From 672e12f45bcb011fd5d9759b81a036a5884657d0 Mon Sep 17 00:00:00 2001 From: Dennis Klein Date: Wed, 21 Nov 2018 00:06:08 +0100 Subject: [PATCH] Add semaphore --- fairmq/ofi/Socket.cxx | 105 +++++++++++++++++++++++++++--------------- fairmq/ofi/Socket.h | 3 +- 2 files changed, 70 insertions(+), 38 deletions(-) diff --git a/fairmq/ofi/Socket.cxx b/fairmq/ofi/Socket.cxx index 547dbfe3..49a96902 100644 --- a/fairmq/ofi/Socket.cxx +++ b/fairmq/ofi/Socket.cxx @@ -53,7 +53,8 @@ Socket::Socket(Context& context, const string& type, const string& name, const s , fSendQueueRead(fIoStrand.context(), ZMQ_PULL) , fRecvQueueWrite(fIoStrand.context(), ZMQ_PUSH) , fRecvQueueRead(fIoStrand.context(), ZMQ_PULL) - , fSentCount(0) + , fSendSem(fIoStrand.context(), 100) + , fRecvSem(fIoStrand.context(), 100) { if (type != "pair") { throw SocketError{tools::ToString("Socket type '", type, "' not implemented for ofi transport.")}; @@ -64,19 +65,6 @@ Socket::Socket(Context& context, const string& type, const string& name, const s // Default value for ZeroMQ is -1, which is to wait forever. fControlEndpoint.set_option(azmq::socket::linger(1000)); - // Setup internal queue - auto hashed_id = std::hash()(fId); - auto queue_id = tools::ToString("inproc://TXQUEUE", hashed_id); - LOG(debug) << "OFI transport (" << fId << "): " << "Binding SQR: " << queue_id; - fSendQueueRead.bind(queue_id); - LOG(debug) << "OFI transport (" << fId << "): " << "Connecting SQW: " << queue_id; - fSendQueueWrite.connect(queue_id); - queue_id = tools::ToString("inproc://RXQUEUE", hashed_id); - LOG(debug) << "OFI transport (" << fId << "): " << "Binding RQR: " << queue_id; - fRecvQueueRead.bind(queue_id); - LOG(debug) << "OFI transport (" << fId << "): " << "Connecting RQW: " << queue_id; - fRecvQueueWrite.connect(queue_id); - // TODO wire this up with config azmq::socket::snd_hwm send_max(10); azmq::socket::rcv_hwm recv_max(10); @@ -90,6 +78,19 @@ Socket::Socket(Context& context, const string& type, const string& name, const s fSendQueueWrite.set_option(recv_max); fControlEndpoint.set_option(send_max); fControlEndpoint.set_option(recv_max); + + // Setup internal queue + auto hashed_id = std::hash()(fId); + auto queue_id = tools::ToString("inproc://TXQUEUE", hashed_id); + LOG(debug) << "OFI transport (" << fId << "): " << "Binding SQR: " << queue_id; + fSendQueueRead.bind(queue_id); + LOG(debug) << "OFI transport (" << fId << "): " << "Connecting SQW: " << queue_id; + fSendQueueWrite.connect(queue_id); + queue_id = tools::ToString("inproc://RXQUEUE", hashed_id); + LOG(debug) << "OFI transport (" << fId << "): " << "Binding RQR: " << queue_id; + fRecvQueueRead.bind(queue_id); + LOG(debug) << "OFI transport (" << fId << "): " << "Connecting RQW: " << queue_id; + fRecvQueueWrite.connect(queue_id); } } @@ -104,7 +105,7 @@ try { fLocalDataAddr = addr; BindDataEndpoint(); - // boost::asio::post(fIoStrand, std::bind(&Socket::SendQueueReader, this)); + boost::asio::post(fIoStrand, std::bind(&Socket::SendQueueReader, this)); boost::asio::post(fIoStrand, std::bind(&Socket::RecvControlQueueReader, this)); return true; @@ -132,7 +133,7 @@ auto Socket::Connect(const string& address) -> bool ConnectDataEndpoint(); - // boost::asio::post(fIoStrand, std::bind(&Socket::SendQueueReader, this)); + boost::asio::post(fIoStrand, std::bind(&Socket::SendQueueReader, this)); boost::asio::post(fIoStrand, std::bind(&Socket::RecvControlQueueReader, this)); } @@ -243,15 +244,13 @@ auto Socket::AnnounceDataAddress() -> void auto Socket::Send(MessagePtr& msg, const int /*timeout*/) -> int { - LOG(debug) << "OFI transport (" << fId << "): ENTER Send: data=" << msg->GetData() << ",size=" << msg->GetSize(); + // LOG(debug) << "OFI transport (" << fId << "): ENTER Send: data=" << msg->GetData() << ",size=" << msg->GetSize(); MessagePtr* msgptr(new std::unique_ptr(std::move(msg))); try { - ++fSentCount; - LOG(info) << fSentCount; auto res = fSendQueueWrite.send(boost::asio::const_buffer(msgptr, sizeof(MessagePtr)), 0); - LOG(debug) << "OFI transport (" << fId << "): LEAVE Send"; + // LOG(debug) << "OFI transport (" << fId << "): LEAVE Send"; return res; } catch (const std::exception& e) { msg = std::move(*msgptr); @@ -297,24 +296,29 @@ auto Socket::Receive(std::vector& msgVec, const int timeout) -> int6 auto Socket::SendQueueReader() -> void { - fSendQueueRead.async_receive(boost::asio::bind_executor( - fIoStrand, - [&](const boost::system::error_code& ec, azmq::message& zmsg, size_t bytes_transferred) { + fSendSem.async_wait( + boost::asio::bind_executor(fIoStrand, [&](const boost::system::error_code& ec) { if (!ec) { - --fSentCount; - OnSend(zmsg, bytes_transferred); + LOG(debug) << "OFI transport (" << fId << "): < Wait fSendSem=" << fSendSem.get_value(); + fSendQueueRead.async_receive([&](const boost::system::error_code& ec2, + azmq::message& zmsg, + size_t bytes_transferred) { + if (!ec2) { + OnSend(zmsg, bytes_transferred); + } + }); } })); } auto Socket::OnSend(azmq::message& zmsg, size_t bytes_transferred) -> void { - LOG(debug) << "OFI transport (" << fId << "): ENTER OnSend: bytes_transferred=" << bytes_transferred; + // LOG(debug) << "OFI transport (" << fId << "): ENTER OnSend: bytes_transferred=" << bytes_transferred; MessagePtr msg(std::move(*(static_cast(zmsg.buffer().data())))); auto size = msg->GetSize(); - LOG(debug) << "OFI transport (" << fId << "): OnSend: data=" << msg->GetData() << ",size=" << msg->GetSize(); + // LOG(debug) << "OFI transport (" << fId << "): OnSend: data=" << msg->GetData() << ",size=" << msg->GetSize(); // Create and send control message auto pb = MakeControlMessage(); @@ -327,14 +331,14 @@ auto Socket::OnSend(azmq::message& zmsg, size_t bytes_transferred) -> void } }); - LOG(debug) << "OFI transport (" << fId << "): LEAVE OnSend"; + // LOG(debug) << "OFI transport (" << fId << "): LEAVE OnSend"; } auto Socket::OnControlMessageSent(size_t bytes_transferred, MessagePtr msg) -> void { - LOG(debug) << "OFI transport (" << fId - << "): ENTER OnControlMessageSent: bytes_transferred=" << bytes_transferred - << ",data=" << msg->GetData() << ",size=" << msg->GetSize(); + // LOG(debug) << "OFI transport (" << fId + // << "): ENTER OnControlMessageSent: bytes_transferred=" << bytes_transferred + // << ",data=" << msg->GetData() << ",size=" << msg->GetSize(); assert(bytes_transferred == sizeof(PostBuffer)); auto size = msg->GetSize(); @@ -359,23 +363,39 @@ auto Socket::OnControlMessageSent(size_t bytes_transferred, MessagePtr msg) -> v buffer, // desc, [&, size, msg2 = std::move(msg)/*, mr2 = std::move(mr)*/](boost::asio::mutable_buffer) mutable { - LOG(debug) << "OFI transport (" << fId << "): >>>>> Data buffer sent"; + // LOG(debug) << "OFI transport (" << fId << "): >>>>> Data buffer sent"; fBytesTx += size; fMessagesTx++; + fSendSem.async_signal([&](const boost::system::error_code& ec){ + if (!ec) { + LOG(debug) << "OFI transport (" << fId << "): > Signal fSendSem=" << fSendSem.get_value(); + } + }); }); + } else { + fSendSem.async_signal([&](const boost::system::error_code& ec) { + if (!ec) { + LOG(debug) << "OFI transport (" << fId << "): > Signal fSendSem=" << fSendSem.get_value(); + } + }); } boost::asio::post(fIoStrand, std::bind(&Socket::SendQueueReader, this)); - LOG(debug) << "OFI transport (" << fId << "): LEAVE OnControlMessageSent"; + // LOG(debug) << "OFI transport (" << fId << "): LEAVE OnControlMessageSent"; } auto Socket::RecvControlQueueReader() -> void { - fControlEndpoint.async_receive(boost::asio::bind_executor( - fIoStrand, - [&](const boost::system::error_code& ec, azmq::message& zmsg, size_t bytes_transferred) { + fRecvSem.async_wait( + boost::asio::bind_executor(fIoStrand, [&](const boost::system::error_code& ec) { if (!ec) { - OnRecvControl(zmsg, bytes_transferred); + fControlEndpoint.async_receive([&](const boost::system::error_code& ec2, + azmq::message& zmsg, + size_t bytes_transferred) { + if (!ec2) { + OnRecvControl(zmsg, bytes_transferred); + } + }); } })); } @@ -413,6 +433,12 @@ auto Socket::OnRecvControl(azmq::message& zmsg, size_t bytes_transferred) -> voi LOG(debug) << "OFI transport (" << fId << "): <<<<< Data buffer received, bytes_transferred2=" << bytes_transferred2; + fRecvSem.async_signal([&](const boost::system::error_code& ec2) { + if (!ec2) { + LOG(debug) + << "OFI transport (" << fId << "): < Signal fRecvSem"; + } + }); } }); }); @@ -431,6 +457,11 @@ auto Socket::OnRecvControl(azmq::message& zmsg, size_t bytes_transferred) -> voi LOG(debug) << "OFI transport (" << fId << "): <<<<< Data buffer received, bytes_transferred2=" << bytes_transferred2; + fRecvSem.async_signal([&](const boost::system::error_code& ec2) { + if (!ec2) { + LOG(debug) << "OFI transport (" << fId << "): < Signal fRecvSem"; + } + }); } }); } diff --git a/fairmq/ofi/Socket.h b/fairmq/ofi/Socket.h index 05947201..a09e817b 100644 --- a/fairmq/ofi/Socket.h +++ b/fairmq/ofi/Socket.h @@ -15,6 +15,7 @@ #include #include +#include #include #include #include // unique_ptr @@ -94,7 +95,7 @@ class Socket final : public fair::mq::Socket int fRcvTimeout; azmq::socket fSendQueueWrite, fSendQueueRead; azmq::socket fRecvQueueWrite, fRecvQueueRead; - std::atomic fSentCount; + asiofi::semaphore fSendSem, fRecvSem; auto SendQueueReader() -> void; auto OnSend(azmq::message& msg, size_t bytes_transferred) -> void;