From cc0c525e0ddb133f94d780b56b814fe05e41adb5 Mon Sep 17 00:00:00 2001 From: mkrzewic Date: Tue, 27 Nov 2018 13:23:42 +0100 Subject: [PATCH] Set pointer to factory also when receiving multi-part --- fairmq/FairMQSocket.h | 8 ++++++++ fairmq/FairMQTransportFactory.h | 2 +- fairmq/nanomsg/FairMQSocketNN.cxx | 7 ++++--- fairmq/nanomsg/FairMQSocketNN.h | 3 ++- fairmq/nanomsg/FairMQTransportFactoryNN.cxx | 4 ++-- fairmq/nanomsg/FairMQTransportFactoryNN.h | 2 +- fairmq/ofi/Socket.cxx | 7 ++++--- fairmq/ofi/Socket.h | 3 ++- fairmq/ofi/TransportFactory.cxx | 4 ++-- fairmq/ofi/TransportFactory.h | 2 +- fairmq/shmem/FairMQSocketSHM.cxx | 7 ++++--- fairmq/shmem/FairMQSocketSHM.h | 3 ++- fairmq/shmem/FairMQTransportFactorySHM.cxx | 4 ++-- fairmq/shmem/FairMQTransportFactorySHM.h | 2 +- fairmq/zeromq/FairMQSocketZMQ.cxx | 7 ++++--- fairmq/zeromq/FairMQSocketZMQ.h | 3 ++- fairmq/zeromq/FairMQTransportFactoryZMQ.cxx | 4 ++-- fairmq/zeromq/FairMQTransportFactoryZMQ.h | 2 +- 18 files changed, 45 insertions(+), 29 deletions(-) diff --git a/fairmq/FairMQSocket.h b/fairmq/FairMQSocket.h index 2174c97f..74528467 100644 --- a/fairmq/FairMQSocket.h +++ b/fairmq/FairMQSocket.h @@ -14,11 +14,13 @@ #include #include "FairMQMessage.h" +class FairMQTransportFactory; class FairMQSocket { public: FairMQSocket() {} + FairMQSocket(FairMQTransportFactory* fac): fTransport(fac) {} virtual std::string GetId() = 0; @@ -51,7 +53,13 @@ class FairMQSocket virtual unsigned long GetMessagesTx() const = 0; virtual unsigned long GetMessagesRx() const = 0; + FairMQTransportFactory* GetTransport() { return fTransport; } + void SetTransport(FairMQTransportFactory* transport) { fTransport=transport; } + virtual ~FairMQSocket() {}; + + private: + FairMQTransportFactory* fTransport{nullptr}; }; using FairMQSocketPtr = std::unique_ptr; diff --git a/fairmq/FairMQTransportFactory.h b/fairmq/FairMQTransportFactory.h index c8770804..c4a74fd3 100644 --- a/fairmq/FairMQTransportFactory.h +++ b/fairmq/FairMQTransportFactory.h @@ -62,7 +62,7 @@ class FairMQTransportFactory virtual FairMQMessagePtr CreateMessage(FairMQUnmanagedRegionPtr& unmanagedRegion, void* data, const size_t size, void* hint = 0) = 0; /// Create a socket - virtual FairMQSocketPtr CreateSocket(const std::string& type, const std::string& name) const = 0; + virtual FairMQSocketPtr CreateSocket(const std::string& type, const std::string& name) = 0; /// Create a poller for a single channel (all subchannels) virtual FairMQPollerPtr CreatePoller(const std::vector& channels) const = 0; diff --git a/fairmq/nanomsg/FairMQSocketNN.cxx b/fairmq/nanomsg/FairMQSocketNN.cxx index d759cb0b..3e4bb162 100644 --- a/fairmq/nanomsg/FairMQSocketNN.cxx +++ b/fairmq/nanomsg/FairMQSocketNN.cxx @@ -32,8 +32,9 @@ using namespace fair::mq; atomic FairMQSocketNN::fInterrupted(false); -FairMQSocketNN::FairMQSocketNN(const string& type, const string& name, const string& id /*= ""*/) - : fSocket(-1) +FairMQSocketNN::FairMQSocketNN(const string& type, const string& name, const string& id /*= ""*/, FairMQTransportFactory* fac /*=nullptr*/) + : FairMQSocket{fac} + , fSocket(-1) , fId(id + "." + name + "." + type) , fBytesTx(0) , fBytesRx(0) @@ -368,7 +369,7 @@ int64_t FairMQSocketNN::Receive(vector& msgVec, const int time object.convert(buf); // get the single message size size_t size = buf.size() * sizeof(char); - FairMQMessagePtr part(new FairMQMessageNN(size)); + FairMQMessagePtr part(new FairMQMessageNN(size, GetTransport())); static_cast(part.get())->fReceiving = true; memcpy(part->GetData(), buf.data(), size); msgVec.push_back(move(part)); diff --git a/fairmq/nanomsg/FairMQSocketNN.h b/fairmq/nanomsg/FairMQSocketNN.h index 525266b0..8007c81b 100644 --- a/fairmq/nanomsg/FairMQSocketNN.h +++ b/fairmq/nanomsg/FairMQSocketNN.h @@ -14,11 +14,12 @@ #include "FairMQSocket.h" #include "FairMQMessage.h" +class FairMQTransportFactory; class FairMQSocketNN final : public FairMQSocket { public: - FairMQSocketNN(const std::string& type, const std::string& name, const std::string& id = ""); + FairMQSocketNN(const std::string& type, const std::string& name, const std::string& id = "", FairMQTransportFactory* fac = nullptr); FairMQSocketNN(const FairMQSocketNN&) = delete; FairMQSocketNN operator=(const FairMQSocketNN&) = delete; diff --git a/fairmq/nanomsg/FairMQTransportFactoryNN.cxx b/fairmq/nanomsg/FairMQTransportFactoryNN.cxx index ef37dd6c..f8ef6504 100644 --- a/fairmq/nanomsg/FairMQTransportFactoryNN.cxx +++ b/fairmq/nanomsg/FairMQTransportFactoryNN.cxx @@ -43,9 +43,9 @@ FairMQMessagePtr FairMQTransportFactoryNN::CreateMessage(FairMQUnmanagedRegionPt return unique_ptr(new FairMQMessageNN(region, data, size, hint, this)); } -FairMQSocketPtr FairMQTransportFactoryNN::CreateSocket(const string& type, const string& name) const +FairMQSocketPtr FairMQTransportFactoryNN::CreateSocket(const string& type, const string& name) { - unique_ptr socket(new FairMQSocketNN(type, name, GetId())); + unique_ptr socket(new FairMQSocketNN(type, name, GetId(), this)); fSockets.push_back(socket.get()); return socket; } diff --git a/fairmq/nanomsg/FairMQTransportFactoryNN.h b/fairmq/nanomsg/FairMQTransportFactoryNN.h index ece07de4..e04b9eb0 100644 --- a/fairmq/nanomsg/FairMQTransportFactoryNN.h +++ b/fairmq/nanomsg/FairMQTransportFactoryNN.h @@ -30,7 +30,7 @@ class FairMQTransportFactoryNN final : public FairMQTransportFactory FairMQMessagePtr CreateMessage(void* data, const size_t size, fairmq_free_fn* ffn, void* hint = nullptr) override; FairMQMessagePtr CreateMessage(FairMQUnmanagedRegionPtr& region, void* data, const size_t size, void* hint = 0) override; - FairMQSocketPtr CreateSocket(const std::string& type, const std::string& name) const override; + FairMQSocketPtr CreateSocket(const std::string& type, const std::string& name) override; FairMQPollerPtr CreatePoller(const std::vector& channels) const override; FairMQPollerPtr CreatePoller(const std::vector& channels) const override; diff --git a/fairmq/ofi/Socket.cxx b/fairmq/ofi/Socket.cxx index 61a6537f..b00539a5 100644 --- a/fairmq/ofi/Socket.cxx +++ b/fairmq/ofi/Socket.cxx @@ -31,8 +31,9 @@ namespace ofi using namespace std; -Socket::Socket(Context& context, const string& type, const string& name, const string& id /*= ""*/) - : fDataEndpoint(nullptr) +Socket::Socket(Context& context, const string& type, const string& name, const string& id /*= ""*/, FairMQTransportFactory* fac) + : FairMQSocket{fac} + , fDataEndpoint(nullptr) , fDataCompletionQueueTx(nullptr) , fDataCompletionQueueRx(nullptr) , fId(id + "." + name + "." + type) @@ -515,7 +516,7 @@ auto Socket::ReceiveImpl(vector& msgVec, const int flags, cons // // do // { - // FairMQMessagePtr part(new FairMQMessageSHM(fManager)); + // FairMQMessagePtr part(new FairMQMessageSHM(fManager, GetTransport())); // zmq_msg_t* msgPtr = static_cast(part.get())->GetMessage(); // // int nbytes = zmq_msg_recv(msgPtr, fSocket, flags); diff --git a/fairmq/ofi/Socket.h b/fairmq/ofi/Socket.h index 9cc93adf..0ab202d2 100644 --- a/fairmq/ofi/Socket.h +++ b/fairmq/ofi/Socket.h @@ -18,6 +18,7 @@ #include // unique_ptr #include #include +class FairMQTransportFactory; namespace fair { @@ -35,7 +36,7 @@ namespace ofi class Socket final : public fair::mq::Socket { public: - Socket(Context& factory, const std::string& type, const std::string& name, const std::string& id = ""); + Socket(Context& factory, const std::string& type, const std::string& name, const std::string& id = "", FairMQTransportFactory* fac); Socket(const Socket&) = delete; Socket operator=(const Socket&) = delete; diff --git a/fairmq/ofi/TransportFactory.cxx b/fairmq/ofi/TransportFactory.cxx index 0bfd2fd6..d71b3502 100644 --- a/fairmq/ofi/TransportFactory.cxx +++ b/fairmq/ofi/TransportFactory.cxx @@ -56,9 +56,9 @@ auto TransportFactory::CreateMessage(UnmanagedRegionPtr& region, void* data, con return MessagePtr{new Message(region, data, size, hint)}; } -auto TransportFactory::CreateSocket(const string& type, const string& name) const -> SocketPtr +auto TransportFactory::CreateSocket(const string& type, const string& name) -> SocketPtr { - return SocketPtr{new Socket(fContext, type, name, GetId())}; + return SocketPtr{new Socket(fContext, type, name, GetId(), this)}; } auto TransportFactory::CreatePoller(const vector& channels) const -> PollerPtr diff --git a/fairmq/ofi/TransportFactory.h b/fairmq/ofi/TransportFactory.h index 73475d39..f618a07f 100644 --- a/fairmq/ofi/TransportFactory.h +++ b/fairmq/ofi/TransportFactory.h @@ -38,7 +38,7 @@ class TransportFactory final : public FairMQTransportFactory auto CreateMessage(void* data, const std::size_t size, fairmq_free_fn* ffn, void* hint = nullptr) const -> MessagePtr override; auto CreateMessage(UnmanagedRegionPtr& region, void* data, const std::size_t size, void* hint = nullptr) const -> MessagePtr override; - auto CreateSocket(const std::string& type, const std::string& name) const -> SocketPtr override; + auto CreateSocket(const std::string& type, const std::string& name) -> SocketPtr override; auto CreatePoller(const std::vector& channels) const -> PollerPtr override; auto CreatePoller(const std::vector& channels) const -> PollerPtr override; diff --git a/fairmq/shmem/FairMQSocketSHM.cxx b/fairmq/shmem/FairMQSocketSHM.cxx index 46b0d2d6..b5882097 100644 --- a/fairmq/shmem/FairMQSocketSHM.cxx +++ b/fairmq/shmem/FairMQSocketSHM.cxx @@ -23,8 +23,9 @@ using namespace fair::mq; atomic FairMQSocketSHM::fInterrupted(false); -FairMQSocketSHM::FairMQSocketSHM(Manager& manager, const string& type, const string& name, const string& id /*= ""*/, void* context) - : fSocket(nullptr) +FairMQSocketSHM::FairMQSocketSHM(Manager& manager, const string& type, const string& name, const string& id /*= ""*/, void* context, FairMQTransportFactory* fac /*=nullptr*/) + : FairMQSocket{fac} + , fSocket(nullptr) , fManager(manager) , fId(id + "." + name + "." + type) , fBytesTx(0) @@ -377,7 +378,7 @@ int64_t FairMQSocketSHM::Receive(vector& msgVec, const int tim MetaHeader metaHeader; memcpy(&metaHeader, &hdrVec[m], sizeof(MetaHeader)); - msgVec.emplace_back(fair::mq::tools::make_unique(fManager)); + msgVec.emplace_back(fair::mq::tools::make_unique(fManager, GetTransport())); FairMQMessageSHM* msg = static_cast(msgVec.back().get()); MetaHeader* msgHdr = static_cast(zmq_msg_data(msg->GetMessage())); diff --git a/fairmq/shmem/FairMQSocketSHM.h b/fairmq/shmem/FairMQSocketSHM.h index 39d8064a..4bb0e4b6 100644 --- a/fairmq/shmem/FairMQSocketSHM.h +++ b/fairmq/shmem/FairMQSocketSHM.h @@ -15,11 +15,12 @@ #include #include // unique_ptr +class FairMQTransportFactory; class FairMQSocketSHM final : public FairMQSocket { public: - FairMQSocketSHM(fair::mq::shmem::Manager& manager, const std::string& type, const std::string& name, const std::string& id = "", void* context = nullptr); + FairMQSocketSHM(fair::mq::shmem::Manager& manager, const std::string& type, const std::string& name, const std::string& id = "", void* context = nullptr, FairMQTransportFactory* fac = nullptr); FairMQSocketSHM(const FairMQSocketSHM&) = delete; FairMQSocketSHM operator=(const FairMQSocketSHM&) = delete; diff --git a/fairmq/shmem/FairMQTransportFactorySHM.cxx b/fairmq/shmem/FairMQTransportFactorySHM.cxx index 5afe25fa..a17e5e16 100644 --- a/fairmq/shmem/FairMQTransportFactorySHM.cxx +++ b/fairmq/shmem/FairMQTransportFactorySHM.cxx @@ -233,10 +233,10 @@ FairMQMessagePtr FairMQTransportFactorySHM::CreateMessage(FairMQUnmanagedRegionP return unique_ptr(new FairMQMessageSHM(*fManager, region, data, size, hint, this)); } -FairMQSocketPtr FairMQTransportFactorySHM::CreateSocket(const string& type, const string& name) const +FairMQSocketPtr FairMQTransportFactorySHM::CreateSocket(const string& type, const string& name) { assert(fContext); - return unique_ptr(new FairMQSocketSHM(*fManager, type, name, GetId(), fContext)); + return unique_ptr(new FairMQSocketSHM(*fManager, type, name, GetId(), fContext, this)); } FairMQPollerPtr FairMQTransportFactorySHM::CreatePoller(const vector& channels) const diff --git a/fairmq/shmem/FairMQTransportFactorySHM.h b/fairmq/shmem/FairMQTransportFactorySHM.h index 30fc502b..93edb39c 100644 --- a/fairmq/shmem/FairMQTransportFactorySHM.h +++ b/fairmq/shmem/FairMQTransportFactorySHM.h @@ -38,7 +38,7 @@ class FairMQTransportFactorySHM final : public FairMQTransportFactory FairMQMessagePtr CreateMessage(void* data, const size_t size, fairmq_free_fn* ffn, void* hint = nullptr) override; FairMQMessagePtr CreateMessage(FairMQUnmanagedRegionPtr& region, void* data, const size_t size, void* hint = 0) override; - FairMQSocketPtr CreateSocket(const std::string& type, const std::string& name) const override; + FairMQSocketPtr CreateSocket(const std::string& type, const std::string& name) override; FairMQPollerPtr CreatePoller(const std::vector& channels) const override; FairMQPollerPtr CreatePoller(const std::vector& channels) const override; diff --git a/fairmq/zeromq/FairMQSocketZMQ.cxx b/fairmq/zeromq/FairMQSocketZMQ.cxx index 59cdfc5f..8d3eea0c 100644 --- a/fairmq/zeromq/FairMQSocketZMQ.cxx +++ b/fairmq/zeromq/FairMQSocketZMQ.cxx @@ -20,8 +20,9 @@ using namespace fair::mq; atomic FairMQSocketZMQ::fInterrupted(false); -FairMQSocketZMQ::FairMQSocketZMQ(const string& type, const string& name, const string& id /*= ""*/, void* context) - : fSocket(nullptr) +FairMQSocketZMQ::FairMQSocketZMQ(const string& type, const string& name, const string& id /*= ""*/, void* context, FairMQTransportFactory* fac) + : FairMQSocket{fac} + , fSocket(nullptr) , fId(id + "." + name + "." + type) , fBytesTx(0) , fBytesRx(0) @@ -314,7 +315,7 @@ int64_t FairMQSocketZMQ::Receive(vector& msgVec, const int tim do { - unique_ptr part(new FairMQMessageZMQ()); + unique_ptr part(new FairMQMessageZMQ(GetTransport())); int nbytes = zmq_msg_recv(static_cast(part.get())->GetMessage(), fSocket, flags); if (nbytes >= 0) diff --git a/fairmq/zeromq/FairMQSocketZMQ.h b/fairmq/zeromq/FairMQSocketZMQ.h index a5d7d21b..3337b77e 100644 --- a/fairmq/zeromq/FairMQSocketZMQ.h +++ b/fairmq/zeromq/FairMQSocketZMQ.h @@ -15,11 +15,12 @@ #include "FairMQSocket.h" #include "FairMQMessage.h" +class FairMQTransportFactory; class FairMQSocketZMQ final : public FairMQSocket { public: - FairMQSocketZMQ(const std::string& type, const std::string& name, const std::string& id = "", void* context = nullptr); + FairMQSocketZMQ(const std::string& type, const std::string& name, const std::string& id = "", void* context = nullptr, FairMQTransportFactory* factory = nullptr); FairMQSocketZMQ(const FairMQSocketZMQ&) = delete; FairMQSocketZMQ operator=(const FairMQSocketZMQ&) = delete; diff --git a/fairmq/zeromq/FairMQTransportFactoryZMQ.cxx b/fairmq/zeromq/FairMQTransportFactoryZMQ.cxx index 7027df0b..6d096511 100644 --- a/fairmq/zeromq/FairMQTransportFactoryZMQ.cxx +++ b/fairmq/zeromq/FairMQTransportFactoryZMQ.cxx @@ -69,10 +69,10 @@ FairMQMessagePtr FairMQTransportFactoryZMQ::CreateMessage(FairMQUnmanagedRegionP return unique_ptr(new FairMQMessageZMQ(region, data, size, hint, this)); } -FairMQSocketPtr FairMQTransportFactoryZMQ::CreateSocket(const string& type, const string& name) const +FairMQSocketPtr FairMQTransportFactoryZMQ::CreateSocket(const string& type, const string& name) { assert(fContext); - return unique_ptr(new FairMQSocketZMQ(type, name, GetId(), fContext)); + return unique_ptr(new FairMQSocketZMQ(type, name, GetId(), fContext, this)); } FairMQPollerPtr FairMQTransportFactoryZMQ::CreatePoller(const vector& channels) const diff --git a/fairmq/zeromq/FairMQTransportFactoryZMQ.h b/fairmq/zeromq/FairMQTransportFactoryZMQ.h index 11abbfce..64d44395 100644 --- a/fairmq/zeromq/FairMQTransportFactoryZMQ.h +++ b/fairmq/zeromq/FairMQTransportFactoryZMQ.h @@ -39,7 +39,7 @@ class FairMQTransportFactoryZMQ final : public FairMQTransportFactory FairMQMessagePtr CreateMessage(void* data, const size_t size, fairmq_free_fn* ffn, void* hint = nullptr) override; FairMQMessagePtr CreateMessage(FairMQUnmanagedRegionPtr& region, void* data, const size_t size, void* hint = 0) override; - FairMQSocketPtr CreateSocket(const std::string& type, const std::string& name) const override; + FairMQSocketPtr CreateSocket(const std::string& type, const std::string& name) override; FairMQPollerPtr CreatePoller(const std::vector& channels) const override; FairMQPollerPtr CreatePoller(const std::vector& channels) const override;