/******************************************************************************** * Copyright (C) 2018 GSI Helmholtzzentrum fuer Schwerionenforschung GmbH * * * * This software is distributed under the terms of the * * GNU Lesser General Public Licence (LGPL) version 3, * * copied verbatim in the file "LICENSE" * ********************************************************************************/ #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include namespace fair { namespace mq { namespace ofi { using namespace std; Socket::Socket(Context& context, const string& type, const string& name, const string& id /*= ""*/) : fContext(context) , fOfiInfo(nullptr) , fOfiFabric(nullptr) , fOfiDomain(nullptr) , fPassiveEndpoint(nullptr) , fDataEndpoint(nullptr) , fControlEndpoint(nullptr) , fId(id + "." + name + "." + type) , fBytesTx(0) , fBytesRx(0) , fMessagesTx(0) , fMessagesRx(0) , fSndTimeout(100) , fRcvTimeout(100) , fMultiPartRecvCounter(-1) , fSendPushSem(fContext.GetIoContext(), 384) , fSendPopSem(fContext.GetIoContext(), 0) , fRecvPushSem(fContext.GetIoContext(), 384) , fRecvPopSem(fContext.GetIoContext(), 0) , fNeedOfiMemoryRegistration(false) { if (type != "pair") { throw SocketError{tools::ToString("Socket type '", type, "' not implemented for ofi transport.")}; } } auto Socket::InitOfi(Address addr) -> void { if (!fOfiInfo) { assert(!fOfiFabric); assert(!fOfiDomain); asiofi::hints hints; if (addr.Protocol == "tcp") { hints.set_provider("sockets"); } else if (addr.Protocol == "verbs") { hints.set_provider("verbs"); } if (fRemoteAddr == addr) { fOfiInfo = tools::make_unique(addr.Ip.c_str(), std::to_string(addr.Port).c_str(), 0, hints); } else { fOfiInfo = tools::make_unique(addr.Ip.c_str(), std::to_string(addr.Port).c_str(), FI_SOURCE, hints); } LOG(debug) << "OFI transport (" << fId << "): " << *fOfiInfo; fOfiFabric = tools::make_unique(*fOfiInfo); fOfiDomain = tools::make_unique(*fOfiFabric); } } auto Socket::Bind(const string& addr) -> bool try { fLocalAddr = Context::VerifyAddress(addr); if (fLocalAddr.Protocol == "verbs") { fNeedOfiMemoryRegistration = true; } InitOfi(fLocalAddr); fPassiveEndpoint = tools::make_unique(fContext.GetIoContext(), *fOfiFabric); //fPassiveEndpoint->set_local_address(Context::ConvertAddress(fLocalAddr)); BindControlEndpoint(); return true; } // TODO catch the correct ofi error catch (const SilentSocketError& e) { // do not print error in this case, this is handled by FairMQDevice // in case no connection could be established after trying a number of random ports from a range. return false; } catch (const std::exception& e) { LOG(error) << "OFI transport: " << e.what(); return false; } catch (...) { LOG(error) << "OFI transport: Unknown exception in ofi::Socket::Bind"; return false; } auto Socket::BindControlEndpoint() -> void { assert(!fControlEndpoint); fPassiveEndpoint->listen([&](asiofi::info&& info) { LOG(debug) << "OFI transport (" << fId << "): control band connection request received. Accepting ..."; fControlEndpoint = tools::make_unique( fContext.GetIoContext(), *fOfiDomain, info); fControlEndpoint->enable(); fControlEndpoint->accept([&]() { LOG(debug) << "OFI transport (" << fId << "): control band connection accepted."; BindDataEndpoint(); }); }); LOG(debug) << "OFI transport (" << fId << "): control band bound to " << fLocalAddr; } auto Socket::BindDataEndpoint() -> void { assert(!fDataEndpoint); fPassiveEndpoint->listen([&](asiofi::info&& info) { LOG(debug) << "OFI transport (" << fId << "): data band connection request received. Accepting ..."; fDataEndpoint = tools::make_unique( fContext.GetIoContext(), *fOfiDomain, info); fDataEndpoint->enable(); fDataEndpoint->accept([&]() { LOG(debug) << "OFI transport (" << fId << "): data band connection accepted."; boost::asio::post(fContext.GetIoContext(), std::bind(&Socket::SendQueueReader, this)); boost::asio::post(fContext.GetIoContext(), std::bind(&Socket::RecvControlQueueReader, this)); }); }); LOG(debug) << "OFI transport (" << fId << "): data band bound to " << fLocalAddr; } auto Socket::Connect(const string& address) -> bool try { fRemoteAddr = Context::VerifyAddress(address); if (fRemoteAddr.Protocol == "verbs") { fNeedOfiMemoryRegistration = true; } InitOfi(fRemoteAddr); ConnectEndpoint(fControlEndpoint, Band::Control); ConnectEndpoint(fDataEndpoint, Band::Data); boost::asio::post(fContext.GetIoContext(), std::bind(&Socket::SendQueueReader, this)); boost::asio::post(fContext.GetIoContext(), std::bind(&Socket::RecvControlQueueReader, this)); return true; } catch (const SilentSocketError& e) { // do not print error in this case, this is handled by FairMQDevice return false; } catch (const std::exception& e) { LOG(error) << "OFI transport: " << e.what(); return false; } catch (...) { LOG(error) << "OFI transport: Unknown exception in ofi::Socket::Connect"; return false; } auto Socket::ConnectEndpoint(std::unique_ptr& endpoint, Band type) -> void { assert(!endpoint); std::string band(type == Band::Control ? "control" : "data"); endpoint = tools::make_unique(fContext.GetIoContext(), *fOfiDomain); endpoint->enable(); LOG(debug) << "OFI transport (" << fId << "): Sending " << band << " band connection request to " << fRemoteAddr; std::mutex mtx; std::condition_variable cv; bool notified(false), connected(false); while (true) { endpoint->connect(Context::ConvertAddress(fRemoteAddr), [&, band](asiofi::eq::event event) { // LOG(debug) << "OFI transport (" << fId << "): " << band << " band conn event happened"; std::unique_lock lk2(mtx); notified = true; if (event == asiofi::eq::event::connected) { LOG(debug) << "OFI transport (" << fId << "): " << band << " band connected."; connected = true; } else { // LOG(debug) << "OFI transport (" << fId << "): " << band << " band connection refused. Trying again."; } lk2.unlock(); cv.notify_one(); }); { std::unique_lock lk(mtx); cv.wait(lk, [&] { return notified; }); if (connected) { break; } else { notified = false; lk.unlock(); std::this_thread::sleep_for(std::chrono::milliseconds(100)); } } } } auto Socket::Send(MessagePtr& msg, const int /*timeout*/) -> int { // timeout argument not yet implemented std::vector msgVec; msgVec.reserve(1); msgVec.emplace_back(std::move(msg)); return Send(msgVec); } auto Socket::Send(std::vector& msgVec, const int /*timeout*/) -> int64_t try { // timeout argument not yet implemented int size(0); for (auto& msg : msgVec) { size += msg->GetSize(); } fSendPushSem.wait(); { std::lock_guard lk(fSendQueueMutex); fSendQueue.emplace(std::move(msgVec)); } fSendPopSem.signal(); return size; } catch (const std::exception& e) { LOG(error) << e.what(); return -1; } auto Socket::SendQueueReader() -> void { fSendPopSem.async_wait([&] { // Read msg from send queue std::unique_lock lk(fSendQueueMutex); std::vector msgVec(std::move(fSendQueue.front())); fSendQueue.pop(); lk.unlock(); bool postMultiPartStartBuffer = msgVec.size() > 1; for (auto& msg : msgVec) { // Create control message ofi::unique_ptr ctrl(nullptr); if (postMultiPartStartBuffer) { postMultiPartStartBuffer = false; ctrl = MakeControlMessageWithPmr(fControlMemPool); ctrl->msg.postMultiPartStartBuffer.numParts = msgVec.size(); ctrl->msg.postMultiPartStartBuffer.size = msg->GetSize(); } else { ctrl = MakeControlMessageWithPmr(fControlMemPool); ctrl->msg.postBuffer.size = msg->GetSize(); } // Send control message boost::asio::mutable_buffer ctrlMsg(ctrl.get(), sizeof(ControlMessage)); if (fNeedOfiMemoryRegistration) { asiofi::memory_region mr(*fOfiDomain, ctrlMsg, asiofi::mr::access::send); auto desc = mr.desc(); fControlEndpoint->send(ctrlMsg, desc, [&, ctrl2 = std::move(ctrlMsg), mr2 = std::move(mr)]( boost::asio::mutable_buffer) mutable {}); } else { fControlEndpoint->send( ctrlMsg, [&, ctrl2 = std::move(ctrl)](boost::asio::mutable_buffer) mutable {}); } // Send data message const auto size = msg->GetSize(); if (size) { boost::asio::mutable_buffer buffer(msg->GetData(), size); if (fNeedOfiMemoryRegistration) { asiofi::memory_region mr(*fOfiDomain, buffer, asiofi::mr::access::send); auto desc = mr.desc(); fDataEndpoint->send(buffer, desc, [&, size, msg2 = std::move(msg), mr2 = std::move(mr)]( boost::asio::mutable_buffer) mutable { fBytesTx += size; fMessagesTx++; fSendPushSem.signal(); }); } else { fDataEndpoint->send( buffer, [&, size, msg2 = std::move(msg)](boost::asio::mutable_buffer) mutable { fBytesTx += size; fMessagesTx++; fSendPushSem.signal(); }); } } else { ++fMessagesTx; fSendPushSem.signal(); } } boost::asio::dispatch(fContext.GetIoContext(), std::bind(&Socket::SendQueueReader, this)); }); } auto Socket::Receive(MessagePtr& msg, const int /*timeout*/) -> int try { // timeout argument not yet implemented fRecvPopSem.wait(); { std::lock_guard lk(fRecvQueueMutex); msg = std::move(fRecvQueue.front().front()); fRecvQueue.pop(); } fRecvPushSem.signal(); int size(msg->GetSize()); fBytesRx += size; ++fMessagesRx; return size; } catch (const std::exception& e) { LOG(error) << e.what(); return -1; } auto Socket::Receive(std::vector& msgVec, const int /*timeout*/) -> int64_t try { // timeout argument not yet implemented fRecvPopSem.wait(); { std::lock_guard lk(fRecvQueueMutex); msgVec = std::move(fRecvQueue.front()); fRecvQueue.pop(); } fRecvPushSem.signal(); int64_t size(0); for (auto& msg : msgVec) { size += msg->GetSize(); } fBytesRx += size; ++fMessagesRx; return size; } catch (const std::exception& e) { LOG(error) << e.what(); return -1; } auto Socket::RecvControlQueueReader() -> void { fRecvPushSem.async_wait([&] { // Receive control message ofi::unique_ptr ctrl(MakeControlMessageWithPmr(fControlMemPool)); boost::asio::mutable_buffer ctrlMsg(ctrl.get(), sizeof(ControlMessage)); if (fNeedOfiMemoryRegistration) { asiofi::memory_region mr(*fOfiDomain, ctrlMsg, asiofi::mr::access::recv); auto desc = mr.desc(); fControlEndpoint->recv( ctrlMsg, desc, [&, ctrl2 = std::move(ctrl), mr2 = std::move(mr)]( boost::asio::mutable_buffer) mutable { OnRecvControl(std::move(ctrl2)); }); } else { fControlEndpoint->recv( ctrlMsg, [&, ctrl2 = std::move(ctrl)](boost::asio::mutable_buffer) mutable { OnRecvControl(std::move(ctrl2)); }); } }); } auto Socket::OnRecvControl(ofi::unique_ptr ctrl) -> void { // Check control message type auto size(0); if (ctrl->type == ControlMessageType::PostMultiPartStartBuffer) { size = ctrl->msg.postMultiPartStartBuffer.size; if (fMultiPartRecvCounter == -1) { fMultiPartRecvCounter = ctrl->msg.postMultiPartStartBuffer.numParts; assert(fInflightMultiPartMessage.empty()); fInflightMultiPartMessage.reserve(ctrl->msg.postMultiPartStartBuffer.numParts); } else { throw SocketError{tools::ToString( "OFI transport: Received control start of new multi part message without completed " "reception of previous multi part message. Number of parts missing: ", fMultiPartRecvCounter)}; } } else if (ctrl->type == ControlMessageType::PostBuffer) { size = ctrl->msg.postBuffer.size; } else { throw SocketError{tools::ToString("OFI transport: Unknown control message type: '", static_cast(ctrl->type))}; } // Receive data auto msg = fContext.MakeReceiveMessage(size); if (size) { boost::asio::mutable_buffer buffer(msg->GetData(), size); if (fNeedOfiMemoryRegistration) { asiofi::memory_region mr(*fOfiDomain, buffer, asiofi::mr::access::recv); auto desc = mr.desc(); fDataEndpoint->recv( buffer, desc, [&, msg2 = std::move(msg), mr2 = std::move(mr)]( boost::asio::mutable_buffer) mutable { DataMessageReceived(std::move(msg2)); }); } else { fDataEndpoint->recv(buffer, [&, msg2 = std::move(msg)](boost::asio::mutable_buffer) mutable { DataMessageReceived(std::move(msg2)); }); } } else { DataMessageReceived(std::move(msg)); } boost::asio::dispatch(fContext.GetIoContext(), std::bind(&Socket::RecvControlQueueReader, this)); } auto Socket::DataMessageReceived(MessagePtr msg) -> void { if (fMultiPartRecvCounter > 0) { --fMultiPartRecvCounter; fInflightMultiPartMessage.push_back(std::move(msg)); } if (fMultiPartRecvCounter == 0) { std::unique_lock lk(fRecvQueueMutex); fRecvQueue.push(std::move(fInflightMultiPartMessage)); lk.unlock(); fMultiPartRecvCounter = -1; fRecvPopSem.signal(); } else if (fMultiPartRecvCounter == -1) { std::vector msgVec; msgVec.push_back(std::move(msg)); std::unique_lock lk(fRecvQueueMutex); fRecvQueue.push(std::move(msgVec)); lk.unlock(); fRecvPopSem.signal(); } } auto Socket::Close() -> void {} auto Socket::SetOption(const string& /*option*/, const void* /*value*/, size_t /*valueSize*/) -> void { // if (zmq_setsockopt(fControlSocket, GetConstant(option), value, valueSize) < 0) { // throw SocketError{tools::ToString("Failed setting socket option, reason: ", zmq_strerror(errno))}; // } } auto Socket::GetOption(const string& /*option*/, void* /*value*/, size_t* /*valueSize*/) -> void { // if (zmq_getsockopt(fControlSocket, GetConstant(option), value, valueSize) < 0) { // throw SocketError{tools::ToString("Failed getting socket option, reason: ", zmq_strerror(errno))}; // } } void Socket::SetLinger(const int /*value*/) { // azmq::socket::linger opt(value); // fControlEndpoint.set_option(opt); } int Socket::GetLinger() const { // azmq::socket::linger opt(0); // fControlEndpoint.get_option(opt); // return opt.value(); return 0; } void Socket::SetSndBufSize(const int /*value*/) { // azmq::socket::snd_hwm opt(value); // fControlEndpoint.set_option(opt); } int Socket::GetSndBufSize() const { // azmq::socket::snd_hwm opt(0); // fControlEndpoint.get_option(opt); // return opt.value(); return 0; } void Socket::SetRcvBufSize(const int /*value*/) { // azmq::socket::rcv_hwm opt(value); // fControlEndpoint.set_option(opt); } int Socket::GetRcvBufSize() const { // azmq::socket::rcv_hwm opt(0); // fControlEndpoint.get_option(opt); // return opt.value(); return 0; } void Socket::SetSndKernelSize(const int /*value*/) { // azmq::socket::snd_buf opt(value); // fControlEndpoint.set_option(opt); } int Socket::GetSndKernelSize() const { // azmq::socket::snd_buf opt(0); // fControlEndpoint.get_option(opt); // return opt.value(); return 0; } void Socket::SetRcvKernelSize(const int /*value*/) { // azmq::socket::rcv_buf opt(value); // fControlEndpoint.set_option(opt); } int Socket::GetRcvKernelSize() const { // azmq::socket::rcv_buf opt(0); // fControlEndpoint.get_option(opt); // return opt.value(); return 0; } auto Socket::GetConstant(const string& /*constant*/) -> int { // if (constant == "") // return 0; // if (constant == "sub") // return ZMQ_SUB; // if (constant == "pub") // return ZMQ_PUB; // if (constant == "xsub") // return ZMQ_XSUB; // if (constant == "xpub") // return ZMQ_XPUB; // if (constant == "push") // return ZMQ_PUSH; // if (constant == "pull") // return ZMQ_PULL; // if (constant == "req") // return ZMQ_REQ; // if (constant == "rep") // return ZMQ_REP; // if (constant == "dealer") // return ZMQ_DEALER; // if (constant == "router") // return ZMQ_ROUTER; // if (constant == "pair") // return ZMQ_PAIR; // // if (constant == "snd-hwm") // return ZMQ_SNDHWM; // if (constant == "rcv-hwm") // return ZMQ_RCVHWM; // if (constant == "snd-size") // return ZMQ_SNDBUF; // if (constant == "rcv-size") // return ZMQ_RCVBUF; // if (constant == "snd-more") // return ZMQ_SNDMORE; // if (constant == "rcv-more") // return ZMQ_RCVMORE; // // if (constant == "linger") // return ZMQ_LINGER; // if (constant == "no-block") // return ZMQ_DONTWAIT; // if (constant == "snd-more no-block") // return ZMQ_DONTWAIT|ZMQ_SNDMORE; // return -1; } Socket::~Socket() { try { Close(); // NOLINT(clang-analyzer-optin.cplusplus.VirtualCall) } catch (SocketError& e) { LOG(error) << e.what(); } } } /* namespace ofi */ } /* namespace mq */ } /* namespace fair */