Add semaphore

This commit is contained in:
Dennis Klein 2018-11-21 00:06:08 +01:00 committed by Dennis Klein
parent 8e7cfacd78
commit 672e12f45b
2 changed files with 70 additions and 38 deletions

View File

@ -53,7 +53,8 @@ Socket::Socket(Context& context, const string& type, const string& name, const s
, fSendQueueRead(fIoStrand.context(), ZMQ_PULL) , fSendQueueRead(fIoStrand.context(), ZMQ_PULL)
, fRecvQueueWrite(fIoStrand.context(), ZMQ_PUSH) , fRecvQueueWrite(fIoStrand.context(), ZMQ_PUSH)
, fRecvQueueRead(fIoStrand.context(), ZMQ_PULL) , fRecvQueueRead(fIoStrand.context(), ZMQ_PULL)
, fSentCount(0) , fSendSem(fIoStrand.context(), 100)
, fRecvSem(fIoStrand.context(), 100)
{ {
if (type != "pair") { if (type != "pair") {
throw SocketError{tools::ToString("Socket type '", type, "' not implemented for ofi transport.")}; throw SocketError{tools::ToString("Socket type '", type, "' not implemented for ofi transport.")};
@ -64,19 +65,6 @@ Socket::Socket(Context& context, const string& type, const string& name, const s
// Default value for ZeroMQ is -1, which is to wait forever. // Default value for ZeroMQ is -1, which is to wait forever.
fControlEndpoint.set_option(azmq::socket::linger(1000)); fControlEndpoint.set_option(azmq::socket::linger(1000));
// Setup internal queue
auto hashed_id = std::hash<std::string>()(fId);
auto queue_id = tools::ToString("inproc://TXQUEUE", hashed_id);
LOG(debug) << "OFI transport (" << fId << "): " << "Binding SQR: " << queue_id;
fSendQueueRead.bind(queue_id);
LOG(debug) << "OFI transport (" << fId << "): " << "Connecting SQW: " << queue_id;
fSendQueueWrite.connect(queue_id);
queue_id = tools::ToString("inproc://RXQUEUE", hashed_id);
LOG(debug) << "OFI transport (" << fId << "): " << "Binding RQR: " << queue_id;
fRecvQueueRead.bind(queue_id);
LOG(debug) << "OFI transport (" << fId << "): " << "Connecting RQW: " << queue_id;
fRecvQueueWrite.connect(queue_id);
// TODO wire this up with config // TODO wire this up with config
azmq::socket::snd_hwm send_max(10); azmq::socket::snd_hwm send_max(10);
azmq::socket::rcv_hwm recv_max(10); azmq::socket::rcv_hwm recv_max(10);
@ -90,6 +78,19 @@ Socket::Socket(Context& context, const string& type, const string& name, const s
fSendQueueWrite.set_option(recv_max); fSendQueueWrite.set_option(recv_max);
fControlEndpoint.set_option(send_max); fControlEndpoint.set_option(send_max);
fControlEndpoint.set_option(recv_max); fControlEndpoint.set_option(recv_max);
// Setup internal queue
auto hashed_id = std::hash<std::string>()(fId);
auto queue_id = tools::ToString("inproc://TXQUEUE", hashed_id);
LOG(debug) << "OFI transport (" << fId << "): " << "Binding SQR: " << queue_id;
fSendQueueRead.bind(queue_id);
LOG(debug) << "OFI transport (" << fId << "): " << "Connecting SQW: " << queue_id;
fSendQueueWrite.connect(queue_id);
queue_id = tools::ToString("inproc://RXQUEUE", hashed_id);
LOG(debug) << "OFI transport (" << fId << "): " << "Binding RQR: " << queue_id;
fRecvQueueRead.bind(queue_id);
LOG(debug) << "OFI transport (" << fId << "): " << "Connecting RQW: " << queue_id;
fRecvQueueWrite.connect(queue_id);
} }
} }
@ -104,7 +105,7 @@ try {
fLocalDataAddr = addr; fLocalDataAddr = addr;
BindDataEndpoint(); BindDataEndpoint();
// boost::asio::post(fIoStrand, std::bind(&Socket::SendQueueReader, this)); boost::asio::post(fIoStrand, std::bind(&Socket::SendQueueReader, this));
boost::asio::post(fIoStrand, std::bind(&Socket::RecvControlQueueReader, this)); boost::asio::post(fIoStrand, std::bind(&Socket::RecvControlQueueReader, this));
return true; return true;
@ -132,7 +133,7 @@ auto Socket::Connect(const string& address) -> bool
ConnectDataEndpoint(); ConnectDataEndpoint();
// boost::asio::post(fIoStrand, std::bind(&Socket::SendQueueReader, this)); boost::asio::post(fIoStrand, std::bind(&Socket::SendQueueReader, this));
boost::asio::post(fIoStrand, std::bind(&Socket::RecvControlQueueReader, this)); boost::asio::post(fIoStrand, std::bind(&Socket::RecvControlQueueReader, this));
} }
@ -243,15 +244,13 @@ auto Socket::AnnounceDataAddress() -> void
auto Socket::Send(MessagePtr& msg, const int /*timeout*/) -> int auto Socket::Send(MessagePtr& msg, const int /*timeout*/) -> int
{ {
LOG(debug) << "OFI transport (" << fId << "): ENTER Send: data=" << msg->GetData() << ",size=" << msg->GetSize(); // LOG(debug) << "OFI transport (" << fId << "): ENTER Send: data=" << msg->GetData() << ",size=" << msg->GetSize();
MessagePtr* msgptr(new std::unique_ptr<Message>(std::move(msg))); MessagePtr* msgptr(new std::unique_ptr<Message>(std::move(msg)));
try { try {
++fSentCount;
LOG(info) << fSentCount;
auto res = fSendQueueWrite.send(boost::asio::const_buffer(msgptr, sizeof(MessagePtr)), 0); auto res = fSendQueueWrite.send(boost::asio::const_buffer(msgptr, sizeof(MessagePtr)), 0);
LOG(debug) << "OFI transport (" << fId << "): LEAVE Send"; // LOG(debug) << "OFI transport (" << fId << "): LEAVE Send";
return res; return res;
} catch (const std::exception& e) { } catch (const std::exception& e) {
msg = std::move(*msgptr); msg = std::move(*msgptr);
@ -297,24 +296,29 @@ auto Socket::Receive(std::vector<MessagePtr>& msgVec, const int timeout) -> int6
auto Socket::SendQueueReader() -> void auto Socket::SendQueueReader() -> void
{ {
fSendQueueRead.async_receive(boost::asio::bind_executor( fSendSem.async_wait(
fIoStrand, boost::asio::bind_executor(fIoStrand, [&](const boost::system::error_code& ec) {
[&](const boost::system::error_code& ec, azmq::message& zmsg, size_t bytes_transferred) {
if (!ec) { if (!ec) {
--fSentCount; LOG(debug) << "OFI transport (" << fId << "): < Wait fSendSem=" << fSendSem.get_value();
OnSend(zmsg, bytes_transferred); fSendQueueRead.async_receive([&](const boost::system::error_code& ec2,
azmq::message& zmsg,
size_t bytes_transferred) {
if (!ec2) {
OnSend(zmsg, bytes_transferred);
}
});
} }
})); }));
} }
auto Socket::OnSend(azmq::message& zmsg, size_t bytes_transferred) -> void auto Socket::OnSend(azmq::message& zmsg, size_t bytes_transferred) -> void
{ {
LOG(debug) << "OFI transport (" << fId << "): ENTER OnSend: bytes_transferred=" << bytes_transferred; // LOG(debug) << "OFI transport (" << fId << "): ENTER OnSend: bytes_transferred=" << bytes_transferred;
MessagePtr msg(std::move(*(static_cast<MessagePtr*>(zmsg.buffer().data())))); MessagePtr msg(std::move(*(static_cast<MessagePtr*>(zmsg.buffer().data()))));
auto size = msg->GetSize(); auto size = msg->GetSize();
LOG(debug) << "OFI transport (" << fId << "): OnSend: data=" << msg->GetData() << ",size=" << msg->GetSize(); // LOG(debug) << "OFI transport (" << fId << "): OnSend: data=" << msg->GetData() << ",size=" << msg->GetSize();
// Create and send control message // Create and send control message
auto pb = MakeControlMessage<PostBuffer>(); auto pb = MakeControlMessage<PostBuffer>();
@ -327,14 +331,14 @@ auto Socket::OnSend(azmq::message& zmsg, size_t bytes_transferred) -> void
} }
}); });
LOG(debug) << "OFI transport (" << fId << "): LEAVE OnSend"; // LOG(debug) << "OFI transport (" << fId << "): LEAVE OnSend";
} }
auto Socket::OnControlMessageSent(size_t bytes_transferred, MessagePtr msg) -> void auto Socket::OnControlMessageSent(size_t bytes_transferred, MessagePtr msg) -> void
{ {
LOG(debug) << "OFI transport (" << fId // LOG(debug) << "OFI transport (" << fId
<< "): ENTER OnControlMessageSent: bytes_transferred=" << bytes_transferred // << "): ENTER OnControlMessageSent: bytes_transferred=" << bytes_transferred
<< ",data=" << msg->GetData() << ",size=" << msg->GetSize(); // << ",data=" << msg->GetData() << ",size=" << msg->GetSize();
assert(bytes_transferred == sizeof(PostBuffer)); assert(bytes_transferred == sizeof(PostBuffer));
auto size = msg->GetSize(); auto size = msg->GetSize();
@ -359,23 +363,39 @@ auto Socket::OnControlMessageSent(size_t bytes_transferred, MessagePtr msg) -> v
buffer, buffer,
// desc, // desc,
[&, size, msg2 = std::move(msg)/*, mr2 = std::move(mr)*/](boost::asio::mutable_buffer) mutable { [&, size, msg2 = std::move(msg)/*, mr2 = std::move(mr)*/](boost::asio::mutable_buffer) mutable {
LOG(debug) << "OFI transport (" << fId << "): >>>>> Data buffer sent"; // LOG(debug) << "OFI transport (" << fId << "): >>>>> Data buffer sent";
fBytesTx += size; fBytesTx += size;
fMessagesTx++; fMessagesTx++;
fSendSem.async_signal([&](const boost::system::error_code& ec){
if (!ec) {
LOG(debug) << "OFI transport (" << fId << "): > Signal fSendSem=" << fSendSem.get_value();
}
});
}); });
} else {
fSendSem.async_signal([&](const boost::system::error_code& ec) {
if (!ec) {
LOG(debug) << "OFI transport (" << fId << "): > Signal fSendSem=" << fSendSem.get_value();
}
});
} }
boost::asio::post(fIoStrand, std::bind(&Socket::SendQueueReader, this)); boost::asio::post(fIoStrand, std::bind(&Socket::SendQueueReader, this));
LOG(debug) << "OFI transport (" << fId << "): LEAVE OnControlMessageSent"; // LOG(debug) << "OFI transport (" << fId << "): LEAVE OnControlMessageSent";
} }
auto Socket::RecvControlQueueReader() -> void auto Socket::RecvControlQueueReader() -> void
{ {
fControlEndpoint.async_receive(boost::asio::bind_executor( fRecvSem.async_wait(
fIoStrand, boost::asio::bind_executor(fIoStrand, [&](const boost::system::error_code& ec) {
[&](const boost::system::error_code& ec, azmq::message& zmsg, size_t bytes_transferred) {
if (!ec) { if (!ec) {
OnRecvControl(zmsg, bytes_transferred); fControlEndpoint.async_receive([&](const boost::system::error_code& ec2,
azmq::message& zmsg,
size_t bytes_transferred) {
if (!ec2) {
OnRecvControl(zmsg, bytes_transferred);
}
});
} }
})); }));
} }
@ -413,6 +433,12 @@ auto Socket::OnRecvControl(azmq::message& zmsg, size_t bytes_transferred) -> voi
LOG(debug) << "OFI transport (" << fId LOG(debug) << "OFI transport (" << fId
<< "): <<<<< Data buffer received, bytes_transferred2=" << "): <<<<< Data buffer received, bytes_transferred2="
<< bytes_transferred2; << bytes_transferred2;
fRecvSem.async_signal([&](const boost::system::error_code& ec2) {
if (!ec2) {
LOG(debug)
<< "OFI transport (" << fId << "): < Signal fRecvSem";
}
});
} }
}); });
}); });
@ -431,6 +457,11 @@ auto Socket::OnRecvControl(azmq::message& zmsg, size_t bytes_transferred) -> voi
LOG(debug) << "OFI transport (" << fId LOG(debug) << "OFI transport (" << fId
<< "): <<<<< Data buffer received, bytes_transferred2=" << "): <<<<< Data buffer received, bytes_transferred2="
<< bytes_transferred2; << bytes_transferred2;
fRecvSem.async_signal([&](const boost::system::error_code& ec2) {
if (!ec2) {
LOG(debug) << "OFI transport (" << fId << "): < Signal fRecvSem";
}
});
} }
}); });
} }

View File

@ -15,6 +15,7 @@
#include <fairmq/ofi/ControlMessages.h> #include <fairmq/ofi/ControlMessages.h>
#include <asiofi/connected_endpoint.hpp> #include <asiofi/connected_endpoint.hpp>
#include <asiofi/semaphore.hpp>
#include <azmq/socket.hpp> #include <azmq/socket.hpp>
#include <boost/asio.hpp> #include <boost/asio.hpp>
#include <memory> // unique_ptr #include <memory> // unique_ptr
@ -94,7 +95,7 @@ class Socket final : public fair::mq::Socket
int fRcvTimeout; int fRcvTimeout;
azmq::socket fSendQueueWrite, fSendQueueRead; azmq::socket fSendQueueWrite, fSendQueueRead;
azmq::socket fRecvQueueWrite, fRecvQueueRead; azmq::socket fRecvQueueWrite, fRecvQueueRead;
std::atomic<unsigned long> fSentCount; asiofi::semaphore fSendSem, fRecvSem;
auto SendQueueReader() -> void; auto SendQueueReader() -> void;
auto OnSend(azmq::message& msg, size_t bytes_transferred) -> void; auto OnSend(azmq::message& msg, size_t bytes_transferred) -> void;