remove Get/SetMessage from interface (internal transport detail)

This commit is contained in:
Alexey Rybalchenko 2017-12-07 13:42:38 +01:00 committed by Mohammad Al-Turany
parent e5aa85b61d
commit ea7ae3ded9
10 changed files with 63 additions and 77 deletions

View File

@ -29,14 +29,11 @@ class FairMQMessage
virtual void Rebuild(const size_t size) = 0; virtual void Rebuild(const size_t size) = 0;
virtual void Rebuild(void* data, const size_t size, fairmq_free_fn* ffn, void* hint = nullptr) = 0; virtual void Rebuild(void* data, const size_t size, fairmq_free_fn* ffn, void* hint = nullptr) = 0;
virtual void* GetMessage() = 0;
virtual void* GetData() = 0; virtual void* GetData() = 0;
virtual size_t GetSize() const = 0; virtual size_t GetSize() const = 0;
virtual bool SetUsedSize(const size_t size) = 0; virtual bool SetUsedSize(const size_t size) = 0;
virtual void SetMessage(void* data, size_t size) = 0;
virtual FairMQ::Transport GetType() const = 0; virtual FairMQ::Transport GetType() const = 0;
virtual void Copy(const std::unique_ptr<FairMQMessage>& msg) = 0; virtual void Copy(const std::unique_ptr<FairMQMessage>& msg) = 0;

View File

