From a1b7efa2f483ec2d3da362a0edaec57e068ff32d Mon Sep 17 00:00:00 2001 From: Dennis Klein Date: Tue, 19 Mar 2019 18:29:12 +0100 Subject: [PATCH] Unify implementation of multi part and single part message interfaces --- fairmq/ofi/ControlMessages.h | 73 ++-- fairmq/ofi/Socket.cxx | 703 ++++++++++++----------------------- fairmq/ofi/Socket.h | 28 +- 3 files changed, 295 insertions(+), 509 deletions(-) diff --git a/fairmq/ofi/ControlMessages.h b/fairmq/ofi/ControlMessages.h index 86209a1d..2f78798a 100644 --- a/fairmq/ofi/ControlMessages.h +++ b/fairmq/ofi/ControlMessages.h @@ -35,59 +35,76 @@ namespace ofi { enum class ControlMessageType { - DataAddressAnnouncement = 1, + Empty = 1, PostBuffer, - PostBufferAcknowledgement + PostMultiPartStartBuffer +}; + +struct Empty +{}; + +struct PostBuffer +{ + uint64_t size; // buffer size (size_t) +}; + +struct PostMultiPartStartBuffer +{ + uint32_t numParts; // buffer size (size_t) + uint64_t size; // buffer size (size_t) +}; + +union ControlMessageContent +{ + PostBuffer postBuffer; + PostMultiPartStartBuffer postMultiPartStartBuffer; }; struct ControlMessage { ControlMessageType type; -}; - -struct DataAddressAnnouncement : ControlMessage -{ - uint32_t ipv4; // in_addr_t from - uint32_t port; // in_port_t from -}; - -struct PostBuffer : ControlMessage -{ - uint64_t size; // buffer size (size_t) + ControlMessageContent msg; }; template using unique_ptr = std::unique_ptr>; template -auto MakeControlMessageWithPmr(boost::container::pmr::memory_resource* pmr, Args&&... args) - -> ofi::unique_ptr +auto MakeControlMessageWithPmr(boost::container::pmr::memory_resource& pmr, Args&&... args) + -> ofi::unique_ptr { - void* mem = pmr->allocate(sizeof(T)); - T* ctrl = new (mem) T(std::forward(args)...); + void* mem = pmr.allocate(sizeof(ControlMessage)); + ControlMessage* ctrl = new (mem) ControlMessage(); - if (std::is_same::value) { - ctrl->type = ControlMessageType::DataAddressAnnouncement; - } else if (std::is_same::value) { + if (std::is_same::value) { ctrl->type = ControlMessageType::PostBuffer; + ctrl->msg.postBuffer = PostBuffer(std::forward(args)...); + } else if (std::is_same::value) { + ctrl->type = ControlMessageType::PostMultiPartStartBuffer; + ctrl->msg.postMultiPartStartBuffer = PostMultiPartStartBuffer(std::forward(args)...); + } else if (std::is_same::value) { + ctrl->type = ControlMessageType::Empty; } - return ofi::unique_ptr(ctrl, [=](T* p) { - p->~T(); - pmr->deallocate(p, sizeof(T)); + return ofi::unique_ptr(ctrl, [&pmr](ControlMessage* p) { + p->~ControlMessage(); + pmr.deallocate(p, sizeof(T)); }); } template -auto MakeControlMessage(Args&&... args) -> T +auto MakeControlMessage(Args&&... args) -> ControlMessage { - T ctrl = T(std::forward(args)...); + ControlMessage ctrl; - if (std::is_same::value) { - ctrl.type = ControlMessageType::DataAddressAnnouncement; - } else if (std::is_same::value) { + if (std::is_same::value) { ctrl.type = ControlMessageType::PostBuffer; + } else if (std::is_same::value) { + ctrl.type = ControlMessageType::PostMultiPartStartBuffer; + } else if (std::is_same::value) { + ctrl.type = ControlMessageType::Empty; } + ctrl.msg = T(std::forward(args)...); return ctrl; } diff --git a/fairmq/ofi/Socket.cxx b/fairmq/ofi/Socket.cxx index 882bc48a..b00c58d9 100644 --- a/fairmq/ofi/Socket.cxx +++ b/fairmq/ofi/Socket.cxx @@ -13,19 +13,17 @@ #include #include -#include #include +#include #include #include #include #include #include #include -#include -#include #include #include -#include +#include namespace fair { @@ -51,41 +49,15 @@ Socket::Socket(Context& context, const string& type, const string& name, const s , fMessagesRx(0) , fSndTimeout(100) , fRcvTimeout(100) - , fSendQueueWrite(fContext.GetIoContext(), ZMQ_PUSH) - , fSendQueueRead(fContext.GetIoContext(), ZMQ_PULL) - , fRecvQueueWrite(fContext.GetIoContext(), ZMQ_PUSH) - , fRecvQueueRead(fContext.GetIoContext(), ZMQ_PULL) - , fSendSem(fContext.GetIoContext(), 300) - , fRecvSem(fContext.GetIoContext(), 300) + , fMultiPartRecvCounter(-1) + , fSendPushSem(fContext.GetIoContext(), 384) + , fSendPopSem(fContext.GetIoContext(), 0) + , fRecvPushSem(fContext.GetIoContext(), 384) + , fRecvPopSem(fContext.GetIoContext(), 0) , fNeedOfiMemoryRegistration(false) { if (type != "pair") { throw SocketError{tools::ToString("Socket type '", type, "' not implemented for ofi transport.")}; - } else { - // TODO wire this up with config - azmq::socket::snd_hwm send_max(300); - azmq::socket::rcv_hwm recv_max(300); - fSendQueueRead.set_option(send_max); - fSendQueueRead.set_option(recv_max); - fSendQueueWrite.set_option(send_max); - fSendQueueWrite.set_option(recv_max); - fRecvQueueRead.set_option(send_max); - fRecvQueueRead.set_option(recv_max); - fRecvQueueWrite.set_option(send_max); - fRecvQueueWrite.set_option(recv_max); - - // Setup internal queue - auto hashed_id = std::hash()(fId); - auto queue_id = tools::ToString("inproc://TXQUEUE", hashed_id); - LOG(debug) << "OFI transport (" << fId << "): " << "Binding SQR: " << queue_id; - fSendQueueRead.bind(queue_id); - LOG(debug) << "OFI transport (" << fId << "): " << "Connecting SQW: " << queue_id; - fSendQueueWrite.connect(queue_id); - queue_id = tools::ToString("inproc://RXQUEUE", hashed_id); - LOG(debug) << "OFI transport (" << fId << "): " << "Binding RQR: " << queue_id; - fRecvQueueRead.bind(queue_id); - LOG(debug) << "OFI transport (" << fId << "): " << "Connecting RQW: " << queue_id; - fRecvQueueWrite.connect(queue_id); } } @@ -107,7 +79,7 @@ auto Socket::InitOfi(Address addr) -> void fOfiInfo = tools::make_unique(addr.Ip.c_str(), std::to_string(addr.Port).c_str(), FI_SOURCE, hints); } - LOG(debug) << "OFI transport: " << *fOfiInfo; + LOG(debug) << "OFI transport (" << fId << "): " << *fOfiInfo; fOfiFabric = tools::make_unique(*fOfiInfo); @@ -268,222 +240,213 @@ auto Socket::ConnectEndpoint(std::unique_ptr& endpoi } } -// auto Socket::ReceiveDataAddressAnnouncement() -> void -// { - // azmq::message ctrl; - // auto recv = fControlEndpoint.receive(ctrl); - // assert(recv == sizeof(DataAddressAnnouncement)); (void)recv; - // auto daa(static_cast(ctrl.data())); - // assert(daa->type == ControlMessageType::DataAddressAnnouncement); -// - // sockaddr_in remoteAddr; - // remoteAddr.sin_family = AF_INET; - // remoteAddr.sin_port = daa->port; - // remoteAddr.sin_addr.s_addr = daa->ipv4; -// - // auto addr = Context::ConvertAddress(remoteAddr); - // addr.Protocol = fRemoteDataAddr.Protocol; - // LOG(debug) << "OFI transport (" << fId << "): Data address announcement of remote endpoint received: " << addr; - // fRemoteDataAddr = addr; -// } -// -// auto Socket::AnnounceDataAddress() -> void -// { - // fLocalDataAddr = fDataEndpoint->get_local_address(); - // LOG(debug) << "Address of local ofi endpoint in socket " << fId << ": " << Context::ConvertAddress(fLocalDataAddr); -// - // Create new data address announcement message - // auto daa = MakeControlMessage(); - // auto addr = Context::ConvertAddress(fLocalDataAddr); - // daa.ipv4 = addr.sin_addr.s_addr; - // daa.port = addr.sin_port; -// - // auto sent = fControlEndpoint.send(boost::asio::buffer(daa)); - // assert(sent == sizeof(addr)); (void)sent; -// - // LOG(debug) << "OFI transport (" << fId << "): data band address " << fLocalDataAddr << " announced."; -// } - auto Socket::Send(MessagePtr& msg, const int /*timeout*/) -> int { - // LOG(debug) << "OFI transport (" << fId << "): ENTER Send: data=" << msg->GetData() << ",size=" << msg->GetSize(); + // timeout argument not yet implemented - MessagePtr* msgptr(new std::unique_ptr(std::move(msg))); - try { - auto res = fSendQueueWrite.send(boost::asio::const_buffer(msgptr, sizeof(MessagePtr)), 0); + std::vector msgVec; + msgVec.reserve(1); + msgVec.emplace_back(std::move(msg)); - // LOG(debug) << "OFI transport (" << fId << "): LEAVE Send"; - return res; - } catch (const std::exception& e) { - msg = std::move(*msgptr); - LOG(error) << e.what(); - return -1; - } catch (const boost::system::error_code& e) { - msg = std::move(*msgptr); - LOG(error) << e; - return -1; + return Send(msgVec); +} + +auto Socket::Send(std::vector& msgVec, const int /*timeout*/) -> int64_t +try { + // timeout argument not yet implemented + + int size(0); + for (auto& msg : msgVec) { + size += msg->GetSize(); + } + + fSendPushSem.wait(); + { + std::lock_guard lk(fSendQueueMutex); + fSendQueue.emplace(std::move(msgVec)); } -} + fSendPopSem.signal(); -auto Socket::Receive(MessagePtr& msg, const int /*timeout*/) -> int -{ - // LOG(debug) << "OFI transport (" << fId << "): ENTER Receive"; - - try { - azmq::message zmsg; - auto recv = fRecvQueueRead.receive(zmsg); - - size_t size(0); - if (recv > 0) { - msg = std::move(*(static_cast(zmsg.buffer().data()))); - size = msg->GetSize(); - } - - fBytesRx += size; - fMessagesRx++; - - // LOG(debug) << "OFI transport (" << fId << "): LEAVE Receive"; - return size; - } catch (const std::exception& e) { - LOG(error) << e.what(); - return -1; - } catch (const boost::system::error_code& e) { - LOG(error) << e; - return -1; - } -} - -auto Socket::Send(std::vector& msgVec, const int timeout) -> int64_t -{ - return SendImpl(msgVec, 0, timeout); -} - -auto Socket::Receive(std::vector& msgVec, const int timeout) -> int64_t -{ - return ReceiveImpl(msgVec, 0, timeout); + return size; +} catch (const std::exception& e) { + LOG(error) << e.what(); + return -1; } auto Socket::SendQueueReader() -> void { - fSendSem.async_wait([&] { - // LOG(debug) << "OFI transport (" << fId << "): < Wait fSendSem=" << - // fSendSem.get_value(); - fSendQueueRead.async_receive([&](const boost::system::error_code& ec2, - azmq::message& zmsg, - size_t bytes_transferred) { - if (!ec2) { - OnSend(zmsg, bytes_transferred); + fSendPopSem.async_wait([&] { + // Read msg from send queue + std::unique_lock lk(fSendQueueMutex); + std::vector msgVec(std::move(fSendQueue.front())); + fSendQueue.pop(); + lk.unlock(); + + bool postMultiPartStartBuffer = msgVec.size() > 1; + for (auto& msg : msgVec) { + // Create control message + ofi::unique_ptr ctrl(nullptr); + if (postMultiPartStartBuffer) { + postMultiPartStartBuffer = false; + ctrl = MakeControlMessageWithPmr(fControlMemPool); + ctrl->msg.postMultiPartStartBuffer.numParts = msgVec.size(); + ctrl->msg.postMultiPartStartBuffer.size = msg->GetSize(); + } else { + ctrl = MakeControlMessageWithPmr(fControlMemPool); + ctrl->msg.postBuffer.size = msg->GetSize(); } - }); + + // Send control message + boost::asio::mutable_buffer ctrlMsg(ctrl.get(), sizeof(ControlMessage)); + + if (fNeedOfiMemoryRegistration) { + asiofi::memory_region mr(*fOfiDomain, ctrlMsg, asiofi::mr::access::send); + auto desc = mr.desc(); + fControlEndpoint->send(ctrlMsg, + desc, + [&, ctrl2 = std::move(ctrlMsg), mr2 = std::move(mr)]( + boost::asio::mutable_buffer) mutable {}); + } else { + fControlEndpoint->send( + ctrlMsg, [&, ctrl2 = std::move(ctrl)](boost::asio::mutable_buffer) mutable {}); + } + + // Send data message + const auto size = msg->GetSize(); + + if (size) { + boost::asio::mutable_buffer buffer(msg->GetData(), size); + + if (fNeedOfiMemoryRegistration) { + asiofi::memory_region mr(*fOfiDomain, buffer, asiofi::mr::access::send); + auto desc = mr.desc(); + + fDataEndpoint->send(buffer, + desc, + [&, size, msg2 = std::move(msg), mr2 = std::move(mr)]( + boost::asio::mutable_buffer) mutable { + fBytesTx += size; + fMessagesTx++; + fSendPushSem.signal(); + }); + + } else { + fDataEndpoint->send( + buffer, [&, size, msg2 = std::move(msg)](boost::asio::mutable_buffer) mutable { + fBytesTx += size; + fMessagesTx++; + fSendPushSem.signal(); + }); + } + } else { + ++fMessagesTx; + fSendPushSem.signal(); + } + } + + boost::asio::dispatch(fContext.GetIoContext(), std::bind(&Socket::SendQueueReader, this)); }); } -auto Socket::OnSend(azmq::message& zmsg, size_t /*bytes_transferred*/) -> void -{ - // LOG(debug) << "OFI transport (" << fId << "): ENTER OnSend: bytes_transferred=" << bytes_transferred; +auto Socket::Receive(MessagePtr& msg, const int /*timeout*/) -> int +try { + // timeout argument not yet implemented - MessagePtr msg(std::move(*(static_cast(zmsg.buffer().data())))); - auto size = msg->GetSize(); - - // LOG(debug) << "OFI transport (" << fId << "): OnSend: data=" << msg->GetData() << ",size=" << msg->GetSize(); - - // Create and send control message - auto ctrl = MakeControlMessageWithPmr(&fControlMemPool); - ctrl->size = size; - auto ctrl_msg = boost::asio::mutable_buffer(ctrl.get(), sizeof(PostBuffer)); - if (fNeedOfiMemoryRegistration) { - asiofi::memory_region mr(*fOfiDomain, ctrl_msg, asiofi::mr::access::send); - auto desc = mr.desc(); - fControlEndpoint->send( - ctrl_msg, desc, [&, ctrl2 = std::move(ctrl), mr2 = std::move(mr)](boost::asio::mutable_buffer) mutable { - // LOG(debug) << "OFI transport (" << fId << "): >>>>> Control message sent"; - }); - } else { - fControlEndpoint->send(ctrl_msg, - [&, ctrl2 = std::move(ctrl)](boost::asio::mutable_buffer) mutable { - // LOG(debug) << "OFI transport (" << fId << "): >>>>> Control - // message sent"; - }); + fRecvPopSem.wait(); + { + std::lock_guard lk(fRecvQueueMutex); + msg = std::move(fRecvQueue.front().front()); + fRecvQueue.pop(); } + fRecvPushSem.signal(); - if (size) { - boost::asio::mutable_buffer buffer(msg->GetData(), size); + int size(msg->GetSize()); + fBytesRx += size; + ++fMessagesRx; - if (fNeedOfiMemoryRegistration) { - asiofi::memory_region mr(*fOfiDomain, buffer, asiofi::mr::access::send); - auto desc = mr.desc(); + return size; +} catch (const std::exception& e) { + LOG(error) << e.what(); + return -1; +} - fDataEndpoint->send(buffer, - desc, - [&, size, msg2 = std::move(msg), mr2 = std::move(mr)]( - boost::asio::mutable_buffer) mutable { - // LOG(debug) << "OFI transport (" << fId << "): >>>>> Data - // buffer sent"; - fBytesTx += size; - fMessagesTx++; - fSendSem.async_signal([&] { - // LOG(debug) << "OFI transport (" << fId << "): > - // Signal fSendSem=" << fSendSem.get_value(); - }); - }); +auto Socket::Receive(std::vector& msgVec, const int /*timeout*/) -> int64_t +try { + // timeout argument not yet implemented - } else { - fDataEndpoint->send( - buffer, [&, size, msg2 = std::move(msg)](boost::asio::mutable_buffer) mutable { - // LOG(debug) << "OFI transport (" << fId << "): >>>>> Data buffer sent"; - fBytesTx += size; - fMessagesTx++; - fSendSem.async_signal([&] { - // LOG(debug) << "OFI transport (" << fId << "): > Signal fSendSem=" - // << fSendSem.get_value(); - }); - }); - } - } else { - ++fMessagesTx; - fSendSem.async_signal([&] { - // LOG(debug) << "OFI transport (" << fId << "): > Signal fSendSem=" << fSendSem.get_value(); - }); + fRecvPopSem.wait(); + { + std::lock_guard lk(fRecvQueueMutex); + msgVec = std::move(fRecvQueue.front()); + fRecvQueue.pop(); } + fRecvPushSem.signal(); - boost::asio::post(fContext.GetIoContext(), std::bind(&Socket::SendQueueReader, this)); - // LOG(debug) << "OFI transport (" << fId << "): LEAVE OnSend"; + int64_t size(0); + for (auto& msg : msgVec) { + size += msg->GetSize(); + } + fBytesRx += size; + ++fMessagesRx; + + return size; +} catch (const std::exception& e) { + LOG(error) << e.what(); + return -1; } auto Socket::RecvControlQueueReader() -> void { - fRecvSem.async_wait([&] { - auto ctrl = MakeControlMessageWithPmr(&fControlMemPool); - auto ctrl_msg = boost::asio::mutable_buffer(ctrl.get(), sizeof(PostBuffer)); + fRecvPushSem.async_wait([&] { + // Receive control message + ofi::unique_ptr ctrl(MakeControlMessageWithPmr(fControlMemPool)); + boost::asio::mutable_buffer ctrlMsg(ctrl.get(), sizeof(ControlMessage)); if (fNeedOfiMemoryRegistration) { - asiofi::memory_region mr(*fOfiDomain, ctrl_msg, asiofi::mr::access::recv); + asiofi::memory_region mr(*fOfiDomain, ctrlMsg, asiofi::mr::access::recv); auto desc = mr.desc(); fControlEndpoint->recv( - ctrl_msg, + ctrlMsg, desc, [&, ctrl2 = std::move(ctrl), mr2 = std::move(mr)]( boost::asio::mutable_buffer) mutable { OnRecvControl(std::move(ctrl2)); }); } else { fControlEndpoint->recv( - ctrl_msg, [&, ctrl2 = std::move(ctrl)](boost::asio::mutable_buffer) mutable { + ctrlMsg, [&, ctrl2 = std::move(ctrl)](boost::asio::mutable_buffer) mutable { OnRecvControl(std::move(ctrl2)); }); } }); } -auto Socket::OnRecvControl(ofi::unique_ptr ctrl) -> void +auto Socket::OnRecvControl(ofi::unique_ptr ctrl) -> void { - // LOG(debug) << "OFI transport (" << fId << "): ENTER OnRecvControl"; - - auto size = ctrl->size; - // LOG(debug) << "OFI transport (" << fId << "): OnRecvControl: PostBuffer.size=" << size; + // Check control message type + auto size(0); + if (ctrl->type == ControlMessageType::PostMultiPartStartBuffer) { + size = ctrl->msg.postMultiPartStartBuffer.size; + if (fMultiPartRecvCounter == -1) { + fMultiPartRecvCounter = ctrl->msg.postMultiPartStartBuffer.numParts; + assert(fInflightMultiPartMessage.empty()); + fInflightMultiPartMessage.reserve(ctrl->msg.postMultiPartStartBuffer.numParts); + } else { + throw SocketError{tools::ToString( + "OFI transport: Received control start of new multi part message without completed " + "reception of previous multi part message. Number of parts missing: ", + fMultiPartRecvCounter)}; + } + } else if (ctrl->type == ControlMessageType::PostBuffer) { + size = ctrl->msg.postBuffer.size; + } else { + throw SocketError{tools::ToString("OFI transport: Unknown control message type: '", + static_cast(ctrl->type))}; + } // Receive data + auto msg = fContext.MakeReceiveMessage(size); + if (size) { - auto msg = fContext.MakeReceiveMessage(size); boost::asio::mutable_buffer buffer(msg->GetData(), size); if (fNeedOfiMemoryRegistration) { @@ -494,229 +457,41 @@ auto Socket::OnRecvControl(ofi::unique_ptr ctrl) -> void buffer, desc, [&, msg2 = std::move(msg), mr2 = std::move(mr)]( - boost::asio::mutable_buffer) mutable { - MessagePtr* msgptr(new std::unique_ptr(std::move(msg2))); - fRecvQueueWrite.async_send( - azmq::message(boost::asio::const_buffer(msgptr, sizeof(MessagePtr))), - [&](const boost::system::error_code& ec, size_t /*bytes_transferred2*/) { - if (!ec) { - // LOG(debug) << "OFI transport (" << fId - // << "): <<<<< Data buffer received, bytes_transferred2=" - // << bytes_transferred2; - fRecvSem.async_signal([&] { - //LOG(debug) << "OFI transport (" << fId << "): < Signal fRecvSem"; - }); - } - }); - }); + boost::asio::mutable_buffer) mutable { DataMessageReceived(std::move(msg2)); }); } else { - fDataEndpoint->recv( - buffer, [&, msg2 = std::move(msg)](boost::asio::mutable_buffer) mutable { - MessagePtr* msgptr(new std::unique_ptr(std::move(msg2))); - fRecvQueueWrite.async_send( - azmq::message(boost::asio::const_buffer(msgptr, sizeof(MessagePtr))), - [&](const boost::system::error_code& ec, size_t /*bytes_transferred2*/) { - if (!ec) { - // LOG(debug) << "OFI transport (" << fId - // << "): <<<<< Data buffer received, bytes_transferred2=" - // << bytes_transferred2; - fRecvSem.async_signal([&] { - // LOG(debug) << "OFI transport (" << fId << "): < Signal fRecvSem"; + fDataEndpoint->recv(buffer, + [&, msg2 = std::move(msg)](boost::asio::mutable_buffer) mutable { + DataMessageReceived(std::move(msg2)); }); - } - }); - }); } } else { - fRecvQueueWrite.async_send( - azmq::message(boost::asio::const_buffer(nullptr, 0)), - [&](const boost::system::error_code& ec, size_t /*bytes_transferred2*/) { - if (!ec) { - // LOG(debug) << "OFI transport (" << fId - // << "): <<<<< Data buffer received, bytes_transferred2=" - // << bytes_transferred2; - fRecvSem.async_signal([&] { - // LOG(debug) << "OFI transport (" << fId << "): < Signal fRecvSem"; - }); - } - }); + DataMessageReceived(std::move(msg)); } - boost::asio::post(fContext.GetIoContext(), std::bind(&Socket::RecvControlQueueReader, this)); - - // LOG(debug) << "OFI transport (" << fId << "): LEAVE OnRecvControl"; + boost::asio::dispatch(fContext.GetIoContext(), + std::bind(&Socket::RecvControlQueueReader, this)); } -auto Socket::SendImpl(vector& /*msgVec*/, const int /*flags*/, const int /*timeout*/) -> int64_t +auto Socket::DataMessageReceived(MessagePtr msg) -> void { - throw SocketError{"Not yet implemented."}; - // const unsigned int vecSize = msgVec.size(); - // int elapsed = 0; - // - // // Sending vector typicaly handles more then one part - // if (vecSize > 1) - // { - // int64_t totalSize = 0; - // int nbytes = -1; - // bool repeat = false; - // - // while (true && !fInterrupted) - // { - // for (unsigned int i = 0; i < vecSize; ++i) - // { - // nbytes = zmq_msg_send(static_cast(msgVec[i].get())->GetMessage(), - // fSocket, - // (i < vecSize - 1) ? ZMQ_SNDMORE|flags : flags); - // if (nbytes >= 0) - // { - // static_cast(msgVec[i].get())->fQueued = true; - // size_t size = msgVec[i]->GetSize(); - // - // totalSize += size; - // } - // else - // { - // // according to ZMQ docs, this can only occur for the first part - // if (zmq_errno() == EAGAIN) - // { - // if (!fInterrupted && ((flags & ZMQ_DONTWAIT) == 0)) - // { - // if (timeout) - // { - // elapsed += fSndTimeout; - // if (elapsed >= timeout) - // { - // return -2; - // } - // } - // repeat = true; - // break; - // } - // else - // { - // return -2; - // } - // } - // if (zmq_errno() == ETERM) - // { - // LOG(info) << "terminating socket " << fId; - // return -1; - // } - // LOG(error) << "Failed sending on socket " << fId << ", reason: " << zmq_strerror(errno); - // return nbytes; - // } - // } - // - // if (repeat) - // { - // continue; - // } - // - // // store statistics on how many messages have been sent (handle all parts as a single message) - // ++fMessagesTx; - // fBytesTx += totalSize; - // return totalSize; - // } - // - // return -1; - // } // If there's only one part, send it as a regular message - // else if (vecSize == 1) - // { - // return Send(msgVec.back(), flags); - // } - // else // if the vector is empty, something might be wrong - // { - // LOG(warn) << "Will not send empty vector"; - // return -1; - // } -} + if (fMultiPartRecvCounter > 0) { + --fMultiPartRecvCounter; + fInflightMultiPartMessage.emplace_back(std::move(msg)); + } -auto Socket::ReceiveImpl(vector& /*msgVec*/, const int /*flags*/, const int /*timeout*/) -> int64_t -{ - throw SocketError{"Not yet implemented."}; - // int64_t totalSize = 0; - // int64_t more = 0; - // bool repeat = false; - // int elapsed = 0; - // - // while (true) - // { - // // Warn if the vector is filled before Receive() and empty it. - // // if (msgVec.size() > 0) - // // { - // // LOG(warn) << "Message vector contains elements before Receive(), they will be deleted!"; - // // msgVec.clear(); - // // } - // - // totalSize = 0; - // more = 0; - // repeat = false; - // - // do - // { - // FairMQMessagePtr part(new FairMQMessageSHM(fManager, GetTransport())); - // zmq_msg_t* msgPtr = static_cast(part.get())->GetMessage(); - // - // int nbytes = zmq_msg_recv(msgPtr, fSocket, flags); - // if (nbytes == 0) - // { - // msgVec.push_back(move(part)); - // } - // else if (nbytes > 0) - // { - // MetaHeader* hdr = static_cast(zmq_msg_data(msgPtr)); - // size_t size = 0; - // static_cast(part.get())->fHandle = hdr->fHandle; - // static_cast(part.get())->fSize = hdr->fSize; - // static_cast(part.get())->fRegionId = hdr->fRegionId; - // static_cast(part.get())->fHint = hdr->fHint; - // size = part->GetSize(); - // - // msgVec.push_back(move(part)); - // - // totalSize += size; - // } - // else if (zmq_errno() == EAGAIN) - // { - // if (!fInterrupted && ((flags & ZMQ_DONTWAIT) == 0)) - // { - // if (timeout) - // { - // elapsed += fSndTimeout; - // if (elapsed >= timeout) - // { - // return -2; - // } - // } - // repeat = true; - // break; - // } - // else - // { - // return -2; - // } - // } - // else - // { - // return nbytes; - // } - // - // size_t more_size = sizeof(more); - // zmq_getsockopt(fSocket, ZMQ_RCVMORE, &more, &more_size); - // } - // while (more); - // - // if (repeat) - // { - // continue; - // } - // - // // store statistics on how many messages have been received (handle all parts as a single message) - // ++fMessagesRx; - // fBytesRx += totalSize; - // return totalSize; - // } + std::unique_lock lk(fRecvQueueMutex); + if (fMultiPartRecvCounter == 0) { + fRecvQueue.push(std::move(fInflightMultiPartMessage)); + fMultiPartRecvCounter = -1; + } else { + std::vector msgVec; + msgVec.push_back(std::move(msg)); + fRecvQueue.push(std::move(msgVec)); + } + lk.unlock(); + + fRecvPopSem.signal(); } auto Socket::Close() -> void {} @@ -805,53 +580,53 @@ int Socket::GetRcvKernelSize() const return 0; } -auto Socket::GetConstant(const string& constant) -> int +auto Socket::GetConstant(const string& /*constant*/) -> int { - if (constant == "") - return 0; - if (constant == "sub") - return ZMQ_SUB; - if (constant == "pub") - return ZMQ_PUB; - if (constant == "xsub") - return ZMQ_XSUB; - if (constant == "xpub") - return ZMQ_XPUB; - if (constant == "push") - return ZMQ_PUSH; - if (constant == "pull") - return ZMQ_PULL; - if (constant == "req") - return ZMQ_REQ; - if (constant == "rep") - return ZMQ_REP; - if (constant == "dealer") - return ZMQ_DEALER; - if (constant == "router") - return ZMQ_ROUTER; - if (constant == "pair") - return ZMQ_PAIR; - - if (constant == "snd-hwm") - return ZMQ_SNDHWM; - if (constant == "rcv-hwm") - return ZMQ_RCVHWM; - if (constant == "snd-size") - return ZMQ_SNDBUF; - if (constant == "rcv-size") - return ZMQ_RCVBUF; - if (constant == "snd-more") - return ZMQ_SNDMORE; - if (constant == "rcv-more") - return ZMQ_RCVMORE; - - if (constant == "linger") - return ZMQ_LINGER; - if (constant == "no-block") - return ZMQ_DONTWAIT; - if (constant == "snd-more no-block") - return ZMQ_DONTWAIT|ZMQ_SNDMORE; - + // if (constant == "") + // return 0; + // if (constant == "sub") + // return ZMQ_SUB; + // if (constant == "pub") + // return ZMQ_PUB; + // if (constant == "xsub") + // return ZMQ_XSUB; + // if (constant == "xpub") + // return ZMQ_XPUB; + // if (constant == "push") + // return ZMQ_PUSH; + // if (constant == "pull") + // return ZMQ_PULL; + // if (constant == "req") + // return ZMQ_REQ; + // if (constant == "rep") + // return ZMQ_REP; + // if (constant == "dealer") + // return ZMQ_DEALER; + // if (constant == "router") + // return ZMQ_ROUTER; + // if (constant == "pair") + // return ZMQ_PAIR; +// + // if (constant == "snd-hwm") + // return ZMQ_SNDHWM; + // if (constant == "rcv-hwm") + // return ZMQ_RCVHWM; + // if (constant == "snd-size") + // return ZMQ_SNDBUF; + // if (constant == "rcv-size") + // return ZMQ_RCVBUF; + // if (constant == "snd-more") + // return ZMQ_SNDMORE; + // if (constant == "rcv-more") + // return ZMQ_RCVMORE; +// + // if (constant == "linger") + // return ZMQ_LINGER; + // if (constant == "no-block") + // return ZMQ_DONTWAIT; + // if (constant == "snd-more no-block") + // return ZMQ_DONTWAIT|ZMQ_SNDMORE; +// return -1; } diff --git a/fairmq/ofi/Socket.h b/fairmq/ofi/Socket.h index 67a742e9..31ecbedf 100644 --- a/fairmq/ofi/Socket.h +++ b/fairmq/ofi/Socket.h @@ -18,10 +18,10 @@ #include #include #include -#include #include #include // unique_ptr -#include +#include + namespace fair { @@ -97,28 +97,22 @@ class Socket final : public fair::mq::Socket Address fLocalAddr; int fSndTimeout; int fRcvTimeout; - azmq::socket fSendQueueWrite, fSendQueueRead; - azmq::socket fRecvQueueWrite, fRecvQueueRead; - asiofi::synchronized_semaphore fSendSem, fRecvSem; + std::mutex fSendQueueMutex, fRecvQueueMutex; + std::queue> fSendQueue, fRecvQueue; + std::vector fInflightMultiPartMessage; + int64_t fMultiPartRecvCounter; + asiofi::synchronized_semaphore fSendPushSem, fSendPopSem, fRecvPushSem, fRecvPopSem; std::atomic fNeedOfiMemoryRegistration; - auto SendQueueReader() -> void; - auto OnSend(azmq::message& msg, size_t bytes_transferred) -> void; - auto RecvControlQueueReader() -> void; - auto OnRecvControl(ofi::unique_ptr ctrl) -> void; - auto OnReceive() -> void; - auto ReceiveImpl(MessagePtr& msg, const int flags, const int timeout) -> int; - auto SendImpl(std::vector& msgVec, const int flags, const int timeout) -> int64_t; - auto ReceiveImpl(std::vector& msgVec, const int flags, const int timeout) -> int64_t; - - // auto WaitForControlPeer() -> void; - // auto AnnounceDataAddress() -> void; auto InitOfi(Address addr) -> void; auto BindControlEndpoint() -> void; auto BindDataEndpoint() -> void; enum class Band { Control, Data }; auto ConnectEndpoint(std::unique_ptr& endpoint, Band type) -> void; - // auto ReceiveDataAddressAnnouncement() -> void; + auto SendQueueReader() -> void; + auto RecvControlQueueReader() -> void; + auto OnRecvControl(ofi::unique_ptr ctrl) -> void; + auto DataMessageReceived(MessagePtr msg) -> void; }; /* class Socket */ struct SilentSocketError : SocketError { using SocketError::SocketError; };