diff --git a/CMakeLists.txt b/CMakeLists.txt index 80d046cb..50a073cf 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -71,7 +71,7 @@ if(BUILD_OFI_TRANSPORT) ) endif() -if(BUILD_NANOMSG_TRANSPORT OR BUILD_OFI_TRANSPORT) +if(BUILD_NANOMSG_TRANSPORT) find_package2(PRIVATE msgpack REQUIRED VERSION 3.1.0 ) diff --git a/fairmq/CMakeLists.txt b/fairmq/CMakeLists.txt index c037370d..58e9f353 100644 --- a/fairmq/CMakeLists.txt +++ b/fairmq/CMakeLists.txt @@ -219,7 +219,10 @@ if(BUILD_NANOMSG_TRANSPORT) set(NANOMSG_DEPS nanomsg msgpackc-cxx) endif() if(BUILD_OFI_TRANSPORT) - set(OFI_DEPS asiofi::asiofi msgpackc-cxx) + set(OFI_DEPS + asiofi::asiofi + Boost::container + ) endif() set(optional_deps ${NANOMSG_DEPS} ${OFI_DEPS}) if(optional_deps) diff --git a/fairmq/ofi/Context.cxx b/fairmq/ofi/Context.cxx index edc045de..1025e32c 100644 --- a/fairmq/ofi/Context.cxx +++ b/fairmq/ofi/Context.cxx @@ -51,9 +51,9 @@ auto Context::InitThreadPool(int numberIoThreads) -> void for (int i = 1; i <= numberIoThreads; ++i) { fThreadPool.emplace_back([&, i, numberIoThreads]{ - LOG(debug) << "OFI transport: I/O thread #" << i << "/" << numberIoThreads << " started"; + LOG(debug) << "OFI transport: I/O thread #" << i << " of " << numberIoThreads << " started"; fIoContext.run(); - LOG(debug) << "OFI transport: I/O thread #" << i << "/" << numberIoThreads << " stopped"; + LOG(debug) << "OFI transport: I/O thread #" << i << " of " << numberIoThreads << " stopped"; }); } } @@ -97,12 +97,31 @@ auto Context::InitOfi(ConnectionType type, Address addr) -> void } else { fOfiInfo = tools::make_unique(addr.Ip.c_str(), std::to_string(addr.Port).c_str(), 0, hints); } + LOG(debug) << "OFI transport: " << *fOfiInfo; fOfiFabric = tools::make_unique(*fOfiInfo); fOfiDomain = tools::make_unique(*fOfiFabric); } +auto Context::MakeOfiPassiveEndpoint(Address addr) -> unique_ptr +{ + InitOfi(ConnectionType::Bind, addr); + + return tools::make_unique(fIoContext, *fOfiFabric); +} + +auto 
Context::MakeOfiConnectedEndpoint(const asiofi::info& info) -> std::unique_ptr +{ + return tools::make_unique(fIoContext, *fOfiDomain, info); +} + +auto Context::MakeOfiConnectedEndpoint(Address addr) -> std::unique_ptr +{ + InitOfi(ConnectionType::Connect, addr); + + return tools::make_unique(fIoContext, *fOfiDomain); +} // auto Context::CreateOfiEndpoint() -> fid_ep* // { // assert(fOfiDomain); diff --git a/fairmq/ofi/Context.h b/fairmq/ofi/Context.h index e12d1bf4..7d8fa4de 100644 --- a/fairmq/ofi/Context.h +++ b/fairmq/ofi/Context.h @@ -9,10 +9,12 @@ #ifndef FAIR_MQ_OFI_CONTEXT_H #define FAIR_MQ_OFI_CONTEXT_H +#include #include #include #include -#include +#include +#include #include #include #include @@ -47,14 +49,16 @@ class Context auto GetZmqVersion() const -> std::string; auto GetAsiofiVersion() const -> std::string; auto GetZmqContext() const -> void* { return fZmqContext; } - auto GetIoContext() -> boost::asio::io_service& { return fIoContext; } + auto GetIoContext() -> boost::asio::io_context& { return fIoContext; } struct Address { std::string Protocol; std::string Ip; unsigned int Port; friend auto operator<<(std::ostream& os, const Address& a) -> std::ostream& { return os << a.Protocol << "://" << a.Ip << ":" << a.Port; } }; - auto InitOfi(ConnectionType type, Address address) -> void; + auto MakeOfiPassiveEndpoint(Address addr) -> std::unique_ptr; + auto MakeOfiConnectedEndpoint(Address addr) -> std::unique_ptr; + auto MakeOfiConnectedEndpoint(const asiofi::info& info) -> std::unique_ptr; static auto ConvertAddress(std::string address) -> Address; static auto ConvertAddress(Address address) -> sockaddr_in; static auto ConvertAddress(sockaddr_in address) -> Address; @@ -65,15 +69,24 @@ class Context std::unique_ptr fOfiInfo; std::unique_ptr fOfiFabric; std::unique_ptr fOfiDomain; - boost::asio::io_service fIoContext; - boost::asio::io_service::work fIoWork; + boost::asio::io_context fIoContext; + boost::asio::io_context::work fIoWork; std::vector 
fThreadPool; auto InitThreadPool(int numberIoThreads) -> void; + auto InitOfi(ConnectionType type, Address address) -> void; }; /* class Context */ struct ContextError : std::runtime_error { using std::runtime_error::runtime_error; }; +template +std::unique_ptr +static_unique_ptr_downcast( std::unique_ptr&& p ) +{ + auto d = static_cast(p.release()); + return std::unique_ptr(d, std::move(p.get_deleter())); +} + } /* namespace ofi */ } /* namespace mq */ } /* namespace fair */ diff --git a/fairmq/ofi/Control.proto b/fairmq/ofi/Control.proto deleted file mode 100644 index 69eee4a7..00000000 --- a/fairmq/ofi/Control.proto +++ /dev/null @@ -1,25 +0,0 @@ -syntax = "proto3"; -option optimize_for = SPEED; - -package fair.mq.ofi; - -message DataAddressAnnouncement { - uint32 ipv4 = 1; // in_addr_t from - uint32 port = 2; // in_port_t from -} - -message PostBuffer { - uint64 size = 1; // buffer size (size_t) -} - -message PostBufferAcknowledgement { - uint64 size = 1; // size_t -} - -message ControlMessage { - oneof type { - DataAddressAnnouncement data_address_announcement = 1; - PostBuffer post_buffer = 2; - PostBufferAcknowledgement post_buffer_acknowledgement = 3; - } -} diff --git a/fairmq/ofi/ControlMessages.h b/fairmq/ofi/ControlMessages.h new file mode 100644 index 00000000..32ff4d31 --- /dev/null +++ b/fairmq/ofi/ControlMessages.h @@ -0,0 +1,84 @@ +/******************************************************************************** + * Copyright (C) 2018 GSI Helmholtzzentrum fuer Schwerionenforschung GmbH * + * * + * This software is distributed under the terms of the * + * GNU Lesser General Public Licence (LGPL) version 3, * + * copied verbatim in the file "LICENSE" * + ********************************************************************************/ + +#ifndef FAIR_MQ_OFI_CONTROLMESSAGES_H +#define FAIR_MQ_OFI_CONTROLMESSAGES_H + +#include +#include +#include +#include + +namespace fair +{ +namespace mq +{ +namespace ofi +{ + +enum class ControlMessageType +{ + 
DataAddressAnnouncement = 1, + PostBuffer, + PostBufferAcknowledgement +}; + +struct ControlMessage { + ControlMessageType type; +}; + +struct DataAddressAnnouncement : ControlMessage { + uint32_t ipv4; // in_addr_t from + uint32_t port; // in_port_t from +}; + +struct PostBuffer : ControlMessage { + uint64_t size; // buffer size (size_t) +}; + +struct PostBufferAcknowledgement : ControlMessage { /* tagged like the other control messages so it carries 'type' and can be up/downcast */ + uint64_t size; // size_t +}; + +template +using CtrlMsgPtr = std::unique_ptr>; + +template +auto MakeControlMessage(A* pmr, Args&& ... args) -> CtrlMsgPtr +{ + void* raw_mem = pmr->allocate(sizeof(T)); + T* raw_ptr = new (raw_mem) T(std::forward(args)...); + + if (std::is_same::value) { + raw_ptr->type = ControlMessageType::DataAddressAnnouncement; + } + + return {raw_ptr, [=](T* p) { pmr->deallocate(p, sizeof(T)); }}; +} + +template +auto StaticUniquePtrDowncast(std::unique_ptr&& p) -> std::unique_ptr +{ + auto down = static_cast(p.release()); + return std::unique_ptr(down, std::move(p.get_deleter())); +} + +template +auto StaticUniquePtrUpcast(std::unique_ptr&& p) -> std::unique_ptr> +{ + auto up = static_cast(p.release()); + return {up, [deleter = std::move(p.get_deleter())](Base* ptr) { + deleter(static_cast(ptr)); + }}; +} + +} /* namespace ofi */ +} /* namespace mq */ +} /* namespace fair */ + +#endif /* FAIR_MQ_OFI_CONTROLMESSAGES_H */ diff --git a/fairmq/ofi/Socket.cxx b/fairmq/ofi/Socket.cxx index 23569160..f33f9365 100644 --- a/fairmq/ofi/Socket.cxx +++ b/fairmq/ofi/Socket.cxx @@ -6,6 +6,7 @@ * copied verbatim in the file "LICENSE" * ********************************************************************************/ +#include #include #include #include @@ -39,6 +40,7 @@ Socket::Socket(Context& context, const string& type, const string& name, const s , fId(id + "." + name + "." + type) , fControlSocket(nullptr) , fMonitorSocket(nullptr) + , fPassiveDataEndpoint(nullptr) , fDataEndpoint(nullptr) , fId(id + "." + name + "." 
+ type) , fBytesTx(0) @@ -92,10 +94,16 @@ Socket::Socket(Context& context, const string& type, const string& name, const s auto Socket::Bind(const string& address) -> bool try { auto addr = Context::VerifyAddress(address); + BindControlSocket(addr); - fContext.InitOfi(ConnectionType::Bind, addr); - InitDataEndpoint(); - fWaitingForControlPeer = true; + + // TODO make data port choice more robust + addr.Port += 500; + fLocalDataAddr = addr; + BindDataEndpoint(); + + AnnounceDataAddress(); + return true; } catch (const SilentSocketError& e) @@ -106,18 +114,21 @@ catch (const SilentSocketError& e) } catch (const SocketError& e) { - LOG(error) << e.what(); + LOG(error) << "OFI transport: " << e.what(); return false; } auto Socket::Connect(const string& address) -> bool { auto addr = Context::VerifyAddress(address); + ConnectControlSocket(addr); - fContext.InitOfi(ConnectionType::Connect, addr); - InitDataEndpoint(); - fWaitingForControlPeer = true; - return true; + + ProcessControlMessage( + StaticUniquePtrDowncast(ReceiveControlMessage())); + + ConnectDataEndpoint(); + return true; /* Connect() is declared '-> bool'; without this the success path flows off the end of a non-void function (undefined behaviour) */ } auto Socket::BindControlSocket(Context::Address address) -> void @@ -128,6 +139,26 @@ auto Socket::BindControlSocket(Context::Address address) -> void if (errno == EADDRINUSE) throw SilentSocketError("EADDRINUSE"); throw SocketError(tools::ToString("Failed binding control socket ", fId, ", reason: ", zmq_strerror(errno))); } + + LOG(debug) << "OFI transport (" << fId << "): control band bound to " << address; +} + +auto Socket::BindDataEndpoint() -> void +{ + assert(!fPassiveDataEndpoint); + assert(!fDataEndpoint); + + fPassiveDataEndpoint = fContext.MakeOfiPassiveEndpoint(fLocalDataAddr); + fPassiveDataEndpoint->listen([&](fid_t /*handle*/, asiofi::info&& info) { + LOG(debug) << "OFI transport (" << fId << "): data band connection request received. 
Accepting ..."; + fDataEndpoint = fContext.MakeOfiConnectedEndpoint(info); + fDataEndpoint->enable(); + fDataEndpoint->accept([&]() { + LOG(debug) << "OFI transport (" << fId << "): data band connection accepted."; + }); + }); + + LOG(debug) << "OFI transport (" << fId << "): data band bound to " << fLocalDataAddr; } auto Socket::ConnectControlSocket(Context::Address address) -> void @@ -138,119 +168,120 @@ auto Socket::ConnectControlSocket(Context::Address address) -> void throw SocketError(tools::ToString("Failed connecting control socket ", fId, ", reason: ", zmq_strerror(errno))); } -// auto Socket::ProcessDataAddressAnnouncement(std::unique_ptr ctrl) -> void -// { - // assert(ctrl->has_data_address_announcement()); - // auto daa = ctrl->data_address_announcement(); -// - // sockaddr_in remoteAddr; - // remoteAddr.sin_family = AF_INET; - // remoteAddr.sin_port = daa.port(); - // remoteAddr.sin_addr.s_addr = daa.ipv4(); -// - // LOG(debug) << "Data address announcement of remote ofi endpoint received: " << Context::ConvertAddress(remoteAddr); - // fRemoteDataAddr = fContext.InsertAddressVector(remoteAddr); -// } - -auto Socket::InitDataEndpoint() -> void +auto Socket::ConnectDataEndpoint() -> void { assert(!fDataEndpoint); - // try { - // fDataEndpoint = fContext.CreateOfiEndpoint(); - // } catch (ContextError& e) { - // throw SocketError(tools::ToString("Failed creating ofi endpoint, reason: ", e.what())); - // } -// - // if (!fDataCompletionQueueTx) - // fDataCompletionQueueTx = fContext.CreateOfiCompletionQueue(Direction::Transmit); - // auto ret = fi_ep_bind(fDataEndpoint, &fDataCompletionQueueTx->fid, FI_TRANSMIT); - // if (ret != FI_SUCCESS) - // throw SocketError(tools::ToString("Failed binding ofi transmit completion queue to endpoint, reason: ", fi_strerror(ret))); -// - // if (!fDataCompletionQueueRx) - // fDataCompletionQueueRx = fContext.CreateOfiCompletionQueue(Direction::Receive); - // ret = fi_ep_bind(fDataEndpoint, &fDataCompletionQueueRx->fid, 
FI_RECV); - // if (ret != FI_SUCCESS) - // throw SocketError(tools::ToString("Failed binding ofi receive completion queue to endpoint, reason: ", fi_strerror(ret))); -// - // ret = fi_enable(fDataEndpoint); - // if (ret != FI_SUCCESS) - // throw SocketError(tools::ToString("Failed enabling ofi endpoint, reason: ", fi_strerror(ret))); + fDataEndpoint = fContext.MakeOfiConnectedEndpoint(fRemoteDataAddr); + fDataEndpoint->enable(); + LOG(debug) << "OFI transport (" << fId << "): local data band address: " << Context::ConvertAddress(fDataEndpoint->get_local_address()); + fDataEndpoint->connect([&]() { + LOG(debug) << "OFI transport (" << fId << "): data band connected."; + }); } -void free_string(void* /*data*/, void* hint) +auto Socket::ProcessControlMessage(CtrlMsgPtr daa) -> void { - delete static_cast(hint); + assert(daa->type == ControlMessageType::DataAddressAnnouncement); + + sockaddr_in remoteAddr; + remoteAddr.sin_family = AF_INET; + remoteAddr.sin_port = daa->port; + remoteAddr.sin_addr.s_addr = daa->ipv4; + + auto addr = Context::ConvertAddress(remoteAddr); + LOG(debug) << "OFI transport (" << fId << "): Data address announcement of remote endpoint received: " << addr; + fRemoteDataAddr = addr; } auto Socket::AnnounceDataAddress() -> void try { - // size_t addrlen = sizeof(sockaddr_in); - // auto ret = fi_getname(&fDataEndpoint->fid, &fLocalDataAddr, &addrlen); - // if (ret != FI_SUCCESS) - // throw SocketError(tools::ToString("Failed retrieving native address from ofi endpoint, reason: ", fi_strerror(ret))); - // assert(addrlen == sizeof(sockaddr_in)); -// + // fLocalDataAddr = fDataEndpoint->get_local_address(); // LOG(debug) << "Address of local ofi endpoint in socket " << fId << ": " << Context::ConvertAddress(fLocalDataAddr); - // Create new control message - // auto ctrl = tools::make_unique(); - // auto daa = tools::make_unique(); + // Create new data address announcement message + auto daa = MakeControlMessage(&fCtrlMemPool); + auto addr = 
Context::ConvertAddress(fLocalDataAddr); + daa->ipv4 = addr.sin_addr.s_addr; + daa->port = addr.sin_port; - // Fill data address announcement - // daa->set_ipv4(fLocalDataAddr.sin_addr.s_addr); - // daa->set_port(fLocalDataAddr.sin_port); + SendControlMessage(StaticUniquePtrUpcast(std::move(daa))); - // Fill control message - // ctrl->set_allocated_data_address_announcement(daa.release()); - // assert(ctrl->IsInitialized()); - - // SendControlMessage(move(ctrl)); + LOG(debug) << "OFI transport (" << fId << "): data address announced."; } catch (const SocketError& e) { throw SocketError(tools::ToString("Failed to announce data address, reason: ", e.what())); } -// auto Socket::SendControlMessage(unique_ptr ctrl) -> void -// { - // assert(fControlSocket); +auto Socket::SendControlMessage(CtrlMsgPtr ctrl) -> void +{ + assert(fControlSocket); // LOG(debug) << "About to send control message: " << ctrl->DebugString(); -// + // Serialize - // string* str = new string(); - // ctrl->SerializeToString(str); - // zmq_msg_t msg; - // auto ret = zmq_msg_init_data(&msg, const_cast(str->c_str()), str->length(), free_string, str); - // assert(ret == 0); -// + struct ZmqMsg + { + zmq_msg_t msg; + ~ZmqMsg() { zmq_msg_close(&msg); } + operator zmq_msg_t*() { return &msg; } + } msg; + + switch (ctrl->type) { + case ControlMessageType::DataAddressAnnouncement: + { + auto ret = zmq_msg_init_size(msg, sizeof(DataAddressAnnouncement)); + (void)ret; + assert(ret == 0); + std::memcpy(zmq_msg_data(msg), ctrl.get(), sizeof(DataAddressAnnouncement)); + } + break; + default: + throw SocketError(tools::ToString("Cannot send control message of unknown type.")); + } + // Send - // if (zmq_msg_send(&msg, fControlSocket, 0) == -1) { - // zmq_msg_close(&msg); - // throw SocketError(tools::ToString("Failed to send control message, reason: ", zmq_strerror(errno))); - // } -// } -// -// auto Socket::ReceiveControlMessage() -> unique_ptr -// { - // assert(fControlSocket); -// + if (zmq_msg_send(msg, 
fControlSocket, 0) == -1) { + throw SocketError( + tools::ToString("Failed to send control message, reason: ", zmq_strerror(errno))); + } +} + +auto Socket::ReceiveControlMessage() -> CtrlMsgPtr +{ + assert(fControlSocket); + // Receive - // zmq_msg_t msg; - // auto ret = zmq_msg_init(&msg); - // assert(ret == 0); - // if (zmq_msg_recv(&msg, fControlSocket, 0) == -1) { - // zmq_msg_close(&msg); - // throw SocketError(tools::ToString("Failed to receive control message, reason: ", zmq_strerror(errno))); - // } -// - // Deserialize - // auto ctrl = tools::make_unique(); - // ctrl->ParseFromArray(zmq_msg_data(&msg), zmq_msg_size(&msg)); -// - // zmq_msg_close(&msg); - // LOG(debug) << "Received control message: " << ctrl->DebugString(); - // return ctrl; -// } + struct ZmqMsg + { + zmq_msg_t msg; + ~ZmqMsg() { zmq_msg_close(&msg); } + operator zmq_msg_t*() { return &msg; } + } msg; + auto ret = zmq_msg_init(msg); + (void)ret; + assert(ret == 0); + if (zmq_msg_recv(msg, fControlSocket, 0) == -1) { + throw SocketError( + tools::ToString("Failed to receive control message, reason: ", zmq_strerror(errno))); + } + + // Deserialize and sanity check + const void* msg_data = zmq_msg_data(msg); + const size_t msg_size = zmq_msg_size(msg); + (void)msg_size; + assert(msg_size >= sizeof(ControlMessage)); + + switch (static_cast(msg_data)->type) { + case ControlMessageType::DataAddressAnnouncement: { + assert(msg_size == sizeof(DataAddressAnnouncement)); + auto daa = MakeControlMessage(&fCtrlMemPool); + std::memcpy(daa.get(), msg_data, sizeof(DataAddressAnnouncement)); + // LOG(debug) << "Received control message: " << ctrl->DebugString(); + return StaticUniquePtrUpcast(std::move(daa)); + } + default: + throw SocketError(tools::ToString("Received control message of unknown type.")); + } +} auto Socket::WaitForControlPeer() -> void { @@ -302,12 +333,6 @@ auto Socket::TryReceive(std::vector& msgVec) -> int64_t { return Rec auto Socket::SendImpl(FairMQMessagePtr& msg, const int 
/*flags*/, const int /*timeout*/) -> int try { - if (fWaitingForControlPeer) { - WaitForControlPeer(); - AnnounceDataAddress(); - // ProcessDataAddressAnnouncement(ReceiveControlMessage()); - } - auto size = msg->GetSize(); // Create and send control message @@ -358,7 +383,7 @@ auto Socket::ReceiveImpl(FairMQMessagePtr& /*msg*/, const int /*flags*/, const i try { if (fWaitingForControlPeer) { WaitForControlPeer(); - AnnounceDataAddress(); + // AnnounceDataAddress(); // ProcessDataAddressAnnouncement(ReceiveControlMessage()); } diff --git a/fairmq/ofi/Socket.h b/fairmq/ofi/Socket.h index fc083f29..71b63583 100644 --- a/fairmq/ofi/Socket.h +++ b/fairmq/ofi/Socket.h @@ -12,9 +12,11 @@ #include #include #include +#include #include #include +#include #include // unique_ptr #include #include @@ -85,6 +87,7 @@ class Socket final : public fair::mq::Socket private: void* fControlSocket; void* fMonitorSocket; + std::unique_ptr fPassiveDataEndpoint; std::unique_ptr fDataEndpoint; std::string fId; std::atomic fBytesTx; @@ -92,10 +95,11 @@ class Socket final : public fair::mq::Socket std::atomic fMessagesTx; std::atomic fMessagesRx; Context& fContext; - fi_addr_t fRemoteDataAddr; - sockaddr_in fLocalDataAddr; + Context::Address fRemoteDataAddr; + Context::Address fLocalDataAddr; bool fWaitingForControlPeer; boost::asio::io_service::strand fIoStrand; + boost::container::pmr::unsynchronized_pool_resource fCtrlMemPool; int fSndTimeout; int fRcvTimeout; @@ -105,19 +109,17 @@ class Socket final : public fair::mq::Socket auto SendImpl(std::vector& msgVec, const int flags, const int timeout) -> int64_t; auto ReceiveImpl(std::vector& msgVec, const int flags, const int timeout) -> int64_t; - auto InitDataEndpoint() -> void; auto WaitForControlPeer() -> void; auto AnnounceDataAddress() -> void; - // auto SendControlMessage(std::unique_ptr ctrl) -> void; - // auto ReceiveControlMessage() -> std::unique_ptr; - // auto ProcessDataAddressAnnouncement(std::unique_ptr ctrl) -> void; + auto 
SendControlMessage(CtrlMsgPtr ctrl) -> void; + auto ReceiveControlMessage() -> CtrlMsgPtr; + auto ProcessControlMessage(CtrlMsgPtr ctrl) -> void; auto ConnectControlSocket(Context::Address address) -> void; auto BindControlSocket(Context::Address address) -> void; + auto BindDataEndpoint() -> void; + auto ConnectDataEndpoint() -> void; }; /* class Socket */ -// helper function to clean up the object holding the data after it is transported. -void free_string(void* /*data*/, void* hint); - struct SilentSocketError : SocketError { using SocketError::SocketError; }; } /* namespace ofi */ diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index c9790d0f..c3093000 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -33,7 +33,15 @@ add_testhelper(runTestDevice ) if(BUILD_NANOMSG_TRANSPORT) - set(definitions DEFINITIONS BUILD_NANOMSG_TRANSPORT) + list(APPEND definitions BUILD_NANOMSG_TRANSPORT) +endif() + +if(BUILD_OFI_TRANSPORT) + LIST(APPEND definitions BUILD_OFI_TRANSPORT) +endif() + +if(definitions) + set(definitions DEFINITIONS ${definitions}) endif() set(MQ_CONFIG "${CMAKE_BINARY_DIR}/test/testsuite_FairMQ.IOPatterns_config.json")