working with mpi compile on main, initialising MPI on main and calling mpi functions on library

This commit is contained in:
antonl 2026-02-05 10:10:27 +01:00
parent f391a0873f
commit 812d360e24
3 changed files with 45 additions and 24 deletions

View File

@ -1,4 +1,4 @@
@rem @echo off
@echo off
setlocal
@rem I rely on environment variables here.
@ -22,7 +22,7 @@ set CudaLibPaths=-L"%CUDA_PATH%\lib\x64" -L"%CUDSS_PATH%\lib\12"
set LibraryLinks=-lcudss -lcudart
nvcc -v -c -o build/%commLayerSource%.obj src/%commLayerSource%.cu ^
nvcc -c -o build/%commLayerSource%.obj src/%commLayerSource%.cu ^
-arch=%nvccArch% ^
%CudaIncludes% %OtherIncludes% %MPIIncludes% %CudaLibPaths% %MPILibPath% ^
%LibraryLinks% ^
@ -48,8 +48,8 @@ if %ERRORLEVEL% NEQ 0 (
nvcc -o build/%mainSource% src/%mainSource%.cu ^
-arch=%nvccArch% ^
%CudaIncludes% %OtherIncludes% %CudaLibPaths% ^
-lcudss -lcudart ^
%CudaIncludes% %OtherIncludes% %CudaLibPaths% %MPIIncludes% %MPILibPath% ^
-lcudss -lcudart -lmsmpi ^
%IgnoreWarnings% ^
-Xcompiler %ExeXcompilerFlags%

View File

@ -3,6 +3,8 @@
#include <mpi.h>
#include "cudss.h"
#define __CSCUDSS_EXPORT extern "C" __declspec(dllexport)
@ -18,20 +20,16 @@ __CSCUDSS_EXPORT int cudssCommSize(void *comm, int *size) {
__CSCUDSS_EXPORT int libraryMain() {
__CSCUDSS_EXPORT int libraryMain(void) {
MPI_Init(0,0);
std::cout << "HELLO LIBRARY" << std::endl;
int rank;
int size;
int Comm = MPI_COMM_WORLD;
cudssCommRank(&Comm, &rank);
cudssCommSize(&Comm, &size);
std::cout << "Rank: " << rank << ", size: " << size << std::endl;
std::cout << "Entered library for Rank: " << rank << ", size: " << size << std::endl;
MPI_Finalize();
return 0;
}

View File

@ -6,13 +6,17 @@
#define WIN32_LEAN_AND_MEAN
#include <windows.h>
#include <mpi.h>
#include "util.h"
/* ============================================================================
* Version Information
* ============================================================================ */
void printCudaVersion(void) {
void printCudaVersion(int rank) {
if(rank !=0) return;
int runtimeVersion = 0;
int driverVersion = 0;
@ -29,7 +33,9 @@ void printCudaVersion(void) {
LOG("CUDA Driver Version: %d.%d", driverMajor, driverMinor);
}
void printCudssVersion(void) {
void printCudssVersion(int rank) {
if(rank !=0) return;
int major = 0;
int minor = 0;
int patch = 0;
@ -41,7 +47,9 @@ void printCudssVersion(void) {
LOG("cuDSS Version: %d.%d.%d", major, minor, patch);
}
void printDeviceInfo(void) {
void printDeviceInfo(int rank) {
if(rank !=0) return;
int deviceCount = 0;
CUDA_CHECK(cudaGetDeviceCount(&deviceCount));
@ -67,13 +75,22 @@ void printDeviceInfo(void) {
using libraryMain = int(*)();
int main(int argc, char **argv) {
LOG("cuDSS Test Program");
LOG("==================");
MPI_Init(0,0);
int rank;
int size;
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
MPI_Comm_size(MPI_COMM_WORLD, &size);
if(rank == 0) {
LOG("cuDSS Test Program");
LOG("==================");
}
/* Print version information */
printCudaVersion();
printCudssVersion();
printDeviceInfo();
printCudaVersion(rank);
printCudssVersion(rank);
printDeviceInfo(rank);
HMODULE hLib = LoadLibraryA("commlayer.dll");
if (!hLib) {
@ -92,12 +109,18 @@ int main(int argc, char **argv) {
/* Initialize cuDSS */
cudssHandle_t handle = NULL;
CUDSS_CHECK(cudssCreate(&handle));
LOG("cuDSS handle created successfully");
if(rank == 0) {
LOG("cuDSS handle created successfully");
}
/* Cleanup */
CUDSS_CHECK(cudssDestroy(handle));
LOG("cuDSS handle destroyed");
LOG("Done.");
if(rank == 0) {
LOG("cuDSS handle destroyed");
}
MPI_Finalize();
if(rank == 0) {
LOG("Done.");
}
return 0;
}