#include #include #include #include #include #include "cudss.h" #define __CSCUDSS_EXPORT extern "C" __declspec(dllexport) #define checkCudaErrors(call) _cudaErrorsThrow(__FILE__, __LINE__, __FUNCTION__, (call)) class flExcept : public std::exception { std::string msg; public: flExcept(const char *file, int line, const char *message) : msg(std::string(message)) {} flExcept(const char *file, int line, std::string &message) : msg(message) { } const char* what() const noexcept override { return msg.c_str(); } flExcept& append(const char *file, int line, std::string appendmsg) { msg += " "+appendmsg; return *this; } }; void _cudaErrorsThrow(const char *file, int line, const char *funcname, cudaError_t err) { if(err != cudaSuccess) { int intError = err; std::string cudaError = std::string(funcname) + " cuda error: " + std::to_string(intError); throw flExcept(file, line, cudaError); } } // TODO remove this when we implemented all functions void throwNotImplemented(const char *file, int line, const char *funcname) { throw flExcept(file, line, "Not yet implemented function:") .append(file, line, funcname); } __CSCUDSS_EXPORT int cudssCommRank(void *comm, int *rank) { int result = MPI_Comm_rank((*(MPI_Comm*)comm), rank); return result; } __CSCUDSS_EXPORT int cudssCommSize(void *comm, int *size) { int result = MPI_Comm_size((*(MPI_Comm*)comm), size); return result; } __CSCUDSS_EXPORT int cudssSend(const void *buffer, int count, cudaDataType_t datatype, int dest, int tag, void *comm, cudaStream_t stream) { checkCudaErrors(cudaStreamSynchronize(stream)); throwNotImplemented(__FILE__, __LINE__, __FUNCTION__); return MPI_ERR_OTHER; } __CSCUDSS_EXPORT int cudssRecv(void *buffer, int count, cudaDataType_t datatype, int root, int tag, void *comm, cudaStream_t stream) { checkCudaErrors(cudaStreamSynchronize(stream)); throwNotImplemented(__FILE__, __LINE__, __FUNCTION__); return MPI_ERR_OTHER; } __CSCUDSS_EXPORT int cudssBcast(void *buffer, int count, cudaDataType_t datatype, int root, void *comm, cudaStream_t stream) { checkCudaErrors(cudaStreamSynchronize(stream)); throwNotImplemented(__FILE__, __LINE__, __FUNCTION__); return MPI_ERR_OTHER; } __CSCUDSS_EXPORT int cudssReduce(const void *sendbuf, void *recvbuf, int count, cudaDataType_t datatype, cudssOpType_t op, int root, void *comm, cudaStream_t stream) { checkCudaErrors(cudaStreamSynchronize(stream)); throwNotImplemented(__FILE__, __LINE__, __FUNCTION__); return MPI_ERR_OTHER; } __CSCUDSS_EXPORT int cudssAllreduce(const void *sendbuf, void *recvbuf, int count, cudaDataType_t datatype, cudssOpType_t op, void *commm, cudaStream_t stream) { checkCudaErrors(cudaStreamSynchronize(stream)); throwNotImplemented(__FILE__, __LINE__, __FUNCTION__); return MPI_ERR_OTHER; } __CSCUDSS_EXPORT int cudssScatterv(const void *sendbuf, const int *sendcounts, const int *displs, cudaDataType_t sendtype, void *recvbuf, int recvcount, cudaDataType_t recvtype, int root, void *comm, cudaStream_t stream) { checkCudaErrors(cudaStreamSynchronize(stream)); throwNotImplemented(__FILE__, __LINE__, __FUNCTION__); return MPI_ERR_OTHER; } __CSCUDSS_EXPORT int cudssCommSplit(const void *comm, int color, int key, void *newcomm) { int result = MPI_Comm_split((*(MPI_Comm*)comm), color, key, ((MPI_Comm*)newcomm)); return result; } __CSCUDSS_EXPORT int cudssCommFree(void *comm) { int result = MPI_Comm_free((MPI_Comm*)comm); return result; } __CSCUDSS_EXPORT int libraryMain(void) { int rank; int size; int Comm = MPI_COMM_WORLD; cudssCommRank(&Comm, &rank); cudssCommSize(&Comm, &size); std::cout << "Entered library for Rank: " << rank << ", size: " << size << std::endl; return 0; } /* * Distributed communication service API wrapper binding table (imported by * cuDSS). The exposed C symbol must be named as "cudssDistributedInterface". * Even though the cudssDistributedInterface_t struct is anonymous, one can use * this type initializer list in C++20, which formats nicely. */ __CSCUDSS_EXPORT cudssDistributedInterface_t cudssDistributedInterface = { cudssCommRank, cudssCommSize, cudssSend, cudssRecv, cudssBcast, cudssReduce, cudssAllreduce, cudssScatterv, cudssCommSplit, cudssCommFree };