diff --git a/fairmq/FairMQSocket.h b/fairmq/FairMQSocket.h index e198a8fa..f2b2f33e 100644 --- a/fairmq/FairMQSocket.h +++ b/fairmq/FairMQSocket.h @@ -63,8 +63,8 @@ namespace fair namespace mq { -using SocketPtr = std::unique_ptr; - +using Socket = FairMQSocket; +using SocketPtr = FairMQSocketPtr; struct SocketError : std::runtime_error { using std::runtime_error::runtime_error; }; } /* namespace mq */ diff --git a/fairmq/ofi/Message.cxx b/fairmq/ofi/Message.cxx index 167b3745..dbc93b6a 100644 --- a/fairmq/ofi/Message.cxx +++ b/fairmq/ofi/Message.cxx @@ -23,9 +23,6 @@ using namespace std; Message::Message() { - // if (zmq_msg_init(&fMessage) != 0) { - // throw MessageError{tools::ToString("Failed initializing meta message, reason: ", zmq_strerror(errno))}; - // } } Message::Message(const size_t size) @@ -92,9 +89,6 @@ auto Message::Copy(const fair::mq::MessagePtr& msg) -> void Message::~Message() noexcept(false) { - // if (zmq_msg_close(&fMessage) != 0) { - // throw MessageError{tools::ToString("Failed closing meta message, reason: ", zmq_strerror(errno))}; - // } } } /* namespace ofi */ diff --git a/fairmq/ofi/Socket.cxx b/fairmq/ofi/Socket.cxx index 1145441f..1d503cf3 100644 --- a/fairmq/ofi/Socket.cxx +++ b/fairmq/ofi/Socket.cxx @@ -7,6 +7,7 @@ ********************************************************************************/ #include +#include #include #include @@ -21,7 +22,7 @@ namespace ofi using namespace std; -Socket::Socket(const string& type, const string& name, const string& id /*= ""*/, void* zmqContext) +Socket::Socket(const TransportFactory& factory, const string& type, const string& name, const string& id /*= ""*/) : fId{id + "." + name + "." + type} , fBytesTx{0} , fBytesRx{0} @@ -30,12 +31,10 @@ Socket::Socket(const string& type, const string& name, const string& id /*= ""*/ , fSndTimeout{100} , fRcvTimeout{100} { - assert(zmqContext); - if (type != "pair") { throw SocketError{tools::ToString("Socket type '", type, "' not implemented for ofi transport.")}; } else { - fMetaSocket = zmq_socket(zmqContext, GetConstant(type)); + fMetaSocket = zmq_socket(factory.fZmqContext, GetConstant(type)); } if (fMetaSocket == nullptr) { @@ -95,122 +94,28 @@ auto Socket::TryReceive(std::vector>& msgVec) -> auto Socket::SendImpl(FairMQMessagePtr& msg, const int flags, const int timeout) -> int { - throw SocketError{"Not yet implemented."}; - // int nbytes = -1; - // int elapsed = 0; - // - // while (true && !fInterrupted) - // { - // nbytes = zmq_msg_send(static_cast(msg.get())->GetMessage(), fSocket, flags); - // if (nbytes == 0) - // { - // return nbytes; - // } - // else if (nbytes > 0) - // { - // static_cast(msg.get())->fQueued = true; - // - // size_t size = msg->GetSize(); - // fBytesTx += size; - // ++fMessagesTx; - // - // return size; - // } - // else if (zmq_errno() == EAGAIN) - // { - // if (!fInterrupted && ((flags & ZMQ_DONTWAIT) == 0)) - // { - // if (timeout) - // { - // elapsed += fSndTimeout; - // if (elapsed >= timeout) - // { - // return -2; - // } - // } - // continue; - // } - // else - // { - // return -2; - // } - // } - // else if (zmq_errno() == ETERM) - // { - // LOG(info) << "terminating socket " << fId; - // return -1; - // } - // else - // { - // LOG(error) << "Failed sending on socket " << fId << ", reason: " << zmq_strerror(errno); - // return nbytes; - // } - // } - // - // return -1; + auto ret = zmq_send(fMetaSocket, nullptr, 0, flags); + if (ret == EAGAIN) { + return -2; + } else if (ret < 0) { + LOG(error) << "Failed sending meta message on socket " << fId << ", reason: " << zmq_strerror(errno); + return -1; + } else { + return ret; + } } auto Socket::ReceiveImpl(FairMQMessagePtr& msg, const int flags, const int timeout) -> int { - throw SocketError{"Not yet implemented."}; - // int nbytes = -1; - // int elapsed = 0; - // - // zmq_msg_t* msgPtr = static_cast(msg.get())->GetMessage(); - // while (true) - // { - // nbytes = zmq_msg_recv(msgPtr, fSocket, flags); - // if (nbytes == 0) - // { - // ++fMessagesRx; - // - // return nbytes; - // } - // else if (nbytes > 0) - // { - // MetaHeader* hdr = static_cast(zmq_msg_data(msgPtr)); - // size_t size = 0; - // static_cast(msg.get())->fHandle = hdr->fHandle; - // static_cast(msg.get())->fSize = hdr->fSize; - // static_cast(msg.get())->fRegionId = hdr->fRegionId; - // static_cast(msg.get())->fHint = hdr->fHint; - // size = msg->GetSize(); - // - // fBytesRx += size; - // ++fMessagesRx; - // - // return size; - // } - // else if (zmq_errno() == EAGAIN) - // { - // if (!fInterrupted && ((flags & ZMQ_DONTWAIT) == 0)) - // { - // if (timeout) - // { - // elapsed += fSndTimeout; - // if (elapsed >= timeout) - // { - // return -2; - // } - // } - // continue; - // } - // else - // { - // return -2; - // } - // } - // else if (zmq_errno() == ETERM) - // { - // LOG(info) << "terminating socket " << fId; - // return -1; - // } - // else - // { - // LOG(error) << "Failed receiving on socket " << fId << ", reason: " << zmq_strerror(errno); - // return nbytes; - // } - // } + auto ret = zmq_recv(fMetaSocket, nullptr, 0, flags); + if (ret == EAGAIN) { + return -2; + } else if (ret < 0) { + LOG(error) << "Failed receiving meta message on socket " << fId << ", reason: " << zmq_strerror(errno); + return -1; + } else { + return ret; + } } auto Socket::SendImpl(vector& msgVec, const int flags, const int timeout) -> int64_t diff --git a/fairmq/ofi/Socket.h b/fairmq/ofi/Socket.h index 6e34fc1c..3cd2637f 100644 --- a/fairmq/ofi/Socket.h +++ b/fairmq/ofi/Socket.h @@ -21,16 +21,18 @@ namespace mq namespace ofi { +class TransportFactory; + /** * @class Socket Socket.h * @brief * * @todo TODO insert long description */ -class Socket : public FairMQSocket +class Socket : public fair::mq::Socket { public: - Socket(const std::string& type, const std::string& name, const std::string& id = "", void* zmqContext = nullptr); + Socket(const TransportFactory& factory, const std::string& type, const std::string& name, const std::string& id = ""); Socket(const Socket&) = delete; Socket operator=(const Socket&) = delete; @@ -39,15 +41,15 @@ class Socket : public FairMQSocket auto Bind(const std::string& address) -> bool override; auto Connect(const std::string& address) -> void override; - auto Send(FairMQMessagePtr& msg, int timeout = 0) -> int override; - auto Receive(FairMQMessagePtr& msg, int timeout = 0) -> int override; - auto Send(std::vector>& msgVec, int timeout = 0) -> int64_t override; - auto Receive(std::vector>& msgVec, int timeout = 0) -> int64_t override; + auto Send(MessagePtr& msg, int timeout = 0) -> int override; + auto Receive(MessagePtr& msg, int timeout = 0) -> int override; + auto Send(std::vector& msgVec, int timeout = 0) -> int64_t override; + auto Receive(std::vector& msgVec, int timeout = 0) -> int64_t override; - auto TrySend(FairMQMessagePtr& msg) -> int override; - auto TryReceive(FairMQMessagePtr& msg) -> int override; - auto TrySend(std::vector>& msgVec) -> int64_t override; - auto TryReceive(std::vector>& msgVec) -> int64_t override; + auto TrySend(MessagePtr& msg) -> int override; + auto TryReceive(MessagePtr& msg) -> int override; + auto TrySend(std::vector& msgVec) -> int64_t override; + auto TryReceive(std::vector& msgVec) -> int64_t override; auto GetSocket() const -> void* override { return fMetaSocket; } auto GetSocket(int nothing) const -> int override { return -1; } @@ -82,10 +84,10 @@ class Socket : public FairMQSocket int fSndTimeout; int fRcvTimeout; - auto SendImpl(FairMQMessagePtr& msg, const int flags, const int timeout) -> int; - auto ReceiveImpl(FairMQMessagePtr& msg, const int flags, const int timeout) -> int; - auto SendImpl(std::vector>& msgVec, const int flags, const int timeout) -> int64_t; - auto ReceiveImpl(std::vector>& msgVec, const int flags, const int timeout) -> int64_t; + auto SendImpl(MessagePtr& msg, const int flags, const int timeout) -> int; + auto ReceiveImpl(MessagePtr& msg, const int flags, const int timeout) -> int; + auto SendImpl(std::vector& msgVec, const int flags, const int timeout) -> int64_t; + auto ReceiveImpl(std::vector& msgVec, const int flags, const int timeout) -> int64_t; }; /* class Socket */ } /* namespace ofi */ diff --git a/fairmq/ofi/TransportFactory.cxx b/fairmq/ofi/TransportFactory.cxx index 2ab89071..9138a95a 100644 --- a/fairmq/ofi/TransportFactory.cxx +++ b/fairmq/ofi/TransportFactory.cxx @@ -86,7 +86,7 @@ auto TransportFactory::CreateMessage(UnmanagedRegionPtr& region, void* data, con auto TransportFactory::CreateSocket(const string& type, const string& name) const -> SocketPtr { assert(fZmqContext); - return SocketPtr{new Socket(type, name, GetId(), fZmqContext)}; + return SocketPtr{new Socket(*this, type, name, GetId())}; } auto TransportFactory::CreatePoller(const vector& channels) const -> PollerPtr diff --git a/fairmq/ofi/TransportFactory.h b/fairmq/ofi/TransportFactory.h index b058e2b5..7bc26e52 100644 --- a/fairmq/ofi/TransportFactory.h +++ b/fairmq/ofi/TransportFactory.h @@ -19,6 +19,8 @@ namespace mq namespace ofi { +class Socket; + /** * @class TransportFactory TransportFactory.h * @brief FairMQ transport factory for the ofi transport (implemented with ZeroMQ + libfabric) @@ -27,6 +29,8 @@ namespace ofi */ class TransportFactory : public FairMQTransportFactory { + friend Socket; + public: TransportFactory(const std::string& id = "", const FairMQProgOptions* config = nullptr); TransportFactory(const TransportFactory&) = delete; diff --git a/fairmq/test/helper/devices/TestPairLeft.cxx b/fairmq/test/helper/devices/TestPairLeft.cxx index d82b85e2..ee38aabc 100644 --- a/fairmq/test/helper/devices/TestPairLeft.cxx +++ b/fairmq/test/helper/devices/TestPairLeft.cxx @@ -30,7 +30,7 @@ class PairLeft : public FairMQDevice auto Run() -> void override { - auto msg = FairMQMessagePtr{NewMessage()}; + auto msg{NewMessageFor("data", 0)}; Send(msg, "data"); }; }; diff --git a/fairmq/test/helper/devices/TestPairRight.cxx b/fairmq/test/helper/devices/TestPairRight.cxx index d8328723..5dfec498 100644 --- a/fairmq/test/helper/devices/TestPairRight.cxx +++ b/fairmq/test/helper/devices/TestPairRight.cxx @@ -30,10 +30,9 @@ class PairRight : public FairMQDevice auto Run() -> void override { - auto msg = FairMQMessagePtr{NewMessage()}; + MessagePtr msg{NewMessageFor("data", 0)}; - if (Receive(msg, "data") >= 0) - { + if (Receive(msg, "data") >= 0) { LOG(info) << "PAIR test successfull"; } };