Implement parallel ofi::Socket::Receive

This commit is contained in:
Dennis Klein
2018-11-20 12:45:46 +01:00
committed by Dennis Klein
parent 46e2420547
commit 8e7cfacd78
5 changed files with 160 additions and 83 deletions

View File

@@ -49,8 +49,11 @@ Socket::Socket(Context& context, const string& type, const string& name, const s
, fControlEndpoint(fIoStrand.context(), ZMQ_PAIR)
, fSndTimeout(100)
, fRcvTimeout(100)
, fQueue1(fIoStrand.context())
, fQueue2(fIoStrand.context())
, fSendQueueWrite(fIoStrand.context(), ZMQ_PUSH)
, fSendQueueRead(fIoStrand.context(), ZMQ_PULL)
, fRecvQueueWrite(fIoStrand.context(), ZMQ_PUSH)
, fRecvQueueRead(fIoStrand.context(), ZMQ_PULL)
, fSentCount(0)
{
if (type != "pair") {
throw SocketError{tools::ToString("Socket type '", type, "' not implemented for ofi transport.")};
@@ -63,17 +66,28 @@ Socket::Socket(Context& context, const string& type, const string& name, const s
// Setup internal queue
auto hashed_id = std::hash<std::string>()(fId);
auto queue_id = tools::ToString("inproc://QUEUE", hashed_id);
LOG(debug) << "OFI transport (" << fId << "): " << "Binding Q1: " << queue_id;
fQueue1.bind(queue_id);
LOG(debug) << "OFI transport (" << fId << "): " << "Connecting Q2: " << queue_id;
fQueue2.connect(queue_id);
azmq::socket::snd_hwm send_max(100);
azmq::socket::rcv_hwm recv_max(100);
fQueue1.set_option(send_max);
fQueue1.set_option(recv_max);
fQueue2.set_option(send_max);
fQueue2.set_option(recv_max);
auto queue_id = tools::ToString("inproc://TXQUEUE", hashed_id);
LOG(debug) << "OFI transport (" << fId << "): " << "Binding SQR: " << queue_id;
fSendQueueRead.bind(queue_id);
LOG(debug) << "OFI transport (" << fId << "): " << "Connecting SQW: " << queue_id;
fSendQueueWrite.connect(queue_id);
queue_id = tools::ToString("inproc://RXQUEUE", hashed_id);
LOG(debug) << "OFI transport (" << fId << "): " << "Binding RQR: " << queue_id;
fRecvQueueRead.bind(queue_id);
LOG(debug) << "OFI transport (" << fId << "): " << "Connecting RQW: " << queue_id;
fRecvQueueWrite.connect(queue_id);
// TODO wire this up with config
azmq::socket::snd_hwm send_max(10);
azmq::socket::rcv_hwm recv_max(10);
fSendQueueRead.set_option(send_max);
fSendQueueRead.set_option(recv_max);
fSendQueueWrite.set_option(send_max);
fSendQueueWrite.set_option(recv_max);
fRecvQueueRead.set_option(send_max);
fRecvQueueRead.set_option(recv_max);
fSendQueueWrite.set_option(send_max);
fSendQueueWrite.set_option(recv_max);
fControlEndpoint.set_option(send_max);
fControlEndpoint.set_option(recv_max);
}
@@ -90,7 +104,8 @@ try {
fLocalDataAddr = addr;
BindDataEndpoint();
boost::asio::post(fIoStrand, std::bind(&Socket::SendQueueReader, this));
// boost::asio::post(fIoStrand, std::bind(&Socket::SendQueueReader, this));
boost::asio::post(fIoStrand, std::bind(&Socket::RecvControlQueueReader, this));
return true;
}
@@ -116,6 +131,9 @@ auto Socket::Connect(const string& address) -> bool
ReceiveDataAddressAnnouncement();
ConnectDataEndpoint();
// boost::asio::post(fIoStrand, std::bind(&Socket::SendQueueReader, this));
boost::asio::post(fIoStrand, std::bind(&Socket::RecvControlQueueReader, this));
}
auto Socket::BindControlEndpoint(Context::Address address) -> void
@@ -225,11 +243,13 @@ auto Socket::AnnounceDataAddress() -> void
auto Socket::Send(MessagePtr& msg, const int /*timeout*/) -> int
{
LOG(debug) << "OFI transport (" << fId << "): ENTER Send: size=" << msg->GetSize();
LOG(debug) << "OFI transport (" << fId << "): ENTER Send: data=" << msg->GetData() << ",size=" << msg->GetSize();
MessagePtr* msgptr(new std::unique_ptr<Message>(std::move(msg)));
try {
auto res = fQueue1.send(boost::asio::const_buffer(msgptr, sizeof(MessagePtr)), 0);
++fSentCount;
LOG(info) << fSentCount;
auto res = fSendQueueWrite.send(boost::asio::const_buffer(msgptr, sizeof(MessagePtr)), 0);
LOG(debug) << "OFI transport (" << fId << "): LEAVE Send";
return res;
@@ -244,16 +264,44 @@ auto Socket::Send(MessagePtr& msg, const int /*timeout*/) -> int
}
}
auto Socket::Receive(MessagePtr& msg, const int timeout) -> int { return 0; /*ReceiveImpl(msg, 0, timeout);*/ }
auto Socket::Receive(MessagePtr& msg, const int /*timeout*/) -> int
{
LOG(debug) << "OFI transport (" << fId << "): ENTER Receive";
try {
azmq::message zmsg;
auto recv = fRecvQueueRead.receive(zmsg);
size_t size(0);
if (recv > 0) {
msg = std::move(*(static_cast<MessagePtr*>(zmsg.buffer().data())));
size = msg->GetSize();
}
fBytesRx += size;
fMessagesRx++;
LOG(debug) << "OFI transport (" << fId << "): LEAVE Receive";
return size;
} catch (const std::exception& e) {
LOG(error) << e.what();
return -1;
} catch (const boost::system::error_code& e) {
LOG(error) << e;
return -1;
}
}
auto Socket::Send(std::vector<MessagePtr>& msgVec, const int timeout) -> int64_t { return SendImpl(msgVec, 0, timeout); }
auto Socket::Receive(std::vector<MessagePtr>& msgVec, const int timeout) -> int64_t { return ReceiveImpl(msgVec, 0, timeout); }
auto Socket::SendQueueReader() -> void
{
fQueue2.async_receive(boost::asio::bind_executor(
fSendQueueRead.async_receive(boost::asio::bind_executor(
fIoStrand,
[&](const boost::system::error_code& ec, azmq::message& zmsg, size_t bytes_transferred) {
if (!ec) {
--fSentCount;
OnSend(zmsg, bytes_transferred);
}
}));
@@ -266,7 +314,7 @@ auto Socket::OnSend(azmq::message& zmsg, size_t bytes_transferred) -> void
MessagePtr msg(std::move(*(static_cast<MessagePtr*>(zmsg.buffer().data()))));
auto size = msg->GetSize();
LOG(debug) << "OFI transport (" << fId << "): >>>>> OnSend: size=" << size;
LOG(debug) << "OFI transport (" << fId << "): OnSend: data=" << msg->GetData() << ",size=" << msg->GetSize();
// Create and send control message
auto pb = MakeControlMessage<PostBuffer>();
@@ -284,7 +332,9 @@ auto Socket::OnSend(azmq::message& zmsg, size_t bytes_transferred) -> void
auto Socket::OnControlMessageSent(size_t bytes_transferred, MessagePtr msg) -> void
{
LOG(debug) << "OFI transport (" << fId << "): ENTER OnControlMessageSent: bytes_transferred=" << bytes_transferred;
LOG(debug) << "OFI transport (" << fId
<< "): ENTER OnControlMessageSent: bytes_transferred=" << bytes_transferred
<< ",data=" << msg->GetData() << ",size=" << msg->GetSize();
assert(bytes_transferred == sizeof(PostBuffer));
auto size = msg->GetSize();
@@ -302,77 +352,92 @@ auto Socket::OnControlMessageSent(size_t bytes_transferred, MessagePtr msg) -> v
// received, size_ack=" << size_ack;
boost::asio::mutable_buffer buffer(msg->GetData(), size);
asiofi::memory_region mr(fContext.GetDomain(), buffer, asiofi::mr::access::send);
// asiofi::memory_region mr(fContext.GetDomain(), buffer, asiofi::mr::access::send);
// auto desc = mr.desc();
fDataEndpoint->send(buffer, mr.desc(), [&, mr2 = std::move(mr)](boost::asio::mutable_buffer) {
LOG(debug) << "OFI transport (" << fId << "): >>>>> Data buffer sent";
fBytesTx += size;
fMessagesTx++;
});
fDataEndpoint->send(
buffer,
// desc,
[&, size, msg2 = std::move(msg)/*, mr2 = std::move(mr)*/](boost::asio::mutable_buffer) mutable {
LOG(debug) << "OFI transport (" << fId << "): >>>>> Data buffer sent";
fBytesTx += size;
fMessagesTx++;
});
}
boost::asio::post(fIoStrand, std::bind(&Socket::SendQueueReader, this));
LOG(debug) << "OFI transport (" << fId << "): LEAVE OnControlMessageSent";
}
auto Socket::ReceiveImpl(FairMQMessagePtr& msg, const int /*flags*/, const int /*timeout*/) -> int
try {
LOG(debug) << "OFI transport (" << fId << "): ENTER ReceiveImpl";
// Receive and process control message
azmq::message ctrl;
auto recv = fControlEndpoint.receive(ctrl);
assert(recv == sizeof(PostBuffer)); (void)recv;
auto pb(static_cast<const PostBuffer*>(ctrl.data()));
auto Socket::RecvControlQueueReader() -> void
{
fControlEndpoint.async_receive(boost::asio::bind_executor(
fIoStrand,
[&](const boost::system::error_code& ec, azmq::message& zmsg, size_t bytes_transferred) {
if (!ec) {
OnRecvControl(zmsg, bytes_transferred);
}
}));
}
auto Socket::OnRecvControl(azmq::message& zmsg, size_t bytes_transferred) -> void
{
LOG(debug) << "OFI transport (" << fId
<< "): ENTER OnRecvControl: bytes_transferred=" << bytes_transferred;
assert(bytes_transferred == sizeof(PostBuffer));
auto pb(static_cast<const PostBuffer*>(zmsg.data()));
assert(pb->type == ControlMessageType::PostBuffer);
auto size = pb->size;
LOG(debug) << "OFI transport (" << fId << "): <<<<< ReceiveImpl: Control message received, size=" << size;
LOG(debug) << "OFI transport (" << fId << "): OnRecvControl: PostBuffer.size=" << size;
// Receive data
if (size) {
msg->Rebuild(size);
auto msg = fContext.MakeReceiveMessage(size);
boost::asio::mutable_buffer buffer(msg->GetData(), size);
asiofi::memory_region mr(fContext.GetDomain(), buffer, asiofi::mr::access::recv);
// asiofi::memory_region mr(fContext.GetDomain(), buffer, asiofi::mr::access::recv);
// auto msg33 = fContext.MakeReceiveMessage(size);
// boost::asio::mutable_buffer buffer33(msg33->GetData(), size);
// asiofi::memory_region mr33(fContext.GetDomain(), buffer33, asiofi::mr::access::recv);
// auto desc = mr.desc();
std::mutex m;
std::condition_variable cv;
bool completed(false);
fDataEndpoint->recv(buffer, mr.desc(), [&](boost::asio::mutable_buffer) {
{
std::unique_lock<std::mutex> lk(m);
completed = true;
}
cv.notify_one();
}
);
fDataEndpoint->recv(
buffer,
// desc,
[&, msg2 = std::move(msg)/*, mr2 = std::move(mr)*/](boost::asio::mutable_buffer) mutable {
MessagePtr* msgptr(new std::unique_ptr<Message>(std::move(msg2)));
fRecvQueueWrite.async_send(
azmq::message(boost::asio::const_buffer(msgptr, sizeof(MessagePtr))),
[&](const boost::system::error_code& ec, size_t bytes_transferred2) {
if (!ec) {
LOG(debug) << "OFI transport (" << fId
<< "): <<<<< Data buffer received, bytes_transferred2="
<< bytes_transferred2;
}
});
});
// LOG(debug) << "OFI transport (" << fId << "): <<<<< ReceiveImpl: Data buffer posted";
auto ack = MakeControlMessage<PostBuffer>();
ack.size = size;
auto sent = fControlEndpoint.send(boost::asio::buffer(ack));
assert(sent == sizeof(PostBuffer)); (void)sent;
// auto ack = MakeControlMessage<PostBuffer>();
// ack.size = size;
// auto sent = fControlEndpoint.send(boost::asio::buffer(ack));
// assert(sent == sizeof(PostBuffer)); (void)sent;
// LOG(debug) << "OFI transport (" << fId << "): <<<<< ReceiveImpl: Control Ack sent";
{
std::unique_lock<std::mutex> lk(m);
cv.wait(lk, [&](){ return completed; });
}
// LOG(debug) << "OFI transport (" << fId << "): <<<<< ReceiveImpl: Data received";
} else {
fRecvQueueWrite.async_send(
azmq::message(boost::asio::const_buffer(nullptr, 0)),
[&](const boost::system::error_code& ec, size_t bytes_transferred2) {
if (!ec) {
LOG(debug) << "OFI transport (" << fId
<< "): <<<<< Data buffer received, bytes_transferred2="
<< bytes_transferred2;
}
});
}
fBytesRx += size;
fMessagesRx++;
boost::asio::post(fIoStrand, std::bind(&Socket::RecvControlQueueReader, this));
// LOG(debug) << "OFI transport (" << fId << "): EXIT ReceiveImpl";
return size;
}
catch (const SilentSocketError& e)
{
return -2;
}
catch (const std::exception& e)
{
LOG(error) << e.what();
return -1;
LOG(debug) << "OFI transport (" << fId << "): LEAVE OnRecvControl";
}
auto Socket::SendImpl(vector<FairMQMessagePtr>& /*msgVec*/, const int /*flags*/, const int /*timeout*/) -> int64_t
@@ -548,14 +613,14 @@ auto Socket::ReceiveImpl(vector<FairMQMessagePtr>& /*msgVec*/, const int /*flags
auto Socket::Close() -> void {}
auto Socket::SetOption(const string& option, const void* value, size_t valueSize) -> void
auto Socket::SetOption(const string& /*option*/, const void* /*value*/, size_t /*valueSize*/) -> void
{
// if (zmq_setsockopt(fControlSocket, GetConstant(option), value, valueSize) < 0) {
// throw SocketError{tools::ToString("Failed setting socket option, reason: ", zmq_strerror(errno))};
// }
}
auto Socket::GetOption(const string& option, void* value, size_t* valueSize) -> void
auto Socket::GetOption(const string& /*option*/, void* /*value*/, size_t* /*valueSize*/) -> void
{
// if (zmq_getsockopt(fControlSocket, GetConstant(option), value, valueSize) < 0) {
// throw SocketError{tools::ToString("Failed getting socket option, reason: ", zmq_strerror(errno))};