Implement shmem msg zero-copy

This commit is contained in:
Alexey Rybalchenko 2021-07-14 10:46:12 +02:00 committed by Dennis Klein
parent c57410b820
commit bce380d871
7 changed files with 333 additions and 100 deletions

View File

@ -47,6 +47,10 @@ struct Message
TransportFactory* GetTransport() { return fTransport; } TransportFactory* GetTransport() { return fTransport; }
void SetTransport(TransportFactory* transport) { fTransport = transport; } void SetTransport(TransportFactory* transport) { fTransport = transport; }
/// Copy the message buffer from another message
/// Transport may choose not to physically copy the buffer, but to share across the messages.
/// Modifying the buffer after a call to Copy() is undefined behaviour.
/// @param msg message to copy the buffer from.
virtual void Copy(const Message& msg) = 0; virtual void Copy(const Message& msg) = 0;
virtual ~Message() = default; virtual ~Message() = default;

View File

@ -146,9 +146,10 @@ struct MetaHeader
{ {
size_t fSize; size_t fSize;
size_t fHint; size_t fHint;
uint16_t fRegionId;
uint16_t fSegmentId;
boost::interprocess::managed_shared_memory::handle_t fHandle; boost::interprocess::managed_shared_memory::handle_t fHandle;
mutable boost::interprocess::managed_shared_memory::handle_t fShared;
uint16_t fRegionId;
mutable uint16_t fSegmentId;
}; };
#ifdef FAIRMQ_DEBUG_MODE #ifdef FAIRMQ_DEBUG_MODE
@ -271,22 +272,22 @@ struct SegmentHandleFromAddress : public boost::static_visitor<boost::interproce
const void* ptr; const void* ptr;
}; };
// Visitor applied (via boost::apply_visitor) to a segment: resolves a Boost.Interprocess
// handle to the address it refers to inside the visited segment.
// This diff changes the visitor's result type from void* (left) to char* (right).
struct SegmentAddressFromHandle : public boost::static_visitor<void*> struct SegmentAddressFromHandle : public boost::static_visitor<char*>
{ {
SegmentAddressFromHandle(const boost::interprocess::managed_shared_memory::handle_t _handle) : handle(_handle) {} SegmentAddressFromHandle(const boost::interprocess::managed_shared_memory::handle_t _handle) : handle(_handle) {}
template<typename S> template<typename S>
// S is the concrete segment type being visited; the result of get_address_from_handle() is returned as char* on the new side
void* operator()(S& s) const { return s.get_address_from_handle(handle); } char* operator()(S& s) const { return reinterpret_cast<char*>(s.get_address_from_handle(handle)); }
// handle to translate, fixed at construction
const boost::interprocess::managed_shared_memory::handle_t handle; const boost::interprocess::managed_shared_memory::handle_t handle;
}; };
// Visitor applied (via boost::apply_visitor) to a segment: allocates `size` bytes from the
// visited segment and returns the start of the allocation.
// This diff changes the visitor's result type from void* (left) to char* (right).
struct SegmentAllocate : public boost::static_visitor<void*> struct SegmentAllocate : public boost::static_visitor<char*>
{ {
SegmentAllocate(const size_t _size) : size(_size) {} SegmentAllocate(const size_t _size) : size(_size) {}
template<typename S> template<typename S>
// S is the concrete segment type being visited; forwards to its allocate()
void* operator()(S& s) const { return s.allocate(size); } char* operator()(S& s) const { return reinterpret_cast<char*>(s.allocate(size)); }
// number of bytes to request, fixed at construction
const size_t size; const size_t size;
}; };
@ -322,12 +323,12 @@ struct SegmentBufferShrink : public boost::static_visitor<char*>
// Visitor applied (via boost::apply_visitor) to a segment: hands `ptr` back to the visited
// segment's deallocate(). This diff changes the stored pointer type from void* (left) to char* (right).
struct SegmentDeallocate : public boost::static_visitor<> struct SegmentDeallocate : public boost::static_visitor<>
{ {
SegmentDeallocate(void* _ptr) : ptr(_ptr) {} SegmentDeallocate(char* _ptr) : ptr(_ptr) {}
template<typename S> template<typename S>
// S is the concrete segment type being visited
void operator()(S& s) const { return s.deallocate(ptr); } void operator()(S& s) const { return s.deallocate(ptr); }
// address passed to deallocate(); callers in this diff pass the start of the allocation (not the user pointer)
void* ptr; char* ptr;
}; };
} // namespace fair::mq::shmem } // namespace fair::mq::shmem

View File

