From 953c4a75c8f7d6174be131b1b051310524393e2c Mon Sep 17 00:00:00 2001 From: Alexey Rybalchenko Date: Mon, 22 Nov 2021 12:42:09 +0100 Subject: [PATCH] refactor: deduplicate more zmq/shmem code --- fairmq/shmem/Socket.h | 60 +++++++------------------------------- fairmq/zeromq/Common.h | 63 ++++++++++++++++++++++++++++++++++++++++ fairmq/zeromq/Socket.h | 65 +++++++----------------------------------- 3 files changed, 83 insertions(+), 105 deletions(-) diff --git a/fairmq/shmem/Socket.h b/fairmq/shmem/Socket.h index 04a013d1..df742293 100644 --- a/fairmq/shmem/Socket.h +++ b/fairmq/shmem/Socket.h @@ -117,52 +117,12 @@ class Socket final : public fair::mq::Socket bool Bind(const std::string& address) override { - // LOG(info) << "binding socket " << fId << " on " << address; - if (zmq_bind(fSocket, address.c_str()) != 0) { - if (errno == EADDRINUSE) { - // do not print error in this case, this is handled by FairMQDevice in case no connection could be established after trying a number of random ports from a range. - return false; - } - LOG(error) << "Failed binding socket " << fId << ", reason: " << zmq_strerror(errno); - return false; - } - return true; + return zmq::Bind(fSocket, address, fId); } bool Connect(const std::string& address) override { - // LOG(info) << "connecting socket " << fId << " on " << address; - if (zmq_connect(fSocket, address.c_str()) != 0) { - LOG(error) << "Failed connecting socket " << fId << ", reason: " << zmq_strerror(errno); - return false; - } - return true; - } - - bool ShouldRetry(int flags, int timeout, int& elapsed) const - { - if ((flags & ZMQ_DONTWAIT) == 0) { - if (timeout > 0) { - elapsed += fTimeout; - if (elapsed >= timeout) { - return false; - } - } - return true; - } else { - return false; - } - } - - int HandleErrors() const - { - if (zmq_errno() == ETERM) { - LOG(debug) << "Terminating socket " << fId; - return static_cast(TransferCode::error); - } else { - LOG(error) << "Failed transfer on socket " << fId << ", reason: " << zmq_strerror(errno); - return static_cast(TransferCode::error); - } + return zmq::Connect(fSocket, address, fId); } int64_t Send(MessagePtr& msg, int timeout = -1) override @@ -186,13 +146,13 @@ class Socket final : public fair::mq::Socket } else if (zmq_errno() == EAGAIN || zmq_errno() == EINTR) { if (fManager.Interrupted()) { return static_cast(TransferCode::interrupted); - } else if (ShouldRetry(flags, timeout, elapsed)) { + } else if (zmq::ShouldRetry(flags, fTimeout, timeout, elapsed)) { continue; } else { return static_cast(TransferCode::timeout); } } else { - return HandleErrors(); + return zmq::HandleErrors(fId); } } @@ -226,13 +186,13 @@ class Socket final : public fair::mq::Socket } else if (zmq_errno() == EAGAIN || zmq_errno() == EINTR) { if (fManager.Interrupted()) { return static_cast(TransferCode::interrupted); - } else if (ShouldRetry(flags, timeout, elapsed)) { + } else if (zmq::ShouldRetry(flags, fTimeout, timeout, elapsed)) { continue; } else { return static_cast(TransferCode::timeout); } } else { - return HandleErrors(); + return zmq::HandleErrors(fId); } } } @@ -277,13 +237,13 @@ class Socket final : public fair::mq::Socket } else if (zmq_errno() == EAGAIN || zmq_errno() == EINTR) { if (fManager.Interrupted()) { return static_cast(TransferCode::interrupted); - } else if (ShouldRetry(flags, timeout, elapsed)) { + } else if (zmq::ShouldRetry(flags, fTimeout, timeout, elapsed)) { continue; } else { return static_cast(TransferCode::timeout); } } else { - return HandleErrors(); + return zmq::HandleErrors(fId); } } @@ -333,13 +293,13 @@ class Socket final : public fair::mq::Socket } else if (zmq_errno() == EAGAIN || zmq_errno() == EINTR) { if (fManager.Interrupted()) { return static_cast(TransferCode::interrupted); - } else if (ShouldRetry(flags, timeout, elapsed)) { + } else if (zmq::ShouldRetry(flags, fTimeout, timeout, elapsed)) { continue; } else { return static_cast(TransferCode::timeout); } } else { - return HandleErrors(); + return zmq::HandleErrors(fId); } } diff --git a/fairmq/zeromq/Common.h b/fairmq/zeromq/Common.h index c433d0fd..c33f0f49 100644 --- a/fairmq/zeromq/Common.h +++ b/fairmq/zeromq/Common.h @@ -20,6 +20,69 @@ namespace fair::mq::zmq struct Error : std::runtime_error { using std::runtime_error::runtime_error; }; +inline bool Bind(void* socket, const std::string& address, const std::string& id) +{ + // LOG(debug) << "Binding socket " << id << " on " << address; + if (zmq_bind(socket, address.c_str()) != 0) { + if (errno == EADDRINUSE) { + // do not print error in this case, this is handled upstream in case no + // connection could be established after trying a number of random ports from a range. + return false; + } else if (errno == EACCES) { + // check if TCP port 1 was given, if yes then it will be handeled upstream, print debug only + size_t protocolPos = address.find(':'); + std::string protocol = address.substr(0, protocolPos); + if (protocol == "tcp") { + size_t portPos = address.rfind(':'); + std::string port = address.substr(portPos + 1); + if (port == "1") { + LOG(debug) << "Failed binding socket " << id << ", address: " << address << ", reason: " << zmq_strerror(errno); + return false; + } + } + } + LOG(error) << "Failed binding socket " << id << ", address: " << address << ", reason: " << zmq_strerror(errno); + return false; + } + return true; +} + +inline bool Connect(void* socket, const std::string& address, const std::string& id) +{ + // LOG(debug) << "Connecting socket " << id << " on " << address; + if (zmq_connect(socket, address.c_str()) != 0) { + LOG(error) << "Failed connecting socket " << id << ", address: " << address << ", reason: " << zmq_strerror(errno); + return false; + } + return true; +} + +inline bool ShouldRetry(int flags, int socketTimeout, int userTimeout, int& elapsed) +{ + if ((flags & ZMQ_DONTWAIT) == 0) { + if (userTimeout > 0) { + elapsed += socketTimeout; + if (elapsed >= userTimeout) { + return false; + } + } + return true; + } else { + return false; + } +} + +inline int HandleErrors(const std::string& id) +{ + if (zmq_errno() == ETERM) { + LOG(debug) << "Terminating socket " << id; + return static_cast(TransferCode::error); + } else { + LOG(error) << "Failed transfer on socket " << id << ", errno: " << errno << ", reason: " << zmq_strerror(errno); + return static_cast(TransferCode::error); + } +} + /// Lookup table for various zmq constants inline auto getConstant(std::string_view constant) -> int { diff --git a/fairmq/zeromq/Socket.h b/fairmq/zeromq/Socket.h index ed6465e9..7063662b 100644 --- a/fairmq/zeromq/Socket.h +++ b/fairmq/zeromq/Socket.h @@ -85,57 +85,12 @@ class Socket final : public fair::mq::Socket bool Bind(const std::string& address) override { - // LOG(debug) << "Binding socket " << fId << " on " << address; - - if (zmq_bind(fSocket, address.c_str()) != 0) { - if (errno == EADDRINUSE) { - // do not print error in this case, this is handled by FairMQDevice in case no - // connection could be established after trying a number of random ports from a range. - return false; - } - LOG(error) << "Failed binding socket " << fId << ", address: " << address << ", reason: " << zmq_strerror(errno); - return false; - } - - return true; + return zmq::Bind(fSocket, address, fId); } bool Connect(const std::string& address) override { - // LOG(debug) << "Connecting socket " << fId << " on " << address; - - if (zmq_connect(fSocket, address.c_str()) != 0) { - LOG(error) << "Failed connecting socket " << fId << ", address: " << address << ", reason: " << zmq_strerror(errno); - return false; - } - - return true; - } - - bool ShouldRetry(int flags, int timeout, int& elapsed) const - { - if ((flags & ZMQ_DONTWAIT) == 0) { - if (timeout > 0) { - elapsed += fTimeout; - if (elapsed >= timeout) { - return false; - } - } - return true; - } else { - return false; - } - } - - int HandleErrors() const - { - if (zmq_errno() == ETERM) { - LOG(debug) << "Terminating socket " << fId; - return static_cast(TransferCode::error); - } else { - LOG(error) << "Failed transfer on socket " << fId << ", errno: " << errno << ", reason: " << zmq_strerror(errno); - return static_cast(TransferCode::error); - } + return zmq::Connect(fSocket, address, fId); } int64_t Send(MessagePtr& msg, int timeout = -1) override @@ -157,13 +112,13 @@ class Socket final : public fair::mq::Socket } else if (zmq_errno() == EAGAIN || zmq_errno() == EINTR) { if (fCtx.Interrupted()) { return static_cast(TransferCode::interrupted); - } else if (ShouldRetry(flags, timeout, elapsed)) { + } else if (zmq::ShouldRetry(flags, fTimeout, timeout, elapsed)) { continue; } else { return static_cast(TransferCode::timeout); } } else { - return HandleErrors(); + return zmq::HandleErrors(fId); } } } @@ -187,13 +142,13 @@ class Socket final : public fair::mq::Socket } else if (zmq_errno() == EAGAIN || zmq_errno() == EINTR) { if (fCtx.Interrupted()) { return static_cast(TransferCode::interrupted); - } else if (ShouldRetry(flags, timeout, elapsed)) { + } else if (zmq::ShouldRetry(flags, fTimeout, timeout, elapsed)) { continue; } else { return static_cast(TransferCode::timeout); } } else { - return HandleErrors(); + return zmq::HandleErrors(fId); } } } @@ -222,14 +177,14 @@ class Socket final : public fair::mq::Socket } else if (zmq_errno() == EAGAIN || zmq_errno() == EINTR) { if (fCtx.Interrupted()) { return static_cast(TransferCode::interrupted); - } else if (ShouldRetry(flags, timeout, elapsed)) { + } else if (zmq::ShouldRetry(flags, fTimeout, timeout, elapsed)) { repeat = true; break; } else { return static_cast(TransferCode::timeout); } } else { - return HandleErrors(); + return zmq::HandleErrors(fId); } } @@ -274,14 +229,14 @@ class Socket final : public fair::mq::Socket } else if (zmq_errno() == EAGAIN || zmq_errno() == EINTR) { if (fCtx.Interrupted()) { return static_cast(TransferCode::interrupted); - } else if (ShouldRetry(flags, timeout, elapsed)) { + } else if (zmq::ShouldRetry(flags, fTimeout, timeout, elapsed)) { repeat = true; break; } else { return static_cast(TransferCode::timeout); } } else { - return HandleErrors(); + return zmq::HandleErrors(fId); } size_t moreSize = sizeof(more);