working mpi link in the commlibrary

This commit is contained in:
antonl 2026-02-05 09:17:53 +01:00
parent af5ba1e4a1
commit f391a0873f
3 changed files with 33 additions and 6 deletions

View File

@ -1,4 +1,4 @@
@echo off
@rem @echo off
setlocal
@rem I rely on environment variables here.
@ -8,19 +8,24 @@ setlocal
set nvccArch=sm_120
set commLayerSource=commlayer
set mainSource=main
set CommonXCompFlags=/W3 /O2 /EHsc
set DllXcompilerFlags="%CommonXCompFlags% /MD"
set ExeXcompilerFlags="%CommonXCompFlags% /MT"
set IgnoreWarnings=-Wno-deprecated-gpu-targets
set MPIIncludes=-I"%MSMPI_INC%\"
set MPILibPath=-L"%MSMPI_LIB64%\"
set CudaIncludes=-I"%CUDA_PATH%\include" -I"%CUDSS_PATH%\include"
set OtherIncludes=-Iinclude
set CudaLibPaths=-L"%CUDA_PATH%\lib\x64" -L"%CUDSS_PATH%\lib\12"
nvcc -c -o build/%commLayerSource%.obj src/%commLayerSource%.cu ^
set LibraryLinks=-lcudss -lcudart
nvcc -v -c -o build/%commLayerSource%.obj src/%commLayerSource%.cu ^
-arch=%nvccArch% ^
%CudaIncludes% %OtherIncludes% %CudaLibPaths% ^
-lcudss -lcudart ^
%CudaIncludes% %OtherIncludes% %MPIIncludes% %CudaLibPaths% %MPILibPath% ^
%LibraryLinks% ^
%IgnoreWarnings% ^
-Xcompiler %DllXcompilerFlags%
@ -32,6 +37,7 @@ if %ERRORLEVEL% NEQ 0 (
nvcc -shared -o build/%commLayerSource%.dll build/%commLayerSource%.obj ^
-arch=%nvccArch% ^
%IgnoreWarnings% ^
%MPILibPath% -lmsmpi ^
-Xcompiler %DllXcompilerFlags% ^
-Xlinker "/NODEFAULTLIB:LIBCMT"

View File

@ -1,14 +1,37 @@
#include <cuda_runtime.h>
#include <iostream>
#include <mpi.h>
#define __CSCUDSS_EXPORT extern "C" __declspec(dllexport)
__CSCUDSS_EXPORT int cudssCommRank(void *comm, int *rank) {
int result = MPI_Comm_rank((*(MPI_Comm*)comm), rank);
return result;
}
__CSCUDSS_EXPORT int cudssCommSize(void *comm, int *size) {
int result = MPI_Comm_size((*(MPI_Comm*)comm), size);
return result;
}
__CSCUDSS_EXPORT int libraryMain() {
MPI_Init(0,0);
std::cout << "HELLO LIBRARY" << std::endl;
int rank;
int size;
int Comm = MPI_COMM_WORLD;
cudssCommRank(&Comm, &rank);
cudssCommSize(&Comm, &size);
std::cout << "Rank: " << rank << ", size: " << size << std::endl;
MPI_Finalize();
return 0;
}

View File

@ -8,8 +8,6 @@
#include "util.h"
/* ============================================================================
* Version Information
* ============================================================================ */