diff --git a/fairmq/shmem/FairMQSocketSHM.cxx b/fairmq/shmem/FairMQSocketSHM.cxx index f7ca5c65..c4f0d2e3 100644 --- a/fairmq/shmem/FairMQSocketSHM.cxx +++ b/fairmq/shmem/FairMQSocketSHM.cxx @@ -1,8 +1,8 @@ /******************************************************************************** * Copyright (C) 2014 GSI Helmholtzzentrum fuer Schwerionenforschung GmbH * * * - * This software is distributed under the terms of the * - * GNU Lesser General Public Licence (LGPL) version 3, * + * This software is distributed under the terms of the * + * GNU Lesser General Public Licence (LGPL) version 3, * * copied verbatim in the file "LICENSE" * ********************************************************************************/ #include @@ -158,6 +158,7 @@ int FairMQSocketSHM::Send(FairMQMessagePtr& msg, const int flags) int FairMQSocketSHM::Receive(FairMQMessagePtr& msg, const int flags) { int nbytes = -1; + zmq_msg_t* msgPtr = static_cast(msg.get())->GetMessage(); while (true) { @@ -170,6 +171,15 @@ int FairMQSocketSHM::Receive(FairMQMessagePtr& msg, const int flags) } else if (nbytes > 0) { + // check for number of receiving messages. must be 1 + const auto numMsgs = nbytes / sizeof(MetaHeader); + if (numMsgs > 1) + { + LOG(ERROR) << "Receiving SHM multipart with a single message receive call"; + } + + assert (numMsgs == 1); + MetaHeader* hdr = static_cast(zmq_msg_data(msgPtr)); size_t size = 0; static_cast(msg.get())->fHandle = hdr->fHandle; @@ -210,151 +220,150 @@ int FairMQSocketSHM::Receive(FairMQMessagePtr& msg, const int flags) int64_t FairMQSocketSHM::Send(vector& msgVec, const int flags) { const unsigned int vecSize = msgVec.size(); + int64_t totalSize = 0; - // Sending vector typicaly handles more then one part - if (vecSize > 1) + if (vecSize == 1) { + return Send(msgVec.back(), flags); + } + + // put it into zmq message + zmq_msg_t lZmqMsg; + zmq_msg_init_size(&lZmqMsg, vecSize * sizeof(MetaHeader)); + + // prepare the message with shm metas + MetaHeader *lMetas = static_cast(zmq_msg_data(&lZmqMsg)); + + for (auto &lMsg : msgVec) + { + zmq_msg_t *lMetaMsg = static_cast(lMsg.get())->GetMessage(); + memcpy(lMetas++, zmq_msg_data(lMetaMsg), sizeof(MetaHeader)); + } + + while (!fInterrupted) { - int64_t totalSize = 0; int nbytes = -1; - bool repeat = false; + nbytes = zmq_msg_send(&lZmqMsg, fSocket, flags); - while (true && !fInterrupted) + if (nbytes == 0) { - for (unsigned int i = 0; i < vecSize; ++i) - { - nbytes = zmq_msg_send(static_cast(msgVec[i].get())->GetMessage(), - fSocket, - (i < vecSize - 1) ? ZMQ_SNDMORE|flags : flags); - if (nbytes >= 0) - { - static_cast(msgVec[i].get())->fQueued = true; - size_t size = msgVec[i]->GetSize(); + zmq_msg_close (&lZmqMsg); + return nbytes; + } + else if (nbytes > 0) + { + assert(nbytes == (vecSize * sizeof(MetaHeader))); // all or nothing - totalSize += size; - } - else - { - // according to ZMQ docs, this can only occur for the first part - if (zmq_errno() == EAGAIN) - { - if (!fInterrupted && ((flags & ZMQ_DONTWAIT) == 0)) - { - repeat = true; - break; - } - else - { - return -2; - } - } - if (zmq_errno() == ETERM) - { - LOG(info) << "terminating socket " << fId; - return -1; - } - LOG(error) << "Failed sending on socket " << fId << ", reason: " << zmq_strerror(errno); - return nbytes; - } + for (auto &lMsg : msgVec) + { + FairMQMessageSHM *lShmMsg = static_cast(lMsg.get()); + lShmMsg->fQueued = true; + totalSize += lShmMsg->fSize; } - if (repeat) + // store statistics on how many messages have been sent + fMessagesTx++; + fBytesTx += totalSize; + + zmq_msg_close (&lZmqMsg); + return totalSize; + } + else if (zmq_errno() == EAGAIN) + { + if (!fInterrupted && ((flags & ZMQ_DONTWAIT) == 0)) { continue; } - - // store statistics on how many messages have been sent (handle all parts as a single message) - ++fMessagesTx; - fBytesTx += totalSize; - return totalSize; + else + { + zmq_msg_close (&lZmqMsg); + return -2; + } + } + else if (zmq_errno() == ETERM) + { + zmq_msg_close (&lZmqMsg); + LOG(INFO) << "terminating socket " << fId; + return -1; + } + else + { + zmq_msg_close (&lZmqMsg); + LOG(ERROR) << "Failed sending on socket " << fId << ", reason: " << zmq_strerror(errno); + return nbytes; } - - return -1; - } // If there's only one part, send it as a regular message - else if (vecSize == 1) - { - return Send(msgVec.back(), flags); - } - else // if the vector is empty, something might be wrong - { - LOG(warn) << "Will not send empty vector"; - return -1; } } + int64_t FairMQSocketSHM::Receive(vector& msgVec, const int flags) { int64_t totalSize = 0; - int64_t more = 0; - bool repeat = false; - while (true) + while (!fInterrupted) { - // Warn if the vector is filled before Receive() and empty it. - // if (msgVec.size() > 0) - // { - // LOG(warn) << "Message vector contains elements before Receive(), they will be deleted!"; - // msgVec.clear(); - // } - - totalSize = 0; - more = 0; - repeat = false; - - do + zmq_msg_t lRcvMsg; + zmq_msg_init(&lRcvMsg); + int nbytes = zmq_msg_recv(&lRcvMsg, fSocket, flags); + if (nbytes == 0) { - FairMQMessagePtr part(new FairMQMessageSHM(fManager)); - zmq_msg_t* msgPtr = static_cast(part.get())->GetMessage(); + zmq_msg_close (&lRcvMsg); + return 0; + } + else if (nbytes > 0) + { + MetaHeader* lHdrVec = static_cast(zmq_msg_data(&lRcvMsg)); + const auto lHdrVecSize = zmq_msg_size(&lRcvMsg); + assert(lHdrVecSize > 0); + assert(lHdrVecSize % sizeof(MetaHeader) == 0); - int nbytes = zmq_msg_recv(msgPtr, fSocket, flags); - if (nbytes == 0) + const auto lNumMessages = lHdrVecSize / sizeof (MetaHeader); + + msgVec.reserve(lNumMessages); + + for (auto m = 0; m < lNumMessages; m++) { - msgVec.push_back(move(part)); + MetaHeader lMetaHeader; + memcpy(&lMetaHeader, &lHdrVec[m], sizeof(MetaHeader)); + + msgVec.emplace_back(fair::mq::tools::make_unique(fManager)); + + FairMQMessageSHM *lMsg = static_cast(msgVec.back().get()); + MetaHeader *lMsgHdr = static_cast(zmq_msg_data(lMsg->GetMessage())); + + memcpy(lMsgHdr, &lMetaHeader, sizeof(MetaHeader)); + + lMsg->fHandle = lMetaHeader.fHandle; + lMsg->fSize = lMetaHeader.fSize; + lMsg->fRegionId = lMetaHeader.fRegionId; + lMsg->fHint = lMetaHeader.fHint; + + totalSize += lMsg->GetSize(); } - else if (nbytes > 0) - { - MetaHeader* hdr = static_cast(zmq_msg_data(msgPtr)); - size_t size = 0; - static_cast(part.get())->fHandle = hdr->fHandle; - static_cast(part.get())->fSize = hdr->fSize; - static_cast(part.get())->fRegionId = hdr->fRegionId; - static_cast(part.get())->fHint = hdr->fHint; - size = part->GetSize(); - msgVec.push_back(move(part)); + // store statistics on how many messages have been received (handle all parts as a single message) + fMessagesRx++; + fBytesRx += totalSize; - totalSize += size; - } - else if (zmq_errno() == EAGAIN) + zmq_msg_close (&lRcvMsg); + return totalSize; + } + else if (zmq_errno() == EAGAIN) + { + zmq_msg_close(&lRcvMsg); + if (!fInterrupted && ((flags & ZMQ_DONTWAIT) == 0)) { - if (!fInterrupted && ((flags & ZMQ_DONTWAIT) == 0)) - { - repeat = true; - break; - } - else - { - return -2; - } + continue; } else { - return nbytes; + return -2; } - - size_t more_size = sizeof(more); - zmq_getsockopt(fSocket, ZMQ_RCVMORE, &more, &more_size); } - while (more); - - if (repeat) + else { - continue; + zmq_msg_close (&lRcvMsg); + return nbytes; } - - // store statistics on how many messages have been received (handle all parts as a single message) - ++fMessagesRx; - fBytesRx += totalSize; - return totalSize; } }