diff --git a/fairmq/CMakeLists.txt b/fairmq/CMakeLists.txt index d763c0cd..91597c06 100644 --- a/fairmq/CMakeLists.txt +++ b/fairmq/CMakeLists.txt @@ -199,10 +199,28 @@ configure_file(${CMAKE_SOURCE_DIR}/fairmq/options/startConfigExample.sh.in ${CMAKE_BINARY_DIR}/bin/startConfigExample.sh) +######################## +# compile protobuffers # +######################## +add_custom_target(mkofibuilddir COMMAND ${CMAKE_COMMAND} -E make_directory ${CMAKE_CURRENT_BINARY_DIR}/ofi) +add_custom_command( + OUTPUT + ${CMAKE_CURRENT_BINARY_DIR}/ofi/Control.pb.h + ${CMAKE_CURRENT_BINARY_DIR}/ofi/Control.pb.cc + COMMAND ${PROTOBUF_PROTOC_EXECUTABLE} -I=${CMAKE_CURRENT_SOURCE_DIR}/ofi --cpp_out=${CMAKE_CURRENT_BINARY_DIR}/ofi Control.proto + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} + DEPENDS mkofibuilddir +) +set_source_files_properties(${CMAKE_CURRENT_BINARY_DIR}/ofi/Control.pb.h PROPERTIES GENERATED TRUE) +set_source_files_properties(${CMAKE_CURRENT_BINARY_DIR}/ofi/Control.pb.cc PROPERTIES GENERATED TRUE) + + ################################# # define libFairMQ build target # ################################# add_library(FairMQ SHARED + ${CMAKE_CURRENT_BINARY_DIR}/ofi/Control.pb.h + ${CMAKE_CURRENT_BINARY_DIR}/ofi/Control.pb.cc ${FAIRMQ_SOURCE_FILES} ${FAIRMQ_HEADER_FILES} # for IDE integration ) @@ -216,6 +234,7 @@ target_include_directories(FairMQ $ $ $ + $ $ $ ) @@ -243,6 +262,7 @@ target_link_libraries(FairMQ PRIVATE # only libFairMQ links against private dependencies ZeroMQ OFI::libfabric + protobuf::libprotobuf Msgpack $<$:nanomsg> ) diff --git a/fairmq/ofi/Context.cxx b/fairmq/ofi/Context.cxx index 186b54ed..ec7b7d82 100644 --- a/fairmq/ofi/Context.cxx +++ b/fairmq/ofi/Context.cxx @@ -12,6 +12,7 @@ #include #include +#include #include #include #include @@ -36,10 +37,15 @@ using namespace std; Context::Context(int numberIoThreads) : fOfiDomain(nullptr) , fOfiFabric(nullptr) + , fOfiInfo(nullptr) + , fOfiAddressVector(nullptr) + , fOfiEventQueue(nullptr) , fZmqContext(zmq_ctx_new()) { if (!fZmqContext) throw ContextError{tools::ToString("Failed creating zmq context, reason: ", zmq_strerror(errno))}; + + GOOGLE_PROTOBUF_VERIFY_VERSION; } Context::~Context() @@ -48,6 +54,12 @@ Context::~Context() LOG(error) << "Failed closing zmq context, reason: " << zmq_strerror(errno); } + if (fOfiEventQueue) { + auto ret = fi_close(&fOfiEventQueue->fid); + if (ret != FI_SUCCESS) + LOG(error) << "Failed closing ofi event queue, reason: " << fi_strerror(ret); + } + if (fOfiAddressVector) { auto ret = fi_close(&fOfiAddressVector->fid); if (ret != FI_SUCCESS) @@ -67,19 +79,24 @@ Context::~Context() } } -auto Context::GetZmqVersion() const -> std::string +auto Context::GetZmqVersion() const -> string { int major, minor, patch; zmq_version(&major, &minor, &patch); return tools::ToString(major, ".", minor, ".", patch); } -auto Context::GetOfiApiVersion() const -> std::string +auto Context::GetOfiApiVersion() const -> string { auto ofi_version{fi_version()}; return tools::ToString(FI_MAJOR(ofi_version), ".", FI_MINOR(ofi_version)); } +auto Context::GetPbVersion() const -> string +{ + return google::protobuf::internal::VersionString(GOOGLE_PROTOBUF_VERSION); +} + auto Context::InitOfi(ConnectionType type, std::string addr) -> void { auto addr2 = ConvertAddress(addr); @@ -93,7 +110,7 @@ auto Context::InitOfi(ConnectionType type, std::string addr) -> void // Prepare fi_getinfo query unique_ptr ofi_hints(fi_allocinfo(), fi_freeinfo); - ofi_hints->caps = FI_MSG | FI_SOURCE; + ofi_hints->caps = FI_MSG | FI_RMA; ofi_hints->mode = FI_ASYNC_IOV; ofi_hints->addr_format = FI_SOCKADDR_IN; ofi_hints->fabric_attr->prov_name = strdup("sockets"); @@ -105,17 +122,17 @@ auto Context::InitOfi(ConnectionType type, std::string addr) -> void // ofi_hints->src_addr = sa; // ofi_hints->src_addrlen = sizeof(sockaddr_in); // } else { - ofi_hints->dest_addr = sa; - ofi_hints->dest_addrlen = sizeof(sockaddr_in); + // ofi_hints->dest_addr = sa; + // ofi_hints->dest_addrlen = sizeof(sockaddr_in); // } // Query fi_getinfo for fabric to use - auto res = fi_getinfo(FI_VERSION(1, 5), nullptr, nullptr, 0, ofi_hints.get(), &fOfiInfo); + auto res = fi_getinfo(FI_VERSION(1, 5), strdup(addr2.Ip.c_str()), 0, 0, ofi_hints.get(), &fOfiInfo); if (res != 0) throw ContextError{tools::ToString("Failed querying fi_getinfo, reason: ", fi_strerror(res))}; if (!fOfiInfo) throw ContextError{"Could not find any ofi compatible fabric."}; // for(auto cursor{ofi_info}; cursor->next != nullptr; cursor = cursor->next) { - LOG(debug) << fi_tostr(fOfiInfo, FI_TYPE_INFO); + // LOG(debug) << fi_tostr(fOfiInfo, FI_TYPE_INFO); // } // } else { @@ -123,6 +140,7 @@ auto Context::InitOfi(ConnectionType type, std::string addr) -> void } OpenOfiFabric(); + OpenOfiEventQueue(); OpenOfiDomain(); OpenOfiAddressVector(); } @@ -154,23 +172,39 @@ auto Context::OpenOfiDomain() -> void } } +auto Context::OpenOfiEventQueue() -> void +{ + fi_eq_attr eqAttr = {100, 0, FI_WAIT_UNSPEC, 0, nullptr}; + // size_t size; [> # entries for EQ <] + // uint64_t flags; [> operation flags <] + // enum fi_wait_obj wait_obj; [> requested wait object <] + // int signaling_vector; [> interrupt affinity <] + // struct fid_wait *wait_set; [> optional wait set <] + auto ret = fi_eq_open(fOfiFabric, &eqAttr, &fOfiEventQueue, nullptr); + if (ret != FI_SUCCESS) + throw ContextError{tools::ToString("Failed opening ofi event queue, reason: ", fi_strerror(ret))}; +} + auto Context::OpenOfiAddressVector() -> void { if (!fOfiAddressVector) { assert(fOfiDomain); fi_av_attr attr = {fOfiInfo->domain_attr->av_type, 0, 1000, 0, nullptr, nullptr, 0}; -// struct fi_av_attr { -// enum fi_av_type type; [> type of AV <] -// int rx_ctx_bits; [> address bits to identify rx ctx <] -// size_t count; [> # entries for AV <] -// size_t ep_per_node; [> # endpoints per fabric address <] -// const char *name; [> system name of AV <] -// void *map_addr; [> base mmap address <] -// uint64_t flags; [> operation flags <] -// }; + // enum fi_av_type type; [> type of AV <] + // int rx_ctx_bits; [> address bits to identify rx ctx <] + // size_t count; [> # entries for AV <] + // size_t ep_per_node; [> # endpoints per fabric address <] + // const char *name; [> system name of AV <] + // void *map_addr; [> base mmap address <] + // uint64_t flags; [> operation flags <] auto ret = fi_av_open(fOfiDomain, &attr, &fOfiAddressVector, nullptr); if (ret != FI_SUCCESS) throw ContextError{tools::ToString("Failed opening ofi address vector, reason: ", fi_strerror(ret))}; + + assert(fOfiEventQueue); + ret = fi_av_bind(fOfiAddressVector, &fOfiEventQueue->fid, 0); + if (ret != FI_SUCCESS) + throw ContextError{tools::ToString("Failed binding ofi event queue to address vector, reason: ", fi_strerror(ret))}; } else { LOG(debug) << "Ofi address vector already opened. Skipping."; } @@ -185,6 +219,11 @@ auto Context::CreateOfiEndpoint() -> fid_ep* if (ret != FI_SUCCESS) throw ContextError{tools::ToString("Failed creating ofi endpoint, reason: ", fi_strerror(ret))}; + assert(fOfiEventQueue); + ret = fi_ep_bind(ep, &fOfiEventQueue->fid, 0); + if (ret != FI_SUCCESS) + throw ContextError{tools::ToString("Failed binding ofi address vector to ofi endpoint, reason: ", fi_strerror(ret))}; + assert(fOfiAddressVector); ret = fi_ep_bind(ep, &fOfiAddressVector->fid, 0); if (ret != FI_SUCCESS) @@ -254,6 +293,21 @@ auto Context::ConvertAddress(Address address) -> sockaddr_in return sa; } +auto Context::ConvertAddress(sockaddr_in address) -> Address +{ + return {"tcp", inet_ntoa(address.sin_addr), ntohs(address.sin_port)}; +} + +auto Context::VerifyAddress(const std::string& address) -> Address +{ + auto addr = ConvertAddress(address); + + if (addr.Protocol != "tcp") + throw ContextError("Wrong protocol: Supported protocols are: tcp"); + + return addr; +} + } /* namespace ofi */ } /* namespace mq */ } /* namespace fair */ diff --git a/fairmq/ofi/Context.h b/fairmq/ofi/Context.h index 04f26f2e..3993032d 100644 --- a/fairmq/ofi/Context.h +++ b/fairmq/ofi/Context.h @@ -11,6 +11,7 @@ #include #include +#include #include #include #include @@ -37,21 +38,24 @@ class Context Context(int numberIoThreads = 1); ~Context(); - /// Deferred Ofi initialization auto InitOfi(ConnectionType type, std::string address) -> void; auto CreateOfiEndpoint() -> fid_ep*; auto CreateOfiCompletionQueue(Direction dir) -> fid_cq*; auto GetZmqVersion() const -> std::string; auto GetOfiApiVersion() const -> std::string; + auto GetPbVersion() const -> std::string; auto GetZmqContext() const -> void* { return fZmqContext; } auto InsertAddressVector(sockaddr_in address) -> fi_addr_t; struct Address { std::string Protocol; std::string Ip; unsigned int Port; + friend auto operator<<(std::ostream& os, const Address& a) -> std::ostream& { return os << a.Protocol << "://" << a.Ip << ":" << a.Port; } }; static auto ConvertAddress(std::string address) -> Address; static auto ConvertAddress(Address address) -> sockaddr_in; + static auto ConvertAddress(sockaddr_in address) -> Address; + static auto VerifyAddress(const std::string& address) -> Address; private: void* fZmqContext; @@ -59,8 +63,10 @@ class Context fid_fabric* fOfiFabric; fid_domain* fOfiDomain; fid_av* fOfiAddressVector; + fid_eq* fOfiEventQueue; auto OpenOfiFabric() -> void; + auto OpenOfiEventQueue() -> void; auto OpenOfiDomain() -> void; auto OpenOfiAddressVector() -> void; }; /* class Context */ diff --git a/fairmq/ofi/Control.proto b/fairmq/ofi/Control.proto new file mode 100644 index 00000000..ff0be7af --- /dev/null +++ b/fairmq/ofi/Control.proto @@ -0,0 +1,15 @@ +syntax = "proto3"; +option optimize_for = SPEED; + +package fair.mq.ofi; + +message DataAddressAnnouncement { + uint32 ipv4 = 1; // in_addr_t from + uint32 port = 2; // in_port_t from +} + +message ControlMessage { + oneof type { + DataAddressAnnouncement data_address_announcement = 1; + } +} diff --git a/fairmq/ofi/Socket.cxx b/fairmq/ofi/Socket.cxx index 4b6f6eac..0c3598c4 100644 --- a/fairmq/ofi/Socket.cxx +++ b/fairmq/ofi/Socket.cxx @@ -16,6 +16,7 @@ #include #include #include +#include #include #include #include @@ -35,34 +36,34 @@ Socket::Socket(Context& context, const string& type, const string& name, const s , fDataCompletionQueueTx(nullptr) , fDataCompletionQueueRx(nullptr) , fId(id + "." + name + "." + type) - , fMetaSocket(nullptr) + , fControlSocket(nullptr) , fMonitorSocket(nullptr) , fSndTimeout(100) , fRcvTimeout(100) , fContext(context) - , fWaitingForRemoteConnect(false) + , fWaitingForControlPeer(false) { if (type != "pair") { throw SocketError{tools::ToString("Socket type '", type, "' not implemented for ofi transport.")}; } else { - fMetaSocket = zmq_socket(fContext.GetZmqContext(), ZMQ_PAIR); + fControlSocket = zmq_socket(fContext.GetZmqContext(), ZMQ_PAIR); - if (fMetaSocket == nullptr) + if (fControlSocket == nullptr) throw SocketError{tools::ToString("Failed creating zmq meta socket ", fId, ", reason: ", zmq_strerror(errno))}; - if (zmq_setsockopt(fMetaSocket, ZMQ_IDENTITY, fId.c_str(), fId.length()) != 0) + if (zmq_setsockopt(fControlSocket, ZMQ_IDENTITY, fId.c_str(), fId.length()) != 0) throw SocketError{tools::ToString("Failed setting ZMQ_IDENTITY socket option, reason: ", zmq_strerror(errno))}; // Tell socket to try and send/receive outstanding messages for milliseconds before terminating. // Default value for ZeroMQ is -1, which is to wait forever. int linger = 1000; - if (zmq_setsockopt(fMetaSocket, ZMQ_LINGER, &linger, sizeof(linger)) != 0) + if (zmq_setsockopt(fControlSocket, ZMQ_LINGER, &linger, sizeof(linger)) != 0) throw SocketError{tools::ToString("Failed setting ZMQ_LINGER socket option, reason: ", zmq_strerror(errno))}; - if (zmq_setsockopt(fMetaSocket, ZMQ_SNDTIMEO, &fSndTimeout, sizeof(fSndTimeout)) != 0) + if (zmq_setsockopt(fControlSocket, ZMQ_SNDTIMEO, &fSndTimeout, sizeof(fSndTimeout)) != 0) throw SocketError{tools::ToString("Failed setting ZMQ_SNDTIMEO socket option, reason: ", zmq_strerror(errno))}; - if (zmq_setsockopt(fMetaSocket, ZMQ_RCVTIMEO, &fRcvTimeout, sizeof(fRcvTimeout)) != 0) + if (zmq_setsockopt(fControlSocket, ZMQ_RCVTIMEO, &fRcvTimeout, sizeof(fRcvTimeout)) != 0) throw SocketError{tools::ToString("Failed setting ZMQ_RCVTIMEO socket option, reason: ", zmq_strerror(errno))}; fMonitorSocket = zmq_socket(fContext.GetZmqContext(), ZMQ_PAIR); @@ -71,7 +72,7 @@ Socket::Socket(Context& context, const string& type, const string& name, const s throw SocketError{tools::ToString("Failed creating zmq monitor socket ", fId, ", reason: ", zmq_strerror(errno))}; auto mon_addr = tools::ToString("inproc://", fId); - if (zmq_socket_monitor(fMetaSocket, mon_addr.c_str(), ZMQ_EVENT_ACCEPTED) < 0) + if (zmq_socket_monitor(fControlSocket, mon_addr.c_str(), ZMQ_EVENT_ACCEPTED | ZMQ_EVENT_CONNECTED) < 0) throw SocketError{tools::ToString("Failed setting up monitor on meta socket, reason: ", zmq_strerror(errno))}; if (zmq_connect(fMonitorSocket, mon_addr.c_str()) != 0) @@ -80,50 +81,65 @@ Socket::Socket(Context& context, const string& type, const string& name, const s } auto Socket::Bind(const string& address) -> bool -{ - auto addr2 = fContext.ConvertAddress(address); - if (addr2.Protocol != "tcp") - throw SocketError("Wrong protocol: Supported protocols are: tcp"); - - if (zmq_bind(fMetaSocket, address.c_str()) != 0) { - if (errno == EADDRINUSE) { - // do not print error in this case, this is handled by FairMQDevice - // in case no connection could be established after trying a number of random ports from a range. - return false; - } - LOG(error) << "Failed binding socket " << fId << ", reason: " << zmq_strerror(errno); - return false; - } - +try { + auto addr = Context::VerifyAddress(address); + BindControlSocket(addr); fContext.InitOfi(ConnectionType::Bind, address); - - try { - InitDataEndpoint(); - } catch (SocketError& e) { - LOG(error) << e.what(); - return false; - } - - fWaitingForRemoteConnect = true; - + InitDataEndpoint(); + fWaitingForControlPeer = true; return true; } +catch (const SilentSocketError& e) +{ + // do not print error in this case, this is handled by FairMQDevice + // in case no connection could be established after trying a number of random ports from a range. + return false; +} +catch (const SocketError& e) +{ + LOG(error) << e.what(); + return false; +} auto Socket::Connect(const string& address) -> void { - auto addr2 = fContext.ConvertAddress(address); - if (addr2.Protocol != "tcp") - throw SocketError("Wrong protocol: Supported protocols are: tcp"); - - if (zmq_connect(fMetaSocket, address.c_str()) != 0) { - throw SocketError(tools::ToString("Failed connecting socket ", fId, ", reason: ", zmq_strerror(errno))); - } - + auto addr = Context::VerifyAddress(address); + ConnectControlSocket(addr); fContext.InitOfi(ConnectionType::Connect, address); - InitDataEndpoint(); + fWaitingForControlPeer = true; +} - fRemoteAddr = fContext.InsertAddressVector(fContext.ConvertAddress(addr2)); +auto Socket::BindControlSocket(Context::Address address) -> void +{ + auto addr = tools::ToString("tcp://", address.Ip, ":", address.Port); + + if (zmq_bind(fControlSocket, addr.c_str()) != 0) { + if (errno == EADDRINUSE) throw SilentSocketError("EADDRINUSE"); + throw SocketError(tools::ToString("Failed binding control socket ", fId, ", reason: ", zmq_strerror(errno))); + } +} + +auto Socket::ConnectControlSocket(Context::Address address) -> void +{ + auto addr = tools::ToString("tcp://", address.Ip, ":", address.Port); + + if (zmq_connect(fControlSocket, addr.c_str()) != 0) + throw SocketError(tools::ToString("Failed connecting control socket ", fId, ", reason: ", zmq_strerror(errno))); +} + +auto Socket::ProcessDataAddressAnnouncement(std::unique_ptr ctrl) -> void +{ + assert(ctrl->has_data_address_announcement()); + auto daa = ctrl->data_address_announcement(); + + sockaddr_in remoteAddr; + remoteAddr.sin_family = AF_INET; + remoteAddr.sin_port = daa.port(); + remoteAddr.sin_addr.s_addr = daa.ipv4(); + + LOG(debug) << Context::ConvertAddress(remoteAddr); + fRemoteDataAddr = fContext.InsertAddressVector(remoteAddr); } auto Socket::InitDataEndpoint() -> void @@ -149,13 +165,81 @@ auto Socket::InitDataEndpoint() -> void ret = fi_enable(fDataEndpoint); if (ret != FI_SUCCESS) - throw SocketError(tools::ToString("Failed opening ofi fabric, reason: ", fi_strerror(ret))); + throw SocketError(tools::ToString("Failed enabling ofi endpoint, reason: ", fi_strerror(ret))); } } -auto Socket::WaitForRemoteConnect() -> void +void free_string(void* /*data*/, void* hint) { - assert(fWaitingForRemoteConnect); + delete static_cast(hint); +} + +auto Socket::AnnounceDataAddress() -> void +try { + using namespace google::protobuf; + + size_t addrlen = sizeof(sockaddr_in); + auto ret = fi_getname(&fDataEndpoint->fid, &fLocalDataAddr, &addrlen); + if (ret != FI_SUCCESS) + throw SocketError(tools::ToString("Failed retrieving native address from ofi endpoint, reason: ", fi_strerror(ret))); + assert(addrlen == sizeof(sockaddr_in)); + + LOG(debug) << "Address of local ofi endpoint in socket " << fId << ": " << Context::ConvertAddress(fLocalDataAddr); + + // Create new control message + auto ctrl = tools::make_unique(); + auto daa = tools::make_unique(); + + // Fill data address announcement + daa->set_ipv4(fLocalDataAddr.sin_addr.s_addr); + daa->set_port(fLocalDataAddr.sin_port); + + // Fill control message + ctrl->set_allocated_data_address_announcement(daa.release()); + assert(ctrl->IsInitialized()); + + SendControlMessage(move(ctrl)); +} catch (const SocketError& e) { + throw SocketError(tools::ToString("Failed to announce data address, reason: ", e.what())); +} + +auto Socket::SendControlMessage(unique_ptr ctrl) -> void +{ + assert(fControlSocket); + + // Serialize + string* str = new string(); + ctrl->SerializeToString(str); + zmq_msg_t msg; + auto ret = zmq_msg_init_data(&msg, const_cast(str->c_str()), str->length(), free_string, str); + assert(ret == 0); + + // Send + if (zmq_msg_send(&msg, fControlSocket, 0) == -1) + throw SocketError(tools::ToString("Failed to send control message, reason: ", zmq_strerror(errno))); +} + +auto Socket::ReceiveControlMessage() -> unique_ptr +{ + assert(fControlSocket); + + // Receive + zmq_msg_t msg; + auto ret = zmq_msg_init(&msg); + assert(ret == 0); + if (zmq_msg_recv(&msg, fControlSocket, 0) == -1) + throw SocketError(tools::ToString("Failed to receive control message, reason: ", zmq_strerror(errno))); + + // Deserialize + auto ctrl = tools::make_unique(); + ctrl->ParseFromArray(zmq_msg_data(&msg), zmq_msg_size(&msg)); + + return ctrl; +} + +auto Socket::WaitForControlPeer() -> void +{ + assert(fWaitingForControlPeer); // First frame in message contains event number and value zmq_msg_t msg; @@ -172,21 +256,23 @@ auto Socket::WaitForRemoteConnect() -> void if (zmq_msg_recv(&msg, fMonitorSocket, 0) == -1) throw SocketError(tools::ToString("Failed to get monitor event, reason: ", zmq_strerror(errno))); - string localAddress = string(static_cast(zmq_msg_data(&msg)), zmq_msg_size(&msg)); + if (event == ZMQ_EVENT_ACCEPTED) { + // string localAddress = string(static_cast(zmq_msg_data(&msg)), zmq_msg_size(&msg)); + sockaddr_in remoteAddr; + socklen_t addrSize = sizeof(sockaddr_in); + int ret = getpeername(value, (sockaddr*)&remoteAddr, &addrSize); + if (ret != 0) + throw SocketError(tools::ToString("Failed retrieving remote address, reason: ", strerror(errno))); + string remoteIp(inet_ntoa(remoteAddr.sin_addr)); + int remotePort = ntohs(remoteAddr.sin_port); + LOG(debug) << "Accepted control peer connection from " << remoteIp << ":" << remotePort; + } else if (event == ZMQ_EVENT_CONNECTED) { + LOG(debug) << "Connected successfully to control peer"; + } else { + LOG(debug) << "Unknown monitor event received: " << event << ". Ignoring."; + } - assert(event == ZMQ_EVENT_ACCEPTED); // we only subscribed for this event - - sockaddr_in remoteAddr; - socklen_t addrSize = sizeof(sockaddr_in); - int ret = getpeername(value, (sockaddr*)&remoteAddr, &addrSize); - if (ret != 0) - throw SocketError(tools::ToString("Failed retrieving peer address, reason: ", strerror(errno))); - string remoteIp(inet_ntoa(remoteAddr.sin_addr)); - int remotePort = ntohs(remoteAddr.sin_port); - LOG(debug) << "peer connected from " << remoteIp << ":" << remotePort << " at " << localAddress; - - fRemoteAddr = fContext.InsertAddressVector(remoteAddr); - fWaitingForRemoteConnect = false; + fWaitingForControlPeer = false; } auto Socket::Send(MessagePtr& msg, const int timeout) -> int { return SendImpl(msg, 0, timeout); } @@ -200,41 +286,51 @@ auto Socket::TrySend(std::vector& msgVec) -> int64_t { return SendIm auto Socket::TryReceive(std::vector& msgVec) -> int64_t { return ReceiveImpl(msgVec, ZMQ_DONTWAIT, 0); } auto Socket::SendImpl(FairMQMessagePtr& msg, const int flags, const int timeout) -> int -{ - if (fWaitingForRemoteConnect) { - try { - WaitForRemoteConnect(); - } catch (const std::exception& e) { - LOG(error) << e.what(); - return -1; - } +try { + if (fWaitingForControlPeer) { + WaitForControlPeer(); + AnnounceDataAddress(); + ProcessDataAddressAnnouncement(ReceiveControlMessage()); } - // void* metadata = malloc(sizeof(size_t)); - - auto ret = zmq_send(fMetaSocket, nullptr, 0, flags); - if (ret == EAGAIN) { - return -2; - } else if (ret < 0) { - LOG(error) << "Failed sending meta message on socket " << fId << ", reason: " << zmq_strerror(errno); - return -1; - } else { - // auto ret2 = fi_send(fDataEndpoint, msg->GetData(), msg->GetSize(), nullptr, fi_addr_t dest_addr, nullptr); - return ret; - } + auto ret = zmq_send(fControlSocket, nullptr, 0, flags); + if (ret == EAGAIN) throw SilentSocketError("EAGAIN"); + if (ret == -1) throw SocketError(tools::ToString("Failed sending control message on socket ", fId, ", reason: ", zmq_strerror(errno))); + + return ret; +} +catch (const SilentSocketError& e) +{ + return -2; +} +catch (const std::exception& e) +{ + LOG(error) << e.what(); + return -1; } auto Socket::ReceiveImpl(FairMQMessagePtr& msg, const int flags, const int timeout) -> int -{ - auto ret = zmq_recv(fMetaSocket, nullptr, 0, flags); - if (ret == EAGAIN) { - return -2; - } else if (ret < 0) { - LOG(error) << "Failed receiving meta message on socket " << fId << ", reason: " << zmq_strerror(errno); - return -1; - } else { - return ret; +try { + if (fWaitingForControlPeer) { + WaitForControlPeer(); + AnnounceDataAddress(); + ProcessDataAddressAnnouncement(ReceiveControlMessage()); } + + auto ret = zmq_recv(fControlSocket, nullptr, 0, flags); + if (ret == EAGAIN) throw SilentSocketError("EAGAIN"); + if (ret == -1) throw SocketError(tools::ToString("Failed sending control message on socket ", fId, ", reason: ", zmq_strerror(errno))); + + return ret; +} +catch (const SilentSocketError& e) +{ + return -2; +} +catch (const std::exception& e) +{ + LOG(error) << e.what(); + return -1; } auto Socket::SendImpl(vector& msgVec, const int flags, const int timeout) -> int64_t @@ -410,7 +506,7 @@ auto Socket::ReceiveImpl(vector& msgVec, const int flags, cons auto Socket::Close() -> void { - if (zmq_close(fMetaSocket) != 0) + if (zmq_close(fControlSocket) != 0) throw SocketError(tools::ToString("Failed closing zmq meta socket, reason: ", zmq_strerror(errno))); if (zmq_close(fMonitorSocket) != 0) @@ -437,14 +533,14 @@ auto Socket::Close() -> void auto Socket::SetOption(const string& option, const void* value, size_t valueSize) -> void { - if (zmq_setsockopt(fMetaSocket, GetConstant(option), value, valueSize) < 0) { + if (zmq_setsockopt(fControlSocket, GetConstant(option), value, valueSize) < 0) { throw SocketError{tools::ToString("Failed setting socket option, reason: ", zmq_strerror(errno))}; } } auto Socket::GetOption(const string& option, void* value, size_t* valueSize) -> void { - if (zmq_getsockopt(fMetaSocket, GetConstant(option), value, valueSize) < 0) { + if (zmq_getsockopt(fControlSocket, GetConstant(option), value, valueSize) < 0) { throw SocketError{tools::ToString("Failed getting socket option, reason: ", zmq_strerror(errno))}; } } diff --git a/fairmq/ofi/Socket.h b/fairmq/ofi/Socket.h index 44d12ad1..1237f7b0 100644 --- a/fairmq/ofi/Socket.h +++ b/fairmq/ofi/Socket.h @@ -12,8 +12,10 @@ #include #include #include +#include #include // unique_ptr +#include #include namespace fair @@ -25,7 +27,7 @@ namespace ofi /** * @class Socket Socket.h - * @brief + * @brief * * @todo TODO insert long description */ @@ -51,7 +53,7 @@ class Socket : public fair::mq::Socket auto TrySend(std::vector& msgVec) -> int64_t override; auto TryReceive(std::vector& msgVec) -> int64_t override; - auto GetSocket() const -> void* override { return fMetaSocket; } + auto GetSocket() const -> void* override { return fControlSocket; } auto GetSocket(int nothing) const -> int override { return -1; } auto Close() -> void override; @@ -74,7 +76,7 @@ class Socket : public fair::mq::Socket ~Socket() override; private: - void* fMetaSocket; + void* fControlSocket; void* fMonitorSocket; fid_ep* fDataEndpoint; fid_cq* fDataCompletionQueueTx; @@ -85,8 +87,9 @@ class Socket : public fair::mq::Socket std::atomic fMessagesTx; std::atomic fMessagesRx; Context& fContext; - fi_addr_t fRemoteAddr; - bool fWaitingForRemoteConnect; + fi_addr_t fRemoteDataAddr; + sockaddr_in fLocalDataAddr; + bool fWaitingForControlPeer; int fSndTimeout; int fRcvTimeout; @@ -97,9 +100,20 @@ class Socket : public fair::mq::Socket auto ReceiveImpl(std::vector& msgVec, const int flags, const int timeout) -> int64_t; auto InitDataEndpoint() -> void; - auto WaitForRemoteConnect() -> void; + auto WaitForControlPeer() -> void; + auto AnnounceDataAddress() -> void; + auto SendControlMessage(std::unique_ptr ctrl) -> void; + auto ReceiveControlMessage() -> std::unique_ptr; + auto ProcessDataAddressAnnouncement(std::unique_ptr ctrl) -> void; + auto ConnectControlSocket(Context::Address address) -> void; + auto BindControlSocket(Context::Address address) -> void; }; /* class Socket */ +// helper function to clean up the object holding the data after it is transported. +void free_string(void* /*data*/, void* hint); + +struct SilentSocketError : SocketError { using SocketError::SocketError; }; + } /* namespace ofi */ } /* namespace mq */ } /* namespace fair */ diff --git a/fairmq/ofi/TransportFactory.cxx b/fairmq/ofi/TransportFactory.cxx index 5be1ab3e..7106f00b 100644 --- a/fairmq/ofi/TransportFactory.cxx +++ b/fairmq/ofi/TransportFactory.cxx @@ -27,7 +27,8 @@ TransportFactory::TransportFactory(const string& id, const FairMQProgOptions* co try : FairMQTransportFactory{id} { LOG(debug) << "Transport: Using ZeroMQ (" << fContext.GetZmqVersion() << ") & " - << "OFI libfabric (API " << fContext.GetOfiApiVersion() << ")"; + << "OFI libfabric (API " << fContext.GetOfiApiVersion() << ") & " + << "Google Protobuf (" << fContext.GetPbVersion() << ")"; } catch (ContextError& e) { throw TransportFactoryError{e.what()}; }