Shmem: track number of message objects, throw if non-zero at reset

This commit is contained in:
Alexey Rybalchenko 2019-12-19 15:06:14 +01:00 committed by Dennis Klein
parent 5b5fecc994
commit 684e711b8b
5 changed files with 63 additions and 1 deletions

View File

@ -9,6 +9,7 @@
#include "Region.h" #include "Region.h"
#include "Message.h" #include "Message.h"
#include "UnmanagedRegion.h" #include "UnmanagedRegion.h"
#include "TransportFactory.h"
#include <FairMQLogger.h> #include <FairMQLogger.h>
@ -39,6 +40,7 @@ Message::Message(Manager& manager, FairMQTransportFactory* factory)
, fRegionPtr(nullptr) , fRegionPtr(nullptr)
, fLocalPtr(nullptr) , fLocalPtr(nullptr)
{ {
static_cast<TransportFactory*>(GetTransport())->IncrementMsgCounter();
} }
Message::Message(Manager& manager, const size_t size, FairMQTransportFactory* factory) Message::Message(Manager& manager, const size_t size, FairMQTransportFactory* factory)
@ -50,6 +52,7 @@ Message::Message(Manager& manager, const size_t size, FairMQTransportFactory* fa
, fLocalPtr(nullptr) , fLocalPtr(nullptr)
{ {
InitializeChunk(size); InitializeChunk(size);
static_cast<TransportFactory*>(GetTransport())->IncrementMsgCounter();
} }
Message::Message(Manager& manager, MetaHeader& hdr, FairMQTransportFactory* factory) Message::Message(Manager& manager, MetaHeader& hdr, FairMQTransportFactory* factory)
@ -60,6 +63,7 @@ Message::Message(Manager& manager, MetaHeader& hdr, FairMQTransportFactory* fact
, fRegionPtr(nullptr) , fRegionPtr(nullptr)
, fLocalPtr(nullptr) , fLocalPtr(nullptr)
{ {
static_cast<TransportFactory*>(GetTransport())->IncrementMsgCounter();
} }
Message::Message(Manager& manager, void* data, const size_t size, fairmq_free_fn* ffn, void* hint, FairMQTransportFactory* factory) Message::Message(Manager& manager, void* data, const size_t size, fairmq_free_fn* ffn, void* hint, FairMQTransportFactory* factory)
@ -78,6 +82,7 @@ Message::Message(Manager& manager, void* data, const size_t size, fairmq_free_fn
free(data); free(data);
} }
} }
static_cast<TransportFactory*>(GetTransport())->IncrementMsgCounter();
} }
Message::Message(Manager& manager, UnmanagedRegionPtr& region, void* data, const size_t size, void* hint, FairMQTransportFactory* factory) Message::Message(Manager& manager, UnmanagedRegionPtr& region, void* data, const size_t size, void* hint, FairMQTransportFactory* factory)
@ -95,6 +100,7 @@ Message::Message(Manager& manager, UnmanagedRegionPtr& region, void* data, const
LOG(error) << "trying to create region message with data from outside the region"; LOG(error) << "trying to create region message with data from outside the region";
throw runtime_error("trying to create region message with data from outside the region"); throw runtime_error("trying to create region message with data from outside the region");
} }
static_cast<TransportFactory*>(GetTransport())->IncrementMsgCounter();
} }
bool Message::InitializeChunk(const size_t size) bool Message::InitializeChunk(const size_t size)
@ -225,6 +231,8 @@ void Message::CloseMessage()
} }
} }
} }
static_cast<TransportFactory*>(GetTransport())->DecrementMsgCounter();
} }
} }

View File

@ -10,6 +10,7 @@
#include "Socket.h" #include "Socket.h"
#include "Message.h" #include "Message.h"
#include "UnmanagedRegion.h" #include "UnmanagedRegion.h"
#include "TransportFactory.h"
#include <FairMQLogger.h> #include <FairMQLogger.h>
#include <fairmq/Tools.h> #include <fairmq/Tools.h>

View File

