/******************************************************************************** * Copyright (C) 2018 GSI Helmholtzzentrum fuer Schwerionenforschung GmbH * * * * This software is distributed under the terms of the * * GNU Lesser General Public Licence (LGPL) version 3, * * copied verbatim in the file "LICENSE" * ********************************************************************************/ #include #include #include #include #include #include #include #include #include #include #include #include #include #include namespace fair { namespace mq { namespace ofi { using namespace std; Socket::Socket(Context& context, const string& type, const string& name, const string& id /*= ""*/, FairMQTransportFactory* fac) : FairMQSocket{fac} , fDataEndpoint(nullptr) , fDataCompletionQueueTx(nullptr) , fDataCompletionQueueRx(nullptr) , fId(id + "." + name + "." + type) , fControlSocket(nullptr) , fMonitorSocket(nullptr) , fSndTimeout(100) , fRcvTimeout(100) , fContext(context) , fWaitingForControlPeer(false) , fIoStrand(fContext.GetIoContext()) , fBytesTx(0) , fBytesRx(0) , fMessagesTx(0) , fMessagesRx(0) { if (type != "pair") { throw SocketError{tools::ToString("Socket type '", type, "' not implemented for ofi transport.")}; } else { fControlSocket = zmq_socket(fContext.GetZmqContext(), ZMQ_PAIR); if (fControlSocket == nullptr) throw SocketError{tools::ToString("Failed creating zmq meta socket ", fId, ", reason: ", zmq_strerror(errno))}; if (zmq_setsockopt(fControlSocket, ZMQ_IDENTITY, fId.c_str(), fId.length()) != 0) throw SocketError{tools::ToString("Failed setting ZMQ_IDENTITY socket option, reason: ", zmq_strerror(errno))}; // Tell socket to try and send/receive outstanding messages for milliseconds before terminating. // Default value for ZeroMQ is -1, which is to wait forever. int linger = 1000; if (zmq_setsockopt(fControlSocket, ZMQ_LINGER, &linger, sizeof(linger)) != 0) throw SocketError{tools::ToString("Failed setting ZMQ_LINGER socket option, reason: ", zmq_strerror(errno))}; // TODO enable again and implement retries // if (zmq_setsockopt(fControlSocket, ZMQ_SNDTIMEO, &fSndTimeout, sizeof(fSndTimeout)) != 0) // throw SocketError{tools::ToString("Failed setting ZMQ_SNDTIMEO socket option, reason: ", zmq_strerror(errno))}; // // if (zmq_setsockopt(fControlSocket, ZMQ_RCVTIMEO, &fRcvTimeout, sizeof(fRcvTimeout)) != 0) // throw SocketError{tools::ToString("Failed setting ZMQ_RCVTIMEO socket option, reason: ", zmq_strerror(errno))}; fMonitorSocket = zmq_socket(fContext.GetZmqContext(), ZMQ_PAIR); if (fMonitorSocket == nullptr) throw SocketError{tools::ToString("Failed creating zmq monitor socket ", fId, ", reason: ", zmq_strerror(errno))}; auto mon_addr = tools::ToString("inproc://", fId); if (zmq_socket_monitor(fControlSocket, mon_addr.c_str(), ZMQ_EVENT_ACCEPTED | ZMQ_EVENT_CONNECTED) < 0) throw SocketError{tools::ToString("Failed setting up monitor on meta socket, reason: ", zmq_strerror(errno))}; if (zmq_connect(fMonitorSocket, mon_addr.c_str()) != 0) throw SocketError{tools::ToString("Failed connecting monitor socket to meta socket, reason: ", zmq_strerror(errno))}; } } auto Socket::Bind(const string& address) -> bool try { auto addr = Context::VerifyAddress(address); BindControlSocket(addr); fContext.InitOfi(ConnectionType::Bind, addr); InitDataEndpoint(); fWaitingForControlPeer = true; return true; } catch (const SilentSocketError& e) { // do not print error in this case, this is handled by FairMQDevice // in case no connection could be established after trying a number of random ports from a range. return false; } catch (const SocketError& e) { LOG(error) << e.what(); return false; } auto Socket::Connect(const string& address) -> bool { auto addr = Context::VerifyAddress(address); ConnectControlSocket(addr); fContext.InitOfi(ConnectionType::Connect, addr); InitDataEndpoint(); fWaitingForControlPeer = true; return true; } auto Socket::BindControlSocket(Context::Address address) -> void { auto addr = tools::ToString("tcp://", address.Ip, ":", address.Port); if (zmq_bind(fControlSocket, addr.c_str()) != 0) { if (errno == EADDRINUSE) throw SilentSocketError("EADDRINUSE"); throw SocketError(tools::ToString("Failed binding control socket ", fId, ", reason: ", zmq_strerror(errno))); } } auto Socket::ConnectControlSocket(Context::Address address) -> void { auto addr = tools::ToString("tcp://", address.Ip, ":", address.Port); if (zmq_connect(fControlSocket, addr.c_str()) != 0) throw SocketError(tools::ToString("Failed connecting control socket ", fId, ", reason: ", zmq_strerror(errno))); } auto Socket::ProcessDataAddressAnnouncement(std::unique_ptr ctrl) -> void { assert(ctrl->has_data_address_announcement()); auto daa = ctrl->data_address_announcement(); sockaddr_in remoteAddr; remoteAddr.sin_family = AF_INET; remoteAddr.sin_port = daa.port(); remoteAddr.sin_addr.s_addr = daa.ipv4(); LOG(debug) << "Data address announcement of remote ofi endpoint received: " << Context::ConvertAddress(remoteAddr); fRemoteDataAddr = fContext.InsertAddressVector(remoteAddr); } auto Socket::InitDataEndpoint() -> void { if (!fDataEndpoint) { try { fDataEndpoint = fContext.CreateOfiEndpoint(); } catch (ContextError& e) { throw SocketError(tools::ToString("Failed creating ofi endpoint, reason: ", e.what())); } if (!fDataCompletionQueueTx) fDataCompletionQueueTx = fContext.CreateOfiCompletionQueue(Direction::Transmit); auto ret = fi_ep_bind(fDataEndpoint, &fDataCompletionQueueTx->fid, FI_TRANSMIT); if (ret != FI_SUCCESS) throw SocketError(tools::ToString("Failed binding ofi transmit completion queue to endpoint, reason: ", fi_strerror(ret))); if (!fDataCompletionQueueRx) fDataCompletionQueueRx = fContext.CreateOfiCompletionQueue(Direction::Receive); ret = fi_ep_bind(fDataEndpoint, &fDataCompletionQueueRx->fid, FI_RECV); if (ret != FI_SUCCESS) throw SocketError(tools::ToString("Failed binding ofi receive completion queue to endpoint, reason: ", fi_strerror(ret))); ret = fi_enable(fDataEndpoint); if (ret != FI_SUCCESS) throw SocketError(tools::ToString("Failed enabling ofi endpoint, reason: ", fi_strerror(ret))); } } void free_string(void* /*data*/, void* hint) { delete static_cast(hint); } auto Socket::AnnounceDataAddress() -> void try { using namespace google::protobuf; size_t addrlen = sizeof(sockaddr_in); auto ret = fi_getname(&fDataEndpoint->fid, &fLocalDataAddr, &addrlen); if (ret != FI_SUCCESS) throw SocketError(tools::ToString("Failed retrieving native address from ofi endpoint, reason: ", fi_strerror(ret))); assert(addrlen == sizeof(sockaddr_in)); LOG(debug) << "Address of local ofi endpoint in socket " << fId << ": " << Context::ConvertAddress(fLocalDataAddr); // Create new control message auto ctrl = tools::make_unique(); auto daa = tools::make_unique(); // Fill data address announcement daa->set_ipv4(fLocalDataAddr.sin_addr.s_addr); daa->set_port(fLocalDataAddr.sin_port); // Fill control message ctrl->set_allocated_data_address_announcement(daa.release()); assert(ctrl->IsInitialized()); SendControlMessage(move(ctrl)); } catch (const SocketError& e) { throw SocketError(tools::ToString("Failed to announce data address, reason: ", e.what())); } auto Socket::SendControlMessage(unique_ptr ctrl) -> void { assert(fControlSocket); // LOG(debug) << "About to send control message: " << ctrl->DebugString(); // Serialize string* str = new string(); ctrl->SerializeToString(str); zmq_msg_t msg; auto ret = zmq_msg_init_data(&msg, const_cast(str->c_str()), str->length(), free_string, str); assert(ret == 0); // Send if (zmq_msg_send(&msg, fControlSocket, 0) == -1) { zmq_msg_close(&msg); throw SocketError(tools::ToString("Failed to send control message, reason: ", zmq_strerror(errno))); } } auto Socket::ReceiveControlMessage() -> unique_ptr { assert(fControlSocket); // Receive zmq_msg_t msg; auto ret = zmq_msg_init(&msg); assert(ret == 0); if (zmq_msg_recv(&msg, fControlSocket, 0) == -1) { zmq_msg_close(&msg); throw SocketError(tools::ToString("Failed to receive control message, reason: ", zmq_strerror(errno))); } // Deserialize auto ctrl = tools::make_unique(); ctrl->ParseFromArray(zmq_msg_data(&msg), zmq_msg_size(&msg)); zmq_msg_close(&msg); // LOG(debug) << "Received control message: " << ctrl->DebugString(); return ctrl; } auto Socket::WaitForControlPeer() -> void { assert(fWaitingForControlPeer); // First frame in message contains event number and value zmq_msg_t msg; zmq_msg_init(&msg); if (zmq_msg_recv(&msg, fMonitorSocket, 0) == -1) throw SocketError(tools::ToString("Failed to get monitor event, reason: ", zmq_strerror(errno))); uint8_t* data = (uint8_t*) zmq_msg_data(&msg); uint16_t event = *(uint16_t*)(data); int value = *(uint32_t *)(data + 2); // Second frame in message contains event address zmq_msg_init(&msg); if (zmq_msg_recv(&msg, fMonitorSocket, 0) == -1) throw SocketError(tools::ToString("Failed to get monitor event, reason: ", zmq_strerror(errno))); if (event == ZMQ_EVENT_ACCEPTED) { // string localAddress = string(static_cast(zmq_msg_data(&msg)), zmq_msg_size(&msg)); sockaddr_in remoteAddr; socklen_t addrSize = sizeof(sockaddr_in); int ret = getpeername(value, (sockaddr*)&remoteAddr, &addrSize); if (ret != 0) throw SocketError(tools::ToString("Failed retrieving remote address, reason: ", strerror(errno))); string remoteIp(inet_ntoa(remoteAddr.sin_addr)); int remotePort = ntohs(remoteAddr.sin_port); LOG(debug) << "Accepted control peer connection from " << remoteIp << ":" << remotePort; } else if (event == ZMQ_EVENT_CONNECTED) { LOG(debug) << "Connected successfully to control peer"; } else { LOG(debug) << "Unknown monitor event received: " << event << ". Ignoring."; } fWaitingForControlPeer = false; } auto Socket::Send(MessagePtr& msg, const int timeout) -> int { return SendImpl(msg, 0, timeout); } auto Socket::Receive(MessagePtr& msg, const int timeout) -> int { return ReceiveImpl(msg, 0, timeout); } auto Socket::Send(std::vector& msgVec, const int timeout) -> int64_t { return SendImpl(msgVec, 0, timeout); } auto Socket::Receive(std::vector& msgVec, const int timeout) -> int64_t { return ReceiveImpl(msgVec, 0, timeout); } auto Socket::TrySend(MessagePtr& msg) -> int { return SendImpl(msg, ZMQ_DONTWAIT, 0); } auto Socket::TryReceive(MessagePtr& msg) -> int { return ReceiveImpl(msg, ZMQ_DONTWAIT, 0); } auto Socket::TrySend(std::vector& msgVec) -> int64_t { return SendImpl(msgVec, ZMQ_DONTWAIT, 0); } auto Socket::TryReceive(std::vector& msgVec) -> int64_t { return ReceiveImpl(msgVec, ZMQ_DONTWAIT, 0); } auto Socket::SendImpl(FairMQMessagePtr& msg, const int flags, const int timeout) -> int try { if (fWaitingForControlPeer) { WaitForControlPeer(); AnnounceDataAddress(); ProcessDataAddressAnnouncement(ReceiveControlMessage()); } auto size = msg->GetSize(); // Create and send control message auto ctrl = tools::make_unique(); auto buf = tools::make_unique(); buf->set_size(size); ctrl->set_allocated_post_buffer(buf.release()); assert(ctrl->IsInitialized()); SendControlMessage(move(ctrl)); if (size) { // Receive and process control message // auto ctrl2 = ReceiveControlMessage(); // assert(ctrl2->has_post_buffer_acknowledgement()); // assert(ctrl2->post_buffer_acknowledgement().size() == size); // Send data fi_context ctx; auto ret = fi_send(fDataEndpoint, msg->GetData(), size, nullptr, fRemoteDataAddr, &ctx); if (ret < 0) throw SocketError(tools::ToString("Failed posting ofi send buffer, reason: ", fi_strerror(ret))); } if (size) { fi_cq_err_entry cqEntry; auto ret = fi_cq_sread(fDataCompletionQueueTx, &cqEntry, 1, nullptr, -1); if (ret != 1) throw SocketError(tools::ToString("Failed reading ofi tx completion queue event, reason: ", fi_strerror(ret))); } msg.reset(nullptr); fBytesTx += size; fMessagesTx++; return size; } catch (const SilentSocketError& e) { return -2; } catch (const std::exception& e) { LOG(error) << e.what(); return -1; } auto Socket::ReceiveImpl(FairMQMessagePtr& msg, const int flags, const int timeout) -> int try { if (fWaitingForControlPeer) { WaitForControlPeer(); AnnounceDataAddress(); ProcessDataAddressAnnouncement(ReceiveControlMessage()); } // Receive and process control message auto ctrl = ReceiveControlMessage(); assert(ctrl->has_post_buffer()); auto postBuffer = ctrl->post_buffer(); auto size = postBuffer.size(); // Receive data if (size) { fi_context ctx; msg->Rebuild(size); auto buf = msg->GetData(); auto size2 = msg->GetSize(); auto ret = fi_recv(fDataEndpoint, buf, size2, nullptr, fRemoteDataAddr, &ctx); if (ret < 0) throw SocketError(tools::ToString("Failed posting ofi receive buffer, reason: ", fi_strerror(ret))); // Create and send control message // auto ctrl2 = tools::make_unique(); // auto ack = tools::make_unique(); // ack->set_size(msg->GetSize()); // ctrl2->set_allocated_post_buffer_acknowledgement(ack.release()); // assert(ctrl2->IsInitialized()); // SendControlMessage(move(ctrl2)); fi_cq_err_entry cqEntry; ret = fi_cq_sread(fDataCompletionQueueRx, &cqEntry, 1, nullptr, -1); if (ret != 1) throw SocketError(tools::ToString("Failed reading ofi rx completion queue event, reason: ", fi_strerror(ret))); assert(cqEntry.len == size2); assert(cqEntry.buf == buf); } fBytesRx += size; fMessagesRx++; return size; } catch (const SilentSocketError& e) { return -2; } catch (const std::exception& e) { LOG(error) << e.what(); return -1; } auto Socket::SendImpl(vector& msgVec, const int flags, const int timeout) -> int64_t { throw SocketError{"Not yet implemented."}; // const unsigned int vecSize = msgVec.size(); // int elapsed = 0; // // // Sending vector typicaly handles more then one part // if (vecSize > 1) // { // int64_t totalSize = 0; // int nbytes = -1; // bool repeat = false; // // while (true && !fInterrupted) // { // for (unsigned int i = 0; i < vecSize; ++i) // { // nbytes = zmq_msg_send(static_cast(msgVec[i].get())->GetMessage(), // fSocket, // (i < vecSize - 1) ? ZMQ_SNDMORE|flags : flags); // if (nbytes >= 0) // { // static_cast(msgVec[i].get())->fQueued = true; // size_t size = msgVec[i]->GetSize(); // // totalSize += size; // } // else // { // // according to ZMQ docs, this can only occur for the first part // if (zmq_errno() == EAGAIN) // { // if (!fInterrupted && ((flags & ZMQ_DONTWAIT) == 0)) // { // if (timeout) // { // elapsed += fSndTimeout; // if (elapsed >= timeout) // { // return -2; // } // } // repeat = true; // break; // } // else // { // return -2; // } // } // if (zmq_errno() == ETERM) // { // LOG(info) << "terminating socket " << fId; // return -1; // } // LOG(error) << "Failed sending on socket " << fId << ", reason: " << zmq_strerror(errno); // return nbytes; // } // } // // if (repeat) // { // continue; // } // // // store statistics on how many messages have been sent (handle all parts as a single message) // ++fMessagesTx; // fBytesTx += totalSize; // return totalSize; // } // // return -1; // } // If there's only one part, send it as a regular message // else if (vecSize == 1) // { // return Send(msgVec.back(), flags); // } // else // if the vector is empty, something might be wrong // { // LOG(warn) << "Will not send empty vector"; // return -1; // } } auto Socket::ReceiveImpl(vector& msgVec, const int flags, const int timeout) -> int64_t { throw SocketError{"Not yet implemented."}; // int64_t totalSize = 0; // int64_t more = 0; // bool repeat = false; // int elapsed = 0; // // while (true) // { // // Warn if the vector is filled before Receive() and empty it. // // if (msgVec.size() > 0) // // { // // LOG(warn) << "Message vector contains elements before Receive(), they will be deleted!"; // // msgVec.clear(); // // } // // totalSize = 0; // more = 0; // repeat = false; // // do // { // FairMQMessagePtr part(new FairMQMessageSHM(fManager, GetTransport())); // zmq_msg_t* msgPtr = static_cast(part.get())->GetMessage(); // // int nbytes = zmq_msg_recv(msgPtr, fSocket, flags); // if (nbytes == 0) // { // msgVec.push_back(move(part)); // } // else if (nbytes > 0) // { // MetaHeader* hdr = static_cast(zmq_msg_data(msgPtr)); // size_t size = 0; // static_cast(part.get())->fHandle = hdr->fHandle; // static_cast(part.get())->fSize = hdr->fSize; // static_cast(part.get())->fRegionId = hdr->fRegionId; // static_cast(part.get())->fHint = hdr->fHint; // size = part->GetSize(); // // msgVec.push_back(move(part)); // // totalSize += size; // } // else if (zmq_errno() == EAGAIN) // { // if (!fInterrupted && ((flags & ZMQ_DONTWAIT) == 0)) // { // if (timeout) // { // elapsed += fSndTimeout; // if (elapsed >= timeout) // { // return -2; // } // } // repeat = true; // break; // } // else // { // return -2; // } // } // else // { // return nbytes; // } // // size_t more_size = sizeof(more); // zmq_getsockopt(fSocket, ZMQ_RCVMORE, &more, &more_size); // } // while (more); // // if (repeat) // { // continue; // } // // // store statistics on how many messages have been received (handle all parts as a single message) // ++fMessagesRx; // fBytesRx += totalSize; // return totalSize; // } } auto Socket::Close() -> void { if (zmq_close(fControlSocket) != 0) throw SocketError(tools::ToString("Failed closing zmq meta socket, reason: ", zmq_strerror(errno))); if (zmq_close(fMonitorSocket) != 0) throw SocketError(tools::ToString("Failed closing zmq monitor socket, reason: ", zmq_strerror(errno))); if (fDataEndpoint) { auto ret = fi_close(&fDataEndpoint->fid); if (ret != FI_SUCCESS) throw SocketError(tools::ToString("Failed closing ofi endpoint, reason: ", fi_strerror(ret))); } if (fDataCompletionQueueTx) { auto ret = fi_close(&fDataCompletionQueueTx->fid); if (ret != FI_SUCCESS) throw SocketError(tools::ToString("Failed closing ofi transmit completion queue, reason: ", fi_strerror(ret))); } if (fDataCompletionQueueRx) { auto ret = fi_close(&fDataCompletionQueueRx->fid); if (ret != FI_SUCCESS) throw SocketError(tools::ToString("Failed closing ofi receive completion queue, reason: ", fi_strerror(ret))); } } auto Socket::SetOption(const string& option, const void* value, size_t valueSize) -> void { if (zmq_setsockopt(fControlSocket, GetConstant(option), value, valueSize) < 0) { throw SocketError{tools::ToString("Failed setting socket option, reason: ", zmq_strerror(errno))}; } } auto Socket::GetOption(const string& option, void* value, size_t* valueSize) -> void { if (zmq_getsockopt(fControlSocket, GetConstant(option), value, valueSize) < 0) { throw SocketError{tools::ToString("Failed getting socket option, reason: ", zmq_strerror(errno))}; } } int Socket::GetLinger() const { int value = 0; size_t valueSize; if (zmq_getsockopt(fControlSocket, ZMQ_LINGER, &value, &valueSize) < 0) { throw SocketError(tools::ToString("failed getting ZMQ_LINGER, reason: ", zmq_strerror(errno))); } return value; } void Socket::SetSndBufSize(const int value) { if (zmq_setsockopt(fControlSocket, ZMQ_SNDHWM, &value, sizeof(value)) < 0) { throw SocketError(tools::ToString("failed setting ZMQ_SNDHWM, reason: ", zmq_strerror(errno))); } } int Socket::GetSndBufSize() const { int value = 0; size_t valueSize; if (zmq_getsockopt(fControlSocket, ZMQ_SNDHWM, &value, &valueSize) < 0) { throw SocketError(tools::ToString("failed getting ZMQ_SNDHWM, reason: ", zmq_strerror(errno))); } return value; } void Socket::SetRcvBufSize(const int value) { if (zmq_setsockopt(fControlSocket, ZMQ_RCVHWM, &value, sizeof(value)) < 0) { throw SocketError(tools::ToString("failed setting ZMQ_RCVHWM, reason: ", zmq_strerror(errno))); } } int Socket::GetRcvBufSize() const { int value = 0; size_t valueSize; if (zmq_getsockopt(fControlSocket, ZMQ_RCVHWM, &value, &valueSize) < 0) { throw SocketError(tools::ToString("failed getting ZMQ_RCVHWM, reason: ", zmq_strerror(errno))); } return value; } void Socket::SetSndKernelSize(const int value) { if (zmq_setsockopt(fControlSocket, ZMQ_SNDBUF, &value, sizeof(value)) < 0) { throw SocketError(tools::ToString("failed getting ZMQ_SNDBUF, reason: ", zmq_strerror(errno))); } } int Socket::GetSndKernelSize() const { int value = 0; size_t valueSize; if (zmq_getsockopt(fControlSocket, ZMQ_SNDBUF, &value, &valueSize) < 0) { throw SocketError(tools::ToString("failed getting ZMQ_SNDBUF, reason: ", zmq_strerror(errno))); } return value; } void Socket::SetRcvKernelSize(const int value) { if (zmq_setsockopt(fControlSocket, ZMQ_RCVBUF, &value, sizeof(value)) < 0) { throw SocketError(tools::ToString("failed getting ZMQ_RCVBUF, reason: ", zmq_strerror(errno))); } } int Socket::GetRcvKernelSize() const { int value = 0; size_t valueSize; if (zmq_getsockopt(fControlSocket, ZMQ_RCVBUF, &value, &valueSize) < 0) { throw SocketError(tools::ToString("failed getting ZMQ_RCVBUF, reason: ", zmq_strerror(errno))); } return value; } auto Socket::GetConstant(const string& constant) -> int { if (constant == "") return 0; if (constant == "sub") return ZMQ_SUB; if (constant == "pub") return ZMQ_PUB; if (constant == "xsub") return ZMQ_XSUB; if (constant == "xpub") return ZMQ_XPUB; if (constant == "push") return ZMQ_PUSH; if (constant == "pull") return ZMQ_PULL; if (constant == "req") return ZMQ_REQ; if (constant == "rep") return ZMQ_REP; if (constant == "dealer") return ZMQ_DEALER; if (constant == "router") return ZMQ_ROUTER; if (constant == "pair") return ZMQ_PAIR; if (constant == "snd-hwm") return ZMQ_SNDHWM; if (constant == "rcv-hwm") return ZMQ_RCVHWM; if (constant == "snd-size") return ZMQ_SNDBUF; if (constant == "rcv-size") return ZMQ_RCVBUF; if (constant == "snd-more") return ZMQ_SNDMORE; if (constant == "rcv-more") return ZMQ_RCVMORE; if (constant == "linger") return ZMQ_LINGER; if (constant == "no-block") return ZMQ_DONTWAIT; if (constant == "snd-more no-block") return ZMQ_DONTWAIT|ZMQ_SNDMORE; return -1; } Socket::~Socket() { try { Close(); // NOLINT(clang-analyzer-optin.cplusplus.VirtualCall) } catch (SocketError& e) { LOG(error) << e.what(); } } } /* namespace ofi */ } /* namespace mq */ } /* namespace fair */