@ -94,13 +94,13 @@ FairMQMessageNN::FairMQMessageNN(FairMQUnmanagedRegionPtr& region, void* data, c
void FairMQMessageNN::Rebuild() void FairMQMessageNN::Rebuild()
{ {
Clear(); CloseMessage();
fReceiving = false; fReceiving = false;
} }
void FairMQMessageNN::Rebuild(const size_t size) void FairMQMessageNN::Rebuild(const size_t size)
{ {
Clear(); CloseMessage();
fMessage = nn_allocmsg(size, 0); fMessage = nn_allocmsg(size, 0);
if (!fMessage) if (!fMessage)
{ {
@ -112,7 +112,7 @@ void FairMQMessageNN::Rebuild(const size_t size)
void FairMQMessageNN::Rebuild(void* data, const size_t size, fairmq_free_fn* ffn, void* hint) void FairMQMessageNN::Rebuild(void* data, const size_t size, fairmq_free_fn* ffn, void* hint)
{ {
Clear(); CloseMessage();
fMessage = nn_allocmsg(size, 0); fMessage = nn_allocmsg(size, 0);
if (!fMessage) if (!fMessage)
{ {
@ -173,7 +173,7 @@ FairMQ::Transport FairMQMessageNN::GetType() const
return fTransportType; return fTransportType;
} }
void FairMQMessageNN::Copy(const unique_ptr<FairMQMessage>& msg) void FairMQMessageNN::Copy(const FairMQMessagePtr& msg)
{ {
if (fMessage) if (fMessage)
{ {
@ -192,12 +192,12 @@ void FairMQMessageNN::Copy(const unique_ptr<FairMQMessage>& msg)
} }
else else
{ {
memcpy(fMessage, msg->GetMessage(), size); memcpy(fMessage, static_cast<FairMQMessageNN*>(msg.get())->GetMessage(), size);
fSize = size; fSize = size;
} }
} }
void FairMQMessageNN::Clear() void FairMQMessageNN::CloseMessage()
{ {
if (nn_freemsg(fMessage) < 0) if (nn_freemsg(fMessage) < 0)
{ {
@ -214,15 +214,6 @@ FairMQMessageNN::~FairMQMessageNN()
{ {
if (fReceiving) if (fReceiving)
{ {
int rc = nn_freemsg(fMessage); CloseMessage();
if (rc < 0)
{
LOG(ERROR) << "failed freeing message, reason: " << nn_strerror(errno);
}
else
{
fMessage = nullptr;
fSize = 0;
}
} }
} }

View File

@ -22,8 +22,12 @@
#include "FairMQMessage.h" #include "FairMQMessage.h"
#include "FairMQUnmanagedRegion.h" #include "FairMQUnmanagedRegion.h"
class FairMQSocketNN;
class FairMQMessageNN : public FairMQMessage class FairMQMessageNN : public FairMQMessage
{ {
friend class FairMQSocketNN;
public: public:
FairMQMessageNN(); FairMQMessageNN();
FairMQMessageNN(const size_t size); FairMQMessageNN(const size_t size);
@ -37,22 +41,17 @@ class FairMQMessageNN : public FairMQMessage
void Rebuild(const size_t size) override; void Rebuild(const size_t size) override;
void Rebuild(void* data, const size_t size, fairmq_free_fn* ffn, void* hint = nullptr) override; void Rebuild(void* data, const size_t size, fairmq_free_fn* ffn, void* hint = nullptr) override;
void* GetMessage() override;
void* GetData() override; void* GetData() override;
size_t GetSize() const override; size_t GetSize() const override;
bool SetUsedSize(const size_t size) override; bool SetUsedSize(const size_t size) override;
void SetMessage(void* data, const size_t size) override;
FairMQ::Transport GetType() const override; FairMQ::Transport GetType() const override;
void Copy(const std::unique_ptr<FairMQMessage>& msg) override; void Copy(const FairMQMessagePtr& msg) override;
~FairMQMessageNN() override; ~FairMQMessageNN() override;
friend class FairMQSocketNN;
private: private:
void* fMessage; void* fMessage;
size_t fSize; size_t fSize;
@ -60,7 +59,9 @@ class FairMQMessageNN : public FairMQMessage
FairMQUnmanagedRegion* fRegionPtr; FairMQUnmanagedRegion* fRegionPtr;
static FairMQ::Transport fTransportType; static FairMQ::Transport fTransportType;
void Clear(); void* GetMessage();
void CloseMessage();
void SetMessage(void* data, const size_t size);
}; };
#endif /* FAIRMQMESSAGENN_H_ */ #endif /* FAIRMQMESSAGENN_H_ */

View File

@ -125,18 +125,20 @@ int FairMQSocketNN::Send(FairMQMessagePtr& msg, const int flags)
{ {
int nbytes = -1; int nbytes = -1;
FairMQMessageNN* msgPtr = static_cast<FairMQMessageNN*>(msg.get());
void* bufPtr = msgPtr->GetMessage();
while (true) while (true)
{ {
void* ptr = msg->GetMessage(); if (msgPtr->fRegionPtr == nullptr)
if (static_cast<FairMQMessageNN*>(msg.get())->fRegionPtr == nullptr)
{ {
nbytes = nn_send(fSocket, &ptr, NN_MSG, flags); nbytes = nn_send(fSocket, &bufPtr, NN_MSG, flags);
} }
else else
{ {
nbytes = nn_send(fSocket, ptr, msg->GetSize(), flags); nbytes = nn_send(fSocket, bufPtr, msg->GetSize(), flags);
// nn_send copies the data, safe to call region callback here // nn_send copies the data, safe to call region callback here
static_cast<FairMQUnmanagedRegionNN*>(static_cast<FairMQMessageNN*>(msg.get())->fRegionPtr)->fCallback(msg->GetMessage(), msg->GetSize()); static_cast<FairMQUnmanagedRegionNN*>(msgPtr->fRegionPtr)->fCallback(bufPtr, msg->GetSize());
} }
if (nbytes >= 0) if (nbytes >= 0)
@ -183,6 +185,8 @@ int FairMQSocketNN::Receive(FairMQMessagePtr& msg, const int flags)
{ {
int nbytes = -1; int nbytes = -1;
FairMQMessageNN* msgPtr = static_cast<FairMQMessageNN*>(msg.get());
while (true) while (true)
{ {
void* ptr = nullptr; void* ptr = nullptr;
@ -191,8 +195,8 @@ int FairMQSocketNN::Receive(FairMQMessagePtr& msg, const int flags)
{ {
fBytesRx += nbytes; fBytesRx += nbytes;
++fMessagesRx; ++fMessagesRx;
msg->SetMessage(ptr, nbytes); msgPtr->SetMessage(ptr, nbytes);
static_cast<FairMQMessageNN*>(msg.get())->fReceiving = true; msgPtr->fReceiving = true;
return nbytes; return nbytes;
} }
#if NN_VERSION_CURRENT>2 // backwards-compatibility with nanomsg version<=0.6 #if NN_VERSION_CURRENT>2 // backwards-compatibility with nanomsg version<=0.6
@ -227,7 +231,7 @@ int FairMQSocketNN::Receive(FairMQMessagePtr& msg, const int flags)
} }
} }
int64_t FairMQSocketNN::Send(vector<unique_ptr<FairMQMessage>>& msgVec, const int flags) int64_t FairMQSocketNN::Send(vector<FairMQMessagePtr>& msgVec, const int flags)
{ {
const unsigned int vecSize = msgVec.size(); const unsigned int vecSize = msgVec.size();
#ifdef MSGPACK_FOUND #ifdef MSGPACK_FOUND
@ -240,13 +244,15 @@ int64_t FairMQSocketNN::Send(vector<unique_ptr<FairMQMessage>>& msgVec, const in
// pack all parts into a single msgpack simple buffer // pack all parts into a single msgpack simple buffer
for (unsigned int i = 0; i < vecSize; ++i) for (unsigned int i = 0; i < vecSize; ++i)
{ {
static_cast<FairMQMessageNN*>(msgVec[i].get())->fReceiving = false; FairMQMessageNN* partPtr = static_cast<FairMQMessageNN*>(msgVec[i].get());
partPtr->fReceiving = false;
packer.pack_bin(msgVec[i]->GetSize()); packer.pack_bin(msgVec[i]->GetSize());
packer.pack_bin_body(static_cast<char*>(msgVec[i]->GetData()), msgVec[i]->GetSize()); packer.pack_bin_body(static_cast<char*>(msgVec[i]->GetData()), msgVec[i]->GetSize());
// call region callback // call region callback
if (static_cast<FairMQMessageNN*>(msgVec[i].get())->fRegionPtr) if (partPtr->fRegionPtr)
{ {
static_cast<FairMQUnmanagedRegionNN*>(static_cast<FairMQMessageNN*>(msgVec[i].get())->fRegionPtr)->fCallback(msgVec[i]->GetMessage(), msgVec[i]->GetSize()); static_cast<FairMQUnmanagedRegionNN*>(partPtr->fRegionPtr)->fCallback(partPtr->GetMessage(), msgVec[i]->GetSize());
} }
} }
@ -297,7 +303,7 @@ int64_t FairMQSocketNN::Send(vector<unique_ptr<FairMQMessage>>& msgVec, const in
#endif /*MSGPACK_FOUND*/ #endif /*MSGPACK_FOUND*/
} }
int64_t FairMQSocketNN::Receive(vector<unique_ptr<FairMQMessage>>& msgVec, const int flags) int64_t FairMQSocketNN::Receive(vector<FairMQMessagePtr>& msgVec, const int flags)
{ {
#ifdef MSGPACK_FOUND #ifdef MSGPACK_FOUND
// Warn if the vector is filled before Receive() and empty it. // Warn if the vector is filled before Receive() and empty it.
@ -334,7 +340,7 @@ int64_t FairMQSocketNN::Receive(vector<unique_ptr<FairMQMessage>>& msgVec, const
object.convert(buf); object.convert(buf);
// get the single message size // get the single message size
size_t size = buf.size() * sizeof(char); size_t size = buf.size() * sizeof(char);
unique_ptr<FairMQMessage> part(new FairMQMessageNN(size)); FairMQMessagePtr part(new FairMQMessageNN(size));
static_cast<FairMQMessageNN*>(part.get())->fReceiving = true; static_cast<FairMQMessageNN*>(part.get())->fReceiving = true;
memcpy(part->GetData(), buf.data(), size); memcpy(part->GetData(), buf.data(), size);
msgVec.push_back(move(part)); msgVec.push_back(move(part));

View File

@ -197,7 +197,7 @@ void FairMQMessageSHM::Rebuild(void* data, const size_t size, fairmq_free_fn* ff
} }
} }
void* FairMQMessageSHM::GetMessage() zmq_msg_t* FairMQMessageSHM::GetMessage()
{ {
return &fMessage; return &fMessage;
} }
@ -269,11 +269,6 @@ bool FairMQMessageSHM::SetUsedSize(const size_t size)
} }
} }
void FairMQMessageSHM::SetMessage(void*, const size_t)
{
// dummy method to comply with the interface. functionality not allowed in zeromq.
}
FairMQ::Transport FairMQMessageSHM::GetType() const FairMQ::Transport FairMQMessageSHM::GetType() const
{ {
return fTransportType; return fTransportType;

View File

@ -20,6 +20,8 @@
#include <cstddef> // size_t #include <cstddef> // size_t
#include <atomic> #include <atomic>
class FairMQSocketSHM;
class FairMQMessageSHM : public FairMQMessage class FairMQMessageSHM : public FairMQMessage
{ {
friend class FairMQSocketSHM; friend class FairMQSocketSHM;
@ -33,25 +35,18 @@ class FairMQMessageSHM : public FairMQMessage
FairMQMessageSHM(const FairMQMessageSHM&) = delete; FairMQMessageSHM(const FairMQMessageSHM&) = delete;
FairMQMessageSHM operator=(const FairMQMessageSHM&) = delete; FairMQMessageSHM operator=(const FairMQMessageSHM&) = delete;
bool InitializeChunk(const size_t size);
void Rebuild() override; void Rebuild() override;
void Rebuild(const size_t size) override; void Rebuild(const size_t size) override;
void Rebuild(void* data, const size_t size, fairmq_free_fn* ffn, void* hint = nullptr) override; void Rebuild(void* data, const size_t size, fairmq_free_fn* ffn, void* hint = nullptr) override;
void* GetMessage() override;
void* GetData() override; void* GetData() override;
size_t GetSize() const override; size_t GetSize() const override;
bool SetUsedSize(const size_t size) override; bool SetUsedSize(const size_t size) override;
void SetMessage(void* data, const size_t size) override;
FairMQ::Transport GetType() const override; FairMQ::Transport GetType() const override;
void Copy(const std::unique_ptr<FairMQMessage>& msg) override; void Copy(const FairMQMessagePtr& msg) override;
void CloseMessage();
~FairMQMessageSHM() override; ~FairMQMessageSHM() override;
@ -67,6 +62,10 @@ class FairMQMessageSHM : public FairMQMessage
boost::interprocess::managed_shared_memory::handle_t fHandle; boost::interprocess::managed_shared_memory::handle_t fHandle;
size_t fSize; size_t fSize;
char* fLocalPtr; char* fLocalPtr;
bool InitializeChunk(const size_t size);
zmq_msg_t* GetMessage();
void CloseMessage();
}; };
#endif /* FAIRMQMESSAGESHM_H_ */ #endif /* FAIRMQMESSAGESHM_H_ */

View File

@ -114,7 +114,7 @@ int FairMQSocketSHM::Send(FairMQMessagePtr& msg, const int flags)
int nbytes = -1; int nbytes = -1;
while (true && !fInterrupted) while (true && !fInterrupted)
{ {
nbytes = zmq_msg_send(static_cast<zmq_msg_t*>(msg->GetMessage()), fSocket, flags); nbytes = zmq_msg_send(static_cast<FairMQMessageSHM*>(msg.get())->GetMessage(), fSocket, flags);
if (nbytes == 0) if (nbytes == 0)
{ {
return nbytes; return nbytes;
@ -158,7 +158,7 @@ int FairMQSocketSHM::Send(FairMQMessagePtr& msg, const int flags)
int FairMQSocketSHM::Receive(FairMQMessagePtr& msg, const int flags) int FairMQSocketSHM::Receive(FairMQMessagePtr& msg, const int flags)
{ {
int nbytes = -1; int nbytes = -1;
zmq_msg_t* msgPtr = static_cast<zmq_msg_t*>(msg->GetMessage()); zmq_msg_t* msgPtr = static_cast<FairMQMessageSHM*>(msg.get())->GetMessage();
while (true) while (true)
{ {
nbytes = zmq_msg_recv(msgPtr, fSocket, flags); nbytes = zmq_msg_recv(msgPtr, fSocket, flags);
@ -221,7 +221,7 @@ int64_t FairMQSocketSHM::Send(vector<FairMQMessagePtr>& msgVec, const int flags)
{ {
for (unsigned int i = 0; i < vecSize; ++i) for (unsigned int i = 0; i < vecSize; ++i)
{ {
nbytes = zmq_msg_send(static_cast<zmq_msg_t*>(msgVec[i]->GetMessage()), nbytes = zmq_msg_send(static_cast<FairMQMessageSHM*>(msgVec[i].get())->GetMessage(),
fSocket, fSocket,
(i < vecSize - 1) ? ZMQ_SNDMORE|flags : flags); (i < vecSize - 1) ? ZMQ_SNDMORE|flags : flags);
if (nbytes >= 0) if (nbytes >= 0)
@ -302,7 +302,7 @@ int64_t FairMQSocketSHM::Receive(vector<FairMQMessagePtr>& msgVec, const int fla
do do
{ {
FairMQMessagePtr part(new FairMQMessageSHM(fManager)); FairMQMessagePtr part(new FairMQMessageSHM(fManager));
zmq_msg_t* msgPtr = static_cast<zmq_msg_t*>(part->GetMessage()); zmq_msg_t* msgPtr = static_cast<FairMQMessageSHM*>(part.get())->GetMessage();
int nbytes = zmq_msg_recv(msgPtr, fSocket, flags); int nbytes = zmq_msg_recv(msgPtr, fSocket, flags);
if (nbytes == 0) if (nbytes == 0)

View File

@ -112,7 +112,7 @@ void FairMQMessageZMQ::Rebuild(void* data, const size_t size, fairmq_free_fn* ff
} }
} }
void* FairMQMessageZMQ::GetMessage() zmq_msg_t* FairMQMessageZMQ::GetMessage()
{ {
if (!fViewMsg) if (!fViewMsg)
{ {
@ -190,11 +190,6 @@ void FairMQMessageZMQ::ApplyUsedSize()
} }
} }
void FairMQMessageZMQ::SetMessage(void*, const size_t)
{
// dummy method to comply with the interface. functionality not allowed in zeromq.
}
FairMQ::Transport FairMQMessageZMQ::GetType() const FairMQ::Transport FairMQMessageZMQ::GetType() const
{ {
return fTransportType; return fTransportType;
@ -202,18 +197,19 @@ FairMQ::Transport FairMQMessageZMQ::GetType() const
void FairMQMessageZMQ::Copy(const FairMQMessagePtr& msg) void FairMQMessageZMQ::Copy(const FairMQMessagePtr& msg)
{ {
FairMQMessageZMQ* msgPtr = static_cast<FairMQMessageZMQ*>(msg.get());
// Shares the message buffer between msg and this fMsg. // Shares the message buffer between msg and this fMsg.
if (zmq_msg_copy(fMsg.get(), static_cast<zmq_msg_t*>(msg->GetMessage())) != 0) if (zmq_msg_copy(fMsg.get(), msgPtr->GetMessage()) != 0)
{ {
LOG(ERROR) << "failed copying message, reason: " << zmq_strerror(errno); LOG(ERROR) << "failed copying message, reason: " << zmq_strerror(errno);
return; return;
} }
// if the target message has been resized, apply same to this message also // if the target message has been resized, apply same to this message also
if (static_cast<FairMQMessageZMQ*>(msg.get())->fUsedSizeModified) if (msgPtr->fUsedSizeModified)
{ {
fUsedSizeModified = true; fUsedSizeModified = true;
fUsedSize = static_cast<FairMQMessageZMQ*>(msg.get())->fUsedSize; fUsedSize = msgPtr->fUsedSize;
} }
} }

View File

@ -24,8 +24,12 @@
#include "FairMQMessage.h" #include "FairMQMessage.h"
#include "FairMQUnmanagedRegion.h" #include "FairMQUnmanagedRegion.h"
class FairMQSocketZMQ;
class FairMQMessageZMQ : public FairMQMessage class FairMQMessageZMQ : public FairMQMessage
{ {
friend class FairMQSocketZMQ;
public: public:
FairMQMessageZMQ(); FairMQMessageZMQ();
FairMQMessageZMQ(const size_t size); FairMQMessageZMQ(const size_t size);
@ -36,21 +40,15 @@ class FairMQMessageZMQ : public FairMQMessage
void Rebuild(const size_t size) override; void Rebuild(const size_t size) override;
void Rebuild(void* data, const size_t size, fairmq_free_fn* ffn, void* hint = nullptr) override; void Rebuild(void* data, const size_t size, fairmq_free_fn* ffn, void* hint = nullptr) override;
void* GetMessage() override;
void* GetData() override; void* GetData() override;
size_t GetSize() const override; size_t GetSize() const override;
bool SetUsedSize(const size_t size) override; bool SetUsedSize(const size_t size) override;
void ApplyUsedSize(); void ApplyUsedSize();
void SetMessage(void* data, const size_t size) override;
FairMQ::Transport GetType() const override; FairMQ::Transport GetType() const override;
void Copy(const std::unique_ptr<FairMQMessage>& msg) override; void Copy(const FairMQMessagePtr& msg) override;
void CloseMessage();
~FairMQMessageZMQ() override; ~FairMQMessageZMQ() override;
@ -60,6 +58,9 @@ class FairMQMessageZMQ : public FairMQMessage
std::unique_ptr<zmq_msg_t> fMsg; std::unique_ptr<zmq_msg_t> fMsg;
std::unique_ptr<zmq_msg_t> fViewMsg; // view on a subset of fMsg (treating it as user buffer) std::unique_ptr<zmq_msg_t> fViewMsg; // view on a subset of fMsg (treating it as user buffer)
static FairMQ::Transport fTransportType; static FairMQ::Transport fTransportType;
zmq_msg_t* GetMessage();
void CloseMessage();
}; };
#endif /* FAIRMQMESSAGEZMQ_H_ */ #endif /* FAIRMQMESSAGEZMQ_H_ */

View File

@ -119,7 +119,7 @@ int FairMQSocketZMQ::Send(FairMQMessagePtr& msg, const int flags)
while (true) while (true)
{ {
nbytes = zmq_msg_send(static_cast<zmq_msg_t*>(msg->GetMessage()), fSocket, flags); nbytes = zmq_msg_send(static_cast<FairMQMessageZMQ*>(msg.get())->GetMessage(), fSocket, flags);
if (nbytes >= 0) if (nbytes >= 0)
{ {
fBytesTx += nbytes; fBytesTx += nbytes;
@ -157,7 +157,7 @@ int FairMQSocketZMQ::Receive(FairMQMessagePtr& msg, const int flags)
while (true) while (true)
{ {
nbytes = zmq_msg_recv(static_cast<zmq_msg_t*>(msg->GetMessage()), fSocket, flags); nbytes = zmq_msg_recv(static_cast<FairMQMessageZMQ*>(msg.get())->GetMessage(), fSocket, flags);
if (nbytes >= 0) if (nbytes >= 0)
{ {
fBytesRx += nbytes; fBytesRx += nbytes;
@ -209,7 +209,7 @@ int64_t FairMQSocketZMQ::Send(vector<FairMQMessagePtr>& msgVec, const int flags)
{ {
static_cast<FairMQMessageZMQ*>(msgVec[i].get())->ApplyUsedSize(); static_cast<FairMQMessageZMQ*>(msgVec[i].get())->ApplyUsedSize();
nbytes = zmq_msg_send(static_cast<zmq_msg_t*>(msgVec[i]->GetMessage()), nbytes = zmq_msg_send(static_cast<FairMQMessageZMQ*>(msgVec[i].get())->GetMessage(),
fSocket, fSocket,
(i < vecSize - 1) ? ZMQ_SNDMORE|flags : flags); (i < vecSize - 1) ? ZMQ_SNDMORE|flags : flags);
if (nbytes >= 0) if (nbytes >= 0)
@ -279,7 +279,7 @@ int64_t FairMQSocketZMQ::Receive(vector<FairMQMessagePtr>& msgVec, const int fla
{ {
unique_ptr<FairMQMessage> part(new FairMQMessageZMQ()); unique_ptr<FairMQMessage> part(new FairMQMessageZMQ());
int nbytes = zmq_msg_recv(static_cast<zmq_msg_t*>(part->GetMessage()), fSocket, flags); int nbytes = zmq_msg_recv(static_cast<FairMQMessageZMQ*>(part.get())->GetMessage(), fSocket, flags);
if (nbytes >= 0) if (nbytes >= 0)
{ {
msgVec.push_back(move(part)); msgVec.push_back(move(part));