@ -47,6 +47,7 @@ TransportFactory::TransportFactory(const string& id, const ProgOptions* config)
, fManager(nullptr) , fManager(nullptr)
, fHeartbeatThread() , fHeartbeatThread()
, fSendHeartbeats(true) , fSendHeartbeats(true)
, fMsgCounter(0)
{ {
int major, minor, patch; int major, minor, patch;
zmq_version(&major, &minor, &patch); zmq_version(&major, &minor, &patch);
@ -168,6 +169,15 @@ Transport TransportFactory::GetType() const
return fTransportType; return fTransportType;
} }
void TransportFactory::Reset()
{
if (fMsgCounter.load() != 0) {
LOG(error) << "Message counter during Reset expected to be 0, found: " << fMsgCounter.load();
throw MessageError(tools::ToString("Message counter during Reset expected to be 0, found: ", fMsgCounter.load()));
}
}
TransportFactory::~TransportFactory() TransportFactory::~TransportFactory()
{ {
LOG(debug) << "Destroying Shared Memory transport..."; LOG(debug) << "Destroying Shared Memory transport...";

View File

@ -55,7 +55,10 @@ class TransportFactory final : public fair::mq::TransportFactory
void Interrupt() override { Socket::Interrupt(); } void Interrupt() override { Socket::Interrupt(); }
void Resume() override { Socket::Resume(); } void Resume() override { Socket::Resume(); }
void Reset() override {} void Reset() override;
void IncrementMsgCounter() { ++fMsgCounter; }
void DecrementMsgCounter() { --fMsgCounter; }
~TransportFactory() override; ~TransportFactory() override;
@ -69,6 +72,7 @@ class TransportFactory final : public fair::mq::TransportFactory
std::unique_ptr<Manager> fManager; std::unique_ptr<Manager> fManager;
std::thread fHeartbeatThread; std::thread fHeartbeatThread;
std::atomic<bool> fSendHeartbeats; std::atomic<bool> fSendHeartbeats;
std::atomic<int32_t> fMsgCounter;
}; };
} // namespace shmem } // namespace shmem

View File

@ -11,6 +11,7 @@
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <boost/process.hpp> #include <boost/process.hpp>
#include <fairmq/Tools.h> #include <fairmq/Tools.h>
#include <FairMQDevice.h>
#include <string> #include <string>
#include <thread> #include <thread>
@ -23,6 +24,39 @@ using namespace std;
using namespace fair::mq::test; using namespace fair::mq::test;
using namespace fair::mq::tools; using namespace fair::mq::tools;
class BadDevice : public FairMQDevice
{
public:
BadDevice()
{
fDeviceThread = thread([&](){
EXPECT_THROW(RunStateMachine(), fair::mq::MessageError);
});
SetTransport("shmem");
ChangeState(fair::mq::Transition::InitDevice);
WaitForState(fair::mq::State::InitializingDevice);
ChangeState(fair::mq::Transition::CompleteInit);
WaitForState(fair::mq::State::Initialized);
parts.AddPart(NewMessage());
}
~BadDevice()
{
ChangeState(fair::mq::Transition::ResetDevice);
if (fDeviceThread.joinable()) {
fDeviceThread.join();
}
}
private:
thread fDeviceThread;
FairMQParts parts;
};
void RunErrorStateIn(const string& state, const string& control, const string& input = "") void RunErrorStateIn(const string& state, const string& control, const string& input = "")
{ {
size_t session{fair::mq::tools::UuidHash()}; size_t session{fair::mq::tools::UuidHash()};
@ -118,4 +152,9 @@ TEST(ErrorState, interactive_InReset)
EXPECT_EXIT(RunErrorStateIn("Reset", "interactive", "q"), ::testing::ExitedWithCode(1), ""); EXPECT_EXIT(RunErrorStateIn("Reset", "interactive", "q"), ::testing::ExitedWithCode(1), "");
} }
TEST(ErrorState, OrphanMessages)
{
BadDevice badDevice;
}
} // namespace } // namespace