full commlayer interface, not everything implemented yet

This commit is contained in:
antonl 2026-02-05 11:26:10 +01:00
parent 812d360e24
commit 7ed9dab8f9

View File

@ -1,12 +1,47 @@
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <iostream> #include <iostream>
#include <string>
#include <exception>
#include <mpi.h> #include <mpi.h>
#include "cudss.h" #include "cudss.h"
#define __CSCUDSS_EXPORT extern "C" __declspec(dllexport) #define __CSCUDSS_EXPORT extern "C" __declspec(dllexport)
#define checkCudaErrors(call) _cudaErrorsThrow(__FILE__, __LINE__, __FUNCTION__, (call))
class flExcept : public std::exception {
std::string msg;
public:
flExcept(const char *file, int line, const char *message) : msg(std::string(message)) {}
flExcept(const char *file, int line, std::string &message) : msg(message) {
}
const char* what() const noexcept override {
return msg.c_str();
}
flExcept& append(const char *file, int line, std::string appendmsg) {
msg += " "+appendmsg;
return *this;
}
};
void _cudaErrorsThrow(const char *file, int line, const char *funcname, cudaError_t err) {
if(err != cudaSuccess) {
int intError = err;
std::string cudaError = std::string(funcname) + " cuda error: " + std::to_string(intError);
throw flExcept(file, line, cudaError);
}
}
// TODO remove this when we implemented all functions
void throwNotImplemented(const char *file, int line, const char *funcname) {
throw flExcept(file, line, "Not yet implemented function:")
.append(file, line, funcname);
}
__CSCUDSS_EXPORT int cudssCommRank(void *comm, int *rank) { __CSCUDSS_EXPORT int cudssCommRank(void *comm, int *rank) {
int result = MPI_Comm_rank((*(MPI_Comm*)comm), rank); int result = MPI_Comm_rank((*(MPI_Comm*)comm), rank);
@ -19,6 +54,69 @@ __CSCUDSS_EXPORT int cudssCommSize(void *comm, int *size) {
} }
__CSCUDSS_EXPORT int cudssSend(const void *buffer, int count,
cudaDataType_t datatype, int dest, int tag,
void *comm, cudaStream_t stream) {
checkCudaErrors(cudaStreamSynchronize(stream));
throwNotImplemented(__FILE__, __LINE__, __FUNCTION__);
return MPI_ERR_OTHER;
}
__CSCUDSS_EXPORT int cudssRecv(void *buffer, int count, cudaDataType_t datatype,
int root, int tag, void *comm,
cudaStream_t stream) {
checkCudaErrors(cudaStreamSynchronize(stream));
throwNotImplemented(__FILE__, __LINE__, __FUNCTION__);
return MPI_ERR_OTHER;
}
__CSCUDSS_EXPORT int cudssBcast(void *buffer, int count,
cudaDataType_t datatype, int root, void *comm,
cudaStream_t stream) {
checkCudaErrors(cudaStreamSynchronize(stream));
throwNotImplemented(__FILE__, __LINE__, __FUNCTION__);
return MPI_ERR_OTHER;
}
__CSCUDSS_EXPORT int cudssReduce(const void *sendbuf, void *recvbuf, int count,
cudaDataType_t datatype, cudssOpType_t op,
int root, void *comm, cudaStream_t stream) {
checkCudaErrors(cudaStreamSynchronize(stream));
throwNotImplemented(__FILE__, __LINE__, __FUNCTION__);
return MPI_ERR_OTHER;
}
__CSCUDSS_EXPORT int cudssAllreduce(const void *sendbuf, void *recvbuf,
int count, cudaDataType_t datatype,
cudssOpType_t op, void *commm,
cudaStream_t stream) {
checkCudaErrors(cudaStreamSynchronize(stream));
throwNotImplemented(__FILE__, __LINE__, __FUNCTION__);
return MPI_ERR_OTHER;
}
__CSCUDSS_EXPORT int cudssScatterv(const void *sendbuf, const int *sendcounts,
const int *displs, cudaDataType_t sendtype,
void *recvbuf, int recvcount,
cudaDataType_t recvtype, int root,
void *comm, cudaStream_t stream) {
checkCudaErrors(cudaStreamSynchronize(stream));
throwNotImplemented(__FILE__, __LINE__, __FUNCTION__);
return MPI_ERR_OTHER;
}
__CSCUDSS_EXPORT int cudssCommSplit(const void *comm, int color, int key,
void *newcomm) {
int result = MPI_Comm_split((*(MPI_Comm*)comm), color, key, ((MPI_Comm*)newcomm));
return result;
}
__CSCUDSS_EXPORT int cudssCommFree(void *comm) {
int result = MPI_Comm_free((MPI_Comm*)comm);
return result;
}
__CSCUDSS_EXPORT int libraryMain(void) { __CSCUDSS_EXPORT int libraryMain(void) {
@ -33,3 +131,24 @@ __CSCUDSS_EXPORT int libraryMain(void) {
return 0; return 0;
} }
/*
* Distributed communication service API wrapper binding table (imported by
* cuDSS). The exposed C symbol must be named as "cudssDistributedInterface".
* Even though the cudssDistributedInterface_t struct is anonymous, one can use
* this type initializer list in C++20, which formats nicely.
*/
__CSCUDSS_EXPORT
cudssDistributedInterface_t cudssDistributedInterface = {
cudssCommRank,
cudssCommSize,
cudssSend,
cudssRecv,
cudssBcast,
cudssReduce,
cudssAllreduce,
cudssScatterv,
cudssCommSplit,
cudssCommFree
};