Add experimental static size mode for ofi transport

When --ofi-size-hint is greater than 0, the OFI transport does not use the
control band. Multipart messages are not supported in this mode.
This commit is contained in:
Dennis Klein 2019-04-29 20:28:40 +02:00
parent 2457094b6c
commit 3582091b1c
No known key found for this signature in database
GPG Key ID: 08E62D23FA0ECBBC
6 changed files with 118 additions and 7 deletions

View File

@ -37,6 +37,7 @@ Context::Context(FairMQTransportFactory& sendFactory,
: fIoWork(fIoContext) : fIoWork(fIoContext)
, fReceiveFactory(receiveFactory) , fReceiveFactory(receiveFactory)
, fSendFactory(sendFactory) , fSendFactory(sendFactory)
, fSizeHint(0)
{ {
InitThreadPool(numberIoThreads); InitThreadPool(numberIoThreads);
} }

View File

@ -72,6 +72,8 @@ class Context
auto Reset() -> void; auto Reset() -> void;
auto MakeReceiveMessage(size_t size) -> MessagePtr; auto MakeReceiveMessage(size_t size) -> MessagePtr;
auto MakeSendMessage(size_t size) -> MessagePtr; auto MakeSendMessage(size_t size) -> MessagePtr;
auto GetSizeHint() -> size_t { return fSizeHint; }
auto SetSizeHint(size_t size) -> void { fSizeHint = size; }
private: private:
boost::asio::io_context fIoContext; boost::asio::io_context fIoContext;
@ -79,6 +81,7 @@ class Context
std::vector<std::thread> fThreadPool; std::vector<std::thread> fThreadPool;
FairMQTransportFactory& fReceiveFactory; FairMQTransportFactory& fReceiveFactory;
FairMQTransportFactory& fSendFactory; FairMQTransportFactory& fSendFactory;
size_t fSizeHint;
auto InitThreadPool(int numberIoThreads) -> void; auto InitThreadPool(int numberIoThreads) -> void;
}; /* class Context */ }; /* class Context */

View File

