diff --git a/src/commlayer.cu b/src/commlayer.cu
index c5211dc..359ae05 100644
--- a/src/commlayer.cu
+++ b/src/commlayer.cu
@@ -1,12 +1,62 @@
 #include
 #include
+#include
+#include
 #include
 #include "cudss.h"
+
 
 
 
 
 #define __CSCUDSS_EXPORT extern "C" __declspec(dllexport)
+#define checkCudaErrors(call) _cudaErrorsThrow(__FILE__, __LINE__, __FUNCTION__, (call))
+
+/*
+ * Exception type carrying a human-readable message. The file/line arguments
+ * let call sites pass __FILE__/__LINE__, but they are currently neither
+ * stored nor included in what().
+ */
+class flExcept : public std::exception {
+    std::string msg;
+  public:
+    flExcept(const char *file, int line, const char *message) : msg(std::string(message)) {}
+    // Takes a const reference so temporaries can be passed as well as lvalues.
+    flExcept(const char *file, int line, const std::string &message) : msg(message) {
+    }
+
+    const char* what() const noexcept override {
+        return msg.c_str();
+    }
+
+    // Appends " " + appendmsg to the message and returns *this for chaining.
+    flExcept& append(const char *file, int line, std::string appendmsg) {
+        msg += " "+appendmsg;
+        return *this;
+    }
+};
+
+/*
+ * Throws flExcept when err is not cudaSuccess; does nothing otherwise.
+ * Message: "<funcname> cuda error: <numeric code> (<error name>: <description>)".
+ */
+void _cudaErrorsThrow(const char *file, int line, const char *funcname, cudaError_t err) {
+    if(err != cudaSuccess) {
+        int intError = err;
+        std::string cudaError = std::string(funcname) + " cuda error: " + std::to_string(intError)
+            + " (" + cudaGetErrorName(err) + ": " + cudaGetErrorString(err) + ")";
+        throw flExcept(file, line, cudaError);
+    }
+}
+
+// TODO: remove this once all functions are implemented.
+// NOTE(review): the exported entry points below have C linkage and call this
+// without a try/catch, so the exception leaves an extern "C" function. Confirm
+// the caller can tolerate that; otherwise return an error code instead.
+void throwNotImplemented(const char *file, int line, const char *funcname) {
+    throw flExcept(file, line, "Not yet implemented function:")
+        .append(file, line, funcname);
+}
 
 __CSCUDSS_EXPORT int cudssCommRank(void *comm, int *rank) {
     int result = MPI_Comm_rank((*(MPI_Comm*)comm), rank);
@@ -19,6 +69,79 @@ __CSCUDSS_EXPORT int cudssCommSize(void *comm, int *size) {
 }
 
 
+/*
+ * cudssSend, cudssRecv, cudssBcast, cudssReduce, cudssAllreduce and
+ * cudssScatterv are not implemented yet. Each one waits for `stream` to
+ * finish and then throws via throwNotImplemented(), so the trailing
+ * `return MPI_ERR_OTHER;` is never reached.
+ */
+__CSCUDSS_EXPORT int cudssSend(const void *buffer, int count,
+        cudaDataType_t datatype, int dest, int tag,
+        void *comm, cudaStream_t stream) {
+    checkCudaErrors(cudaStreamSynchronize(stream));
+    throwNotImplemented(__FILE__, __LINE__, __FUNCTION__);
+    return MPI_ERR_OTHER;
+}
+
+__CSCUDSS_EXPORT int cudssRecv(void *buffer, int count, cudaDataType_t datatype,
+        int root, int tag, void *comm,
+        cudaStream_t stream) {
+    checkCudaErrors(cudaStreamSynchronize(stream));
+    throwNotImplemented(__FILE__, __LINE__, __FUNCTION__);
+    return MPI_ERR_OTHER;
+}
+
+__CSCUDSS_EXPORT int cudssBcast(void *buffer, int count,
+        cudaDataType_t datatype, int root, void *comm,
+        cudaStream_t stream) {
+    checkCudaErrors(cudaStreamSynchronize(stream));
+    throwNotImplemented(__FILE__, __LINE__, __FUNCTION__);
+    return MPI_ERR_OTHER;
+}
+
+__CSCUDSS_EXPORT int cudssReduce(const void *sendbuf, void *recvbuf, int count,
+        cudaDataType_t datatype, cudssOpType_t op,
+        int root, void *comm, cudaStream_t stream) {
+    checkCudaErrors(cudaStreamSynchronize(stream));
+    throwNotImplemented(__FILE__, __LINE__, __FUNCTION__);
+    return MPI_ERR_OTHER;
+}
+
+__CSCUDSS_EXPORT int cudssAllreduce(const void *sendbuf, void *recvbuf,
+        int count, cudaDataType_t datatype,
+        cudssOpType_t op, void *comm,
+        cudaStream_t stream) {
+    checkCudaErrors(cudaStreamSynchronize(stream));
+    throwNotImplemented(__FILE__, __LINE__, __FUNCTION__);
+    return MPI_ERR_OTHER;
+}
+
+__CSCUDSS_EXPORT int cudssScatterv(const void *sendbuf, const int *sendcounts,
+        const int *displs, cudaDataType_t sendtype,
+        void *recvbuf, int recvcount,
+        cudaDataType_t recvtype, int root,
+        void *comm, cudaStream_t stream) {
+    checkCudaErrors(cudaStreamSynchronize(stream));
+    throwNotImplemented(__FILE__, __LINE__, __FUNCTION__);
+    return MPI_ERR_OTHER;
+}
+
+// Calls MPI_Comm_split on the communicator that comm points to, writing the
+// new communicator through newcomm; returns the value MPI_Comm_split returned.
+__CSCUDSS_EXPORT int cudssCommSplit(const void *comm, int color, int key,
+        void *newcomm) {
+    int result = MPI_Comm_split(*static_cast<const MPI_Comm*>(comm), color, key, static_cast<MPI_Comm*>(newcomm));
+    return result;
+}
+
+// Calls MPI_Comm_free on the communicator that comm points to; returns the
+// value MPI_Comm_free returned.
+__CSCUDSS_EXPORT int cudssCommFree(void *comm) {
+    int result = MPI_Comm_free(static_cast<MPI_Comm*>(comm));
+    return result;
+}
+
+
 
 
 __CSCUDSS_EXPORT int libraryMain(void) {
@@ -33,3 +156,25 @@ __CSCUDSS_EXPORT int libraryMain(void) {
 
     return 0;
 }
+
+/*
+ * Distributed communication service API wrapper binding table (imported by
+ * cuDSS). The exposed C symbol must be named "cudssDistributedInterface".
+ * Even though the cudssDistributedInterface_t struct is anonymous, it can be
+ * filled with this initializer list in C++20, which formats nicely. The
+ * entries are positional, so they must follow the struct's member order.
+ */
+__CSCUDSS_EXPORT
+cudssDistributedInterface_t cudssDistributedInterface = {
+    cudssCommRank,
+    cudssCommSize,
+    cudssSend,
+    cudssRecv,
+    cudssBcast,
+    cudssReduce,
+    cudssAllreduce,
+    cudssScatterv,
+    cudssCommSplit,
+    cudssCommFree
+};
+