From b31ab1cc48d8af3d84ce9b4388bc22d22edae045 Mon Sep 17 00:00:00 2001 From: Dennis Klein Date: Mon, 26 Nov 2018 00:20:41 +0100 Subject: [PATCH] Implement control band with asiofi --- fairmq/ofi/Context.cxx | 88 ++---- fairmq/ofi/Context.h | 21 +- fairmq/ofi/ControlMessages.h | 23 ++ fairmq/ofi/Socket.cxx | 538 +++++++++++++++++--------------- fairmq/ofi/Socket.h | 26 +- fairmq/ofi/TransportFactory.cxx | 2 +- 6 files changed, 366 insertions(+), 332 deletions(-) diff --git a/fairmq/ofi/Context.cxx b/fairmq/ofi/Context.cxx index 4f64b535..83909d23 100644 --- a/fairmq/ofi/Context.cxx +++ b/fairmq/ofi/Context.cxx @@ -31,12 +31,15 @@ namespace ofi using namespace std; -Context::Context(FairMQTransportFactory& receiveFactory, int numberIoThreads) +Context::Context(FairMQTransportFactory& sendFactory, + FairMQTransportFactory& receiveFactory, + int numberIoThreads) : fOfiInfo(nullptr) , fOfiFabric(nullptr) , fOfiDomain(nullptr) , fIoWork(fIoContext) , fReceiveFactory(receiveFactory) + , fSendFactory(sendFactory) { InitThreadPool(numberIoThreads); } @@ -66,73 +69,27 @@ auto Context::GetAsiofiVersion() const -> string return ASIOFI_VERSION; } -auto Context::InitOfi(ConnectionType type, Address addr) -> void +auto Context::InitOfi(Address addr) -> void { - assert(!fOfiInfo); - assert(!fOfiFabric); - assert(!fOfiDomain); + if (!fOfiInfo) { + assert(!fOfiFabric); + assert(!fOfiDomain); - asiofi::hints hints; - if (addr.Protocol == "tcp") { - hints.set_provider("sockets"); - } else if (addr.Protocol == "verbs") { - hints.set_provider("verbs"); + asiofi::hints hints; + if (addr.Protocol == "tcp") { + hints.set_provider("sockets"); + } else if (addr.Protocol == "verbs") { + hints.set_provider("verbs"); + } + fOfiInfo = tools::make_unique(addr.Ip.c_str(), std::to_string(addr.Port).c_str(), 0, hints); + // LOG(debug) << "OFI transport: " << *fOfiInfo; + + fOfiFabric = tools::make_unique(*fOfiInfo); + + fOfiDomain = tools::make_unique(*fOfiFabric); } - if (type == ConnectionType::Bind) { - fOfiInfo = tools::make_unique(addr.Ip.c_str(), std::to_string(addr.Port).c_str(), FI_SOURCE, hints); - } else { - fOfiInfo = tools::make_unique(addr.Ip.c_str(), std::to_string(addr.Port).c_str(), 0, hints); - } - // LOG(debug) << "OFI transport: " << *fOfiInfo; - - fOfiFabric = tools::make_unique(*fOfiInfo); - - fOfiDomain = tools::make_unique(*fOfiFabric); } -auto Context::MakeOfiPassiveEndpoint(Address addr) -> unique_ptr -{ - InitOfi(ConnectionType::Bind, addr); - - return tools::make_unique(fIoContext, *fOfiFabric); -} - -auto Context::MakeOfiConnectedEndpoint(const asiofi::info& info) -> std::unique_ptr -{ - assert(fOfiDomain); - - return tools::make_unique(fIoContext, *fOfiDomain, info); -} - -auto Context::MakeOfiConnectedEndpoint(Address addr) -> std::unique_ptr -{ - InitOfi(ConnectionType::Connect, addr); - - return tools::make_unique(fIoContext, *fOfiDomain); -} -// auto Context::CreateOfiEndpoint() -> fid_ep* -// { - // assert(fOfiDomain); - // assert(fOfiInfo); - // fid_ep* ep = nullptr; - // fi_context ctx; - // auto ret = fi_endpoint(fOfiDomain, fOfiInfo, &ep, &ctx); - // if (ret != FI_SUCCESS) - // throw ContextError{tools::ToString("Failed creating ofi endpoint, reason: ", fi_strerror(ret))}; - - //assert(fOfiEventQueue); - //ret = fi_ep_bind(ep, &fOfiEventQueue->fid, 0); - //if (ret != FI_SUCCESS) - // throw ContextError{tools::ToString("Failed binding ofi event queue to ofi endpoint, reason: ", fi_strerror(ret))}; - - // assert(fOfiAddressVector); - // ret = fi_ep_bind(ep, &fOfiAddressVector->fid, 0); - // if (ret != FI_SUCCESS) - // throw ContextError{tools::ToString("Failed binding ofi address vector to ofi endpoint, reason: ", fi_strerror(ret))}; -// - // return ep; -// } - auto Context::ConvertAddress(std::string address) -> Address { string protocol, ip; @@ -182,6 +139,11 @@ auto Context::MakeReceiveMessage(size_t size) -> MessagePtr return fReceiveFactory.CreateMessage(size); } +auto Context::MakeSendMessage(size_t size) -> MessagePtr +{ + return fSendFactory.CreateMessage(size); +} + } /* namespace ofi */ } /* namespace mq */ } /* namespace fair */ diff --git a/fairmq/ofi/Context.h b/fairmq/ofi/Context.h index a45e11af..3dcea79a 100644 --- a/fairmq/ofi/Context.h +++ b/fairmq/ofi/Context.h @@ -12,11 +12,9 @@ #include #include -#include #include #include #include -#include #include #include #include @@ -33,9 +31,6 @@ namespace mq namespace ofi { -enum class ConnectionType : bool { Bind, Connect }; -enum class Direction : bool { Receive, Transmit }; - /** * @class Context Context.h * @brief Transport-wide context @@ -45,10 +40,11 @@ enum class Direction : bool { Receive, Transmit }; class Context { public: - Context(FairMQTransportFactory& receiveFactory, int numberIoThreads = 1); + Context(FairMQTransportFactory& sendFactory, + FairMQTransportFactory& receiveFactory, + int numberIoThreads = 1); ~Context(); - // auto CreateOfiEndpoint() -> fid_ep*; auto GetAsiofiVersion() const -> std::string; auto GetIoContext() -> boost::asio::io_context& { return fIoContext; } struct Address { @@ -57,17 +53,18 @@ class Context unsigned int Port; friend auto operator<<(std::ostream& os, const Address& a) -> std::ostream& { return os << a.Protocol << "://" << a.Ip << ":" << a.Port; } }; - auto MakeOfiPassiveEndpoint(Address addr) -> std::unique_ptr; - auto MakeOfiConnectedEndpoint(Address addr) -> std::unique_ptr; - auto MakeOfiConnectedEndpoint(const asiofi::info& info) -> std::unique_ptr; + auto InitOfi(Address address) -> void; + auto GetOfiInfo() const -> const asiofi::info& { return *fOfiInfo; } + auto GetOfiFabric() const -> const asiofi::fabric& { return *fOfiFabric; } + auto GetOfiDomain() const -> const asiofi::domain& { return *fOfiDomain; } static auto ConvertAddress(std::string address) -> Address; static auto ConvertAddress(Address address) -> sockaddr_in; static auto ConvertAddress(sockaddr_in address) -> Address; static auto VerifyAddress(const std::string& address) -> Address; - auto GetDomain() const -> const asiofi::domain& { return *fOfiDomain; } auto Interrupt() -> void { LOG(debug) << "OFI transport: Interrupted (NOOP - not implemented)."; } auto Resume() -> void { LOG(debug) << "OFI transport: Resumed (NOOP - not implemented)."; } auto MakeReceiveMessage(size_t size) -> MessagePtr; + auto MakeSendMessage(size_t size) -> MessagePtr; private: std::unique_ptr fOfiInfo; @@ -77,9 +74,9 @@ class Context boost::asio::io_context::work fIoWork; std::vector fThreadPool; FairMQTransportFactory& fReceiveFactory; + FairMQTransportFactory& fSendFactory; auto InitThreadPool(int numberIoThreads) -> void; - auto InitOfi(ConnectionType type, Address address) -> void; }; /* class Context */ struct ContextError : std::runtime_error { using std::runtime_error::runtime_error; }; diff --git a/fairmq/ofi/ControlMessages.h b/fairmq/ofi/ControlMessages.h index 512ee35c..86209a1d 100644 --- a/fairmq/ofi/ControlMessages.h +++ b/fairmq/ofi/ControlMessages.h @@ -11,6 +11,7 @@ #include #include +#include #include #include #include @@ -55,6 +56,28 @@ struct PostBuffer : ControlMessage uint64_t size; // buffer size (size_t) }; +template +using unique_ptr = std::unique_ptr>; + +template +auto MakeControlMessageWithPmr(boost::container::pmr::memory_resource* pmr, Args&&... args) + -> ofi::unique_ptr +{ + void* mem = pmr->allocate(sizeof(T)); + T* ctrl = new (mem) T(std::forward(args)...); + + if (std::is_same::value) { + ctrl->type = ControlMessageType::DataAddressAnnouncement; + } else if (std::is_same::value) { + ctrl->type = ControlMessageType::PostBuffer; + } + + return ofi::unique_ptr(ctrl, [=](T* p) { + p->~T(); + pmr->deallocate(p, sizeof(T)); + }); +} + template auto MakeControlMessage(Args&&... args) -> T { diff --git a/fairmq/ofi/Socket.cxx b/fairmq/ofi/Socket.cxx index 49a96902..1cc37764 100644 --- a/fairmq/ofi/Socket.cxx +++ b/fairmq/ofi/Socket.cxx @@ -38,36 +38,31 @@ using namespace std; Socket::Socket(Context& context, const string& type, const string& name, const string& id /*= ""*/) : fContext(context) - , fPassiveDataEndpoint(nullptr) + , fPassiveEndpoint(nullptr) , fDataEndpoint(nullptr) + , fControlEndpoint(nullptr) , fId(id + "." + name + "." + type) , fBytesTx(0) , fBytesRx(0) , fMessagesTx(0) , fMessagesRx(0) , fIoStrand(fContext.GetIoContext()) - , fControlEndpoint(fIoStrand.context(), ZMQ_PAIR) , fSndTimeout(100) , fRcvTimeout(100) , fSendQueueWrite(fIoStrand.context(), ZMQ_PUSH) , fSendQueueRead(fIoStrand.context(), ZMQ_PULL) , fRecvQueueWrite(fIoStrand.context(), ZMQ_PUSH) , fRecvQueueRead(fIoStrand.context(), ZMQ_PULL) - , fSendSem(fIoStrand.context(), 100) - , fRecvSem(fIoStrand.context(), 100) + , fSendSem(fIoStrand.context(), 300) + , fRecvSem(fIoStrand.context(), 300) + , fNeedOfiMemoryRegistration(false) { if (type != "pair") { throw SocketError{tools::ToString("Socket type '", type, "' not implemented for ofi transport.")}; } else { - fControlEndpoint.set_option(azmq::socket::identity(fId)); - - // Tell socket to try and send/receive outstanding messages for milliseconds before terminating. - // Default value for ZeroMQ is -1, which is to wait forever. - fControlEndpoint.set_option(azmq::socket::linger(1000)); - // TODO wire this up with config - azmq::socket::snd_hwm send_max(10); - azmq::socket::rcv_hwm recv_max(10); + azmq::socket::snd_hwm send_max(300); + azmq::socket::rcv_hwm recv_max(300); fSendQueueRead.set_option(send_max); fSendQueueRead.set_option(recv_max); fSendQueueWrite.set_option(send_max); @@ -76,8 +71,6 @@ Socket::Socket(Context& context, const string& type, const string& name, const s fRecvQueueRead.set_option(recv_max); fSendQueueWrite.set_option(send_max); fSendQueueWrite.set_option(recv_max); - fControlEndpoint.set_option(send_max); - fControlEndpoint.set_option(recv_max); // Setup internal queue auto hashed_id = std::hash()(fId); @@ -94,22 +87,25 @@ Socket::Socket(Context& context, const string& type, const string& name, const s } } -auto Socket::Bind(const string& address) -> bool +auto Socket::Bind(const string& addr) -> bool try { - auto addr = Context::VerifyAddress(address); + fLocalAddr = Context::VerifyAddress(addr); + if (fLocalAddr.Protocol == "verbs") { + fNeedOfiMemoryRegistration = true; + } - BindControlEndpoint(addr); + fContext.InitOfi(fLocalAddr); + + fPassiveEndpoint = tools::make_unique(fIoStrand.context(), fContext.GetOfiFabric()); + fPassiveEndpoint->set_local_address(Context::ConvertAddress(fLocalAddr)); + + BindControlEndpoint(); - // TODO make data port choice more robust - addr.Port += 555; - fLocalDataAddr = addr; BindDataEndpoint(); - boost::asio::post(fIoStrand, std::bind(&Socket::SendQueueReader, this)); - boost::asio::post(fIoStrand, std::bind(&Socket::RecvControlQueueReader, this)); - return true; } +// TODO catch the correct ofi error catch (const SilentSocketError& e) { // do not print error in this case, this is handled by FairMQDevice @@ -122,14 +118,55 @@ catch (const SocketError& e) return false; } -auto Socket::Connect(const string& address) -> bool +auto Socket::BindControlEndpoint() -> void { - auto addr = Context::VerifyAddress(address); - fRemoteDataAddr = addr; + assert(!fControlEndpoint); - ConnectControlEndpoint(addr); + fPassiveEndpoint->listen([&](fid_t /*handle*/, asiofi::info&& info) { + LOG(debug) << "OFI transport (" << fId + << "): control band connection request received. Accepting ..."; + fControlEndpoint = tools::make_unique( + fIoStrand.context(), fContext.GetOfiDomain(), info); + fControlEndpoint->enable(); + fControlEndpoint->accept([&]() { + LOG(debug) << "OFI transport (" << fId << "): control band connection accepted."; + }); + }); - ReceiveDataAddressAnnouncement(); + LOG(debug) << "OFI transport (" << fId << "): control band bound to " << fLocalAddr; +} + +auto Socket::BindDataEndpoint() -> void +{ + assert(!fDataEndpoint); + + fPassiveEndpoint->listen([&](fid_t /*handle*/, asiofi::info&& info) { + LOG(debug) << "OFI transport (" << fId + << "): data band connection request received. Accepting ..."; + fDataEndpoint = tools::make_unique( + fIoStrand.context(), fContext.GetOfiDomain(), info); + fDataEndpoint->enable(); + fDataEndpoint->accept([&]() { + LOG(debug) << "OFI transport (" << fId << "): data band connection accepted."; + + boost::asio::post(fIoStrand, std::bind(&Socket::SendQueueReader, this)); + boost::asio::post(fIoStrand, std::bind(&Socket::RecvControlQueueReader, this)); + }); + }); + + LOG(debug) << "OFI transport (" << fId << "): data band bound to " << fLocalAddr; +} + +auto Socket::Connect(const string& address) -> void +{ + fRemoteAddr = Context::VerifyAddress(address); + if (fRemoteAddr.Protocol == "verbs") { + fNeedOfiMemoryRegistration = true; + } + + fContext.InitOfi(fRemoteAddr); + + ConnectControlEndpoint(); ConnectDataEndpoint(); @@ -137,110 +174,101 @@ auto Socket::Connect(const string& address) -> bool boost::asio::post(fIoStrand, std::bind(&Socket::RecvControlQueueReader, this)); } -auto Socket::BindControlEndpoint(Context::Address address) -> void +auto Socket::ConnectControlEndpoint() -> void { - auto addr = tools::ToString("tcp://", address.Ip, ":", address.Port); - - fControlEndpoint.bind(addr); - // if (zmq_bind(fControlSocket, addr.c_str()) != 0) { - // TODO if (errno == EADDRINUSE) throw SilentSocketError("EADDRINUSE"); - // throw SocketError(tools::ToString("Failed binding control socket ", fId, ", reason: ", zmq_strerror(errno))); - // } - - LOG(debug) << "OFI transport (" << fId << "): control band bound to " << address; -} - -auto Socket::BindDataEndpoint() -> void -{ - assert(!fPassiveDataEndpoint); - assert(!fDataEndpoint); + assert(!fControlEndpoint); std::mutex m; std::condition_variable cv; bool completed(false); - fPassiveDataEndpoint = fContext.MakeOfiPassiveEndpoint(fLocalDataAddr); - fPassiveDataEndpoint->listen([&](fid_t /*handle*/, asiofi::info&& info) { - LOG(debug) << "OFI transport (" << fId << "): data band connection request received. Accepting ..."; - fDataEndpoint = fContext.MakeOfiConnectedEndpoint(info); - fDataEndpoint->enable(); - fDataEndpoint->accept([&]() { - { - std::unique_lock lk(m); - completed = true; - } - cv.notify_one(); - }); + fControlEndpoint = + tools::make_unique(fIoStrand.context(), fContext.GetOfiDomain()); + fControlEndpoint->enable(); + + fControlEndpoint->connect(Context::ConvertAddress(fRemoteAddr), [&]() { + { + std::unique_lock lk(m); + completed = true; + } + cv.notify_one(); }); - LOG(debug) << "OFI transport (" << fId << "): data band bound to " << fLocalDataAddr; - - AnnounceDataAddress(); + LOG(debug) << "OFI transport (" << fId << "): control band connection request sent to " + << fRemoteAddr; { std::unique_lock lk(m); - cv.wait(lk, [&](){ return completed; }); + cv.wait(lk, [&]() { return completed; }); } - LOG(debug) << "OFI transport (" << fId << "): data band connection accepted."; -} - -auto Socket::ConnectControlEndpoint(Context::Address address) -> void -{ - auto addr = tools::ToString("tcp://", address.Ip, ":", address.Port); - - fControlEndpoint.connect(addr); - - LOG(debug) << "OFI transport (" << fId << "): control band connected to " << address; + LOG(debug) << "OFI transport (" << fId << "): control band connected."; } auto Socket::ConnectDataEndpoint() -> void { assert(!fDataEndpoint); - fDataEndpoint = fContext.MakeOfiConnectedEndpoint(fRemoteDataAddr); + std::mutex m; + std::condition_variable cv; + bool completed(false); + + fDataEndpoint = + tools::make_unique(fIoStrand.context(), fContext.GetOfiDomain()); fDataEndpoint->enable(); - LOG(debug) << "OFI transport (" << fId << "): local data band address: " << Context::ConvertAddress(fDataEndpoint->get_local_address()); - fDataEndpoint->connect([&]() { - LOG(debug) << "OFI transport (" << fId << "): data band connected."; + fDataEndpoint->connect(Context::ConvertAddress(fRemoteAddr), [&]() { + { + std::unique_lock lk(m); + completed = true; + } + cv.notify_one(); }); + + LOG(debug) << "OFI transport (" << fId << "): data band connection request sent to " + << fRemoteAddr; + + { + std::unique_lock lk(m); + cv.wait(lk, [&]() { return completed; }); + } + LOG(debug) << "OFI transport (" << fId << "): data band connected."; } -auto Socket::ReceiveDataAddressAnnouncement() -> void -{ - azmq::message ctrl; - auto recv = fControlEndpoint.receive(ctrl); - assert(recv == sizeof(DataAddressAnnouncement)); (void)recv; - auto daa(static_cast(ctrl.data())); - assert(daa->type == ControlMessageType::DataAddressAnnouncement); - - sockaddr_in remoteAddr; - remoteAddr.sin_family = AF_INET; - remoteAddr.sin_port = daa->port; - remoteAddr.sin_addr.s_addr = daa->ipv4; - - auto addr = Context::ConvertAddress(remoteAddr); - addr.Protocol = fRemoteDataAddr.Protocol; - LOG(debug) << "OFI transport (" << fId << "): Data address announcement of remote endpoint received: " << addr; - fRemoteDataAddr = addr; -} - -auto Socket::AnnounceDataAddress() -> void -{ +// auto Socket::ReceiveDataAddressAnnouncement() -> void +// { + // azmq::message ctrl; + // auto recv = fControlEndpoint.receive(ctrl); + // assert(recv == sizeof(DataAddressAnnouncement)); (void)recv; + // auto daa(static_cast(ctrl.data())); + // assert(daa->type == ControlMessageType::DataAddressAnnouncement); +// + // sockaddr_in remoteAddr; + // remoteAddr.sin_family = AF_INET; + // remoteAddr.sin_port = daa->port; + // remoteAddr.sin_addr.s_addr = daa->ipv4; +// + // auto addr = Context::ConvertAddress(remoteAddr); + // addr.Protocol = fRemoteDataAddr.Protocol; + // LOG(debug) << "OFI transport (" << fId << "): Data address announcement of remote endpoint received: " << addr; + // fRemoteDataAddr = addr; +// } +// +// auto Socket::AnnounceDataAddress() -> void +// { // fLocalDataAddr = fDataEndpoint->get_local_address(); // LOG(debug) << "Address of local ofi endpoint in socket " << fId << ": " << Context::ConvertAddress(fLocalDataAddr); - +// // Create new data address announcement message - auto daa = MakeControlMessage(); - auto addr = Context::ConvertAddress(fLocalDataAddr); - daa.ipv4 = addr.sin_addr.s_addr; - daa.port = addr.sin_port; - - auto sent = fControlEndpoint.send(boost::asio::buffer(daa)); - assert(sent == sizeof(addr)); (void)sent; - - LOG(debug) << "OFI transport (" << fId << "): data band address " << fLocalDataAddr << " announced."; -} + // auto daa = MakeControlMessage(); + // auto addr = Context::ConvertAddress(fLocalDataAddr); + // daa.ipv4 = addr.sin_addr.s_addr; + // daa.port = addr.sin_port; +// + // auto sent = fControlEndpoint.send(boost::asio::buffer(daa)); + // assert(sent == sizeof(addr)); (void)sent; +// + // LOG(debug) << "OFI transport (" << fId << "): data band address " << fLocalDataAddr << " announced."; +// } auto Socket::Send(MessagePtr& msg, const int /*timeout*/) -> int { @@ -265,7 +293,7 @@ auto Socket::Send(MessagePtr& msg, const int /*timeout*/) -> int auto Socket::Receive(MessagePtr& msg, const int /*timeout*/) -> int { - LOG(debug) << "OFI transport (" << fId << "): ENTER Receive"; + // LOG(debug) << "OFI transport (" << fId << "): ENTER Receive"; try { azmq::message zmsg; @@ -280,7 +308,7 @@ auto Socket::Receive(MessagePtr& msg, const int /*timeout*/) -> int fBytesRx += size; fMessagesRx++; - LOG(debug) << "OFI transport (" << fId << "): LEAVE Receive"; + // LOG(debug) << "OFI transport (" << fId << "): LEAVE Receive"; return size; } catch (const std::exception& e) { LOG(error) << e.what(); @@ -299,7 +327,7 @@ auto Socket::SendQueueReader() -> void fSendSem.async_wait( boost::asio::bind_executor(fIoStrand, [&](const boost::system::error_code& ec) { if (!ec) { - LOG(debug) << "OFI transport (" << fId << "): < Wait fSendSem=" << fSendSem.get_value(); + // LOG(debug) << "OFI transport (" << fId << "): < Wait fSendSem=" << fSendSem.get_value(); fSendQueueRead.async_receive([&](const boost::system::error_code& ec2, azmq::message& zmsg, size_t bytes_transferred) { @@ -311,7 +339,7 @@ auto Socket::SendQueueReader() -> void })); } -auto Socket::OnSend(azmq::message& zmsg, size_t bytes_transferred) -> void +auto Socket::OnSend(azmq::message& zmsg, size_t /*bytes_transferred*/) -> void { // LOG(debug) << "OFI transport (" << fId << "): ENTER OnSend: bytes_transferred=" << bytes_transferred; @@ -321,67 +349,72 @@ auto Socket::OnSend(azmq::message& zmsg, size_t bytes_transferred) -> void // LOG(debug) << "OFI transport (" << fId << "): OnSend: data=" << msg->GetData() << ",size=" << msg->GetSize(); // Create and send control message - auto pb = MakeControlMessage(); - pb.size = size; - fControlEndpoint.async_send( - azmq::message(boost::asio::buffer(pb)), - [&, msg2 = std::move(msg)](const boost::system::error_code& ec, size_t bytes_transferred2) mutable { - if (!ec) { - OnControlMessageSent(bytes_transferred2, std::move(msg2)); - } - }); - - // LOG(debug) << "OFI transport (" << fId << "): LEAVE OnSend"; -} - -auto Socket::OnControlMessageSent(size_t bytes_transferred, MessagePtr msg) -> void -{ - // LOG(debug) << "OFI transport (" << fId - // << "): ENTER OnControlMessageSent: bytes_transferred=" << bytes_transferred - // << ",data=" << msg->GetData() << ",size=" << msg->GetSize(); - assert(bytes_transferred == sizeof(PostBuffer)); - - auto size = msg->GetSize(); - - if (size) { - // Receive ack - // azmq::message ctrl; - // auto recv = fControlEndpoint.receive(ctrl); - // assert(recv == sizeof(PostBuffer)); - // (void)recv; - // auto ack(static_cast(ctrl.data())); - // assert(ack->type == ControlMessageType::PostBuffer); - // (void)ack; - // LOG(debug) << "OFI transport (" << fId << "): >>>>> SendImpl: Control ack - // received, size_ack=" << size_ack; - - boost::asio::mutable_buffer buffer(msg->GetData(), size); - // asiofi::memory_region mr(fContext.GetDomain(), buffer, asiofi::mr::access::send); - // auto desc = mr.desc(); - - fDataEndpoint->send( - buffer, - // desc, - [&, size, msg2 = std::move(msg)/*, mr2 = std::move(mr)*/](boost::asio::mutable_buffer) mutable { - // LOG(debug) << "OFI transport (" << fId << "): >>>>> Data buffer sent"; - fBytesTx += size; - fMessagesTx++; - fSendSem.async_signal([&](const boost::system::error_code& ec){ - if (!ec) { - LOG(debug) << "OFI transport (" << fId << "): > Signal fSendSem=" << fSendSem.get_value(); - } - }); + auto ctrl = MakeControlMessageWithPmr(&fControlMemPool); + ctrl->size = size; + auto ctrl_msg = boost::asio::mutable_buffer(ctrl.get(), sizeof(PostBuffer)); + if (fNeedOfiMemoryRegistration) { + asiofi::memory_region mr(fContext.GetOfiDomain(), ctrl_msg, asiofi::mr::access::send); + auto desc = mr.desc(); + fControlEndpoint->send( + ctrl_msg, desc, [&, ctrl2 = std::move(ctrl)](boost::asio::mutable_buffer) mutable { + // LOG(debug) << "OFI transport (" << fId << "): >>>>> Control message sent"; }); } else { + fControlEndpoint->send(ctrl_msg, + [&, ctrl2 = std::move(ctrl)](boost::asio::mutable_buffer) mutable { + // LOG(debug) << "OFI transport (" << fId << "): >>>>> Control + // message sent"; + }); + } + + if (size) { + boost::asio::mutable_buffer buffer(msg->GetData(), size); + + if (fNeedOfiMemoryRegistration) { + asiofi::memory_region mr(fContext.GetOfiDomain(), buffer, asiofi::mr::access::send); + auto desc = mr.desc(); + + fDataEndpoint->send(buffer, + desc, + [&, size, msg2 = std::move(msg), mr2 = std::move(mr)]( + boost::asio::mutable_buffer) mutable { + // LOG(debug) << "OFI transport (" << fId << "): >>>>> Data + // buffer sent"; + fBytesTx += size; + fMessagesTx++; + fSendSem.async_signal([&](const boost::system::error_code& ec) { + if (!ec) { + // LOG(debug) << "OFI transport (" << fId << "): > + // Signal fSendSem=" << fSendSem.get_value(); + } + }); + }); + + } else { + fDataEndpoint->send( + buffer, [&, size, msg2 = std::move(msg)](boost::asio::mutable_buffer) mutable { + // LOG(debug) << "OFI transport (" << fId << "): >>>>> Data buffer sent"; + fBytesTx += size; + fMessagesTx++; + fSendSem.async_signal([&](const boost::system::error_code& ec) { + if (!ec) { + // LOG(debug) << "OFI transport (" << fId << "): > Signal fSendSem=" + // << fSendSem.get_value(); + } + }); + }); + } + } else { + ++fMessagesTx; fSendSem.async_signal([&](const boost::system::error_code& ec) { if (!ec) { - LOG(debug) << "OFI transport (" << fId << "): > Signal fSendSem=" << fSendSem.get_value(); + // LOG(debug) << "OFI transport (" << fId << "): > Signal fSendSem=" << fSendSem.get_value(); } }); } boost::asio::post(fIoStrand, std::bind(&Socket::SendQueueReader, this)); - // LOG(debug) << "OFI transport (" << fId << "): LEAVE OnControlMessageSent"; + // LOG(debug) << "OFI transport (" << fId << "): LEAVE OnSend"; } auto Socket::RecvControlQueueReader() -> void @@ -389,77 +422,89 @@ auto Socket::RecvControlQueueReader() -> void fRecvSem.async_wait( boost::asio::bind_executor(fIoStrand, [&](const boost::system::error_code& ec) { if (!ec) { - fControlEndpoint.async_receive([&](const boost::system::error_code& ec2, - azmq::message& zmsg, - size_t bytes_transferred) { - if (!ec2) { - OnRecvControl(zmsg, bytes_transferred); - } + auto ctrl = MakeControlMessageWithPmr(&fControlMemPool); + auto ctrl_msg = boost::asio::mutable_buffer(ctrl.get(), sizeof(PostBuffer)); + + fControlEndpoint->recv(ctrl_msg, [&, ctrl2 = std::move(ctrl)](boost::asio::mutable_buffer) mutable { + OnRecvControl(std::move(ctrl2)); }); } })); } -auto Socket::OnRecvControl(azmq::message& zmsg, size_t bytes_transferred) -> void +auto Socket::OnRecvControl(ofi::unique_ptr ctrl) -> void { - LOG(debug) << "OFI transport (" << fId - << "): ENTER OnRecvControl: bytes_transferred=" << bytes_transferred; + // LOG(debug) << "OFI transport (" << fId << "): ENTER OnRecvControl"; - assert(bytes_transferred == sizeof(PostBuffer)); - auto pb(static_cast(zmsg.data())); - assert(pb->type == ControlMessageType::PostBuffer); - auto size = pb->size; - LOG(debug) << "OFI transport (" << fId << "): OnRecvControl: PostBuffer.size=" << size; + auto size = ctrl->size; + // LOG(debug) << "OFI transport (" << fId << "): OnRecvControl: PostBuffer.size=" << size; // Receive data if (size) { auto msg = fContext.MakeReceiveMessage(size); boost::asio::mutable_buffer buffer(msg->GetData(), size); - // asiofi::memory_region mr(fContext.GetDomain(), buffer, asiofi::mr::access::recv); - // auto msg33 = fContext.MakeReceiveMessage(size); - // boost::asio::mutable_buffer buffer33(msg33->GetData(), size); - // asiofi::memory_region mr33(fContext.GetDomain(), buffer33, asiofi::mr::access::recv); - // auto desc = mr.desc(); - fDataEndpoint->recv( - buffer, - // desc, - [&, msg2 = std::move(msg)/*, mr2 = std::move(mr)*/](boost::asio::mutable_buffer) mutable { - MessagePtr* msgptr(new std::unique_ptr(std::move(msg2))); - fRecvQueueWrite.async_send( - azmq::message(boost::asio::const_buffer(msgptr, sizeof(MessagePtr))), - [&](const boost::system::error_code& ec, size_t bytes_transferred2) { - if (!ec) { - LOG(debug) << "OFI transport (" << fId - << "): <<<<< Data buffer received, bytes_transferred2=" - << bytes_transferred2; - fRecvSem.async_signal([&](const boost::system::error_code& ec2) { - if (!ec2) { - LOG(debug) - << "OFI transport (" << fId << "): < Signal fRecvSem"; - } - }); - } - }); - }); - // LOG(debug) << "OFI transport (" << fId << "): <<<<< ReceiveImpl: Data buffer posted"; + if (fNeedOfiMemoryRegistration) { + asiofi::memory_region mr(fContext.GetOfiDomain(), buffer, asiofi::mr::access::recv); + auto desc = mr.desc(); - // auto ack = MakeControlMessage(); - // ack.size = size; - // auto sent = fControlEndpoint.send(boost::asio::buffer(ack)); - // assert(sent == sizeof(PostBuffer)); (void)sent; - // LOG(debug) << "OFI transport (" << fId << "): <<<<< ReceiveImpl: Control Ack sent"; + fDataEndpoint->recv( + buffer, + desc, + [&, msg2 = std::move(msg), mr2 = std::move(mr)]( + boost::asio::mutable_buffer) mutable { + MessagePtr* msgptr(new std::unique_ptr(std::move(msg2))); + fRecvQueueWrite.async_send( + azmq::message(boost::asio::const_buffer(msgptr, sizeof(MessagePtr))), + [&](const boost::system::error_code& ec, size_t /*bytes_transferred2*/) { + if (!ec) { + // LOG(debug) << "OFI transport (" << fId + // << "): <<<<< Data buffer received, bytes_transferred2=" + // << bytes_transferred2; + fRecvSem.async_signal([&](const boost::system::error_code& ec2) { + if (!ec2) { + // LOG(debug) + // << "OFI transport (" << fId << "): < Signal + // fRecvSem"; + } + }); + } + }); + }); + + } else { + fDataEndpoint->recv( + buffer, [&, msg2 = std::move(msg)](boost::asio::mutable_buffer) mutable { + MessagePtr* msgptr(new std::unique_ptr(std::move(msg2))); + fRecvQueueWrite.async_send( + azmq::message(boost::asio::const_buffer(msgptr, sizeof(MessagePtr))), + [&](const boost::system::error_code& ec, size_t /*bytes_transferred2*/) { + if (!ec) { + // LOG(debug) << "OFI transport (" << fId + // << "): <<<<< Data buffer received, bytes_transferred2=" + // << bytes_transferred2; + fRecvSem.async_signal([&](const boost::system::error_code& ec2) { + if (!ec2) { + // LOG(debug) + // << "OFI transport (" << fId << "): < Signal + // fRecvSem"; + } + }); + } + }); + }); + } } else { fRecvQueueWrite.async_send( azmq::message(boost::asio::const_buffer(nullptr, 0)), - [&](const boost::system::error_code& ec, size_t bytes_transferred2) { + [&](const boost::system::error_code& ec, size_t /*bytes_transferred2*/) { if (!ec) { - LOG(debug) << "OFI transport (" << fId - << "): <<<<< Data buffer received, bytes_transferred2=" - << bytes_transferred2; + // LOG(debug) << "OFI transport (" << fId + // << "): <<<<< Data buffer received, bytes_transferred2=" + // << bytes_transferred2; fRecvSem.async_signal([&](const boost::system::error_code& ec2) { if (!ec2) { - LOG(debug) << "OFI transport (" << fId << "): < Signal fRecvSem"; + // LOG(debug) << "OFI transport (" << fId << "): < Signal fRecvSem"; } }); } @@ -468,7 +513,7 @@ auto Socket::OnRecvControl(azmq::message& zmsg, size_t bytes_transferred) -> voi boost::asio::post(fIoStrand, std::bind(&Socket::RecvControlQueueReader, this)); - LOG(debug) << "OFI transport (" << fId << "): LEAVE OnRecvControl"; + // LOG(debug) << "OFI transport (" << fId << "): LEAVE OnRecvControl"; } auto Socket::SendImpl(vector& /*msgVec*/, const int /*flags*/, const int /*timeout*/) -> int64_t @@ -658,69 +703,74 @@ auto Socket::GetOption(const string& /*option*/, void* /*value*/, size_t* /*valu // } } -void Socket::SetLinger(const int value) +void Socket::SetLinger(const int /*value*/) { - azmq::socket::linger opt(value); - fControlEndpoint.set_option(opt); + // azmq::socket::linger opt(value); + // fControlEndpoint.set_option(opt); } int Socket::GetLinger() const { - azmq::socket::linger opt(0); - fControlEndpoint.get_option(opt); - return opt.value(); + // azmq::socket::linger opt(0); + // fControlEndpoint.get_option(opt); + // return opt.value(); + return 0; } -void Socket::SetSndBufSize(const int value) +void Socket::SetSndBufSize(const int /*value*/) { - azmq::socket::snd_hwm opt(value); - fControlEndpoint.set_option(opt); + // azmq::socket::snd_hwm opt(value); + // fControlEndpoint.set_option(opt); } int Socket::GetSndBufSize() const { - azmq::socket::snd_hwm opt(0); - fControlEndpoint.get_option(opt); - return opt.value(); + // azmq::socket::snd_hwm opt(0); + // fControlEndpoint.get_option(opt); + // return opt.value(); + return 0; } -void Socket::SetRcvBufSize(const int value) +void Socket::SetRcvBufSize(const int /*value*/) { - azmq::socket::rcv_hwm opt(value); - fControlEndpoint.set_option(opt); + // azmq::socket::rcv_hwm opt(value); + // fControlEndpoint.set_option(opt); } int Socket::GetRcvBufSize() const { - azmq::socket::rcv_hwm opt(0); - fControlEndpoint.get_option(opt); - return opt.value(); + // azmq::socket::rcv_hwm opt(0); + // fControlEndpoint.get_option(opt); + // return opt.value(); + return 0; } -void Socket::SetSndKernelSize(const int value) +void Socket::SetSndKernelSize(const int /*value*/) { - azmq::socket::snd_buf opt(value); - fControlEndpoint.set_option(opt); + // azmq::socket::snd_buf opt(value); + // fControlEndpoint.set_option(opt); } int Socket::GetSndKernelSize() const { - azmq::socket::snd_buf opt(0); - fControlEndpoint.get_option(opt); - return opt.value(); + // azmq::socket::snd_buf opt(0); + // fControlEndpoint.get_option(opt); + // return opt.value(); + return 0; } -void Socket::SetRcvKernelSize(const int value) +void Socket::SetRcvKernelSize(const int /*value*/) { - azmq::socket::rcv_buf opt(value); - fControlEndpoint.set_option(opt); + // azmq::socket::rcv_buf opt(value); + // fControlEndpoint.set_option(opt); } int Socket::GetRcvKernelSize() const { - azmq::socket::rcv_buf opt(0); - fControlEndpoint.get_option(opt); - return opt.value(); + // azmq::socket::rcv_buf opt(0); + // fControlEndpoint.get_option(opt); + // return opt.value(); + return 0; } auto Socket::GetConstant(const string& constant) -> int diff --git a/fairmq/ofi/Socket.h b/fairmq/ofi/Socket.h index a09e817b..01f22280 100644 --- a/fairmq/ofi/Socket.h +++ b/fairmq/ofi/Socket.h @@ -15,6 +15,8 @@ #include #include +#include +#include #include #include #include @@ -51,7 +53,7 @@ class Socket final : public fair::mq::Socket auto Send(std::vector& msgVec, int timeout = 0) -> int64_t override; auto Receive(std::vector& msgVec, int timeout = 0) -> int64_t override; - auto GetSocket() const -> void* { return fControlEndpoint.native_handle(); } + auto GetSocket() const -> void* { return nullptr; } void SetLinger(const int value) override; int GetLinger() const override; @@ -80,40 +82,40 @@ class Socket final : public fair::mq::Socket private: Context& fContext; - std::unique_ptr fPassiveDataEndpoint; - std::unique_ptr fDataEndpoint; + std::unique_ptr fPassiveEndpoint; + std::unique_ptr fDataEndpoint, fControlEndpoint; std::string fId; std::atomic fBytesTx; std::atomic fBytesRx; std::atomic fMessagesTx; std::atomic fMessagesRx; - Context::Address fRemoteDataAddr; - Context::Address fLocalDataAddr; + Context::Address fRemoteAddr; + Context::Address fLocalAddr; boost::asio::io_service::strand fIoStrand; - mutable azmq::socket fControlEndpoint; int fSndTimeout; int fRcvTimeout; azmq::socket fSendQueueWrite, fSendQueueRead; azmq::socket fRecvQueueWrite, fRecvQueueRead; asiofi::semaphore fSendSem, fRecvSem; + asiofi::allocated_pool_resource fControlMemPool; + std::atomic fNeedOfiMemoryRegistration; auto SendQueueReader() -> void; auto OnSend(azmq::message& msg, size_t bytes_transferred) -> void; - auto OnControlMessageSent(size_t bytes_transferred, MessagePtr msg) -> void; auto RecvControlQueueReader() -> void; - auto OnRecvControl(azmq::message& msg, size_t bytes_transferred) -> void; + auto OnRecvControl(ofi::unique_ptr ctrl) -> void; auto OnReceive() -> void; auto ReceiveImpl(MessagePtr& msg, const int flags, const int timeout) -> int; auto SendImpl(std::vector& msgVec, const int flags, const int timeout) -> int64_t; auto ReceiveImpl(std::vector& msgVec, const int flags, const int timeout) -> int64_t; // auto WaitForControlPeer() -> void; - auto AnnounceDataAddress() -> void; - auto ConnectControlEndpoint(Context::Address address) -> void; - auto BindControlEndpoint(Context::Address address) -> void; + // auto AnnounceDataAddress() -> void; + auto BindControlEndpoint() -> void; auto BindDataEndpoint() -> void; + auto ConnectControlEndpoint() -> void; auto ConnectDataEndpoint() -> void; - auto ReceiveDataAddressAnnouncement() -> void; + // auto ReceiveDataAddressAnnouncement() -> void; }; /* class Socket */ struct SilentSocketError : SocketError { using SocketError::SocketError; }; diff --git a/fairmq/ofi/TransportFactory.cxx b/fairmq/ofi/TransportFactory.cxx index 7f76c794..763f4847 100644 --- a/fairmq/ofi/TransportFactory.cxx +++ b/fairmq/ofi/TransportFactory.cxx @@ -25,7 +25,7 @@ using namespace std; TransportFactory::TransportFactory(const string& id, const FairMQProgOptions* /*config*/) try : FairMQTransportFactory(id) - , fContext(*this, 1) + , fContext(*this, *this, 1) { LOG(debug) << "OFI transport: Using AZMQ & " << "asiofi (" << fContext.GetAsiofiVersion() << ")";