@ -154,8 +154,17 @@ auto Socket::BindDataEndpoint() -> void
fDataEndpoint->accept([&]() { fDataEndpoint->accept([&]() {
LOG(debug) << "OFI transport (" << fId << "): data band connection accepted."; LOG(debug) << "OFI transport (" << fId << "): data band connection accepted.";
boost::asio::post(fContext.GetIoContext(), std::bind(&Socket::SendQueueReader, this)); if (fContext.GetSizeHint()) {
boost::asio::post(fContext.GetIoContext(), std::bind(&Socket::RecvControlQueueReader, this)); boost::asio::post(fContext.GetIoContext(),
std::bind(&Socket::SendQueueReaderStatic, this));
boost::asio::post(fContext.GetIoContext(),
std::bind(&Socket::RecvQueueReaderStatic, this));
} else {
boost::asio::post(fContext.GetIoContext(),
std::bind(&Socket::SendQueueReader, this));
boost::asio::post(fContext.GetIoContext(),
std::bind(&Socket::RecvControlQueueReader, this));
}
}); });
}); });
@ -174,8 +183,13 @@ try {
ConnectEndpoint(fControlEndpoint, Band::Control); ConnectEndpoint(fControlEndpoint, Band::Control);
ConnectEndpoint(fDataEndpoint, Band::Data); ConnectEndpoint(fDataEndpoint, Band::Data);
boost::asio::post(fContext.GetIoContext(), std::bind(&Socket::SendQueueReader, this)); if (fContext.GetSizeHint()) {
boost::asio::post(fContext.GetIoContext(), std::bind(&Socket::RecvControlQueueReader, this)); boost::asio::post(fContext.GetIoContext(), std::bind(&Socket::SendQueueReaderStatic, this));
boost::asio::post(fContext.GetIoContext(), std::bind(&Socket::RecvQueueReaderStatic, this));
} else {
boost::asio::post(fContext.GetIoContext(), std::bind(&Socket::SendQueueReader, this));
boost::asio::post(fContext.GetIoContext(), std::bind(&Socket::RecvControlQueueReader, this));
}
return true; return true;
} }
@ -347,6 +361,57 @@ auto Socket::SendQueueReader() -> void
}); });
} }
auto Socket::SendQueueReaderStatic() -> void
{
fSendPopSem.async_wait([&] {
// Read msg from send queue
std::unique_lock<std::mutex> lk(fSendQueueMutex);
std::vector<MessagePtr> msgVec(std::move(fSendQueue.front()));
fSendQueue.pop();
lk.unlock();
bool postMultiPartStartBuffer = msgVec.size() > 1;
if (postMultiPartStartBuffer) {
throw SocketError{tools::ToString("Multipart API not supported in static size mode.")};
}
MessagePtr& msg = msgVec[0];
// Send data message
const auto size = msg->GetSize();
if (size) {
boost::asio::mutable_buffer buffer(msg->GetData(), size);
if (fNeedOfiMemoryRegistration) {
asiofi::memory_region mr(*fOfiDomain, buffer, asiofi::mr::access::send);
auto desc = mr.desc();
fDataEndpoint->send(buffer,
desc,
[&, size, msg2 = std::move(msg), mr2 = std::move(mr)](
boost::asio::mutable_buffer) mutable {
fBytesTx += size;
fMessagesTx++;
fSendPushSem.signal();
});
} else {
fDataEndpoint->send(
buffer, [&, size, msg2 = std::move(msg)](boost::asio::mutable_buffer) mutable {
fBytesTx += size;
fMessagesTx++;
fSendPushSem.signal();
});
}
} else {
++fMessagesTx;
fSendPushSem.signal();
}
boost::asio::dispatch(fContext.GetIoContext(), std::bind(&Socket::SendQueueReaderStatic, this));
});
}
auto Socket::Receive(MessagePtr& msg, const int /*timeout*/) -> int auto Socket::Receive(MessagePtr& msg, const int /*timeout*/) -> int
try { try {
// timeout argument not yet implemented // timeout argument not yet implemented
@ -472,6 +537,42 @@ auto Socket::OnRecvControl(ofi::unique_ptr<ControlMessage> ctrl) -> void
std::bind(&Socket::RecvControlQueueReader, this)); std::bind(&Socket::RecvControlQueueReader, this));
} }
/// Receive loop used in static size mode.
///
/// Arms an asynchronous wait on fRecvPushSem. When the wait completes, the handler
/// creates a receive message of the context's size hint, posts a receive for it on the
/// data endpoint, and re-schedules this function on the io context. Completed messages
/// are handed to DataMessageReceived().
auto Socket::RecvQueueReaderStatic() -> void
{
    fRecvPushSem.async_wait([&] {
        // Read the hint on every invocation. A function-local static would be
        // initialized once per process and then shared by every Socket instance,
        // regardless of which Context it belongs to.
        const size_t size = fContext.GetSizeHint();

        // Receive data
        auto msg = fContext.MakeReceiveMessage(size);

        if (size) {
            boost::asio::mutable_buffer buffer(msg->GetData(), size);

            if (fNeedOfiMemoryRegistration) {
                asiofi::memory_region mr(*fOfiDomain, buffer, asiofi::mr::access::recv);
                // Fetch the descriptor before mr is moved into the completion handler.
                auto desc = mr.desc();

                // The handler takes ownership of the message and the memory region, so
                // both are held by the handler object while the receive is pending.
                fDataEndpoint->recv(buffer,
                                    desc,
                                    [&, msg2 = std::move(msg), mr2 = std::move(mr)](
                                        boost::asio::mutable_buffer) mutable {
                                        DataMessageReceived(std::move(msg2));
                                    });
            } else {
                fDataEndpoint->recv(
                    buffer, [&, msg2 = std::move(msg)](boost::asio::mutable_buffer) mutable {
                        DataMessageReceived(std::move(msg2));
                    });
            }
        } else {
            // Zero size hint: no receive is posted, the message is delivered directly.
            DataMessageReceived(std::move(msg));
        }

        // Re-arm for the next message.
        boost::asio::dispatch(fContext.GetIoContext(),
                              std::bind(&Socket::RecvQueueReaderStatic, this));
    });
}
auto Socket::DataMessageReceived(MessagePtr msg) -> void auto Socket::DataMessageReceived(MessagePtr msg) -> void
{ {
if (fMultiPartRecvCounter > 0) { if (fMultiPartRecvCounter > 0) {

View File

@ -110,7 +110,9 @@ class Socket final : public fair::mq::Socket
enum class Band { Control, Data }; enum class Band { Control, Data };
auto ConnectEndpoint(std::unique_ptr<asiofi::connected_endpoint>& endpoint, Band type) -> void; auto ConnectEndpoint(std::unique_ptr<asiofi::connected_endpoint>& endpoint, Band type) -> void;
auto SendQueueReader() -> void; auto SendQueueReader() -> void;
auto SendQueueReaderStatic() -> void;
auto RecvControlQueueReader() -> void; auto RecvControlQueueReader() -> void;
auto RecvQueueReaderStatic() -> void;
auto OnRecvControl(ofi::unique_ptr<ControlMessage> ctrl) -> void; auto OnRecvControl(ofi::unique_ptr<ControlMessage> ctrl) -> void;
auto DataMessageReceived(MessagePtr msg) -> void; auto DataMessageReceived(MessagePtr msg) -> void;
}; /* class Socket */ }; /* class Socket */

View File

@ -23,12 +23,15 @@ namespace ofi
using namespace std; using namespace std;
TransportFactory::TransportFactory(const string& id, const FairMQProgOptions* /*config*/) TransportFactory::TransportFactory(const string& id, const FairMQProgOptions* config)
try : FairMQTransportFactory(id) try : FairMQTransportFactory(id)
, fContext(*this, *this, 1) , fContext(*this, *this, 1)
{ {
LOG(debug) << "OFI transport: Using AZMQ & " LOG(debug) << "OFI transport: asiofi (" << fContext.GetAsiofiVersion() << ")";
<< "asiofi (" << fContext.GetAsiofiVersion() << ")";
if (config) {
fContext.SetSizeHint(config->GetValue<size_t>("ofi-size-hint"));
}
} catch (ContextError& e) { } catch (ContextError& e) {
throw TransportFactoryError{e.what()}; throw TransportFactoryError{e.what()};
} }

View File

@ -66,6 +66,7 @@ FairMQProgOptions::FairMQProgOptions()
("print-channels", po::value<bool >()->implicit_value(true), "Print registered channel endpoints in a machine-readable format (<channel name>:<min num subchannels>:<max num subchannels>)") ("print-channels", po::value<bool >()->implicit_value(true), "Print registered channel endpoints in a machine-readable format (<channel name>:<min num subchannels>:<max num subchannels>)")
("shm-segment-size", po::value<size_t >()->default_value(2000000000), "Shared memory: size of the shared memory segment (in bytes).") ("shm-segment-size", po::value<size_t >()->default_value(2000000000), "Shared memory: size of the shared memory segment (in bytes).")
("shm-monitor", po::value<bool >()->default_value(true), "Shared memory: run monitor daemon.") ("shm-monitor", po::value<bool >()->default_value(true), "Shared memory: run monitor daemon.")
("ofi-size-hint", po::value<size_t >()->default_value(0), "EXPERIMENTAL: OFI size hint for the allocator.")
("rate", po::value<float >()->default_value(0.), "Rate for conditional run loop (Hz).") ("rate", po::value<float >()->default_value(0.), "Rate for conditional run loop (Hz).")
("session", po::value<string >()->default_value("default"), "Session name."); ("session", po::value<string >()->default_value("default"), "Session name.");