diff --git a/fairmq/CMakeLists.txt b/fairmq/CMakeLists.txt index 9e686a44..6dfab539 100644 --- a/fairmq/CMakeLists.txt +++ b/fairmq/CMakeLists.txt @@ -232,12 +232,6 @@ if(BUILD_FAIRMQ) SuboptParser.cxx plugins/config/Config.cxx plugins/Control.cxx - shmem/Message.cxx - shmem/Poller.cxx - shmem/Socket.cxx - shmem/TransportFactory.cxx - shmem/Manager.cxx - shmem/Region.cxx zeromq/FairMQMessageZMQ.cxx zeromq/FairMQPollerZMQ.cxx zeromq/FairMQUnmanagedRegionZMQ.cxx diff --git a/fairmq/shmem/Manager.cxx b/fairmq/shmem/Manager.cxx deleted file mode 100644 index 38e7de4e..00000000 --- a/fairmq/shmem/Manager.cxx +++ /dev/null @@ -1,344 +0,0 @@ -/******************************************************************************** - * Copyright (C) 2014 GSI Helmholtzzentrum fuer Schwerionenforschung GmbH * - * * - * This software is distributed under the terms of the * - * GNU Lesser General Public Licence (LGPL) version 3, * - * copied verbatim in the file "LICENSE" * - ********************************************************************************/ - -#include "Manager.h" - -#include -#include - -#include -#include - -using namespace std; -using bie = ::boost::interprocess::interprocess_exception; -namespace bipc = ::boost::interprocess; -namespace bfs = ::boost::filesystem; -namespace bpt = ::boost::posix_time; - -namespace fair -{ -namespace mq -{ -namespace shmem -{ - -Manager::Manager(const string& id, size_t size) - : fShmId(id) - , fSegmentName("fmq_" + fShmId + "_main") - , fManagementSegmentName("fmq_" + fShmId + "_mng") - , fSegment(bipc::open_or_create, fSegmentName.c_str(), size) - , fManagementSegment(bipc::open_or_create, fManagementSegmentName.c_str(), 655360) - , fShmVoidAlloc(fManagementSegment.get_segment_manager()) - , fShmMtx(bipc::open_or_create, string("fmq_" + fShmId + "_mtx").c_str()) - , fRegionEventsCV(bipc::open_or_create, string("fmq_" + fShmId + "_cv").c_str()) - , fRegionEventsSubscriptionActive(false) - , fDeviceCounter(nullptr) - , 
fRegionInfos(nullptr) - , fInterrupted(false) -{ - LOG(debug) << "created/opened shared memory segment '" << "fmq_" << fShmId << "_main" << "' of " << size << " bytes. Available are " << fSegment.get_free_memory() << " bytes."; - - fRegionInfos = fManagementSegment.find_or_construct(bipc::unique_instance)(fShmVoidAlloc); - // store info about the managed segment as region with id 0 - fRegionInfos->emplace(0, RegionInfo("", 0, 0, fShmVoidAlloc)); - - bipc::scoped_lock lock(fShmMtx); - - fDeviceCounter = fManagementSegment.find(bipc::unique_instance).first; - - if (fDeviceCounter) { - LOG(debug) << "device counter found, with value of " << fDeviceCounter->fCount << ". incrementing."; - (fDeviceCounter->fCount)++; - LOG(debug) << "incremented device counter, now: " << fDeviceCounter->fCount; - } else { - LOG(debug) << "no device counter found, creating one and initializing with 1"; - fDeviceCounter = fManagementSegment.construct(bipc::unique_instance)(1); - LOG(debug) << "initialized device counter with: " << fDeviceCounter->fCount; - } -} - -void Manager::StartMonitor(const string& id) -{ - try { - bipc::named_mutex monitorStatus(bipc::open_only, string("fmq_" + id + "_ms").c_str()); - LOG(debug) << "Found fairmq-shmmonitor for shared memory id " << id; - } catch (bie&) { - LOG(debug) << "no fairmq-shmmonitor found for shared memory id " << id << ", starting..."; - auto env = boost::this_process::environment(); - - vector ownPath = boost::this_process::path(); - - if (const char* fmqp = getenv("FAIRMQ_PATH")) { - ownPath.insert(ownPath.begin(), bfs::path(fmqp)); - } - - bfs::path p = boost::process::search_path("fairmq-shmmonitor", ownPath); - - if (!p.empty()) { - boost::process::spawn(p, "-x", "--shmid", id, "-d", "-t", "2000", env); - int numTries = 0; - do { - try { - bipc::named_mutex monitorStatus(bipc::open_only, string("fmq_" + id + "_ms").c_str()); - LOG(debug) << "Started fairmq-shmmonitor for shared memory id " << id; - break; - } catch (bie&) { - 
this_thread::sleep_for(chrono::milliseconds(10)); - if (++numTries > 1000) { - LOG(error) << "Did not get response from fairmq-shmmonitor after " << 10 * 1000 << " milliseconds. Exiting."; - throw runtime_error(tools::ToString("Did not get response from fairmq-shmmonitor after ", 10 * 1000, " milliseconds. Exiting.")); - } - } - } while (true); - } else { - LOG(warn) << "could not find fairmq-shmmonitor in the path"; - } - } -} - -pair Manager::CreateRegion(const size_t size, - const int64_t userFlags, - RegionCallback callback, - RegionBulkCallback bulkCallback, - const string& path /* = "" */, - int flags /* = 0 */) -{ - try { - - pair result; - - { - uint64_t id = 0; - bipc::scoped_lock lock(fShmMtx); - - RegionCounter* rc = fManagementSegment.find(bipc::unique_instance).first; - - if (rc) { - LOG(debug) << "region counter found, with value of " << rc->fCount << ". incrementing."; - (rc->fCount)++; - LOG(debug) << "incremented region counter, now: " << rc->fCount; - } else { - LOG(debug) << "no region counter found, creating one and initializing with 1"; - rc = fManagementSegment.construct(bipc::unique_instance)(1); - LOG(debug) << "initialized region counter with: " << rc->fCount; - } - - id = rc->fCount; - - auto it = fRegions.find(id); - if (it != fRegions.end()) { - LOG(error) << "Trying to create a region that already exists"; - return {nullptr, id}; - } - - // create region info - fRegionInfos->emplace(id, RegionInfo(path.c_str(), flags, userFlags, fShmVoidAlloc)); - - auto r = fRegions.emplace(id, tools::make_unique(*this, id, size, false, callback, bulkCallback, path, flags)); - // LOG(debug) << "Created region with id '" << id << "', path: '" << path << "', flags: '" << flags << "'"; - - r.first->second->StartReceivingAcks(); - result.first = &(r.first->second->fRegion); - result.second = id; - } - fRegionEventsCV.notify_all(); - - return result; - - } catch (bipc::interprocess_exception& e) { - LOG(error) << "cannot create region. 
Already created/not cleaned up?"; - LOG(error) << e.what(); - throw; - } -} - -void Manager::RemoveRegion(const uint64_t id) -{ - { - bipc::scoped_lock lock(fShmMtx); - fRegions.erase(id); - fRegionInfos->at(id).fDestroyed = true; - } - fRegionEventsCV.notify_all(); -} - -Region* Manager::GetRegion(const uint64_t id) -{ - bipc::scoped_lock lock(fShmMtx); - return GetRegionUnsafe(id); -} - -Region* Manager::GetRegionUnsafe(const uint64_t id) -{ - // remote region could actually be a local one if a message originates from this device (has been sent out and returned) - auto it = fRegions.find(id); - if (it != fRegions.end()) { - return it->second.get(); - } else { - try { - // get region info - RegionInfo regionInfo = fRegionInfos->at(id); - string path = regionInfo.fPath.c_str(); - int flags = regionInfo.fFlags; - // LOG(debug) << "Located remote region with id '" << id << "', path: '" << path << "', flags: '" << flags << "'"; - - auto r = fRegions.emplace(id, tools::make_unique(*this, id, 0, true, nullptr, nullptr, path, flags)); - return r.first->second.get(); - } catch (bie& e) { - LOG(warn) << "Could not get remote region for id: " << id; - return nullptr; - } - } -} - -vector Manager::GetRegionInfo() -{ - bipc::scoped_lock lock(fShmMtx); - return GetRegionInfoUnsafe(); -} - -vector Manager::GetRegionInfoUnsafe() -{ - vector result; - - for (const auto& e : *fRegionInfos) { - fair::mq::RegionInfo info; - info.id = e.first; - info.flags = e.second.fUserFlags; - info.event = e.second.fDestroyed ? 
RegionEvent::destroyed : RegionEvent::created; - if (info.id != 0) { - if (!e.second.fDestroyed) { - auto region = GetRegionUnsafe(info.id); - info.ptr = region->fRegion.get_address(); - info.size = region->fRegion.get_size(); - } else { - info.ptr = nullptr; - info.size = 0; - } - result.push_back(info); - } else { - if (!e.second.fDestroyed) { - info.ptr = fSegment.get_address(); - info.size = fSegment.get_size(); - } else { - info.ptr = nullptr; - info.size = 0; - } - result.push_back(info); - } - } - - return result; -} - -void Manager::SubscribeToRegionEvents(RegionEventCallback callback) -{ - if (fRegionEventThread.joinable()) { - LOG(debug) << "Already subscribed. Overwriting previous subscription."; - bipc::scoped_lock lock(fShmMtx); - fRegionEventsSubscriptionActive = false; - lock.unlock(); - fRegionEventsCV.notify_all(); - fRegionEventThread.join(); - } - bipc::scoped_lock lock(fShmMtx); - fRegionEventCallback = callback; - fRegionEventsSubscriptionActive = true; - fRegionEventThread = thread(&Manager::RegionEventsSubscription, this); -} - -bool Manager::SubscribedToRegionEvents() -{ - return fRegionEventThread.joinable(); -} - -void Manager::UnsubscribeFromRegionEvents() -{ - if (fRegionEventThread.joinable()) { - bipc::scoped_lock lock(fShmMtx); - fRegionEventsSubscriptionActive = false; - lock.unlock(); - fRegionEventsCV.notify_all(); - fRegionEventThread.join(); - lock.lock(); - fRegionEventCallback = nullptr; - } -} - -void Manager::RegionEventsSubscription() -{ - bipc::scoped_lock lock(fShmMtx); - while (fRegionEventsSubscriptionActive) { - auto infos = GetRegionInfoUnsafe(); - for (const auto& i : infos) { - auto el = fObservedRegionEvents.find(i.id); - if (el == fObservedRegionEvents.end()) { - fRegionEventCallback(i); - fObservedRegionEvents.emplace(i.id, i.event); - } else { - if (el->second == RegionEvent::created && i.event == RegionEvent::destroyed) { - fRegionEventCallback(i); - el->second = i.event; - } else { - // LOG(debug) << "ignoring 
event for id" << i.id << ":"; - // LOG(debug) << "incoming event: " << i.event; - // LOG(debug) << "stored event: " << el->second; - } - } - } - fRegionEventsCV.wait(lock); - } -} - -void Manager::RemoveSegments() -{ - if (bipc::shared_memory_object::remove(fSegmentName.c_str())) { - LOG(debug) << "successfully removed '" << fSegmentName << "' segment after the device has stopped."; - } else { - LOG(debug) << "did not remove " << fSegmentName << " segment after the device stopped. Already removed?"; - } - - if (bipc::shared_memory_object::remove(fManagementSegmentName.c_str())) { - LOG(debug) << "successfully removed '" << fManagementSegmentName << "' segment after the device has stopped."; - } else { - LOG(debug) << "did not remove '" << fManagementSegmentName << "' segment after the device stopped. Already removed?"; - } -} - -Manager::~Manager() -{ - bool lastRemoved = false; - - UnsubscribeFromRegionEvents(); - - try { - bipc::scoped_lock lock(fShmMtx); - - (fDeviceCounter->fCount)--; - - if (fDeviceCounter->fCount == 0) { - LOG(debug) << "last segment user, removing segment."; - - RemoveSegments(); - lastRemoved = true; - } else { - LOG(debug) << "other segment users present (" << fDeviceCounter->fCount << "), not removing it."; - } - } catch(bie& e) { - LOG(error) << "error while acquiring lock in Manager destructor: " << e.what(); - } - - if (lastRemoved) { - bipc::named_mutex::remove(string("fmq_" + fShmId + "_mtx").c_str()); - bipc::named_condition::remove(string("fmq_" + fShmId + "_cv").c_str()); - } -} - -} // namespace shmem -} // namespace mq -} // namespace fair diff --git a/fairmq/shmem/Manager.h b/fairmq/shmem/Manager.h index 03fc77c9..3ad3c64c 100644 --- a/fairmq/shmem/Manager.h +++ b/fairmq/shmem/Manager.h @@ -19,19 +19,25 @@ #include "Region.h" #include -#include +#include +#include +#include -#include +#include +#include +#include #include #include #include +#include +#include // getenv #include #include #include #include #include -#include 
+#include // pair #include namespace fair @@ -45,52 +51,373 @@ struct SharedMemoryError : std::runtime_error { using std::runtime_error::runtim class Manager { - friend struct Region; - public: - Manager(const std::string& id, size_t size); + Manager(std::string id, std::string deviceId, size_t size) + : fShmId(std::move(id)) + , fDeviceId(std::move(deviceId)) + , fSegmentName("fmq_" + fShmId + "_main") + , fManagementSegmentName("fmq_" + fShmId + "_mng") + , fSegment(boost::interprocess::open_or_create, fSegmentName.c_str(), size) + , fManagementSegment(boost::interprocess::open_or_create, fManagementSegmentName.c_str(), 655360) + , fShmVoidAlloc(fManagementSegment.get_segment_manager()) + , fShmMtx(boost::interprocess::open_or_create, std::string("fmq_" + fShmId + "_mtx").c_str()) + , fRegionEventsCV(boost::interprocess::open_or_create, std::string("fmq_" + fShmId + "_cv").c_str()) + , fRegionEventsSubscriptionActive(false) + , fDeviceCounter(nullptr) + , fRegionInfos(nullptr) + , fInterrupted(false) + , fMsgCounter(0) + , fHeartbeatThread() + , fSendHeartbeats(true) + { + using namespace boost::interprocess; + LOG(debug) << "created/opened shared memory segment '" << "fmq_" << fShmId << "_main" << "' of " << size << " bytes. Available are " << fSegment.get_free_memory() << " bytes."; + + fRegionInfos = fManagementSegment.find_or_construct(unique_instance)(fShmVoidAlloc); + // store info about the managed segment as region with id 0 + fRegionInfos->emplace(0, RegionInfo("", 0, 0, fShmVoidAlloc)); + + boost::interprocess::scoped_lock lock(fShmMtx); + + fDeviceCounter = fManagementSegment.find(unique_instance).first; + + if (fDeviceCounter) { + LOG(debug) << "device counter found, with value of " << fDeviceCounter->fCount << ". 
incrementing."; + (fDeviceCounter->fCount)++; + LOG(debug) << "incremented device counter, now: " << fDeviceCounter->fCount; + } else { + LOG(debug) << "no device counter found, creating one and initializing with 1"; + fDeviceCounter = fManagementSegment.construct(unique_instance)(1); + LOG(debug) << "initialized device counter with: " << fDeviceCounter->fCount; + } + + fHeartbeatThread = std::thread(&Manager::SendHeartbeats, this); + } Manager() = delete; Manager(const Manager&) = delete; Manager operator=(const Manager&) = delete; - ~Manager(); + ~Manager() + { + using namespace boost::interprocess; + bool lastRemoved = false; + + UnsubscribeFromRegionEvents(); + + fSendHeartbeats = false; + fHeartbeatThread.join(); + + try { + boost::interprocess::scoped_lock lock(fShmMtx); + + (fDeviceCounter->fCount)--; + + if (fDeviceCounter->fCount == 0) { + LOG(debug) << "last segment user, removing segment."; + + RemoveSegments(); + lastRemoved = true; + } else { + LOG(debug) << "other segment users present (" << fDeviceCounter->fCount << "), not removing it."; + } + } catch(interprocess_exception& e) { + LOG(error) << "error while acquiring lock in Manager destructor: " << e.what(); + } + + if (lastRemoved) { + named_mutex::remove(std::string("fmq_" + fShmId + "_mtx").c_str()); + named_condition::remove(std::string("fmq_" + fShmId + "_cv").c_str()); + } + } boost::interprocess::managed_shared_memory& Segment() { return fSegment; } boost::interprocess::managed_shared_memory& ManagementSegment() { return fManagementSegment; } - static void StartMonitor(const std::string&); + static void StartMonitor(const std::string& id) + { + using namespace boost::interprocess; + try { + named_mutex monitorStatus(open_only, std::string("fmq_" + id + "_ms").c_str()); + LOG(debug) << "Found fairmq-shmmonitor for shared memory id " << id; + } catch (interprocess_exception&) { + LOG(debug) << "no fairmq-shmmonitor found for shared memory id " << id << ", starting..."; + auto env = 
boost::this_process::environment(); + + std::vector ownPath = boost::this_process::path(); + + if (const char* fmqp = getenv("FAIRMQ_PATH")) { + ownPath.insert(ownPath.begin(), boost::filesystem::path(fmqp)); + } + + boost::filesystem::path p = boost::process::search_path("fairmq-shmmonitor", ownPath); + + if (!p.empty()) { + boost::process::spawn(p, "-x", "--shmid", id, "-d", "-t", "2000", env); + int numTries = 0; + do { + try { + named_mutex monitorStatus(open_only, std::string("fmq_" + id + "_ms").c_str()); + LOG(debug) << "Started fairmq-shmmonitor for shared memory id " << id; + break; + } catch (interprocess_exception&) { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + if (++numTries > 1000) { + LOG(error) << "Did not get response from fairmq-shmmonitor after " << 10 * 1000 << " milliseconds. Exiting."; + throw std::runtime_error(tools::ToString("Did not get response from fairmq-shmmonitor after ", 10 * 1000, " milliseconds. Exiting.")); + } + } + } while (true); + } else { + LOG(warn) << "could not find fairmq-shmmonitor in the path"; + } + } + } void Interrupt() { fInterrupted.store(true); } void Resume() { fInterrupted.store(false); } + void Reset() + { + if (fMsgCounter.load() != 0) { + LOG(error) << "Message counter during Reset expected to be 0, found: " << fMsgCounter.load(); + throw MessageError(tools::ToString("Message counter during Reset expected to be 0, found: ", fMsgCounter.load())); + } + } bool Interrupted() { return fInterrupted.load(); } - int GetDeviceCounter(); - int IncrementDeviceCounter(); - int DecrementDeviceCounter(); - std::pair CreateRegion(const size_t size, const int64_t userFlags, RegionCallback callback, RegionBulkCallback bulkCallback, const std::string& path = "", - int flags = 0); - Region* GetRegion(const uint64_t id); - Region* GetRegionUnsafe(const uint64_t id); - void RemoveRegion(const uint64_t id); + int flags = 0) + { + using namespace boost::interprocess; + try { + std::pair result; - std::vector 
GetRegionInfo(); - std::vector GetRegionInfoUnsafe(); - void SubscribeToRegionEvents(RegionEventCallback callback); - bool SubscribedToRegionEvents(); - void UnsubscribeFromRegionEvents(); - void RegionEventsSubscription(); + { + uint64_t id = 0; + boost::interprocess::scoped_lock lock(fShmMtx); - void RemoveSegments(); + RegionCounter* rc = fManagementSegment.find(unique_instance).first; + + if (rc) { + LOG(debug) << "region counter found, with value of " << rc->fCount << ". incrementing."; + (rc->fCount)++; + LOG(debug) << "incremented region counter, now: " << rc->fCount; + } else { + LOG(debug) << "no region counter found, creating one and initializing with 1"; + rc = fManagementSegment.construct(unique_instance)(1); + LOG(debug) << "initialized region counter with: " << rc->fCount; + } + + id = rc->fCount; + + auto it = fRegions.find(id); + if (it != fRegions.end()) { + LOG(error) << "Trying to create a region that already exists"; + return {nullptr, id}; + } + + // create region info + fRegionInfos->emplace(id, RegionInfo(path.c_str(), flags, userFlags, fShmVoidAlloc)); + + auto r = fRegions.emplace(id, tools::make_unique(fShmId, id, size, false, callback, bulkCallback, path, flags)); + // LOG(debug) << "Created region with id '" << id << "', path: '" << path << "', flags: '" << flags << "'"; + + r.first->second->StartReceivingAcks(); + result.first = &(r.first->second->fRegion); + result.second = id; + } + fRegionEventsCV.notify_all(); + + return result; + + } catch (interprocess_exception& e) { + LOG(error) << "cannot create region. 
Already created/not cleaned up?"; + LOG(error) << e.what(); + throw; + } + } + + Region* GetRegion(const uint64_t id) + { + boost::interprocess::scoped_lock lock(fShmMtx); + return GetRegionUnsafe(id); + } + + Region* GetRegionUnsafe(const uint64_t id) + { + // remote region could actually be a local one if a message originates from this device (has been sent out and returned) + auto it = fRegions.find(id); + if (it != fRegions.end()) { + return it->second.get(); + } else { + try { + // get region info + RegionInfo regionInfo = fRegionInfos->at(id); + std::string path = regionInfo.fPath.c_str(); + int flags = regionInfo.fFlags; + // LOG(debug) << "Located remote region with id '" << id << "', path: '" << path << "', flags: '" << flags << "'"; + + auto r = fRegions.emplace(id, tools::make_unique(fShmId, id, 0, true, nullptr, nullptr, path, flags)); + return r.first->second.get(); + } catch (boost::interprocess::interprocess_exception& e) { + LOG(warn) << "Could not get remote region for id: " << id; + return nullptr; + } + } + } + + void RemoveRegion(const uint64_t id) + { + { + boost::interprocess::scoped_lock lock(fShmMtx); + fRegions.erase(id); + fRegionInfos->at(id).fDestroyed = true; + } + fRegionEventsCV.notify_all(); + } + + std::vector GetRegionInfo() + { + boost::interprocess::scoped_lock lock(fShmMtx); + return GetRegionInfoUnsafe(); + } + + std::vector GetRegionInfoUnsafe() + { + std::vector result; + + for (const auto& e : *fRegionInfos) { + fair::mq::RegionInfo info; + info.id = e.first; + info.flags = e.second.fUserFlags; + info.event = e.second.fDestroyed ? 
RegionEvent::destroyed : RegionEvent::created; + if (info.id != 0) { + if (!e.second.fDestroyed) { + auto region = GetRegionUnsafe(info.id); + info.ptr = region->fRegion.get_address(); + info.size = region->fRegion.get_size(); + } else { + info.ptr = nullptr; + info.size = 0; + } + result.push_back(info); + } else { + if (!e.second.fDestroyed) { + info.ptr = fSegment.get_address(); + info.size = fSegment.get_size(); + } else { + info.ptr = nullptr; + info.size = 0; + } + result.push_back(info); + } + } + + return result; + } + + void SubscribeToRegionEvents(RegionEventCallback callback) + { + if (fRegionEventThread.joinable()) { + LOG(debug) << "Already subscribed. Overwriting previous subscription."; + boost::interprocess::scoped_lock lock(fShmMtx); + fRegionEventsSubscriptionActive = false; + lock.unlock(); + fRegionEventsCV.notify_all(); + fRegionEventThread.join(); + } + boost::interprocess::scoped_lock lock(fShmMtx); + fRegionEventCallback = callback; + fRegionEventsSubscriptionActive = true; + fRegionEventThread = std::thread(&Manager::RegionEventsSubscription, this); + } + + bool SubscribedToRegionEvents() { return fRegionEventThread.joinable(); } + + void UnsubscribeFromRegionEvents() + { + if (fRegionEventThread.joinable()) { + boost::interprocess::scoped_lock lock(fShmMtx); + fRegionEventsSubscriptionActive = false; + lock.unlock(); + fRegionEventsCV.notify_all(); + fRegionEventThread.join(); + lock.lock(); + fRegionEventCallback = nullptr; + } + } + + void RegionEventsSubscription() + { + boost::interprocess::scoped_lock lock(fShmMtx); + while (fRegionEventsSubscriptionActive) { + auto infos = GetRegionInfoUnsafe(); + for (const auto& i : infos) { + auto el = fObservedRegionEvents.find(i.id); + if (el == fObservedRegionEvents.end()) { + fRegionEventCallback(i); + fObservedRegionEvents.emplace(i.id, i.event); + } else { + if (el->second == RegionEvent::created && i.event == RegionEvent::destroyed) { + fRegionEventCallback(i); + el->second = i.event; + } 
else { + // LOG(debug) << "ignoring event for id" << i.id << ":"; + // LOG(debug) << "incoming event: " << i.event; + // LOG(debug) << "stored event: " << el->second; + } + } + } + fRegionEventsCV.wait(lock); + } + } + + void IncrementMsgCounter() { ++fMsgCounter; } + void DecrementMsgCounter() { --fMsgCounter; } + + void RemoveSegments() + { + using namespace boost::interprocess; + if (shared_memory_object::remove(fSegmentName.c_str())) { + LOG(debug) << "successfully removed '" << fSegmentName << "' segment after the device has stopped."; + } else { + LOG(debug) << "did not remove " << fSegmentName << " segment after the device stopped. Already removed?"; + } + + if (shared_memory_object::remove(fManagementSegmentName.c_str())) { + LOG(debug) << "successfully removed '" << fManagementSegmentName << "' segment after the device has stopped."; + } else { + LOG(debug) << "did not remove '" << fManagementSegmentName << "' segment after the device stopped. Already removed?"; + } + } + + void SendHeartbeats() + { + std::string controlQueueName("fmq_" + fShmId + "_cq"); + while (fSendHeartbeats) { + try { + boost::interprocess::message_queue mq(boost::interprocess::open_only, controlQueueName.c_str()); + boost::posix_time::ptime sndTill = boost::posix_time::microsec_clock::universal_time() + boost::posix_time::milliseconds(100); + if (mq.timed_send(fDeviceId.c_str(), fDeviceId.size(), 0, sndTill)) { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } else { + LOG(debug) << "control queue timeout"; + } + } catch (boost::interprocess::interprocess_exception& ie) { + std::this_thread::sleep_for(std::chrono::milliseconds(500)); + // LOG(warn) << "no " << controlQueueName << " found"; + } + } + } private: std::string fShmId; + std::string fDeviceId; std::string fSegmentName; std::string fManagementSegmentName; boost::interprocess::managed_shared_memory fSegment; @@ -109,6 +436,10 @@ class Manager std::unordered_map> fRegions; std::atomic fInterrupted; + 
std::atomic fMsgCounter; // TODO: find a better lifetime solution instead of the counter + + std::thread fHeartbeatThread; + std::atomic fSendHeartbeats; }; } // namespace shmem diff --git a/fairmq/shmem/Message.cxx b/fairmq/shmem/Message.cxx deleted file mode 100644 index 2536da72..00000000 --- a/fairmq/shmem/Message.cxx +++ /dev/null @@ -1,248 +0,0 @@ -/******************************************************************************** - * Copyright (C) 2014 GSI Helmholtzzentrum fuer Schwerionenforschung GmbH * - * * - * This software is distributed under the terms of the * - * GNU Lesser General Public Licence (LGPL) version 3, * - * copied verbatim in the file "LICENSE" * - ********************************************************************************/ - -#include "Region.h" -#include "Message.h" -#include "UnmanagedRegion.h" -#include "TransportFactory.h" - -#include - -#include - -#include - -using namespace std; - -namespace bipc = ::boost::interprocess; -namespace bpt = ::boost::posix_time; - -namespace fair -{ -namespace mq -{ -namespace shmem -{ - -Message::Message(Manager& manager, FairMQTransportFactory* factory) - : fair::mq::Message{factory} - , fManager(manager) - , fQueued(false) - , fMeta{0, 0, 0, -1} - , fRegionPtr(nullptr) - , fLocalPtr(nullptr) -{ - static_cast(GetTransport())->IncrementMsgCounter(); -} - -Message::Message(Manager& manager, const size_t size, FairMQTransportFactory* factory) - : fair::mq::Message{factory} - , fManager(manager) - , fQueued(false) - , fMeta{0, 0, 0, -1} - , fRegionPtr(nullptr) - , fLocalPtr(nullptr) -{ - InitializeChunk(size); - static_cast(GetTransport())->IncrementMsgCounter(); -} - -Message::Message(Manager& manager, MetaHeader& hdr, FairMQTransportFactory* factory) - : fair::mq::Message{factory} - , fManager(manager) - , fQueued(false) - , fMeta{hdr} - , fRegionPtr(nullptr) - , fLocalPtr(nullptr) -{ - static_cast(GetTransport())->IncrementMsgCounter(); -} - -Message::Message(Manager& manager, void* data, const 
size_t size, fairmq_free_fn* ffn, void* hint, FairMQTransportFactory* factory) - : fair::mq::Message{factory} - , fManager(manager) - , fQueued(false) - , fMeta{0, 0, 0, -1} - , fRegionPtr(nullptr) - , fLocalPtr(nullptr) -{ - if (InitializeChunk(size)) { - std::memcpy(fLocalPtr, data, size); - if (ffn) { - ffn(data, hint); - } else { - free(data); - } - } - static_cast(GetTransport())->IncrementMsgCounter(); -} - -Message::Message(Manager& manager, UnmanagedRegionPtr& region, void* data, const size_t size, void* hint, FairMQTransportFactory* factory) - : fair::mq::Message{factory} - , fManager(manager) - , fQueued(false) - , fMeta{size, static_cast(region.get())->fRegionId, reinterpret_cast(hint), -1} - , fRegionPtr(nullptr) - , fLocalPtr(static_cast(data)) -{ - if (reinterpret_cast(data) >= reinterpret_cast(region->GetData()) || - reinterpret_cast(data) <= reinterpret_cast(region->GetData()) + region->GetSize()) { - fMeta.fHandle = (bipc::managed_shared_memory::handle_t)(reinterpret_cast(data) - reinterpret_cast(region->GetData())); - } else { - LOG(error) << "trying to create region message with data from outside the region"; - throw runtime_error("trying to create region message with data from outside the region"); - } - static_cast(GetTransport())->IncrementMsgCounter(); -} - -bool Message::InitializeChunk(const size_t size) -{ - while (fMeta.fHandle < 0) { - try { - bipc::managed_shared_memory::size_type actualSize = size; - char* hint = 0; // unused for bipc::allocate_new - fLocalPtr = fManager.Segment().allocation_command(bipc::allocate_new, size, actualSize, hint); - } catch (bipc::bad_alloc& ba) { - // LOG(warn) << "Shared memory full..."; - this_thread::sleep_for(chrono::milliseconds(50)); - if (fManager.Interrupted()) { - return false; - } else { - continue; - } - } - fMeta.fHandle = fManager.Segment().get_handle_from_address(fLocalPtr); - } - - fMeta.fSize = size; - return true; -} - -void Message::Rebuild() -{ - CloseMessage(); - fQueued = false; -} - 
-void Message::Rebuild(const size_t size) -{ - CloseMessage(); - fQueued = false; - InitializeChunk(size); -} - -void Message::Rebuild(void* data, const size_t size, fairmq_free_fn* ffn, void* hint) -{ - CloseMessage(); - fQueued = false; - - if (InitializeChunk(size)) { - std::memcpy(fLocalPtr, data, size); - if (ffn) { - ffn(data, hint); - } else { - free(data); - } - } -} - -void* Message::GetData() const -{ - if (!fLocalPtr) { - if (fMeta.fRegionId == 0) { - if (fMeta.fSize > 0) { - fLocalPtr = reinterpret_cast(fManager.Segment().get_address_from_handle(fMeta.fHandle)); - } else { - fLocalPtr = nullptr; - } - } else { - fRegionPtr = fManager.GetRegion(fMeta.fRegionId); - if (fRegionPtr) { - fLocalPtr = reinterpret_cast(fRegionPtr->fRegion.get_address()) + fMeta.fHandle; - } else { - // LOG(warn) << "could not get pointer from a region message"; - fLocalPtr = nullptr; - } - } - } - - return fLocalPtr; -} - -bool Message::SetUsedSize(const size_t size) -{ - if (size == fMeta.fSize) { - return true; - } else if (size <= fMeta.fSize) { - try { - bipc::managed_shared_memory::size_type shrunkSize = size; - fLocalPtr = fManager.Segment().allocation_command(bipc::shrink_in_place, fMeta.fSize + 128, shrunkSize, fLocalPtr); - fMeta.fSize = size; - return true; - } catch (bipc::interprocess_exception& e) { - LOG(info) << "could not set used size: " << e.what(); - return false; - } - } else { - LOG(error) << "cannot set used size higher than original."; - return false; - } -} - -void Message::Copy(const fair::mq::Message& msg) -{ - if (fMeta.fHandle < 0) { - bipc::managed_shared_memory::handle_t otherHandle = static_cast(msg).fMeta.fHandle; - if (otherHandle) { - if (InitializeChunk(msg.GetSize())) { - std::memcpy(GetData(), msg.GetData(), msg.GetSize()); - } - } else { - LOG(error) << "copy fail: source message not initialized!"; - } - } else { - LOG(error) << "copy fail: target message already initialized!"; - } -} - -void Message::CloseMessage() -{ - if (fMeta.fHandle 
>= 0 && !fQueued) { - if (fMeta.fRegionId == 0) { - fManager.Segment().deallocate(fManager.Segment().get_address_from_handle(fMeta.fHandle)); - fMeta.fHandle = -1; - } else { - if (!fRegionPtr) { - fRegionPtr = fManager.GetRegion(fMeta.fRegionId); - } - - if (fRegionPtr) { - fRegionPtr->ReleaseBlock({fMeta.fHandle, fMeta.fSize, fMeta.fHint}); - } else { - LOG(warn) << "region ack queue for id " << fMeta.fRegionId << " no longer exist. Not sending ack"; - } - } - } - - static_cast(GetTransport())->DecrementMsgCounter(); -} - -Message::~Message() -{ - try { - CloseMessage(); - } catch(SharedMemoryError& sme) { - LOG(error) << "error closing message: " << sme.what(); - } catch(bipc::lock_exception& le) { - LOG(error) << "error closing message: " << le.what(); - } -} - -} -} -} diff --git a/fairmq/shmem/Message.h b/fairmq/shmem/Message.h index 7646556e..759537db 100644 --- a/fairmq/shmem/Message.h +++ b/fairmq/shmem/Message.h @@ -10,7 +10,10 @@ #include "Common.h" #include "Manager.h" +#include "Region.h" +#include "UnmanagedRegion.h" +#include #include #include @@ -33,30 +36,181 @@ class Message final : public fair::mq::Message friend class Socket; public: - Message(Manager& manager, FairMQTransportFactory* factory = nullptr); - Message(Manager& manager, const size_t size, FairMQTransportFactory* factory = nullptr); - Message(Manager& manager, void* data, const size_t size, fairmq_free_fn* ffn, void* hint = nullptr, FairMQTransportFactory* factory = nullptr); - Message(Manager& manager, UnmanagedRegionPtr& region, void* data, const size_t size, void* hint = 0, FairMQTransportFactory* factory = nullptr); + Message(Manager& manager, FairMQTransportFactory* factory = nullptr) + : fair::mq::Message{factory} + , fManager(manager) + , fQueued(false) + , fMeta{0, 0, 0, -1} + , fRegionPtr(nullptr) + , fLocalPtr(nullptr) + { + fManager.IncrementMsgCounter(); + } - Message(Manager& manager, MetaHeader& hdr, FairMQTransportFactory* factory = nullptr); + Message(Manager& manager, 
const size_t size, FairMQTransportFactory* factory = nullptr) + : fair::mq::Message{factory} + , fManager(manager) + , fQueued(false) + , fMeta{0, 0, 0, -1} + , fRegionPtr(nullptr) + , fLocalPtr(nullptr) + { + InitializeChunk(size); + fManager.IncrementMsgCounter(); + } + + Message(Manager& manager, void* data, const size_t size, fairmq_free_fn* ffn, void* hint = nullptr, FairMQTransportFactory* factory = nullptr) + : fair::mq::Message{factory} + , fManager(manager) + , fQueued(false) + , fMeta{0, 0, 0, -1} + , fRegionPtr(nullptr) + , fLocalPtr(nullptr) + { + if (InitializeChunk(size)) { + std::memcpy(fLocalPtr, data, size); + if (ffn) { + ffn(data, hint); + } else { + free(data); + } + } + fManager.IncrementMsgCounter(); + } + + Message(Manager& manager, UnmanagedRegionPtr& region, void* data, const size_t size, void* hint = 0, FairMQTransportFactory* factory = nullptr) + : fair::mq::Message{factory} + , fManager(manager) + , fQueued(false) + , fMeta{size, static_cast(region.get())->fRegionId, reinterpret_cast(hint), -1} + , fRegionPtr(nullptr) + , fLocalPtr(static_cast(data)) + { + if (reinterpret_cast(data) >= reinterpret_cast(region->GetData()) || + reinterpret_cast(data) <= reinterpret_cast(region->GetData()) + region->GetSize()) { + fMeta.fHandle = (boost::interprocess::managed_shared_memory::handle_t)(reinterpret_cast(data) - reinterpret_cast(region->GetData())); + } else { + LOG(error) << "trying to create region message with data from outside the region"; + throw std::runtime_error("trying to create region message with data from outside the region"); + } + fManager.IncrementMsgCounter(); + } + + Message(Manager& manager, MetaHeader& hdr, FairMQTransportFactory* factory = nullptr) + : fair::mq::Message{factory} + , fManager(manager) + , fQueued(false) + , fMeta{hdr} + , fRegionPtr(nullptr) + , fLocalPtr(nullptr) + { + fManager.IncrementMsgCounter(); + } Message(const Message&) = delete; Message operator=(const Message&) = delete; - void Rebuild() override; 
- void Rebuild(const size_t size) override; - void Rebuild(void* data, const size_t size, fairmq_free_fn* ffn, void* hint = nullptr) override; + void Rebuild() override + { + CloseMessage(); + fQueued = false; + } + + void Rebuild(const size_t size) override + { + CloseMessage(); + fQueued = false; + InitializeChunk(size); + } + + void Rebuild(void* data, const size_t size, fairmq_free_fn* ffn, void* hint = nullptr) override + { + CloseMessage(); + fQueued = false; + + if (InitializeChunk(size)) { + std::memcpy(fLocalPtr, data, size); + if (ffn) { + ffn(data, hint); + } else { + free(data); + } + } + } + + void* GetData() const override + { + if (!fLocalPtr) { + if (fMeta.fRegionId == 0) { + if (fMeta.fSize > 0) { + fLocalPtr = reinterpret_cast(fManager.Segment().get_address_from_handle(fMeta.fHandle)); + } else { + fLocalPtr = nullptr; + } + } else { + fRegionPtr = fManager.GetRegion(fMeta.fRegionId); + if (fRegionPtr) { + fLocalPtr = reinterpret_cast(fRegionPtr->fRegion.get_address()) + fMeta.fHandle; + } else { + // LOG(warn) << "could not get pointer from a region message"; + fLocalPtr = nullptr; + } + } + } + + return fLocalPtr; + } - void* GetData() const override; size_t GetSize() const override { return fMeta.fSize; } - bool SetUsedSize(const size_t size) override; + bool SetUsedSize(const size_t size) override + { + if (size == fMeta.fSize) { + return true; + } else if (size <= fMeta.fSize) { + try { + boost::interprocess::managed_shared_memory::size_type shrunkSize = size; + fLocalPtr = fManager.Segment().allocation_command(boost::interprocess::shrink_in_place, fMeta.fSize + 128, shrunkSize, fLocalPtr); + fMeta.fSize = size; + return true; + } catch (boost::interprocess::interprocess_exception& e) { + LOG(info) << "could not set used size: " << e.what(); + return false; + } + } else { + LOG(error) << "cannot set used size higher than original."; + return false; + } + } Transport GetType() const override { return fair::mq::Transport::SHM; } - void 
Copy(const fair::mq::Message& msg) override; + void Copy(const fair::mq::Message& msg) override + { + if (fMeta.fHandle < 0) { + boost::interprocess::managed_shared_memory::handle_t otherHandle = static_cast(msg).fMeta.fHandle; + if (otherHandle) { + if (InitializeChunk(msg.GetSize())) { + std::memcpy(GetData(), msg.GetData(), msg.GetSize()); + } + } else { + LOG(error) << "copy fail: source message not initialized!"; + } + } else { + LOG(error) << "copy fail: target message already initialized!"; + } + } - ~Message() override; + ~Message() override + { + try { + CloseMessage(); + } catch(SharedMemoryError& sme) { + LOG(error) << "error closing message: " << sme.what(); + } catch(boost::interprocess::lock_exception& le) { + LOG(error) << "error closing message: " << le.what(); + } + } private: Manager& fManager; @@ -65,8 +219,50 @@ class Message final : public fair::mq::Message mutable Region* fRegionPtr; mutable char* fLocalPtr; - bool InitializeChunk(const size_t size); - void CloseMessage(); + bool InitializeChunk(const size_t size) + { + while (fMeta.fHandle < 0) { + try { + boost::interprocess::managed_shared_memory::size_type actualSize = size; + char* hint = 0; // unused for boost::interprocess::allocate_new + fLocalPtr = fManager.Segment().allocation_command(boost::interprocess::allocate_new, size, actualSize, hint); + } catch (boost::interprocess::bad_alloc& ba) { + // LOG(warn) << "Shared memory full..."; + std::this_thread::sleep_for(std::chrono::milliseconds(50)); + if (fManager.Interrupted()) { + return false; + } else { + continue; + } + } + fMeta.fHandle = fManager.Segment().get_handle_from_address(fLocalPtr); + } + + fMeta.fSize = size; + return true; + } + + void CloseMessage() + { + if (fMeta.fHandle >= 0 && !fQueued) { + if (fMeta.fRegionId == 0) { + fManager.Segment().deallocate(fManager.Segment().get_address_from_handle(fMeta.fHandle)); + fMeta.fHandle = -1; + } else { + if (!fRegionPtr) { + fRegionPtr = fManager.GetRegion(fMeta.fRegionId); + } 
+ + if (fRegionPtr) { + fRegionPtr->ReleaseBlock({fMeta.fHandle, fMeta.fSize, fMeta.fHint}); + } else { + LOG(warn) << "region ack queue for id " << fMeta.fRegionId << " no longer exist. Not sending ack"; + } + } + } + + fManager.DecrementMsgCounter(); + } }; } diff --git a/fairmq/shmem/Poller.cxx b/fairmq/shmem/Poller.cxx deleted file mode 100644 index a70bf2eb..00000000 --- a/fairmq/shmem/Poller.cxx +++ /dev/null @@ -1,222 +0,0 @@ -/******************************************************************************** - * Copyright (C) 2014 GSI Helmholtzzentrum fuer Schwerionenforschung GmbH * - * * - * This software is distributed under the terms of the * - * GNU Lesser General Public Licence (LGPL) version 3, * - * copied verbatim in the file "LICENSE" * - ********************************************************************************/ -/** - * Poller.cxx - * - * @since 2014-01-23 - * @author A. Rybalchenko - */ - -#include "Poller.h" -#include "Socket.h" - -#include - -#include - -using namespace std; - -namespace fair -{ -namespace mq -{ -namespace shmem -{ - -Poller::Poller(const vector& channels) - : fItems() - , fNumItems(0) - , fOffsetMap() -{ - fNumItems = channels.size(); - fItems = new zmq_pollitem_t[fNumItems]; - - for (int i = 0; i < fNumItems; ++i) - { - fItems[i].socket = static_cast(&(channels.at(i).GetSocket()))->GetSocket(); - fItems[i].fd = 0; - fItems[i].revents = 0; - - int type = 0; - size_t size = sizeof(type); - zmq_getsockopt(static_cast(&(channels.at(i).GetSocket()))->GetSocket(), ZMQ_TYPE, &type, &size); - - SetItemEvents(fItems[i], type); - } -} - -Poller::Poller(const vector& channels) - : fItems() - , fNumItems(0) - , fOffsetMap() -{ - fNumItems = channels.size(); - fItems = new zmq_pollitem_t[fNumItems]; - - for (int i = 0; i < fNumItems; ++i) - { - fItems[i].socket = static_cast(&(channels.at(i)->GetSocket()))->GetSocket(); - fItems[i].fd = 0; - fItems[i].revents = 0; - - int type = 0; - size_t size = sizeof(type); - 
zmq_getsockopt(static_cast(&(channels.at(i)->GetSocket()))->GetSocket(), ZMQ_TYPE, &type, &size); - - SetItemEvents(fItems[i], type); - } -} - -Poller::Poller(const unordered_map>& channelsMap, const vector& channelList) - : fItems() - , fNumItems(0) - , fOffsetMap() -{ - try - { - int offset = 0; - // calculate offsets and the total size of the poll item set - for (string channel : channelList) - { - fOffsetMap[channel] = offset; - offset += channelsMap.at(channel).size(); - fNumItems += channelsMap.at(channel).size(); - } - - fItems = new zmq_pollitem_t[fNumItems]; - - int index = 0; - for (string channel : channelList) - { - for (unsigned int i = 0; i < channelsMap.at(channel).size(); ++i) - { - index = fOffsetMap[channel] + i; - - fItems[index].socket = static_cast(&(channelsMap.at(channel).at(i).GetSocket()))->GetSocket(); - fItems[index].fd = 0; - fItems[index].revents = 0; - - int type = 0; - size_t size = sizeof(type); - zmq_getsockopt(static_cast(&(channelsMap.at(channel).at(i).GetSocket()))->GetSocket(), ZMQ_TYPE, &type, &size); - - SetItemEvents(fItems[index], type); - } - } - } - catch (const out_of_range& oor) - { - LOG(error) << "at least one of the provided channel keys for poller initialization is invalid"; - LOG(error) << "out of range error: " << oor.what() << '\n'; - exit(EXIT_FAILURE); - } -} - -void Poller::SetItemEvents(zmq_pollitem_t& item, const int type) -{ - if (type == ZMQ_REQ || type == ZMQ_REP || type == ZMQ_PAIR || type == ZMQ_DEALER || type == ZMQ_ROUTER) - { - item.events = ZMQ_POLLIN|ZMQ_POLLOUT; - } - else if (type == ZMQ_PUSH || type == ZMQ_PUB || type == ZMQ_XPUB) - { - item.events = ZMQ_POLLOUT; - } - else if (type == ZMQ_PULL || type == ZMQ_SUB || type == ZMQ_XSUB) - { - item.events = ZMQ_POLLIN; - } - else - { - LOG(error) << "invalid poller configuration, exiting."; - exit(EXIT_FAILURE); - } -} - -void Poller::Poll(const int timeout) -{ - if (zmq_poll(fItems, fNumItems, timeout) < 0) - { - if (errno == ETERM) - { - LOG(debug) 
<< "polling exited, reason: " << zmq_strerror(errno); - } - else - { - LOG(error) << "polling failed, reason: " << zmq_strerror(errno); - throw runtime_error("polling failed"); - } - } -} - -bool Poller::CheckInput(const int index) -{ - if (fItems[index].revents & ZMQ_POLLIN) - { - return true; - } - - return false; -} - -bool Poller::CheckOutput(const int index) -{ - if (fItems[index].revents & ZMQ_POLLOUT) - { - return true; - } - - return false; -} - -bool Poller::CheckInput(const string& channelKey, const int index) -{ - try - { - if (fItems[fOffsetMap.at(channelKey) + index].revents & ZMQ_POLLIN) - { - return true; - } - - return false; - } - catch (const out_of_range& oor) - { - LOG(error) << "invalid channel key: \"" << channelKey << "\""; - LOG(error) << "out of range error: " << oor.what() << '\n'; - exit(EXIT_FAILURE); - } -} - -bool Poller::CheckOutput(const string& channelKey, const int index) -{ - try - { - if (fItems[fOffsetMap.at(channelKey) + index].revents & ZMQ_POLLOUT) - { - return true; - } - - return false; - } - catch (const out_of_range& oor) - { - LOG(error) << "Invalid channel key: \"" << channelKey << "\""; - LOG(error) << "out of range error: " << oor.what() << '\n'; - exit(EXIT_FAILURE); - } -} - -Poller::~Poller() -{ - delete[] fItems; -} - -} -} -} diff --git a/fairmq/shmem/Poller.h b/fairmq/shmem/Poller.h index d8498b25..016495b0 100644 --- a/fairmq/shmem/Poller.h +++ b/fairmq/shmem/Poller.h @@ -8,42 +8,183 @@ #ifndef FAIR_MQ_SHMEM_POLLER_H_ #define FAIR_MQ_SHMEM_POLLER_H_ -#include +#include "Socket.h" -#include #include - -#include +#include +#include #include +#include +#include class FairMQChannel; -namespace fair -{ -namespace mq -{ -namespace shmem -{ +namespace fair { +namespace mq { +namespace shmem { class Poller final : public fair::mq::Poller { public: - Poller(const std::vector& channels); - Poller(const std::vector& channels); - Poller(const std::unordered_map>& channelsMap, const std::vector& channelList); + Poller(const 
std::vector& channels) + : fItems() + , fNumItems(0) + , fOffsetMap() + { + fNumItems = channels.size(); + fItems = new zmq_pollitem_t[fNumItems]; + + for (int i = 0; i < fNumItems; ++i) { + fItems[i].socket = static_cast(&(channels.at(i).GetSocket()))->GetSocket(); + fItems[i].fd = 0; + fItems[i].revents = 0; + + int type = 0; + size_t size = sizeof(type); + zmq_getsockopt(static_cast(&(channels.at(i).GetSocket()))->GetSocket(), ZMQ_TYPE, &type, &size); + + SetItemEvents(fItems[i], type); + } + } + + Poller(const std::vector& channels) + : fItems() + , fNumItems(0) + , fOffsetMap() + { + fNumItems = channels.size(); + fItems = new zmq_pollitem_t[fNumItems]; + + for (int i = 0; i < fNumItems; ++i) { + fItems[i].socket = static_cast(&(channels.at(i)->GetSocket()))->GetSocket(); + fItems[i].fd = 0; + fItems[i].revents = 0; + + int type = 0; + size_t size = sizeof(type); + zmq_getsockopt(static_cast(&(channels.at(i)->GetSocket()))->GetSocket(), ZMQ_TYPE, &type, &size); + + SetItemEvents(fItems[i], type); + } + } + + Poller(const std::unordered_map>& channelsMap, const std::vector& channelList) + : fItems() + , fNumItems(0) + , fOffsetMap() + { + try { + int offset = 0; + // calculate offsets and the total size of the poll item set + for (std::string channel : channelList) { + fOffsetMap[channel] = offset; + offset += channelsMap.at(channel).size(); + fNumItems += channelsMap.at(channel).size(); + } + + fItems = new zmq_pollitem_t[fNumItems]; + + int index = 0; + for (std::string channel : channelList) { + for (unsigned int i = 0; i < channelsMap.at(channel).size(); ++i) { + index = fOffsetMap[channel] + i; + + fItems[index].socket = static_cast(&(channelsMap.at(channel).at(i).GetSocket()))->GetSocket(); + fItems[index].fd = 0; + fItems[index].revents = 0; + + int type = 0; + size_t size = sizeof(type); + zmq_getsockopt(static_cast(&(channelsMap.at(channel).at(i).GetSocket()))->GetSocket(), ZMQ_TYPE, &type, &size); + + SetItemEvents(fItems[index], type); + } + } + } 
catch (const std::out_of_range& oor) { + LOG(error) << "at least one of the provided channel keys for poller initialization is invalid"; + LOG(error) << "out of range error: " << oor.what() << '\n'; + exit(EXIT_FAILURE); + } + } Poller(const Poller&) = delete; Poller operator=(const Poller&) = delete; - void SetItemEvents(zmq_pollitem_t& item, const int type); + void SetItemEvents(zmq_pollitem_t& item, const int type) + { + if (type == ZMQ_REQ || type == ZMQ_REP || type == ZMQ_PAIR || type == ZMQ_DEALER || type == ZMQ_ROUTER) { + item.events = ZMQ_POLLIN | ZMQ_POLLOUT; + } else if (type == ZMQ_PUSH || type == ZMQ_PUB || type == ZMQ_XPUB) { + item.events = ZMQ_POLLOUT; + } else if (type == ZMQ_PULL || type == ZMQ_SUB || type == ZMQ_XSUB) { + item.events = ZMQ_POLLIN; + } else { + LOG(error) << "invalid poller configuration, exiting."; + exit(EXIT_FAILURE); + } + } - void Poll(const int timeout) override; - bool CheckInput(const int index) override; - bool CheckOutput(const int index) override; - bool CheckInput(const std::string& channelKey, const int index) override; - bool CheckOutput(const std::string& channelKey, const int index) override; + void Poll(const int timeout) override + { + if (zmq_poll(fItems, fNumItems, timeout) < 0) { + if (errno == ETERM) { + LOG(debug) << "polling exited, reason: " << zmq_strerror(errno); + } else { + LOG(error) << "polling failed, reason: " << zmq_strerror(errno); + throw std::runtime_error("polling failed"); + } + } + } - ~Poller() override; + bool CheckInput(const int index) override + { + if (fItems[index].revents & ZMQ_POLLIN) { + return true; + } + + return false; + } + + bool CheckOutput(const int index) override + { + if (fItems[index].revents & ZMQ_POLLOUT) { + return true; + } + + return false; + } + + bool CheckInput(const std::string& channelKey, const int index) override + { + try { + if (fItems[fOffsetMap.at(channelKey) + index].revents & ZMQ_POLLIN) { + return true; + } + + return false; + } catch (const 
std::out_of_range& oor) { + LOG(error) << "invalid channel key: \"" << channelKey << "\""; + LOG(error) << "out of range error: " << oor.what() << '\n'; + exit(EXIT_FAILURE); + } + } + + bool CheckOutput(const std::string& channelKey, const int index) override + { + try { + if (fItems[fOffsetMap.at(channelKey) + index].revents & ZMQ_POLLOUT) { + return true; + } + + return false; + } catch (const std::out_of_range& oor) { + LOG(error) << "Invalid channel key: \"" << channelKey << "\""; + LOG(error) << "out of range error: " << oor.what() << '\n'; + exit(EXIT_FAILURE); + } + } + + ~Poller() override { delete[] fItems; } private: zmq_pollitem_t* fItems; @@ -52,8 +193,8 @@ class Poller final : public fair::mq::Poller std::unordered_map fOffsetMap; }; -} -} -} +} // namespace shmem +} // namespace mq +} // namespace fair #endif /* FAIR_MQ_SHMEM_POLLER_H_ */ diff --git a/fairmq/shmem/Region.cxx b/fairmq/shmem/Region.cxx deleted file mode 100644 index a40ba57d..00000000 --- a/fairmq/shmem/Region.cxx +++ /dev/null @@ -1,226 +0,0 @@ -/******************************************************************************** - * Copyright (C) 2014 GSI Helmholtzzentrum fuer Schwerionenforschung GmbH * - * * - * This software is distributed under the terms of the * - * GNU Lesser General Public Licence (LGPL) version 3, * - * copied verbatim in the file "LICENSE" * - ********************************************************************************/ - -#include "Region.h" -#include "Common.h" -#include "Manager.h" - -#include -#include - -#include -#include -#include - -#include -#include -#include - -using namespace std; - -namespace bipc = ::boost::interprocess; -namespace bpt = ::boost::posix_time; - -namespace fair -{ -namespace mq -{ -namespace shmem -{ - -Region::Region(Manager& manager, uint64_t id, uint64_t size, bool remote, RegionCallback callback, RegionBulkCallback bulkCallback, const string& path, int flags) - : fManager(manager) - , fRemote(remote) - , fStop(false) - , 
fName("fmq_" + fManager.fShmId + "_rg_" + to_string(id)) - , fQueueName("fmq_" + fManager.fShmId + "_rgq_" + to_string(id)) - , fShmemObject() - , fFile(nullptr) - , fFileMapping() - , fQueue(nullptr) - , fReceiveAcksWorker() - , fSendAcksWorker() - , fCallback(callback) - , fBulkCallback(bulkCallback) -{ - if (path != "") { - fName = string(path + fName); - - if (!fRemote) { - // create a file - filebuf fbuf; - if (fbuf.open(fName, ios_base::in | ios_base::out | ios_base::trunc | ios_base::binary)) { - // set the size - fbuf.pubseekoff(size - 1, ios_base::beg); - fbuf.sputc(0); - } - } - - fFile = fopen(fName.c_str(), "r+"); - - if (!fFile) { - LOG(error) << "Failed to initialize file: " << fName; - LOG(error) << "errno: " << errno << ": " << strerror(errno); - throw runtime_error(tools::ToString("Failed to initialize file for shared memory region: ", strerror(errno))); - } - fFileMapping = bipc::file_mapping(fName.c_str(), bipc::read_write); - LOG(debug) << "shmem: initialized file: " << fName; - fRegion = bipc::mapped_region(fFileMapping, bipc::read_write, 0, size, 0, flags); - } else { - if (fRemote) { - fShmemObject = bipc::shared_memory_object(bipc::open_only, fName.c_str(), bipc::read_write); - } else { - fShmemObject = bipc::shared_memory_object(bipc::create_only, fName.c_str(), bipc::read_write); - fShmemObject.truncate(size); - } - fRegion = bipc::mapped_region(fShmemObject, bipc::read_write, 0, 0, 0, flags); - } - - InitializeQueues(); - StartSendingAcks(); - LOG(debug) << "shmem: initialized region: " << fName; -} - -void Region::InitializeQueues() -{ - if (fRemote) { - fQueue = tools::make_unique(bipc::open_only, fQueueName.c_str()); - } else { - fQueue = tools::make_unique(bipc::create_only, fQueueName.c_str(), 1024, fAckBunchSize * sizeof(RegionBlock)); - } - LOG(debug) << "shmem: initialized region queue: " << fQueueName; -} - -void Region::StartSendingAcks() -{ - fSendAcksWorker = thread(&Region::SendAcks, this); -} - -void 
Region::StartReceivingAcks() -{ - fReceiveAcksWorker = thread(&Region::ReceiveAcks, this); -} - -void Region::ReceiveAcks() -{ - unsigned int priority; - bipc::message_queue::size_type recvdSize; - unique_ptr blocks = tools::make_unique(fAckBunchSize); - std::vector result; - result.reserve(fAckBunchSize); - - while (!fStop) { // end thread condition (should exist until region is destroyed) - auto rcvTill = bpt::microsec_clock::universal_time() + bpt::milliseconds(500); - - while (fQueue->timed_receive(blocks.get(), fAckBunchSize * sizeof(RegionBlock), recvdSize, priority, rcvTill)) { - // LOG(debug) << "received: " << block.fHandle << " " << block.fSize << " " << block.fMessageId; - const auto numBlocks = recvdSize / sizeof(RegionBlock); - if (fBulkCallback) { - result.clear(); - for (size_t i = 0; i < numBlocks; i++) { - result.emplace_back(reinterpret_cast(fRegion.get_address()) + blocks[i].fHandle, blocks[i].fSize, reinterpret_cast(blocks[i].fHint)); - } - fBulkCallback(result); - } else if (fCallback) { - for (size_t i = 0; i < numBlocks; i++) { - fCallback(reinterpret_cast(fRegion.get_address()) + blocks[i].fHandle, blocks[i].fSize, reinterpret_cast(blocks[i].fHint)); - } - } - } - } // while !fStop - - LOG(debug) << "ReceiveAcks() worker for " << fName << " leaving."; -} - -void Region::ReleaseBlock(const RegionBlock &block) -{ - unique_lock lock(fBlockMtx); - - fBlocksToFree.emplace_back(block); - - if (fBlocksToFree.size() >= fAckBunchSize) { - lock.unlock(); // reduces contention on fBlockMtx - fBlockSendCV.notify_one(); - } -} - -void Region::SendAcks() -{ - unique_ptr blocks = tools::make_unique(fAckBunchSize); - - while (true) { // we'll try to send all acks before stopping - size_t blocksToSend = 0; - - { // mutex locking block - unique_lock lock(fBlockMtx); - - // try to get more blocks without waiting (we can miss a notify from CloseMessage()) - if (!fStop && (fBlocksToFree.size() < fAckBunchSize)) { - // cv.wait() timeout: send whatever blocks we 
have - fBlockSendCV.wait_for(lock, chrono::milliseconds(500)); - } - - blocksToSend = min(fBlocksToFree.size(), fAckBunchSize); - - copy_n(fBlocksToFree.end() - blocksToSend, blocksToSend, blocks.get()); - fBlocksToFree.resize(fBlocksToFree.size() - blocksToSend); - } // unlock the block mutex here while sending over IPC - - if (blocksToSend > 0) { - while (!fQueue->try_send(blocks.get(), blocksToSend * sizeof(RegionBlock), 0) && !fStop) { - // receiver slow? yield and try again... - this_thread::yield(); - } - } else { // blocksToSend == 0 - if (fStop) { - break; - } - } - } - - LOG(debug) << "send ack worker for " << fName << " leaving."; -} - -Region::~Region() -{ - fStop = true; - - if (fSendAcksWorker.joinable()) { - fBlockSendCV.notify_one(); - fSendAcksWorker.join(); - } - - if (!fRemote) { - if (fReceiveAcksWorker.joinable()) { - fReceiveAcksWorker.join(); - } - - if (bipc::shared_memory_object::remove(fName.c_str())) { - LOG(debug) << "shmem: destroyed region " << fName; - } - - if (bipc::file_mapping::remove(fName.c_str())) { - LOG(debug) << "shmem: destroyed file mapping " << fName; - } - - if (fFile) { - fclose(fFile); - } - - if (bipc::message_queue::remove(fQueueName.c_str())) { - LOG(debug) << "shmem: removed region queue " << fQueueName; - } - } else { - // LOG(debug) << "shmem: region '" << fName << "' is remote, no cleanup necessary."; - LOG(debug) << "shmem: region queue '" << fQueueName << "' is remote, no cleanup necessary"; - } -} - -} // namespace shmem -} // namespace mq -} // namespace fair diff --git a/fairmq/shmem/Region.h b/fairmq/shmem/Region.h index b15dfbcb..16441a9f 100644 --- a/fairmq/shmem/Region.h +++ b/fairmq/shmem/Region.h @@ -19,15 +19,24 @@ #include #include +#include +#include +#include +#include +#include #include #include #include +#include // min #include #include #include #include +#include +#include +#include namespace fair { @@ -36,28 +45,204 @@ namespace mq namespace shmem { -class Manager; - struct Region { - 
Region(Manager& manager, uint64_t id, uint64_t size, bool remote, RegionCallback callback, RegionBulkCallback bulkCallback, const std::string& path, int flags); + Region(const std::string& shmId, uint64_t id, uint64_t size, bool remote, RegionCallback callback, RegionBulkCallback bulkCallback, const std::string& path, int flags) + : fRemote(remote) + , fStop(false) + , fName("fmq_" + shmId + "_rg_" + std::to_string(id)) + , fQueueName("fmq_" + shmId + "_rgq_" + std::to_string(id)) + , fShmemObject() + , fFile(nullptr) + , fFileMapping() + , fQueue(nullptr) + , fReceiveAcksWorker() + , fSendAcksWorker() + , fCallback(callback) + , fBulkCallback(bulkCallback) + { + using namespace boost::interprocess; + + if (path != "") { + fName = std::string(path + fName); + + if (!fRemote) { + // create a file + std::filebuf fbuf; + if (fbuf.open(fName, std::ios_base::in | std::ios_base::out | std::ios_base::trunc | std::ios_base::binary)) { + // set the size + fbuf.pubseekoff(size - 1, std::ios_base::beg); + fbuf.sputc(0); + } + } + + fFile = fopen(fName.c_str(), "r+"); + + if (!fFile) { + LOG(error) << "Failed to initialize file: " << fName; + LOG(error) << "errno: " << errno << ": " << strerror(errno); + throw std::runtime_error(tools::ToString("Failed to initialize file for shared memory region: ", strerror(errno))); + } + fFileMapping = file_mapping(fName.c_str(), read_write); + LOG(debug) << "shmem: initialized file: " << fName; + fRegion = mapped_region(fFileMapping, read_write, 0, size, 0, flags); + } else { + if (fRemote) { + fShmemObject = shared_memory_object(open_only, fName.c_str(), read_write); + } else { + fShmemObject = shared_memory_object(create_only, fName.c_str(), read_write); + fShmemObject.truncate(size); + } + fRegion = mapped_region(fShmemObject, read_write, 0, 0, 0, flags); + } + + InitializeQueues(); + StartSendingAcks(); + LOG(debug) << "shmem: initialized region: " << fName; + } Region() = delete; Region(const Region&) = delete; Region(Region&&) = 
delete; - void InitializeQueues(); + void InitializeQueues() + { + using namespace boost::interprocess; - void StartSendingAcks(); - void SendAcks(); - void StartReceivingAcks(); - void ReceiveAcks(); - void ReleaseBlock(const RegionBlock &); + if (fRemote) { + fQueue = tools::make_unique(open_only, fQueueName.c_str()); + } else { + fQueue = tools::make_unique(create_only, fQueueName.c_str(), 1024, fAckBunchSize * sizeof(RegionBlock)); + } + LOG(debug) << "shmem: initialized region queue: " << fQueueName; + } - ~Region(); + void StartSendingAcks() + { + fSendAcksWorker = std::thread(&Region::SendAcks, this); + } + + void SendAcks() + { + std::unique_ptr blocks = tools::make_unique(fAckBunchSize); + + while (true) { // we'll try to send all acks before stopping + size_t blocksToSend = 0; + + { // mutex locking block + std::unique_lock lock(fBlockMtx); + + // try to get more blocks without waiting (we can miss a notify from CloseMessage()) + if (!fStop && (fBlocksToFree.size() < fAckBunchSize)) { + // cv.wait() timeout: send whatever blocks we have + fBlockSendCV.wait_for(lock, std::chrono::milliseconds(500)); + } + + blocksToSend = std::min(fBlocksToFree.size(), fAckBunchSize); + + copy_n(fBlocksToFree.end() - blocksToSend, blocksToSend, blocks.get()); + fBlocksToFree.resize(fBlocksToFree.size() - blocksToSend); + } // unlock the block mutex here while sending over IPC + + if (blocksToSend > 0) { + while (!fQueue->try_send(blocks.get(), blocksToSend * sizeof(RegionBlock), 0) && !fStop) { + // receiver slow? yield and try again... 
+ std::this_thread::yield(); + } + } else { // blocksToSend == 0 + if (fStop) { + break; + } + } + } + + LOG(debug) << "send ack worker for " << fName << " leaving."; + } + + void StartReceivingAcks() + { + fReceiveAcksWorker = std::thread(&Region::ReceiveAcks, this); + } + + void ReceiveAcks() + { + unsigned int priority; + boost::interprocess::message_queue::size_type recvdSize; + std::unique_ptr blocks = tools::make_unique(fAckBunchSize); + std::vector result; + result.reserve(fAckBunchSize); + + while (!fStop) { // end thread condition (should exist until region is destroyed) + auto rcvTill = boost::posix_time::microsec_clock::universal_time() + boost::posix_time::milliseconds(500); + + while (fQueue->timed_receive(blocks.get(), fAckBunchSize * sizeof(RegionBlock), recvdSize, priority, rcvTill)) { + // LOG(debug) << "received: " << block.fHandle << " " << block.fSize << " " << block.fMessageId; + const auto numBlocks = recvdSize / sizeof(RegionBlock); + if (fBulkCallback) { + result.clear(); + for (size_t i = 0; i < numBlocks; i++) { + result.emplace_back(reinterpret_cast(fRegion.get_address()) + blocks[i].fHandle, blocks[i].fSize, reinterpret_cast(blocks[i].fHint)); + } + fBulkCallback(result); + } else if (fCallback) { + for (size_t i = 0; i < numBlocks; i++) { + fCallback(reinterpret_cast(fRegion.get_address()) + blocks[i].fHandle, blocks[i].fSize, reinterpret_cast(blocks[i].fHint)); + } + } + } + } // while !fStop + + LOG(debug) << "ReceiveAcks() worker for " << fName << " leaving."; + } + + void ReleaseBlock(const RegionBlock& block) + { + std::unique_lock lock(fBlockMtx); + + fBlocksToFree.emplace_back(block); + + if (fBlocksToFree.size() >= fAckBunchSize) { + lock.unlock(); // reduces contention on fBlockMtx + fBlockSendCV.notify_one(); + } + } + + ~Region() + { + fStop = true; + + if (fSendAcksWorker.joinable()) { + fBlockSendCV.notify_one(); + fSendAcksWorker.join(); + } + + if (!fRemote) { + if (fReceiveAcksWorker.joinable()) { + 
fReceiveAcksWorker.join(); + } + + if (boost::interprocess::shared_memory_object::remove(fName.c_str())) { + LOG(debug) << "shmem: destroyed region " << fName; + } + + if (boost::interprocess::file_mapping::remove(fName.c_str())) { + LOG(debug) << "shmem: destroyed file mapping " << fName; + } + + if (fFile) { + fclose(fFile); + } + + if (boost::interprocess::message_queue::remove(fQueueName.c_str())) { + LOG(debug) << "shmem: removed region queue " << fQueueName; + } + } else { + // LOG(debug) << "shmem: region '" << fName << "' is remote, no cleanup necessary."; + LOG(debug) << "shmem: region queue '" << fQueueName << "' is remote, no cleanup necessary"; + } + } - Manager& fManager; bool fRemote; bool fStop; std::string fName; diff --git a/fairmq/shmem/Socket.cxx b/fairmq/shmem/Socket.cxx deleted file mode 100644 index 42e183a4..00000000 --- a/fairmq/shmem/Socket.cxx +++ /dev/null @@ -1,496 +0,0 @@ -/******************************************************************************** - * Copyright (C) 2014-2018 GSI Helmholtzzentrum fuer Schwerionenforschung GmbH * - * * - * This software is distributed under the terms of the * - * GNU Lesser General Public Licence (LGPL) version 3, * - * copied verbatim in the file "LICENSE" * - ********************************************************************************/ - -#include "Common.h" -#include "Socket.h" -#include "Message.h" -#include "UnmanagedRegion.h" -#include "TransportFactory.h" - -#include -#include - -#include - -#include - -using namespace std; - -namespace fair -{ -namespace mq -{ -namespace shmem -{ - -struct ZMsg -{ - ZMsg() { int rc __attribute__((unused)) = zmq_msg_init(&fMsg); assert(rc == 0); } - explicit ZMsg(size_t size) { int rc __attribute__((unused)) = zmq_msg_init_size(&fMsg, size); assert(rc == 0); } - ~ZMsg() { int rc __attribute__((unused)) = zmq_msg_close(&fMsg); assert(rc == 0); } - - void* Data() { return zmq_msg_data(&fMsg); } - size_t Size() { return zmq_msg_size(&fMsg); } - zmq_msg_t* 
Msg() { return &fMsg; } - - zmq_msg_t fMsg; -}; - -Socket::Socket(Manager& manager, const string& type, const string& name, const string& id /*= ""*/, void* context, FairMQTransportFactory* fac /*=nullptr*/) - : fair::mq::Socket{fac} - , fSocket(nullptr) - , fManager(manager) - , fId(id + "." + name + "." + type) - , fBytesTx(0) - , fBytesRx(0) - , fMessagesTx(0) - , fMessagesRx(0) - , fSndTimeout(100) - , fRcvTimeout(100) -{ - assert(context); - fSocket = zmq_socket(context, GetConstant(type)); - - if (fSocket == nullptr) { - LOG(error) << "Failed creating socket " << fId << ", reason: " << zmq_strerror(errno); - throw SocketError(tools::ToString("Failed creating socket ", fId, ", reason: ", zmq_strerror(errno))); - } - - if (zmq_setsockopt(fSocket, ZMQ_IDENTITY, fId.c_str(), fId.length()) != 0) { - LOG(error) << "Failed setting ZMQ_IDENTITY socket option, reason: " << zmq_strerror(errno); - } - - // Tell socket to try and send/receive outstanding messages for milliseconds before terminating. - // Default value for ZeroMQ is -1, which is to wait forever. 
- int linger = 1000; - if (zmq_setsockopt(fSocket, ZMQ_LINGER, &linger, sizeof(linger)) != 0) { - LOG(error) << "Failed setting ZMQ_LINGER socket option, reason: " << zmq_strerror(errno); - } - - if (zmq_setsockopt(fSocket, ZMQ_SNDTIMEO, &fSndTimeout, sizeof(fSndTimeout)) != 0) { - LOG(error) << "Failed setting ZMQ_SNDTIMEO socket option, reason: " << zmq_strerror(errno); - } - - if (zmq_setsockopt(fSocket, ZMQ_RCVTIMEO, &fRcvTimeout, sizeof(fRcvTimeout)) != 0) { - LOG(error) << "Failed setting ZMQ_RCVTIMEO socket option, reason: " << zmq_strerror(errno); - } - - // if (type == "sub") - // { - // if (zmq_setsockopt(fSocket, ZMQ_SUBSCRIBE, nullptr, 0) != 0) - // { - // LOG(error) << "Failed setting ZMQ_SUBSCRIBE socket option, reason: " << zmq_strerror(errno); - // } - // } - - if (type == "sub" || type == "pub") { - LOG(error) << "PUB/SUB socket type is not supported for shared memory transport"; - throw SocketError("PUB/SUB socket type is not supported for shared memory transport"); - } - LOG(debug) << "Created socket " << GetId(); -} - -bool Socket::Bind(const string& address) -{ - // LOG(info) << "binding socket " << fId << " on " << address; - if (zmq_bind(fSocket, address.c_str()) != 0) { - if (errno == EADDRINUSE) { - // do not print error in this case, this is handled by FairMQDevice in case no connection could be established after trying a number of random ports from a range. 
- return false; - } - LOG(error) << "Failed binding socket " << fId << ", reason: " << zmq_strerror(errno); - return false; - } - return true; -} - -bool Socket::Connect(const string& address) -{ - // LOG(info) << "connecting socket " << fId << " on " << address; - if (zmq_connect(fSocket, address.c_str()) != 0) { - LOG(error) << "Failed connecting socket " << fId << ", reason: " << zmq_strerror(errno); - return false; - } - return true; -} - -int Socket::Send(MessagePtr& msg, const int timeout) -{ - int flags = 0; - if (timeout == 0) { - flags = ZMQ_DONTWAIT; - } - int elapsed = 0; - - Message* shmMsg = static_cast(msg.get()); - ZMsg zmqMsg(sizeof(MetaHeader)); - std::memcpy(zmqMsg.Data(), &(shmMsg->fMeta), sizeof(MetaHeader)); - - while (true && !fManager.Interrupted()) { - int nbytes = zmq_msg_send(zmqMsg.Msg(), fSocket, flags); - if (nbytes > 0) { - shmMsg->fQueued = true; - ++fMessagesTx; - size_t size = msg->GetSize(); - fBytesTx += size; - return size; - } else if (zmq_errno() == EAGAIN) { - if (!fManager.Interrupted() && ((flags & ZMQ_DONTWAIT) == 0)) { - if (timeout > 0) { - elapsed += fSndTimeout; - if (elapsed >= timeout) { - return -2; - } - } - continue; - } else { - return -2; - } - } else if (zmq_errno() == ETERM) { - LOG(info) << "terminating socket " << fId; - return -1; - } else if (zmq_errno() == EINTR) { - LOG(debug) << "Send interrupted by system call"; - return nbytes; - }else { - LOG(error) << "Failed sending on socket " << fId << ", reason: " << zmq_strerror(errno) << ", nbytes = " << nbytes; - return nbytes; - } - } - - return -1; -} - -int Socket::Receive(MessagePtr& msg, const int timeout) -{ - int flags = 0; - if (timeout == 0) { - flags = ZMQ_DONTWAIT; - } - int elapsed = 0; - - ZMsg zmqMsg; - - while (true) { - Message* shmMsg = static_cast(msg.get()); - int nbytes = zmq_msg_recv(zmqMsg.Msg(), fSocket, flags); - if (nbytes > 0) { - // check for number of received messages. 
must be 1 - if (nbytes != sizeof(MetaHeader)) { - throw SocketError( - tools::ToString("Received message is not a valid FairMQ shared memory message. ", - "Possibly due to a misconfigured transport on the sender side. ", - "Expected size of ", sizeof(MetaHeader), " bytes, received ", nbytes)); - } - - MetaHeader* hdr = static_cast(zmqMsg.Data()); - size_t size = hdr->fSize; - shmMsg->fMeta = *hdr; - - fBytesRx += size; - ++fMessagesRx; - return size; - } else if (zmq_errno() == EAGAIN) { - if (!fManager.Interrupted() && ((flags & ZMQ_DONTWAIT) == 0)) { - if (timeout > 0) { - elapsed += fRcvTimeout; - if (elapsed >= timeout) { - return -2; - } - } - continue; - } else { - return -2; - } - } else if (zmq_errno() == ETERM) { - LOG(info) << "terminating socket " << fId; - return -1; - } else if (zmq_errno() == EINTR) { - LOG(debug) << "Receive interrupted by system call"; - return nbytes; - }else { - LOG(error) << "Failed receiving on socket " << fId << ", errno: " << errno << ", reason: " << zmq_strerror(errno) << ", nbytes = " << nbytes; - return nbytes; - } - } -} - -int64_t Socket::Send(vector& msgVec, const int timeout) -{ - int flags = 0; - if (timeout == 0) { - flags = ZMQ_DONTWAIT; - } - int elapsed = 0; - - // put it into zmq message - const unsigned int vecSize = msgVec.size(); - ZMsg zmqMsg(vecSize * sizeof(MetaHeader)); - - // prepare the message with shm metas - MetaHeader* metas = static_cast(zmqMsg.Data()); - - for (auto& msg : msgVec) { - Message* shmMsg = static_cast(msg.get()); - std::memcpy(metas++, &(shmMsg->fMeta), sizeof(MetaHeader)); - } - - while (!fManager.Interrupted()) { - int64_t totalSize = 0; - int nbytes = zmq_msg_send(zmqMsg.Msg(), fSocket, flags); - if (nbytes > 0) { - assert(static_cast(nbytes) == (vecSize * sizeof(MetaHeader))); // all or nothing - - for (auto& msg : msgVec) { - Message* shmMsg = static_cast(msg.get()); - shmMsg->fQueued = true; - totalSize += shmMsg->fMeta.fSize; - } - - // store statistics on how many messages have 
been sent - fMessagesTx++; - fBytesTx += totalSize; - - return totalSize; - } else if (zmq_errno() == EAGAIN) { - if (!fManager.Interrupted() && ((flags & ZMQ_DONTWAIT) == 0)) { - if (timeout > 0) { - elapsed += fSndTimeout; - if (elapsed >= timeout) { - return -2; - } - } - continue; - } else { - return -2; - } - } else if (zmq_errno() == ETERM) { - LOG(info) << "terminating socket " << fId; - return -1; - } else if (zmq_errno() == EINTR) { - LOG(debug) << "Send interrupted by system call"; - return nbytes; - }else { - LOG(error) << "Failed sending on socket " << fId << ", reason: " << zmq_strerror(errno) << ", nbytes = " << nbytes; - return nbytes; - } - } - - return -1; -} - -int64_t Socket::Receive(vector& msgVec, const int timeout) -{ - int flags = 0; - if (timeout == 0) { - flags = ZMQ_DONTWAIT; - } - int elapsed = 0; - - ZMsg zmqMsg; - - while (!fManager.Interrupted()) { - int64_t totalSize = 0; - int nbytes = zmq_msg_recv(zmqMsg.Msg(), fSocket, flags); - if (nbytes > 0) { - MetaHeader* hdrVec = static_cast(zmqMsg.Data()); - const auto hdrVecSize = zmqMsg.Size(); - - assert(hdrVecSize > 0); - if (hdrVecSize % sizeof(MetaHeader) != 0) { - throw SocketError( - tools::ToString("Received message is not a valid FairMQ shared memory message. ", - "Possibly due to a misconfigured transport on the sender side. 
", - "Expected size of ", sizeof(MetaHeader), " bytes, received ", nbytes)); - } - - const auto numMessages = hdrVecSize / sizeof(MetaHeader); - msgVec.reserve(numMessages); - - for (size_t m = 0; m < numMessages; m++) { - // create new message (part) - msgVec.emplace_back(tools::make_unique(fManager, hdrVec[m], GetTransport())); - Message* shmMsg = static_cast(msgVec.back().get()); - totalSize += shmMsg->GetSize(); - } - - // store statistics on how many messages have been received (handle all parts as a single message) - fMessagesRx++; - fBytesRx += totalSize; - - return totalSize; - } else if (zmq_errno() == EAGAIN) { - if (!fManager.Interrupted() && ((flags & ZMQ_DONTWAIT) == 0)) { - if (timeout > 0) { - elapsed += fRcvTimeout; - if (elapsed >= timeout) { - return -2; - } - } - continue; - } else { - return -2; - } - } else if (zmq_errno() == EINTR) { - LOG(debug) << "Receive interrupted by system call"; - return nbytes; - } else { - LOG(error) << "Failed receiving on socket " << fId << ", errno: " << errno << ", reason: " << zmq_strerror(errno) << ", nbytes = " << nbytes; - return nbytes; - } - } - - return -1; -} - -void Socket::Close() -{ - // LOG(debug) << "Closing socket " << fId; - - if (fSocket == nullptr) { - return; - } - - if (zmq_close(fSocket) != 0) { - LOG(error) << "Failed closing socket " << fId << ", reason: " << zmq_strerror(errno); - } - - fSocket = nullptr; -} - -void Socket::SetOption(const string& option, const void* value, size_t valueSize) -{ - if (zmq_setsockopt(fSocket, GetConstant(option), value, valueSize) < 0) { - LOG(error) << "Failed setting socket option, reason: " << zmq_strerror(errno); - } -} - -void Socket::GetOption(const string& option, void* value, size_t* valueSize) -{ - if (zmq_getsockopt(fSocket, GetConstant(option), value, valueSize) < 0) { - LOG(error) << "Failed getting socket option, reason: " << zmq_strerror(errno); - } -} - -void Socket::SetLinger(const int value) -{ - if (zmq_setsockopt(fSocket, ZMQ_LINGER, 
&value, sizeof(value)) < 0) { - throw SocketError(tools::ToString("failed setting ZMQ_LINGER, reason: ", zmq_strerror(errno))); - } -} - -int Socket::GetLinger() const -{ - int value = 0; - size_t valueSize = sizeof(value); - if (zmq_getsockopt(fSocket, ZMQ_LINGER, &value, &valueSize) < 0) { - throw SocketError(tools::ToString("failed getting ZMQ_LINGER, reason: ", zmq_strerror(errno))); - } - return value; -} - -void Socket::SetSndBufSize(const int value) -{ - if (zmq_setsockopt(fSocket, ZMQ_SNDHWM, &value, sizeof(value)) < 0) { - throw SocketError(tools::ToString("failed setting ZMQ_SNDHWM, reason: ", zmq_strerror(errno))); - } -} - -int Socket::GetSndBufSize() const -{ - int value = 0; - size_t valueSize = sizeof(value); - if (zmq_getsockopt(fSocket, ZMQ_SNDHWM, &value, &valueSize) < 0) { - throw SocketError(tools::ToString("failed getting ZMQ_SNDHWM, reason: ", zmq_strerror(errno))); - } - return value; -} - -void Socket::SetRcvBufSize(const int value) -{ - if (zmq_setsockopt(fSocket, ZMQ_RCVHWM, &value, sizeof(value)) < 0) { - throw SocketError(tools::ToString("failed setting ZMQ_RCVHWM, reason: ", zmq_strerror(errno))); - } -} - -int Socket::GetRcvBufSize() const -{ - int value = 0; - size_t valueSize = sizeof(value); - if (zmq_getsockopt(fSocket, ZMQ_RCVHWM, &value, &valueSize) < 0) { - throw SocketError(tools::ToString("failed getting ZMQ_RCVHWM, reason: ", zmq_strerror(errno))); - } - return value; -} - -void Socket::SetSndKernelSize(const int value) -{ - if (zmq_setsockopt(fSocket, ZMQ_SNDBUF, &value, sizeof(value)) < 0) { - throw SocketError(tools::ToString("failed getting ZMQ_SNDBUF, reason: ", zmq_strerror(errno))); - } -} - -int Socket::GetSndKernelSize() const -{ - int value = 0; - size_t valueSize = sizeof(value); - if (zmq_getsockopt(fSocket, ZMQ_SNDBUF, &value, &valueSize) < 0) { - throw SocketError(tools::ToString("failed getting ZMQ_SNDBUF, reason: ", zmq_strerror(errno))); - } - return value; -} - -void Socket::SetRcvKernelSize(const int value) 
-{ - if (zmq_setsockopt(fSocket, ZMQ_RCVBUF, &value, sizeof(value)) < 0) { - throw SocketError(tools::ToString("failed getting ZMQ_RCVBUF, reason: ", zmq_strerror(errno))); - } -} - -int Socket::GetRcvKernelSize() const -{ - int value = 0; - size_t valueSize = sizeof(value); - if (zmq_getsockopt(fSocket, ZMQ_RCVBUF, &value, &valueSize) < 0) { - throw SocketError(tools::ToString("failed getting ZMQ_RCVBUF, reason: ", zmq_strerror(errno))); - } - return value; -} - -int Socket::GetConstant(const string& constant) -{ - if (constant == "") return 0; - if (constant == "sub") return ZMQ_SUB; - if (constant == "pub") return ZMQ_PUB; - if (constant == "xsub") return ZMQ_XSUB; - if (constant == "xpub") return ZMQ_XPUB; - if (constant == "push") return ZMQ_PUSH; - if (constant == "pull") return ZMQ_PULL; - if (constant == "req") return ZMQ_REQ; - if (constant == "rep") return ZMQ_REP; - if (constant == "dealer") return ZMQ_DEALER; - if (constant == "router") return ZMQ_ROUTER; - if (constant == "pair") return ZMQ_PAIR; - - if (constant == "snd-hwm") return ZMQ_SNDHWM; - if (constant == "rcv-hwm") return ZMQ_RCVHWM; - if (constant == "snd-size") return ZMQ_SNDBUF; - if (constant == "rcv-size") return ZMQ_RCVBUF; - if (constant == "snd-more") return ZMQ_SNDMORE; - if (constant == "rcv-more") return ZMQ_RCVMORE; - - if (constant == "linger") return ZMQ_LINGER; - if (constant == "no-block") return ZMQ_DONTWAIT; - if (constant == "snd-more no-block") return ZMQ_DONTWAIT|ZMQ_SNDMORE; - - return -1; -} - -} -} -} diff --git a/fairmq/shmem/Socket.h b/fairmq/shmem/Socket.h index 143ffdb4..05a273de 100644 --- a/fairmq/shmem/Socket.h +++ b/fairmq/shmem/Socket.h @@ -8,13 +8,18 @@ #ifndef FAIR_MQ_SHMEM_SOCKET_H_ #define FAIR_MQ_SHMEM_SOCKET_H_ +#include "Common.h" #include "Manager.h" +#include "Message.h" #include #include +#include +#include + +#include #include -#include // unique_ptr class FairMQTransportFactory; @@ -25,47 +30,483 @@ namespace mq namespace shmem { +struct ZMsg +{ + 
ZMsg() { int rc __attribute__((unused)) = zmq_msg_init(&fMsg); assert(rc == 0); } + explicit ZMsg(size_t size) { int rc __attribute__((unused)) = zmq_msg_init_size(&fMsg, size); assert(rc == 0); } + ~ZMsg() { int rc __attribute__((unused)) = zmq_msg_close(&fMsg); assert(rc == 0); } + + void* Data() { return zmq_msg_data(&fMsg); } + size_t Size() { return zmq_msg_size(&fMsg); } + zmq_msg_t* Msg() { return &fMsg; } + + zmq_msg_t fMsg; +}; + class Socket final : public fair::mq::Socket { public: - Socket(Manager& manager, const std::string& type, const std::string& name, const std::string& id = "", void* context = nullptr, FairMQTransportFactory* fac = nullptr); + Socket(Manager& manager, const std::string& type, const std::string& name, const std::string& id = "", void* context = nullptr, FairMQTransportFactory* fac = nullptr) + : fair::mq::Socket(fac) + , fSocket(nullptr) + , fManager(manager) + , fId(id + "." + name + "." + type) + , fBytesTx(0) + , fBytesRx(0) + , fMessagesTx(0) + , fMessagesRx(0) + , fSndTimeout(100) + , fRcvTimeout(100) + { + assert(context); + fSocket = zmq_socket(context, GetConstant(type)); + + if (fSocket == nullptr) { + LOG(error) << "Failed creating socket " << fId << ", reason: " << zmq_strerror(errno); + throw SocketError(tools::ToString("Failed creating socket ", fId, ", reason: ", zmq_strerror(errno))); + } + + if (zmq_setsockopt(fSocket, ZMQ_IDENTITY, fId.c_str(), fId.length()) != 0) { + LOG(error) << "Failed setting ZMQ_IDENTITY socket option, reason: " << zmq_strerror(errno); + } + + // Tell socket to try and send/receive outstanding messages for milliseconds before terminating. + // Default value for ZeroMQ is -1, which is to wait forever. 
+ int linger = 1000; + if (zmq_setsockopt(fSocket, ZMQ_LINGER, &linger, sizeof(linger)) != 0) { + LOG(error) << "Failed setting ZMQ_LINGER socket option, reason: " << zmq_strerror(errno); + } + + if (zmq_setsockopt(fSocket, ZMQ_SNDTIMEO, &fSndTimeout, sizeof(fSndTimeout)) != 0) { + LOG(error) << "Failed setting ZMQ_SNDTIMEO socket option, reason: " << zmq_strerror(errno); + } + + if (zmq_setsockopt(fSocket, ZMQ_RCVTIMEO, &fRcvTimeout, sizeof(fRcvTimeout)) != 0) { + LOG(error) << "Failed setting ZMQ_RCVTIMEO socket option, reason: " << zmq_strerror(errno); + } + + // if (type == "sub") + // { + // if (zmq_setsockopt(fSocket, ZMQ_SUBSCRIBE, nullptr, 0) != 0) + // { + // LOG(error) << "Failed setting ZMQ_SUBSCRIBE socket option, reason: " << zmq_strerror(errno); + // } + // } + + if (type == "sub" || type == "pub") { + LOG(error) << "PUB/SUB socket type is not supported for shared memory transport"; + throw SocketError("PUB/SUB socket type is not supported for shared memory transport"); + } + LOG(debug) << "Created socket " << GetId(); + } + Socket(const Socket&) = delete; Socket operator=(const Socket&) = delete; std::string GetId() const override { return fId; } - bool Bind(const std::string& address) override; - bool Connect(const std::string& address) override; + bool Bind(const std::string& address) override + { + // LOG(info) << "binding socket " << fId << " on " << address; + if (zmq_bind(fSocket, address.c_str()) != 0) { + if (errno == EADDRINUSE) { + // do not print error in this case, this is handled by FairMQDevice in case no connection could be established after trying a number of random ports from a range. 
+ return false; + } + LOG(error) << "Failed binding socket " << fId << ", reason: " << zmq_strerror(errno); + return false; + } + return true; + } - int Send(MessagePtr& msg, const int timeout = -1) override; - int Receive(MessagePtr& msg, const int timeout = -1) override; - int64_t Send(std::vector& msgVec, const int timeout = -1) override; - int64_t Receive(std::vector& msgVec, const int timeout = -1) override; + bool Connect(const std::string& address) override + { + // LOG(info) << "connecting socket " << fId << " on " << address; + if (zmq_connect(fSocket, address.c_str()) != 0) { + LOG(error) << "Failed connecting socket " << fId << ", reason: " << zmq_strerror(errno); + return false; + } + return true; + } + + int Send(MessagePtr& msg, const int timeout = -1) override + { + int flags = 0; + if (timeout == 0) { + flags = ZMQ_DONTWAIT; + } + int elapsed = 0; + + Message* shmMsg = static_cast(msg.get()); + ZMsg zmqMsg(sizeof(MetaHeader)); + std::memcpy(zmqMsg.Data(), &(shmMsg->fMeta), sizeof(MetaHeader)); + + while (true && !fManager.Interrupted()) { + int nbytes = zmq_msg_send(zmqMsg.Msg(), fSocket, flags); + if (nbytes > 0) { + shmMsg->fQueued = true; + ++fMessagesTx; + size_t size = msg->GetSize(); + fBytesTx += size; + return size; + } else if (zmq_errno() == EAGAIN) { + if (!fManager.Interrupted() && ((flags & ZMQ_DONTWAIT) == 0)) { + if (timeout > 0) { + elapsed += fSndTimeout; + if (elapsed >= timeout) { + return -2; + } + } + continue; + } else { + return -2; + } + } else if (zmq_errno() == ETERM) { + LOG(info) << "terminating socket " << fId; + return -1; + } else if (zmq_errno() == EINTR) { + LOG(debug) << "Send interrupted by system call"; + return nbytes; + }else { + LOG(error) << "Failed sending on socket " << fId << ", reason: " << zmq_strerror(errno) << ", nbytes = " << nbytes; + return nbytes; + } + } + + return -1; + } + + int Receive(MessagePtr& msg, const int timeout = -1) override + { + int flags = 0; + if (timeout == 0) { + flags = 
ZMQ_DONTWAIT; + } + int elapsed = 0; + + ZMsg zmqMsg; + + while (true) { + Message* shmMsg = static_cast(msg.get()); + int nbytes = zmq_msg_recv(zmqMsg.Msg(), fSocket, flags); + if (nbytes > 0) { + // check for number of received messages. must be 1 + if (nbytes != sizeof(MetaHeader)) { + throw SocketError( + tools::ToString("Received message is not a valid FairMQ shared memory message. ", + "Possibly due to a misconfigured transport on the sender side. ", + "Expected size of ", sizeof(MetaHeader), " bytes, received ", nbytes)); + } + + MetaHeader* hdr = static_cast(zmqMsg.Data()); + size_t size = hdr->fSize; + shmMsg->fMeta = *hdr; + + fBytesRx += size; + ++fMessagesRx; + return size; + } else if (zmq_errno() == EAGAIN) { + if (!fManager.Interrupted() && ((flags & ZMQ_DONTWAIT) == 0)) { + if (timeout > 0) { + elapsed += fRcvTimeout; + if (elapsed >= timeout) { + return -2; + } + } + continue; + } else { + return -2; + } + } else if (zmq_errno() == ETERM) { + LOG(info) << "terminating socket " << fId; + return -1; + } else if (zmq_errno() == EINTR) { + LOG(debug) << "Receive interrupted by system call"; + return nbytes; + }else { + LOG(error) << "Failed receiving on socket " << fId << ", errno: " << errno << ", reason: " << zmq_strerror(errno) << ", nbytes = " << nbytes; + return nbytes; + } + } + } + + int64_t Send(std::vector& msgVec, const int timeout = -1) override + { + int flags = 0; + if (timeout == 0) { + flags = ZMQ_DONTWAIT; + } + int elapsed = 0; + + // put it into zmq message + const unsigned int vecSize = msgVec.size(); + ZMsg zmqMsg(vecSize * sizeof(MetaHeader)); + + // prepare the message with shm metas + MetaHeader* metas = static_cast(zmqMsg.Data()); + + for (auto& msg : msgVec) { + Message* shmMsg = static_cast(msg.get()); + std::memcpy(metas++, &(shmMsg->fMeta), sizeof(MetaHeader)); + } + + while (!fManager.Interrupted()) { + int64_t totalSize = 0; + int nbytes = zmq_msg_send(zmqMsg.Msg(), fSocket, flags); + if (nbytes > 0) { + 
assert(static_cast(nbytes) == (vecSize * sizeof(MetaHeader))); // all or nothing + + for (auto& msg : msgVec) { + Message* shmMsg = static_cast(msg.get()); + shmMsg->fQueued = true; + totalSize += shmMsg->fMeta.fSize; + } + + // store statistics on how many messages have been sent + fMessagesTx++; + fBytesTx += totalSize; + + return totalSize; + } else if (zmq_errno() == EAGAIN) { + if (!fManager.Interrupted() && ((flags & ZMQ_DONTWAIT) == 0)) { + if (timeout > 0) { + elapsed += fSndTimeout; + if (elapsed >= timeout) { + return -2; + } + } + continue; + } else { + return -2; + } + } else if (zmq_errno() == ETERM) { + LOG(info) << "terminating socket " << fId; + return -1; + } else if (zmq_errno() == EINTR) { + LOG(debug) << "Send interrupted by system call"; + return nbytes; + }else { + LOG(error) << "Failed sending on socket " << fId << ", reason: " << zmq_strerror(errno) << ", nbytes = " << nbytes; + return nbytes; + } + } + + return -1; + } + + int64_t Receive(std::vector& msgVec, const int timeout = -1) override + { + int flags = 0; + if (timeout == 0) { + flags = ZMQ_DONTWAIT; + } + int elapsed = 0; + + ZMsg zmqMsg; + + while (!fManager.Interrupted()) { + int64_t totalSize = 0; + int nbytes = zmq_msg_recv(zmqMsg.Msg(), fSocket, flags); + if (nbytes > 0) { + MetaHeader* hdrVec = static_cast(zmqMsg.Data()); + const auto hdrVecSize = zmqMsg.Size(); + + assert(hdrVecSize > 0); + if (hdrVecSize % sizeof(MetaHeader) != 0) { + throw SocketError( + tools::ToString("Received message is not a valid FairMQ shared memory message. ", + "Possibly due to a misconfigured transport on the sender side. 
", + "Expected size of ", sizeof(MetaHeader), " bytes, received ", nbytes)); + } + + const auto numMessages = hdrVecSize / sizeof(MetaHeader); + msgVec.reserve(numMessages); + + for (size_t m = 0; m < numMessages; m++) { + // create new message (part) + msgVec.emplace_back(tools::make_unique(fManager, hdrVec[m], GetTransport())); + Message* shmMsg = static_cast(msgVec.back().get()); + totalSize += shmMsg->GetSize(); + } + + // store statistics on how many messages have been received (handle all parts as a single message) + fMessagesRx++; + fBytesRx += totalSize; + + return totalSize; + } else if (zmq_errno() == EAGAIN) { + if (!fManager.Interrupted() && ((flags & ZMQ_DONTWAIT) == 0)) { + if (timeout > 0) { + elapsed += fRcvTimeout; + if (elapsed >= timeout) { + return -2; + } + } + continue; + } else { + return -2; + } + } else if (zmq_errno() == EINTR) { + LOG(debug) << "Receive interrupted by system call"; + return nbytes; + } else { + LOG(error) << "Failed receiving on socket " << fId << ", errno: " << errno << ", reason: " << zmq_strerror(errno) << ", nbytes = " << nbytes; + return nbytes; + } + } + + return -1; + } void* GetSocket() const { return fSocket; } - void Close() override; + void Close() override + { + // LOG(debug) << "Closing socket " << fId; - void SetOption(const std::string& option, const void* value, size_t valueSize) override; - void GetOption(const std::string& option, void* value, size_t* valueSize) override; + if (fSocket == nullptr) { + return; + } - void SetLinger(const int value) override; - int GetLinger() const override; - void SetSndBufSize(const int value) override; - int GetSndBufSize() const override; - void SetRcvBufSize(const int value) override; - int GetRcvBufSize() const override; - void SetSndKernelSize(const int value) override; - int GetSndKernelSize() const override; - void SetRcvKernelSize(const int value) override; - int GetRcvKernelSize() const override; + if (zmq_close(fSocket) != 0) { + LOG(error) << "Failed closing 
socket " << fId << ", reason: " << zmq_strerror(errno); + } + + fSocket = nullptr; + } + + void SetOption(const std::string& option, const void* value, size_t valueSize) override + { + if (zmq_setsockopt(fSocket, GetConstant(option), value, valueSize) < 0) { + LOG(error) << "Failed setting socket option, reason: " << zmq_strerror(errno); + } + } + + void GetOption(const std::string& option, void* value, size_t* valueSize) override + { + if (zmq_getsockopt(fSocket, GetConstant(option), value, valueSize) < 0) { + LOG(error) << "Failed getting socket option, reason: " << zmq_strerror(errno); + } + } + + void SetLinger(const int value) override + { + if (zmq_setsockopt(fSocket, ZMQ_LINGER, &value, sizeof(value)) < 0) { + throw SocketError(tools::ToString("failed setting ZMQ_LINGER, reason: ", zmq_strerror(errno))); + } + } + + int GetLinger() const override + { + int value = 0; + size_t valueSize = sizeof(value); + if (zmq_getsockopt(fSocket, ZMQ_LINGER, &value, &valueSize) < 0) { + throw SocketError(tools::ToString("failed getting ZMQ_LINGER, reason: ", zmq_strerror(errno))); + } + return value; + } + + void SetSndBufSize(const int value) override + { + if (zmq_setsockopt(fSocket, ZMQ_SNDHWM, &value, sizeof(value)) < 0) { + throw SocketError(tools::ToString("failed setting ZMQ_SNDHWM, reason: ", zmq_strerror(errno))); + } + } + + int GetSndBufSize() const override + { + int value = 0; + size_t valueSize = sizeof(value); + if (zmq_getsockopt(fSocket, ZMQ_SNDHWM, &value, &valueSize) < 0) { + throw SocketError(tools::ToString("failed getting ZMQ_SNDHWM, reason: ", zmq_strerror(errno))); + } + return value; + } + + void SetRcvBufSize(const int value) override + { + if (zmq_setsockopt(fSocket, ZMQ_RCVHWM, &value, sizeof(value)) < 0) { + throw SocketError(tools::ToString("failed setting ZMQ_RCVHWM, reason: ", zmq_strerror(errno))); + } + } + + int GetRcvBufSize() const override + { + int value = 0; + size_t valueSize = sizeof(value); + if (zmq_getsockopt(fSocket, 
ZMQ_RCVHWM, &value, &valueSize) < 0) { + throw SocketError(tools::ToString("failed getting ZMQ_RCVHWM, reason: ", zmq_strerror(errno))); + } + return value; + } + + void SetSndKernelSize(const int value) override + { + if (zmq_setsockopt(fSocket, ZMQ_SNDBUF, &value, sizeof(value)) < 0) { + throw SocketError(tools::ToString("failed setting ZMQ_SNDBUF, reason: ", zmq_strerror(errno))); + } + } + + int GetSndKernelSize() const override + { + int value = 0; + size_t valueSize = sizeof(value); + if (zmq_getsockopt(fSocket, ZMQ_SNDBUF, &value, &valueSize) < 0) { + throw SocketError(tools::ToString("failed getting ZMQ_SNDBUF, reason: ", zmq_strerror(errno))); + } + return value; + } + + void SetRcvKernelSize(const int value) override + { + if (zmq_setsockopt(fSocket, ZMQ_RCVBUF, &value, sizeof(value)) < 0) { + throw SocketError(tools::ToString("failed setting ZMQ_RCVBUF, reason: ", zmq_strerror(errno))); + } + } + + int GetRcvKernelSize() const override + { + int value = 0; + size_t valueSize = sizeof(value); + if (zmq_getsockopt(fSocket, ZMQ_RCVBUF, &value, &valueSize) < 0) { + throw SocketError(tools::ToString("failed getting ZMQ_RCVBUF, reason: ", zmq_strerror(errno))); + } + return value; + } unsigned long GetBytesTx() const override { return fBytesTx; } unsigned long GetBytesRx() const override { return fBytesRx; } unsigned long GetMessagesTx() const override { return fMessagesTx; } unsigned long GetMessagesRx() const override { return fMessagesRx; } - static int GetConstant(const std::string& constant); + static int GetConstant(const std::string& constant) + { + if (constant == "") return 0; + if (constant == "sub") return ZMQ_SUB; + if (constant == "pub") return ZMQ_PUB; + if (constant == "xsub") return ZMQ_XSUB; + if (constant == "xpub") return ZMQ_XPUB; + if (constant == "push") return ZMQ_PUSH; + if (constant == "pull") return ZMQ_PULL; + if (constant == "req") return ZMQ_REQ; + if (constant == "rep") return ZMQ_REP; + if (constant == "dealer") return
ZMQ_DEALER; + if (constant == "router") return ZMQ_ROUTER; + if (constant == "pair") return ZMQ_PAIR; + + if (constant == "snd-hwm") return ZMQ_SNDHWM; + if (constant == "rcv-hwm") return ZMQ_RCVHWM; + if (constant == "snd-size") return ZMQ_SNDBUF; + if (constant == "rcv-size") return ZMQ_RCVBUF; + if (constant == "snd-more") return ZMQ_SNDMORE; + if (constant == "rcv-more") return ZMQ_RCVMORE; + + if (constant == "linger") return ZMQ_LINGER; + if (constant == "no-block") return ZMQ_DONTWAIT; + if (constant == "snd-more no-block") return ZMQ_DONTWAIT|ZMQ_SNDMORE; + + return -1; + } ~Socket() override { Close(); } diff --git a/fairmq/shmem/TransportFactory.cxx b/fairmq/shmem/TransportFactory.cxx deleted file mode 100644 index 104327f6..00000000 --- a/fairmq/shmem/TransportFactory.cxx +++ /dev/null @@ -1,231 +0,0 @@ -/******************************************************************************** - * Copyright (C) 2016-2017 GSI Helmholtzzentrum fuer Schwerionenforschung GmbH * - * * - * This software is distributed under the terms of the * - * GNU Lesser General Public Licence (LGPL) version 3, * - * copied verbatim in the file "LICENSE" * - ********************************************************************************/ - -#include "TransportFactory.h" - -#include -#include - -#include - -#include -#include - -#include -#include - -#include - -#include -#include -#include // getenv - -using namespace std; - -namespace bpt = ::boost::posix_time; -namespace bipc = ::boost::interprocess; - -namespace fair -{ -namespace mq -{ -namespace shmem -{ - -TransportFactory::TransportFactory(const string& id, const ProgOptions* config) - : fair::mq::TransportFactory(id) - , fDeviceId(id) - , fShmId() - , fZMQContext(nullptr) - , fManager(nullptr) - , fHeartbeatThread() - , fSendHeartbeats(true) - , fMsgCounter(0) -{ - int major, minor, patch; - zmq_version(&major, &minor, &patch); - LOG(debug) << "Transport: Using ZeroMQ (" << major << "." << minor << "." 
<< patch << ") & " - << "boost::interprocess (" << (BOOST_VERSION / 100000) << "." << (BOOST_VERSION / 100 % 1000) << "." << (BOOST_VERSION % 100) << ")"; - - fZMQContext = zmq_ctx_new(); - if (!fZMQContext) { - throw runtime_error(tools::ToString("failed creating context, reason: ", zmq_strerror(errno))); - } - - int numIoThreads = 1; - string sessionName = "default"; - size_t segmentSize = 2000000000; - bool autolaunchMonitor = false; - if (config) { - numIoThreads = config->GetProperty("io-threads", numIoThreads); - sessionName = config->GetProperty("session", sessionName); - segmentSize = config->GetProperty("shm-segment-size", segmentSize); - autolaunchMonitor = config->GetProperty("shm-monitor", autolaunchMonitor); - } else { - LOG(debug) << "ProgOptions not available! Using defaults."; - } - - fShmId = buildShmIdFromSessionIdAndUserId(sessionName); - - try { - if (zmq_ctx_set(fZMQContext, ZMQ_IO_THREADS, numIoThreads) != 0) { - LOG(error) << "failed configuring context, reason: " << zmq_strerror(errno); - } - - // Set the maximum number of allowed sockets on the context. 
- if (zmq_ctx_set(fZMQContext, ZMQ_MAX_SOCKETS, 10000) != 0) { - LOG(error) << "failed configuring context, reason: " << zmq_strerror(errno); - } - - if (autolaunchMonitor) { - Manager::StartMonitor(fShmId); - } - - fManager = tools::make_unique(fShmId, segmentSize); - - } catch (bipc::interprocess_exception& e) { - LOG(error) << "Could not initialize shared memory transport: " << e.what(); - throw runtime_error(tools::ToString("Could not initialize shared memory transport: ", e.what())); - } - - fSendHeartbeats = true; - fHeartbeatThread = thread(&TransportFactory::SendHeartbeats, this); -} - -void TransportFactory::SendHeartbeats() -{ - string controlQueueName("fmq_" + fShmId + "_cq"); - while (fSendHeartbeats) { - try { - bipc::message_queue mq(bipc::open_only, controlQueueName.c_str()); - bpt::ptime sndTill = bpt::microsec_clock::universal_time() + bpt::milliseconds(100); - if (mq.timed_send(fDeviceId.c_str(), fDeviceId.size(), 0, sndTill)) { - this_thread::sleep_for(chrono::milliseconds(100)); - } else { - LOG(debug) << "control queue timeout"; - } - } catch (bipc::interprocess_exception& ie) { - this_thread::sleep_for(chrono::milliseconds(500)); - // LOG(warn) << "no " << controlQueueName << " found"; - } - } -} - -MessagePtr TransportFactory::CreateMessage() -{ - return tools::make_unique(*fManager, this); -} - -MessagePtr TransportFactory::CreateMessage(const size_t size) -{ - return tools::make_unique(*fManager, size, this); -} - -MessagePtr TransportFactory::CreateMessage(void* data, const size_t size, fairmq_free_fn* ffn, void* hint) -{ - return tools::make_unique(*fManager, data, size, ffn, hint, this); -} - -MessagePtr TransportFactory::CreateMessage(UnmanagedRegionPtr& region, void* data, const size_t size, void* hint) -{ - return tools::make_unique(*fManager, region, data, size, hint, this); -} - -SocketPtr TransportFactory::CreateSocket(const string& type, const string& name) -{ - assert(fZMQContext); - return tools::make_unique(*fManager, type, 
name, GetId(), fZMQContext, this); -} - -PollerPtr TransportFactory::CreatePoller(const vector& channels) const -{ - return tools::make_unique(channels); -} - -PollerPtr TransportFactory::CreatePoller(const vector& channels) const -{ - return tools::make_unique(channels); -} - -PollerPtr TransportFactory::CreatePoller(const unordered_map>& channelsMap, const vector& channelList) const -{ - return tools::make_unique(channelsMap, channelList); -} - -UnmanagedRegionPtr TransportFactory::CreateUnmanagedRegion(const size_t size, RegionCallback callback, const std::string& path /* = "" */, int flags /* = 0 */) -{ - return tools::make_unique(*fManager, size, 0, callback, nullptr, path, flags, this); -} - -UnmanagedRegionPtr TransportFactory::CreateUnmanagedRegion(const size_t size, RegionBulkCallback bulkCallback, const std::string& path /* = "" */, int flags /* = 0 */) -{ - return tools::make_unique(*fManager, size, 0, nullptr, bulkCallback, path, flags, this); -} - -UnmanagedRegionPtr TransportFactory::CreateUnmanagedRegion(const size_t size, const int64_t userFlags, RegionCallback callback, const std::string& path /* = "" */, int flags /* = 0 */) -{ - return tools::make_unique(*fManager, size, userFlags, callback, nullptr, path, flags, this); -} - -UnmanagedRegionPtr TransportFactory::CreateUnmanagedRegion(const size_t size, const int64_t userFlags, RegionBulkCallback bulkCallback, const std::string& path /* = "" */, int flags /* = 0 */) -{ - return tools::make_unique(*fManager, size, userFlags, nullptr, bulkCallback, path, flags, this); -} - -void TransportFactory::SubscribeToRegionEvents(RegionEventCallback callback) -{ - fManager->SubscribeToRegionEvents(callback); -} - -bool TransportFactory::SubscribedToRegionEvents() -{ - return fManager->SubscribedToRegionEvents(); -} - -void TransportFactory::UnsubscribeFromRegionEvents() -{ - fManager->UnsubscribeFromRegionEvents(); -} - -vector TransportFactory::GetRegionInfo() -{ - return fManager->GetRegionInfo(); -} - 
-void TransportFactory::Reset() -{ - if (fMsgCounter.load() != 0) { - LOG(error) << "Message counter during Reset expected to be 0, found: " << fMsgCounter.load(); - throw MessageError(tools::ToString("Message counter during Reset expected to be 0, found: ", fMsgCounter.load())); - } -} - - -TransportFactory::~TransportFactory() -{ - LOG(debug) << "Destroying Shared Memory transport..."; - fSendHeartbeats = false; - fHeartbeatThread.join(); - - if (fZMQContext) { - if (zmq_ctx_term(fZMQContext) != 0) { - if (errno == EINTR) { - LOG(error) << "failed closing context, reason: " << zmq_strerror(errno); - } else { - fZMQContext = nullptr; - return; - } - } - } else { - LOG(error) << "context not available for shutdown"; - } -} - -} // namespace shmem -} // namespace mq -} // namespace fair diff --git a/fairmq/shmem/TransportFactory.h b/fairmq/shmem/TransportFactory.h index bb92b39c..77f38405 100644 --- a/fairmq/shmem/TransportFactory.h +++ b/fairmq/shmem/TransportFactory.h @@ -18,11 +18,18 @@ #include #include +#include +#include + +#include + +#include #include #include -#include #include +#include +#include namespace fair { @@ -34,52 +41,173 @@ namespace shmem class TransportFactory final : public fair::mq::TransportFactory { public: - TransportFactory(const std::string& id = "", const ProgOptions* config = nullptr); + TransportFactory(const std::string& id = "", const ProgOptions* config = nullptr) + : fair::mq::TransportFactory(id) + , fDeviceId(id) + , fShmId() + , fZMQContext(nullptr) + , fManager(nullptr) + { + int major, minor, patch; + zmq_version(&major, &minor, &patch); + LOG(debug) << "Transport: Using ZeroMQ (" << major << "." << minor << "." << patch << ") & " + << "boost::interprocess (" << (BOOST_VERSION / 100000) << "." << (BOOST_VERSION / 100 % 1000) << "." 
<< (BOOST_VERSION % 100) << ")"; + + fZMQContext = zmq_ctx_new(); + if (!fZMQContext) { + throw std::runtime_error(tools::ToString("failed creating context, reason: ", zmq_strerror(errno))); + } + + int numIoThreads = 1; + std::string sessionName = "default"; + size_t segmentSize = 2000000000; + bool autolaunchMonitor = false; + if (config) { + numIoThreads = config->GetProperty("io-threads", numIoThreads); + sessionName = config->GetProperty("session", sessionName); + segmentSize = config->GetProperty("shm-segment-size", segmentSize); + autolaunchMonitor = config->GetProperty("shm-monitor", autolaunchMonitor); + } else { + LOG(debug) << "ProgOptions not available! Using defaults."; + } + + fShmId = buildShmIdFromSessionIdAndUserId(sessionName); + + try { + if (zmq_ctx_set(fZMQContext, ZMQ_IO_THREADS, numIoThreads) != 0) { + LOG(error) << "failed configuring context, reason: " << zmq_strerror(errno); + } + + // Set the maximum number of allowed sockets on the context. + if (zmq_ctx_set(fZMQContext, ZMQ_MAX_SOCKETS, 10000) != 0) { + LOG(error) << "failed configuring context, reason: " << zmq_strerror(errno); + } + + if (autolaunchMonitor) { + Manager::StartMonitor(fShmId); + } + + fManager = tools::make_unique(fShmId, fDeviceId, segmentSize); + + } catch (boost::interprocess::interprocess_exception& e) { + LOG(error) << "Could not initialize shared memory transport: " << e.what(); + throw std::runtime_error(tools::ToString("Could not initialize shared memory transport: ", e.what())); + } + } + TransportFactory(const TransportFactory&) = delete; TransportFactory operator=(const TransportFactory&) = delete; - MessagePtr CreateMessage() override; - MessagePtr CreateMessage(const size_t size) override; - MessagePtr CreateMessage(void* data, const size_t size, fairmq_free_fn* ffn, void* hint = nullptr) override; - MessagePtr CreateMessage(UnmanagedRegionPtr& region, void* data, const size_t size, void* hint = 0) override; + MessagePtr CreateMessage() override + { + 
return tools::make_unique(*fManager, this); + } - SocketPtr CreateSocket(const std::string& type, const std::string& name) override; + MessagePtr CreateMessage(const size_t size) override + { + return tools::make_unique(*fManager, size, this); + } - PollerPtr CreatePoller(const std::vector& channels) const override; - PollerPtr CreatePoller(const std::vector& channels) const override; - PollerPtr CreatePoller(const std::unordered_map>& channelsMap, const std::vector& channelList) const override; + MessagePtr CreateMessage(void* data, const size_t size, fairmq_free_fn* ffn, void* hint = nullptr) override + { + return tools::make_unique(*fManager, data, size, ffn, hint, this); + } - UnmanagedRegionPtr CreateUnmanagedRegion(const size_t size, RegionCallback callback = nullptr, const std::string& path = "", int flags = 0) override; - UnmanagedRegionPtr CreateUnmanagedRegion(const size_t size, RegionBulkCallback callback = nullptr, const std::string& path = "", int flags = 0) override; - UnmanagedRegionPtr CreateUnmanagedRegion(const size_t size, int64_t userFlags, RegionCallback callback = nullptr, const std::string& path = "", int flags = 0) override; - UnmanagedRegionPtr CreateUnmanagedRegion(const size_t size, int64_t userFlags, RegionBulkCallback callback = nullptr, const std::string& path = "", int flags = 0) override; + MessagePtr CreateMessage(UnmanagedRegionPtr& region, void* data, const size_t size, void* hint = 0) override + { + return tools::make_unique(*fManager, region, data, size, hint, this); + } - void SubscribeToRegionEvents(RegionEventCallback callback) override; - bool SubscribedToRegionEvents() override; - void UnsubscribeFromRegionEvents() override; - std::vector GetRegionInfo() override; + SocketPtr CreateSocket(const std::string& type, const std::string& name) override + { + assert(fZMQContext); + return tools::make_unique(*fManager, type, name, GetId(), fZMQContext, this); + } + + PollerPtr CreatePoller(const std::vector& channels) const 
override + { + return tools::make_unique(channels); + } + + PollerPtr CreatePoller(const std::vector& channels) const override + { + return tools::make_unique(channels); + } + + PollerPtr CreatePoller(const std::unordered_map>& channelsMap, const std::vector& channelList) const override + { + return tools::make_unique(channelsMap, channelList); + } + + UnmanagedRegionPtr CreateUnmanagedRegion(const size_t size, RegionCallback callback = nullptr, const std::string& path = "", int flags = 0) override + { + return tools::make_unique(*fManager, size, 0, callback, nullptr, path, flags, this); + } + + UnmanagedRegionPtr CreateUnmanagedRegion(const size_t size, RegionBulkCallback bulkCallback = nullptr, const std::string& path = "", int flags = 0) override + { + return tools::make_unique(*fManager, size, 0, nullptr, bulkCallback, path, flags, this); + } + + UnmanagedRegionPtr CreateUnmanagedRegion(const size_t size, int64_t userFlags, RegionCallback callback = nullptr, const std::string& path = "", int flags = 0) override + { + return tools::make_unique(*fManager, size, userFlags, callback, nullptr, path, flags, this); + } + + UnmanagedRegionPtr CreateUnmanagedRegion(const size_t size, int64_t userFlags, RegionBulkCallback bulkCallback = nullptr, const std::string& path = "", int flags = 0) override + { + return tools::make_unique(*fManager, size, userFlags, nullptr, bulkCallback, path, flags, this); + } + + void SubscribeToRegionEvents(RegionEventCallback callback) override + { + fManager->SubscribeToRegionEvents(callback); + } + + bool SubscribedToRegionEvents() override + { + return fManager->SubscribedToRegionEvents(); + } + + void UnsubscribeFromRegionEvents() override + { + fManager->UnsubscribeFromRegionEvents(); + } + + std::vector GetRegionInfo() override + { + return fManager->GetRegionInfo(); + } Transport GetType() const override { return fair::mq::Transport::SHM; } void Interrupt() override { fManager->Interrupt(); } void Resume() override { 
fManager->Resume(); } - void Reset() override; + void Reset() override { fManager->Reset(); } - void IncrementMsgCounter() { ++fMsgCounter; } - void DecrementMsgCounter() { --fMsgCounter; } + ~TransportFactory() override + { + LOG(debug) << "Destroying Shared Memory transport..."; - ~TransportFactory() override; + if (fZMQContext) { + if (zmq_ctx_term(fZMQContext) != 0) { + if (errno == EINTR) { + LOG(error) << "failed closing context, reason: " << zmq_strerror(errno); + } else { + fZMQContext = nullptr; + return; + } + } + } else { + LOG(error) << "context not available for shutdown"; + } + } private: - void SendHeartbeats(); - std::string fDeviceId; std::string fShmId; void* fZMQContext; std::unique_ptr fManager; - std::thread fHeartbeatThread; - std::atomic fSendHeartbeats; - std::atomic fMsgCounter; }; } // namespace shmem