@ -52,29 +52,77 @@
#include <unistd.h> // getuid #include <unistd.h> // getuid
#include <sys/types.h> // getuid #include <sys/types.h> // getuid
#include <sys/mman.h> // mlock #include <sys/mman.h> // mlock
namespace fair::mq::shmem namespace fair::mq::shmem
{ {
// ShmHeader manages the bookkeeping that precedes every user buffer allocated from a segment.
// It records the offset to the aligned user buffer and the reference count, laid out as:
// [HdrOffset(uint16_t)][Hdr alignment][Hdr][user buffer alignment][user buffer]
// The padding in front of Hdr depends on the alignment of Hdr (driven by std::atomic); its
// size is stored in the leading HdrOffset entry so that readers can locate Hdr again.
// All functions take `ptr`, the start of the allocation (the address of HdrOffset).
struct ShmHeader
{
    struct Hdr
    {
        uint16_t userOffset;            // distance from (ptr + HdrPartSize()) to the user buffer
        std::atomic<uint16_t> refCount; // number of messages referring to the buffer
    };

    // The ref count is manipulated directly in the allocation by every user of the buffer.
    // An atomic that falls back to a lock would not be usable that way, so reject it at compile time.
    static_assert(std::atomic<uint16_t>::is_always_lock_free,
                  "ShmHeader requires a lock-free std::atomic<uint16_t>");

    /// @return pointer to the Hdr stored in the allocation starting at ptr
    static Hdr* HdrPtr(char* ptr)
    {
        // [HdrOffset(uint16_t)][Hdr alignment][Hdr][user buffer alignment][user buffer]
        //                                     ^
        // read HdrOffset the same way Construct() writes it (memcpy), without type-punning the buffer
        uint16_t hdrOffset = 0;
        std::memcpy(&hdrOffset, ptr, sizeof(hdrOffset));
        return reinterpret_cast<Hdr*>(ptr + sizeof(uint16_t) + hdrOffset);
    }

    /// @return upper bound for the size of the [HdrOffset(uint16_t)][Hdr alignment][Hdr] part
    static uint16_t HdrPartSize()
    {
        // [HdrOffset(uint16_t)][Hdr alignment][Hdr][user buffer alignment][user buffer]
        // <--------------------------------------->
        return static_cast<uint16_t>(sizeof(uint16_t) + alignof(Hdr) + sizeof(Hdr));
    }

    /// @return reference to the ref count stored in the Hdr (despite the name, not a pointer)
    static std::atomic<uint16_t>& RefCountPtr(char* ptr)
    {
        return HdrPtr(ptr)->refCount;
    }

    /// @return address of the aligned user buffer
    static char* UserPtr(char* ptr)
    {
        // [HdrOffset(uint16_t)][Hdr alignment][Hdr][user buffer alignment][user buffer]
        //                                                                 ^
        return ptr + HdrPartSize() + HdrPtr(ptr)->userOffset;
    }

    /// @return current ref count
    static uint16_t RefCount(char* ptr) { return RefCountPtr(ptr).load(); }
    /// @return ref count as it was BEFORE the increment (fetch_add semantics)
    static uint16_t IncrementRefCount(char* ptr) { return RefCountPtr(ptr).fetch_add(1); }
    /// @return ref count as it was BEFORE the decrement (fetch_sub semantics); 1 means the last reference was dropped
    static uint16_t DecrementRefCount(char* ptr) { return RefCountPtr(ptr).fetch_sub(1); }

    /// @return number of bytes to allocate for a user buffer of given size and alignment, including all bookkeeping
    static size_t FullSize(size_t size, size_t alignment)
    {
        // [HdrOffset(uint16_t)][Hdr alignment][Hdr][user buffer alignment][user buffer]
        // <--------------------------------------------------------------------------->
        return HdrPartSize() + alignment + size;
    }

    /// Places the Hdr at an aligned location inside the allocation, initializes the ref count to 1
    /// and stores the offsets needed to find the Hdr and the user buffer again.
    /// @param ptr start of an allocation of at least FullSize(size, alignment) bytes, 2-byte aligned
    /// @param alignment requested user buffer alignment, must be in [1, UINT16_MAX]
    static void Construct(char* ptr, size_t alignment)
    {
        // the address alignment should be at least 2, HdrOffset is stored right at ptr
        assert(reinterpret_cast<uintptr_t>(ptr) % 2 == 0);
        // userOffset lies in [1, alignment] and is stored as uint16_t: larger alignments would be
        // truncated silently and yield a misaligned user buffer; zero would divide by zero below
        assert(alignment > 0 && alignment <= UINT16_MAX);

        const auto address = reinterpret_cast<uintptr_t>(ptr);

        // offset to the beginning of the Hdr, in [1, alignof(Hdr)]. store it in the beginning
        const auto hdrOffset = static_cast<uint16_t>(alignof(Hdr) - ((address + sizeof(uint16_t)) % alignof(Hdr)));
        std::memcpy(ptr, &hdrOffset, sizeof(hdrOffset));

        // offset to the beginning of the user buffer, store in Hdr together with the ref count
        const auto userOffset = static_cast<uint16_t>(alignment - ((address + HdrPartSize()) % alignment));
        new(ptr + sizeof(uint16_t) + hdrOffset) Hdr{ userOffset, std::atomic<uint16_t>(1) };
    }

    /// Ends the lifetime of the ref count; the memory itself has to be released by the caller.
    static void Destruct(char* ptr) { RefCountPtr(ptr).~atomic(); }
};
class Manager class Manager
@ -635,44 +683,35 @@ class Manager
{ {
return boost::apply_visitor(SegmentHandleFromAddress(ptr), fSegments.at(segmentId)); return boost::apply_visitor(SegmentHandleFromAddress(ptr), fSegments.at(segmentId));
} }
// Translates a shared-memory handle into a local address by applying the
// SegmentAddressFromHandle visitor to the segment stored under `segmentId` in fSegments.
// This diff changes the return type from void* (left) to char* (right), matching the visitor.
void* GetAddressFromHandle(const boost::interprocess::managed_shared_memory::handle_t handle, uint16_t segmentId) const char* GetAddressFromHandle(const boost::interprocess::managed_shared_memory::handle_t handle, uint16_t segmentId) const
{ {
return boost::apply_visitor(SegmentAddressFromHandle(handle), fSegments.at(segmentId)); return boost::apply_visitor(SegmentAddressFromHandle(handle), fSegments.at(segmentId));
} }
ShmPtr Allocate(size_t size, size_t alignment = 0) char* Allocate(size_t size, size_t alignment = 0)
{ {
alignment = std::max(alignment, alignof(std::max_align_t)); alignment = std::max(alignment, alignof(std::max_align_t));
char* ptr = nullptr; char* ptr = nullptr;
// [offset(uint16_t)][alignment][buffer] size_t fullSize = ShmHeader::FullSize(size, alignment);
size_t fullSize = sizeof(uint16_t) + alignment + size;
// tools::RateLimiter rateLimiter(20);
while (ptr == nullptr) { while (ptr == nullptr) {
try { try {
// boost::interprocess::managed_shared_memory::size_type actualSize = size;
// char* hint = 0; // unused for boost::interprocess::allocate_new
// ptr = fSegments.at(fSegmentId).allocation_command<char>(boost::interprocess::allocate_new, size, actualSize, hint);
size_t segmentSize = boost::apply_visitor(SegmentSize(), fSegments.at(fSegmentId)); size_t segmentSize = boost::apply_visitor(SegmentSize(), fSegments.at(fSegmentId));
if (fullSize > segmentSize) { if (fullSize > segmentSize) {
throw MessageBadAlloc(tools::ToString("Requested message size (", fullSize, ") exceeds segment size (", segmentSize, ")")); throw MessageBadAlloc(tools::ToString("Requested message size (", fullSize, ") exceeds segment size (", segmentSize, ")"));
} }
ptr = reinterpret_cast<char*>(boost::apply_visitor(SegmentAllocate{fullSize}, fSegments.at(fSegmentId))); ptr = boost::apply_visitor(SegmentAllocate{fullSize}, fSegments.at(fSegmentId));
assert(reinterpret_cast<uintptr_t>(ptr) % 2 == 0); ShmHeader::Construct(ptr, alignment);
uint16_t offset = 0;
offset = alignment - ((reinterpret_cast<uintptr_t>(ptr) + sizeof(uint16_t)) % alignment);
std::memcpy(ptr, &offset, sizeof(offset));
} catch (boost::interprocess::bad_alloc& ba) { } catch (boost::interprocess::bad_alloc& ba) {
// LOG(warn) << "Shared memory full..."; // LOG(warn) << "Shared memory full...";
if (ThrowingOnBadAlloc()) { if (ThrowingOnBadAlloc()) {
throw MessageBadAlloc(tools::ToString("shmem: could not create a message of size ", size, ", alignment: ", (alignment != 0) ? std::to_string(alignment) : "default", ", free memory: ", boost::apply_visitor(SegmentFreeMemory(), fSegments.at(fSegmentId)))); throw MessageBadAlloc(tools::ToString("shmem: could not create a message of size ", size, ", alignment: ", (alignment != 0) ? std::to_string(alignment) : "default", ", free memory: ", boost::apply_visitor(SegmentFreeMemory(), fSegments.at(fSegmentId))));
} }
// rateLimiter.maybe_sleep();
std::this_thread::sleep_for(std::chrono::milliseconds(50)); std::this_thread::sleep_for(std::chrono::milliseconds(50));
if (Interrupted()) { if (Interrupted()) {
return ShmPtr(ptr); throw MessageBadAlloc(tools::ToString("shmem: could not create a message of size ", size, ", alignment: ", (alignment != 0) ? std::to_string(alignment) : "default", ", free memory: ", boost::apply_visitor(SegmentFreeMemory(), fSegments.at(fSegmentId))));
} else { } else {
continue; continue;
} }
@ -684,18 +723,20 @@ class Manager
(*fMsgDebug).emplace(fSegmentId, fShmVoidAlloc); (*fMsgDebug).emplace(fSegmentId, fShmVoidAlloc);
} }
(*fMsgDebug).at(fSegmentId).emplace( (*fMsgDebug).at(fSegmentId).emplace(
static_cast<size_t>(GetHandleFromAddress(ShmPtr(ptr).UserPtr(), fSegmentId)), static_cast<size_t>(GetHandleFromAddress(ShmHeader::UserPtr(ptr), fSegmentId)),
MsgDebug(getpid(), size, std::chrono::system_clock::now().time_since_epoch().count()) MsgDebug(getpid(), size, std::chrono::system_clock::now().time_since_epoch().count())
); );
#endif #endif
} }
return ShmPtr(ptr); return ptr;
} }
void Deallocate(boost::interprocess::managed_shared_memory::handle_t handle, uint16_t segmentId) void Deallocate(boost::interprocess::managed_shared_memory::handle_t handle, uint16_t segmentId)
{ {
boost::apply_visitor(SegmentDeallocate(GetAddressFromHandle(handle, segmentId)), fSegments.at(segmentId)); char* ptr = GetAddressFromHandle(handle, segmentId);
ShmHeader::Destruct(ptr);
boost::apply_visitor(SegmentDeallocate(ptr), fSegments.at(segmentId));
#ifdef FAIRMQ_DEBUG_MODE #ifdef FAIRMQ_DEBUG_MODE
boost::interprocess::scoped_lock<boost::interprocess::named_mutex> lock(fShmMtx); boost::interprocess::scoped_lock<boost::interprocess::named_mutex> lock(fShmMtx);
DecrementShmMsgCounter(segmentId); DecrementShmMsgCounter(segmentId);

View File

@ -38,7 +38,7 @@ class Message final : public fair::mq::Message
: fair::mq::Message(factory) : fair::mq::Message(factory)
, fManager(manager) , fManager(manager)
, fQueued(false) , fQueued(false)
, fMeta{0, 0, 0, fManager.GetSegmentId(), -1} , fMeta{0, 0, -1, -1, 0, fManager.GetSegmentId()}
, fRegionPtr(nullptr) , fRegionPtr(nullptr)
, fLocalPtr(nullptr) , fLocalPtr(nullptr)
{ {
@ -49,7 +49,7 @@ class Message final : public fair::mq::Message
: fair::mq::Message(factory) : fair::mq::Message(factory)
, fManager(manager) , fManager(manager)
, fQueued(false) , fQueued(false)
, fMeta{0, 0, 0, fManager.GetSegmentId(), -1} , fMeta{0, 0, -1, -1, 0, fManager.GetSegmentId()}
, fAlignment(alignment.alignment) , fAlignment(alignment.alignment)
, fRegionPtr(nullptr) , fRegionPtr(nullptr)
, fLocalPtr(nullptr) , fLocalPtr(nullptr)
@ -61,7 +61,7 @@ class Message final : public fair::mq::Message
: fair::mq::Message(factory) : fair::mq::Message(factory)
, fManager(manager) , fManager(manager)
, fQueued(false) , fQueued(false)
, fMeta{0, 0, 0, fManager.GetSegmentId(), -1} , fMeta{0, 0, -1, -1, 0, fManager.GetSegmentId()}
, fRegionPtr(nullptr) , fRegionPtr(nullptr)
, fLocalPtr(nullptr) , fLocalPtr(nullptr)
{ {
@ -73,7 +73,7 @@ class Message final : public fair::mq::Message
: fair::mq::Message(factory) : fair::mq::Message(factory)
, fManager(manager) , fManager(manager)
, fQueued(false) , fQueued(false)
, fMeta{0, 0, 0, fManager.GetSegmentId(), -1} , fMeta{0, 0, -1, -1, 0, fManager.GetSegmentId()}
, fAlignment(alignment.alignment) , fAlignment(alignment.alignment)
, fRegionPtr(nullptr) , fRegionPtr(nullptr)
, fLocalPtr(nullptr) , fLocalPtr(nullptr)
@ -86,7 +86,7 @@ class Message final : public fair::mq::Message
: fair::mq::Message(factory) : fair::mq::Message(factory)
, fManager(manager) , fManager(manager)
, fQueued(false) , fQueued(false)
, fMeta{0, 0, 0, fManager.GetSegmentId(), -1} , fMeta{0, 0, -1, -1, 0, fManager.GetSegmentId()}
, fRegionPtr(nullptr) , fRegionPtr(nullptr)
, fLocalPtr(nullptr) , fLocalPtr(nullptr)
{ {
@ -105,7 +105,7 @@ class Message final : public fair::mq::Message
: fair::mq::Message(factory) : fair::mq::Message(factory)
, fManager(manager) , fManager(manager)
, fQueued(false) , fQueued(false)
, fMeta{size, reinterpret_cast<size_t>(hint), static_cast<UnmanagedRegion*>(region.get())->fRegionId, fManager.GetSegmentId(), -1} , fMeta{size, reinterpret_cast<size_t>(hint), -1, -1, static_cast<UnmanagedRegion*>(region.get())->fRegionId, fManager.GetSegmentId()}
, fRegionPtr(nullptr) , fRegionPtr(nullptr)
, fLocalPtr(static_cast<char*>(data)) , fLocalPtr(static_cast<char*>(data))
{ {
@ -187,8 +187,7 @@ class Message final : public fair::mq::Message
if (fMeta.fRegionId == 0) { if (fMeta.fRegionId == 0) {
if (fMeta.fSize > 0) { if (fMeta.fSize > 0) {
fManager.GetSegment(fMeta.fSegmentId); fManager.GetSegment(fMeta.fSegmentId);
ShmPtr shmPtr(reinterpret_cast<char*>(fManager.GetAddressFromHandle(fMeta.fHandle, fMeta.fSegmentId))); fLocalPtr = ShmHeader::UserPtr(fManager.GetAddressFromHandle(fMeta.fHandle, fMeta.fSegmentId));
fLocalPtr = shmPtr.UserPtr();
} else { } else {
fLocalPtr = nullptr; fLocalPtr = nullptr;
} }
@ -218,8 +217,8 @@ class Message final : public fair::mq::Message
} else if (newSize <= fMeta.fSize) { } else if (newSize <= fMeta.fSize) {
try { try {
try { try {
ShmPtr shmPtr(fManager.ShrinkInPlace(newSize, static_cast<char*>(fManager.GetAddressFromHandle(fMeta.fHandle, fMeta.fSegmentId)), fMeta.fSegmentId)); char* ptr = fManager.ShrinkInPlace(newSize, fManager.GetAddressFromHandle(fMeta.fHandle, fMeta.fSegmentId), fMeta.fSegmentId);
fLocalPtr = shmPtr.UserPtr(); fLocalPtr = ShmHeader::UserPtr(ptr);
fMeta.fSize = newSize; fMeta.fSize = newSize;
return true; return true;
} catch (boost::interprocess::bad_alloc& e) { } catch (boost::interprocess::bad_alloc& e) {
@ -227,17 +226,12 @@ class Message final : public fair::mq::Message
// unused size >= 1000000 bytes: reallocate fully // unused size >= 1000000 bytes: reallocate fully
// unused size < 1000000 bytes: simply reset the size and keep the rest of the buffer until message destruction // unused size < 1000000 bytes: simply reset the size and keep the rest of the buffer until message destruction
if (fMeta.fSize - newSize >= 1000000) { if (fMeta.fSize - newSize >= 1000000) {
ShmPtr shmPtr = fManager.Allocate(newSize, fAlignment); char* ptr = fManager.Allocate(newSize, fAlignment);
if (shmPtr.RealPtr()) { char* userPtr = ShmHeader::UserPtr(ptr);
char* userPtr = shmPtr.UserPtr();
std::memcpy(userPtr, fLocalPtr, newSize); std::memcpy(userPtr, fLocalPtr, newSize);
fManager.Deallocate(fMeta.fHandle, fMeta.fSegmentId); fManager.Deallocate(fMeta.fHandle, fMeta.fSegmentId);
fLocalPtr = userPtr; fLocalPtr = userPtr;
fMeta.fHandle = fManager.GetHandleFromAddress(shmPtr.RealPtr(), fMeta.fSegmentId); fMeta.fHandle = fManager.GetHandleFromAddress(ptr, fMeta.fSegmentId);
} else {
LOG(debug) << "could not set used size: " << e.what();
return false;
}
} }
fMeta.fSize = newSize; fMeta.fSize = newSize;
return true; return true;
@ -254,32 +248,64 @@ class Message final : public fair::mq::Message
// Identifies this message as belonging to the shared memory transport.
Transport GetType() const override { return fair::mq::Transport::SHM; } Transport GetType() const override { return fair::mq::Transport::SHM; }
// Right column: new GetRefCount() - number of messages currently sharing this message's buffer.
// Left column: fragments of the removed Copy() implementation, which physically copied the buffer.
// GetRefCount() returns 1 when no shared counter exists (uninitialized message, or an unmanaged
// region message that was never copied); otherwise it reads the counter from the ShmHeader.
void Copy(const fair::mq::Message& msg) override uint16_t GetRefCount() const
{ {
// a negative handle marks an uninitialized message
if (fMeta.fHandle < 0) { if (fMeta.fHandle < 0) {
boost::interprocess::managed_shared_memory::handle_t otherHandle = static_cast<const Message&>(msg).fMeta.fHandle; return 1;
if (otherHandle) {
if (InitializeChunk(msg.GetSize())) {
std::memcpy(GetData(), msg.GetData(), msg.GetSize());
} }
// region id 0 denotes the managed segment: the counter lives in the header of the buffer itself (fHandle)
if (fMeta.fRegionId == 0) { // managed segment
// GetSegment() is called before every address lookup in this class
fManager.GetSegment(fMeta.fSegmentId);
return ShmHeader::RefCount(fManager.GetAddressFromHandle(fMeta.fHandle, fMeta.fSegmentId));
} else { // unmanaged region
// for unmanaged regions the counter lives in a separate allocation referenced by fShared
if (fMeta.fShared < 0) { // UR msg is not yet shared
return 1;
} else { } else {
LOG(error) << "copy fail: source message not initialized!"; fManager.GetSegment(fMeta.fSegmentId);
return ShmHeader::RefCount(fManager.GetAddressFromHandle(fMeta.fShared, fMeta.fSegmentId));
} }
} else {
LOG(error) << "copy fail: target message already initialized!";
}
} }
// Right column: new Copy() - makes this message refer to the same buffer as `other` and
// increments the shared reference count instead of copying the bytes.
// Left column: fragments of the removed destructor body (its try/catch moves into CloseMessage()).
// `other` is cast to shmem::Message without a check.
// NOTE(review): CloseMessage() also calls fManager.DecrementMsgCounter(), and the destructor of
// this (still alive) message calls CloseMessage() again - confirm the message counter is not
// decremented twice for a message that went through Copy().
~Message() override void Copy(const fair::mq::Message& other) override
{ {
try { const Message& otherMsg = static_cast<const Message&>(other);
if (otherMsg.fMeta.fHandle < 0) {
// if the other message is not initialized, close this one too and return
CloseMessage(); CloseMessage();
} catch(SharedMemoryError& sme) { return;
LOG(error) << "error closing message: " << sme.what(); }
} catch(boost::interprocess::lock_exception& le) {
LOG(error) << "error closing message: " << le.what(); if (fMeta.fHandle >= 0) {
// if this msg is already initialized, close it first
CloseMessage();
}
// managed segment: the counter lives in the header of the shared buffer itself
if (otherMsg.fMeta.fRegionId == 0) { // managed segment
fMeta = otherMsg.fMeta;
fManager.GetSegment(fMeta.fSegmentId);
ShmHeader::IncrementRefCount(fManager.GetAddressFromHandle(fMeta.fHandle, fMeta.fSegmentId));
} else { // unmanaged region
// first copy of an unmanaged region message: allocate a small buffer whose ShmHeader holds the counter.
// fShared and fSegmentId are declared mutable in MetaHeader, which is why the const `other` can be updated here.
if (otherMsg.fMeta.fShared < 0) { // if UR msg is not yet shared
// TODO: minimize the size to 0 and don't create extra space for user buffer alignment
char* ptr = fManager.Allocate(2, 0);
// point the fShared in the unmanaged region message to the refCount holder
// NOTE(review): Allocate() takes no segment id, while the handle is resolved with this message's
// fMeta.fSegmentId, which Deallocate() does not reset - confirm both always name the same segment,
// e.g. when this message previously held a copy coming from another segment.
otherMsg.fMeta.fShared = fManager.GetHandleFromAddress(ptr, fMeta.fSegmentId);
// the message needs to be able to locate in which segment the refCount is stored
otherMsg.fMeta.fSegmentId = fMeta.fSegmentId;
// point this message to the same content as the unmanaged region message
fMeta = otherMsg.fMeta;
// increment the refCount
// (the counter is constructed with value 1 for the original message, so it becomes 2 here)
ShmHeader::IncrementRefCount(ptr);
} else { // if the UR msg is already shared
fMeta = otherMsg.fMeta;
fManager.GetSegment(fMeta.fSegmentId);
ShmHeader::IncrementRefCount(fManager.GetAddressFromHandle(fMeta.fShared, fMeta.fSegmentId));
}
}
}
// Releases the message through CloseMessage(), which contains the try/catch that used to live in the destructor.
~Message() override { CloseMessage(); }
private: private:
Manager& fManager; Manager& fManager;
@ -291,23 +317,48 @@ class Message final : public fair::mq::Message
// Allocates a buffer for this message and records its handle, size and user pointer.
// Right column (new): a zero-size request allocates nothing and returns the current fLocalPtr;
// otherwise Allocate() either returns a valid pointer or throws, so no null check is needed.
// Left column (removed): the ShmPtr-based version, which had to check for a null RealPtr().
// @param size requested user buffer size in bytes
// @param alignment requested user buffer alignment, 0 selects the default chosen by Allocate()
// @return pointer to the user buffer
char* InitializeChunk(const size_t size, size_t alignment = 0) char* InitializeChunk(const size_t size, size_t alignment = 0)
{ {
ShmPtr shmPtr = fManager.Allocate(size, alignment); if (size == 0) {
if (shmPtr.RealPtr()) { fMeta.fSize = 0;
fMeta.fHandle = fManager.GetHandleFromAddress(shmPtr.RealPtr(), fMeta.fSegmentId); return fLocalPtr;
fMeta.fSize = size;
fLocalPtr = shmPtr.UserPtr();
}
char* ptr = fManager.Allocate(size, alignment);
// the handle refers to the start of the allocation (where the ShmHeader begins), not to the user buffer
fMeta.fHandle = fManager.GetHandleFromAddress(ptr, fMeta.fSegmentId);
fMeta.fSize = size;
fLocalPtr = ShmHeader::UserPtr(ptr);
return fLocalPtr; return fLocalPtr;
} }
// Drops this message's reference to its buffer and resets handle, pointer and size.
// Nothing is released when the message is uninitialized (fHandle < 0) or fQueued is set.
// ShmHeader::DecrementRefCount() returns the count as it was before the decrement, so a
// result of 1 means this message held the last reference.
// NOTE(review): fMeta.fShared is not reset here - confirm a stale value cannot be picked up
// when the message object is reused.
void Deallocate() void Deallocate()
{ {
if (fMeta.fHandle >= 0 && !fQueued) { if (fMeta.fHandle >= 0 && !fQueued) {
if (fMeta.fRegionId == 0) { if (fMeta.fRegionId == 0) { // managed segment
fManager.GetSegment(fMeta.fSegmentId); fManager.GetSegment(fMeta.fSegmentId);
uint16_t refCount = ShmHeader::DecrementRefCount(fManager.GetAddressFromHandle(fMeta.fHandle, fMeta.fSegmentId));
// last reference: give the buffer back to the segment
if (refCount == 1) {
fManager.Deallocate(fMeta.fHandle, fMeta.fSegmentId); fManager.Deallocate(fMeta.fHandle, fMeta.fSegmentId);
fMeta.fHandle = -1; }
} else { // unmanaged region
// fShared >= 0 means the message was copied and a separate ref count holder exists
if (fMeta.fShared >= 0) {
// make sure segment is initialized in this transport
fManager.GetSegment(fMeta.fSegmentId);
// release unmanaged region block if ref count is one
uint16_t refCount = ShmHeader::DecrementRefCount(fManager.GetAddressFromHandle(fMeta.fShared, fMeta.fSegmentId));
if (refCount == 1) {
// free the ref count holder first, then release the region block itself
fManager.Deallocate(fMeta.fShared, fMeta.fSegmentId);
ReleaseUnmanagedRegionBlock();
}
} else { } else {
// never copied: this message is the only user of the region block
ReleaseUnmanagedRegionBlock();
}
}
}
// reset unconditionally, also when nothing was released above
fMeta.fHandle = -1;
fLocalPtr = nullptr;
fMeta.fSize = 0;
}
void ReleaseUnmanagedRegionBlock()
{
if (!fRegionPtr) { if (!fRegionPtr) {
fRegionPtr = fManager.GetRegion(fMeta.fRegionId); fRegionPtr = fManager.GetRegion(fMeta.fRegionId);
} }
@ -318,17 +369,18 @@ class Message final : public fair::mq::Message
LOG(warn) << "region ack queue for id " << fMeta.fRegionId << " no longer exist. Not sending ack"; LOG(warn) << "region ack queue for id " << fMeta.fRegionId << " no longer exist. Not sending ack";
} }
} }
}
fLocalPtr = nullptr;
fMeta.fSize = 0;
}
// Releases the buffer reference (Deallocate), resets the alignment and decrements the
// manager's message counter. On the new side, SharedMemoryError and
// boost::interprocess::lock_exception are logged instead of propagated; the destructor
// relies on this. Any other exception type still propagates out of this function.
// NOTE(review): the exceptions are caught by non-const reference; `const&` would be the usual form.
void CloseMessage() void CloseMessage()
{ {
try {
Deallocate(); Deallocate();
fAlignment = 0; fAlignment = 0;
fManager.DecrementMsgCounter(); fManager.DecrementMsgCounter();
} catch(SharedMemoryError& sme) {
LOG(error) << "error closing message: " << sme.what();
} catch(boost::interprocess::lock_exception& le) {
LOG(error) << "error closing message: " << le.what();
}
} }
}; };

View File

@ -89,7 +89,7 @@ add_testsuite(Message
${CMAKE_CURRENT_BINARY_DIR}/runner.cxx ${CMAKE_CURRENT_BINARY_DIR}/runner.cxx
message/_message.cxx message/_message.cxx
LINKS FairMQ LINKS FairMQ PicoSHA2
INCLUDES ${CMAKE_CURRENT_SOURCE_DIR} INCLUDES ${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/message ${CMAKE_CURRENT_SOURCE_DIR}/message
${CMAKE_CURRENT_BINARY_DIR} ${CMAKE_CURRENT_BINARY_DIR}

View File

@ -6,19 +6,23 @@
* copied verbatim in the file "LICENSE" * * copied verbatim in the file "LICENSE" *
********************************************************************************/ ********************************************************************************/
#include <array>
#include <cassert>
#include <cstdint>
#include <fairlogger/Logger.h> #include <fairlogger/Logger.h>
#include <fairmq/Channel.h> #include <fairmq/Channel.h>
#include <fairmq/ProgOptions.h> #include <fairmq/ProgOptions.h>
#include <fairmq/TransportFactory.h> #include <fairmq/tools/Semaphore.h>
#include <fairmq/tools/Strings.h> #include <fairmq/tools/Strings.h>
#include <fairmq/tools/Unique.h> #include <fairmq/tools/Unique.h>
#include <fairmq/TransportFactory.h>
#include <fairmq/shmem/Message.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <array>
#include <cassert>
#include <cstdint>
#include <memory> #include <memory>
#include <string>
#include <string_view> #include <string_view>
#include <string>
#include <utility> #include <utility>
namespace namespace
@ -190,7 +194,6 @@ auto EmptyMessage(string const& transport, string const& _address) -> void
push.Bind(address); push.Bind(address);
pull.Connect(address); pull.Connect(address);
{ {
auto outMsg(push.NewMessage()); auto outMsg(push.NewMessage());
ASSERT_EQ(outMsg->GetData(), nullptr); ASSERT_EQ(outMsg->GetData(), nullptr);
@ -227,6 +230,129 @@ auto EmptyMessage(string const& transport, string const& _address) -> void
} }
} }
// The "zero copy" property of the Copy() method is an implementation detail and is not guaranteed.
// Currently it holds true for the shmem (across devices) and for zeromq (within same device) transports.
//
// Copies a 2-byte shmem message within a single transport factory and checks that the copy
// shares the original's buffer (same data pointer and size), that both messages report a
// reference count of 2 while the copy is alive, and that the count drops back to 1 once the
// copy goes out of scope.
auto ZeroCopy() -> void
{
    ProgOptions config;
    config.SetProperty<string>("session", tools::Uuid());

    auto factory(TransportFactory::CreateTransportFactory("shmem", tools::Uuid(), &config));

    const size_t size = 2;
    MessagePtr original(factory->CreateMessage(size));
    memcpy(original->GetData(), "AB", size);

    {
        MessagePtr copy(factory->CreateMessage());
        copy->Copy(*original);

        // the copy must point at the very same buffer
        EXPECT_EQ(original->GetSize(), copy->GetSize());
        EXPECT_EQ(original->GetData(), copy->GetData());
        EXPECT_EQ(static_cast<const shmem::Message&>(*original).GetRefCount(), 2);
        EXPECT_EQ(static_cast<const shmem::Message&>(*copy).GetRefCount(), 2);

        // buffer must be still intact
        ASSERT_EQ(AsStringView(*original)[0], 'A');
        ASSERT_EQ(AsStringView(*original)[1], 'B');
        ASSERT_EQ(AsStringView(*copy)[0], 'A');
        ASSERT_EQ(AsStringView(*copy)[1], 'B');
    }

    // destroying the copy releases its reference
    EXPECT_EQ(static_cast<const shmem::Message&>(*original).GetRefCount(), 1);
}
// The "zero copy" property of the Copy() method is an implementation detail and is not guaranteed.
// Currently it holds true for the shmem (across devices) and for zeromq (within same device) transports.
//
// Exercises Copy() on unmanaged region messages, using two transport factories that share a
// session but use different segments. The nested scopes below are significant: every closing
// brace destroys a copy, and the expected reference counts are asserted right after it.
auto ZeroCopyFromUnmanaged(string const& address) -> void
{
ProgOptions config1;
ProgOptions config2;
string session(tools::Uuid());
config1.SetProperty<string>("session", session);
config2.SetProperty<string>("session", session);
// ref counts should be accessible across different segments
config2.SetProperty<uint16_t>("shm-segment-id", 2);
auto factory1(TransportFactory::CreateTransportFactory("shmem", tools::Uuid(), &config1));
auto factory2(TransportFactory::CreateTransportFactory("shmem", tools::Uuid(), &config2));
const size_t msgSize{100};
const size_t regionSize{1000000};
tools::Semaphore blocker;
// the region callback signals the semaphore; the test waits for two signals at the very end
auto region = factory1->CreateUnmanagedRegion(regionSize, [&blocker](void*, size_t, void*) {
blocker.Signal();
});
{
FairMQChannel push("Push", "push", factory1);
FairMQChannel pull("Pull", "pull", factory2);
push.Bind(address);
pull.Connect(address);
// two non-overlapping blocks of the region: [0, msgSize) and [offset, offset + msgSize)
const size_t offset = 100;
auto msg1(push.NewMessage(region, static_cast<char*>(region->GetData()), msgSize, nullptr));
auto msg2(push.NewMessage(region, static_cast<char*>(region->GetData()) + offset, msgSize, nullptr));
const size_t contentSize = 2;
memcpy(msg1->GetData(), "AB", contentSize);
memcpy(msg2->GetData(), "CD", contentSize);
// a region message that was never copied reports a ref count of 1
EXPECT_EQ(static_cast<const shmem::Message&>(*msg1).GetRefCount(), 1);
{
auto copyFromOriginal(push.NewMessage());
copyFromOriginal->Copy(*msg1);
EXPECT_EQ(static_cast<const shmem::Message&>(*msg1).GetRefCount(), 2);
EXPECT_EQ(static_cast<const shmem::Message&>(*msg1).GetRefCount(), static_cast<const shmem::Message&>(*copyFromOriginal).GetRefCount());
{
// copying a copy must share the same counter as the original
auto copyFromCopy(push.NewMessage());
copyFromCopy->Copy(*copyFromOriginal);
EXPECT_EQ(static_cast<const shmem::Message&>(*msg1).GetRefCount(), 3);
EXPECT_EQ(static_cast<const shmem::Message&>(*msg1).GetRefCount(), static_cast<const shmem::Message&>(*copyFromCopy).GetRefCount());
EXPECT_EQ(msg1->GetSize(), copyFromOriginal->GetSize());
EXPECT_EQ(msg1->GetData(), copyFromOriginal->GetData());
EXPECT_EQ(msg1->GetSize(), copyFromCopy->GetSize());
EXPECT_EQ(msg1->GetData(), copyFromCopy->GetData());
EXPECT_EQ(copyFromOriginal->GetSize(), copyFromCopy->GetSize());
EXPECT_EQ(copyFromOriginal->GetData(), copyFromCopy->GetData());
// messing with the ref count should not have affected the user buffer
ASSERT_EQ(AsStringView(*msg1)[0], 'A');
ASSERT_EQ(AsStringView(*msg1)[1], 'B');
// send one shared message (a copy of msg1) and one unshared message (msg2) to the second factory
push.Send(copyFromCopy);
push.Send(msg2);
auto incomingCopiedMsg(pull.NewMessage());
auto incomingOriginalMsg(pull.NewMessage());
pull.Receive(incomingCopiedMsg);
pull.Receive(incomingOriginalMsg);
// the transfer itself is expected to leave the counts unchanged: 3 for the shared buffer, 1 for msg2's
EXPECT_EQ(static_cast<const shmem::Message&>(*incomingCopiedMsg).GetRefCount(), 3);
EXPECT_EQ(static_cast<const shmem::Message&>(*incomingOriginalMsg).GetRefCount(), 1);
ASSERT_EQ(AsStringView(*incomingCopiedMsg)[0], 'A');
ASSERT_EQ(AsStringView(*incomingCopiedMsg)[1], 'B');
{
// copying on a different segment should work
auto copyFromIncoming(pull.NewMessage());
copyFromIncoming->Copy(*incomingOriginalMsg);
EXPECT_EQ(static_cast<const shmem::Message&>(*copyFromIncoming).GetRefCount(), 2);
ASSERT_EQ(AsStringView(*incomingOriginalMsg)[0], 'C');
ASSERT_EQ(AsStringView(*incomingOriginalMsg)[1], 'D');
}
EXPECT_EQ(static_cast<const shmem::Message&>(*incomingOriginalMsg).GetRefCount(), 1);
}
// the innermost scope destroyed the received copy of the shared buffer: 3 -> 2
EXPECT_EQ(static_cast<const shmem::Message&>(*msg1).GetRefCount(), 2);
}
// copyFromOriginal is gone as well: 2 -> 1
EXPECT_EQ(static_cast<const shmem::Message&>(*msg1).GetRefCount(), 1);
}
// presumably one callback per released region block (msg1's and msg2's) - TODO confirm
blocker.Wait();
blocker.Wait();
}
TEST(Resize, zeromq) // NOLINT TEST(Resize, zeromq) // NOLINT
{ {
RunPushPullWithMsgResize("zeromq", "ipc://test_message_resize"); RunPushPullWithMsgResize("zeromq", "ipc://test_message_resize");
@ -267,4 +393,14 @@ TEST(EmptyMessage, shmem) // NOLINT
EmptyMessage("shmem", "ipc://test_empty_message"); EmptyMessage("shmem", "ipc://test_empty_message");
} }
// Registers the ZeroCopy() scenario (buffer sharing within the managed segment) as a gtest case.
TEST(ZeroCopy, shmem) // NOLINT
{
ZeroCopy();
}
// Registers the ZeroCopyFromUnmanaged() scenario (buffer sharing for unmanaged region messages)
// as a gtest case, using a dedicated ipc address.
TEST(ZeroCopyFromUnmanaged, shmem) // NOLINT
{
ZeroCopyFromUnmanaged("ipc://test_zerocopy_unmanaged");
}
} // namespace } // namespace

View File

@ -199,7 +199,6 @@ void RegionCallbacks(const string& transport, const string& _address)
}); });
ptr2 = region2->GetData(); ptr2 = region2->GetData();
{ {
FairMQMessagePtr msg1out(push.NewMessage(region1, ptr1, size1, intPtr1.get())); FairMQMessagePtr msg1out(push.NewMessage(region1, ptr1, size1, intPtr1.get()));
FairMQMessagePtr msg2out(push.NewMessage(region2, ptr2, size2, intPtr2.get())); FairMQMessagePtr msg2out(push.NewMessage(region2, ptr2, size2, intPtr2.get()));