From 812d360e2470c896c948bd3f5ff89acc0e054f8f Mon Sep 17 00:00:00 2001 From: antonl Date: Thu, 5 Feb 2026 10:10:27 +0100 Subject: [PATCH] working with mpi compile on main, initialising MPI on main and calling mpi functions on library --- build.bat | 8 ++++---- src/commlayer.cu | 10 ++++------ src/main.cu | 51 +++++++++++++++++++++++++++++++++++------------- 3 files changed, 45 insertions(+), 24 deletions(-) diff --git a/build.bat b/build.bat index c39e0c0..e930a94 100644 --- a/build.bat +++ b/build.bat @@ -1,4 +1,4 @@ -@rem @echo off +@echo off setlocal @rem I rely on environment variables here. @@ -22,7 +22,7 @@ set CudaLibPaths=-L"%CUDA_PATH%\lib\x64" -L"%CUDSS_PATH%\lib\12" set LibraryLinks=-lcudss -lcudart -nvcc -v -c -o build/%commLayerSource%.obj src/%commLayerSource%.cu ^ +nvcc -c -o build/%commLayerSource%.obj src/%commLayerSource%.cu ^ -arch=%nvccArch% ^ %CudaIncludes% %OtherIncludes% %MPIIncludes% %CudaLibPaths% %MPILibPath% ^ %LibraryLinks% ^ @@ -48,8 +48,8 @@ if %ERRORLEVEL% NEQ 0 ( nvcc -o build/%mainSource% src/%mainSource%.cu ^ -arch=%nvccArch% ^ - %CudaIncludes% %OtherIncludes% %CudaLibPaths% ^ - -lcudss -lcudart ^ + %CudaIncludes% %OtherIncludes% %CudaLibPaths% %MPIIncludes% %MPILibPath% ^ + -lcudss -lcudart -lmsmpi ^ %IgnoreWarnings% ^ -Xcompiler %ExeXcompilerFlags% diff --git a/src/commlayer.cu b/src/commlayer.cu index 882b74c..c5211dc 100644 --- a/src/commlayer.cu +++ b/src/commlayer.cu @@ -3,6 +3,8 @@ #include +#include "cudss.h" + #define __CSCUDSS_EXPORT extern "C" __declspec(dllexport) @@ -18,20 +20,16 @@ __CSCUDSS_EXPORT int cudssCommSize(void *comm, int *size) { -__CSCUDSS_EXPORT int libraryMain() { +__CSCUDSS_EXPORT int libraryMain(void) { - MPI_Init(0,0); - std::cout << "HELLO LIBRARY" << std::endl; - int rank; int size; int Comm = MPI_COMM_WORLD; cudssCommRank(&Comm, &rank); cudssCommSize(&Comm, &size); - std::cout << "Rank: " << rank << ", size: " << size << std::endl; + std::cout << "Entered library for Rank: " << rank << ", size: " << size << std::endl; - MPI_Finalize(); return 0; } diff --git a/src/main.cu b/src/main.cu index b5d461c..6492d0e 100644 --- a/src/main.cu +++ b/src/main.cu @@ -6,13 +6,17 @@ #define WIN32_LEAN_AND_MEAN #include +#include + #include "util.h" /* ============================================================================ * Version Information * ============================================================================ */ -void printCudaVersion(void) { +void printCudaVersion(int rank) { + if(rank !=0) return; + int runtimeVersion = 0; int driverVersion = 0; @@ -29,7 +33,9 @@ void printCudaVersion(void) { LOG("CUDA Driver Version: %d.%d", driverMajor, driverMinor); } -void printCudssVersion(void) { +void printCudssVersion(int rank) { + if(rank !=0) return; + int major = 0; int minor = 0; int patch = 0; @@ -41,7 +47,9 @@ void printCudssVersion(void) { LOG("cuDSS Version: %d.%d.%d", major, minor, patch); } -void printDeviceInfo(void) { +void printDeviceInfo(int rank) { + if(rank !=0) return; + int deviceCount = 0; CUDA_CHECK(cudaGetDeviceCount(&deviceCount)); @@ -67,13 +75,22 @@ void printDeviceInfo(void) { using libraryMain = int(*)(); int main(int argc, char **argv) { - LOG("cuDSS Test Program"); - LOG("=================="); - + + MPI_Init(0,0); + int rank; + int size; + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + MPI_Comm_size(MPI_COMM_WORLD, &size); + + if(rank == 0) { + LOG("cuDSS Test Program"); + LOG("=================="); + } + /* Print version information */ - printCudaVersion(); - printCudssVersion(); - printDeviceInfo(); + printCudaVersion(rank); + printCudssVersion(rank); + printDeviceInfo(rank); HMODULE hLib = LoadLibraryA("commlayer.dll"); if (!hLib) { @@ -92,12 +109,18 @@ int main(int argc, char **argv) { /* Initialize cuDSS */ cudssHandle_t handle = NULL; CUDSS_CHECK(cudssCreate(&handle)); - LOG("cuDSS handle created successfully"); - + if(rank == 0) { + LOG("cuDSS handle created successfully"); + } /* Cleanup */ CUDSS_CHECK(cudssDestroy(handle)); - LOG("cuDSS handle destroyed"); - - LOG("Done."); + if(rank == 0) { + LOG("cuDSS handle destroyed"); + } + + MPI_Finalize(); + if(rank == 0) { + LOG("Done."); + } return 0; }