Implement bulk callbacks for unmanaged regions

This commit is contained in:
Alexey Rybalchenko
2020-05-15 00:20:49 +02:00
parent a15d59c725
commit d22023bcb5
24 changed files with 243 additions and 83 deletions

View File

@@ -32,7 +32,7 @@ void RegionEventSubscriptions(const string& transport)
constexpr int size1 = 1000000;
constexpr int size2 = 5000000;
constexpr int64_t userFlags = 12345;
fair::mq::tools::SharedSemaphore blocker;
fair::mq::tools::Semaphore blocker;
{
auto region1 = factory->CreateUnmanagedRegion(size1, [](void*, size_t, void*) {});
@@ -90,6 +90,72 @@ void RegionEventSubscriptions(const string& transport)
ASSERT_EQ(factory->SubscribedToRegionEvents(), false);
}
void RegionCallbacks(const string& transport, const string& _address)
{
size_t session(fair::mq::tools::UuidHash());
std::string address(fair::mq::tools::ToString(_address, "_", transport));
fair::mq::ProgOptions config;
config.SetProperty<string>("session", to_string(session));
auto factory = FairMQTransportFactory::CreateTransportFactory(transport, fair::mq::tools::Uuid(), &config);
unique_ptr<int> intPtr1 = fair::mq::tools::make_unique<int>(42);
unique_ptr<int> intPtr2 = fair::mq::tools::make_unique<int>(43);
fair::mq::tools::Semaphore blocker;
FairMQChannel push("Push", "push", factory);
push.Bind(address);
FairMQChannel pull("Pull", "pull", factory);
pull.Connect(address);
void* ptr1 = nullptr;
size_t size1 = 100;
void* ptr2 = nullptr;
size_t size2 = 200;
auto region1 = factory->CreateUnmanagedRegion(2000000, [&](void* ptr, size_t size, void* hint) {
ASSERT_EQ(ptr, ptr1);
ASSERT_EQ(size, size1);
ASSERT_EQ(hint, intPtr1.get());
ASSERT_EQ(*static_cast<int*>(hint), 42);
blocker.Signal();
});
ptr1 = region1->GetData();
auto region2 = factory->CreateUnmanagedRegion(3000000, [&](const std::vector<fair::mq::RegionBlock>& blocks) {
ASSERT_EQ(blocks.size(), 1);
ASSERT_EQ(blocks.at(0).ptr, ptr2);
ASSERT_EQ(blocks.at(0).size, size2);
ASSERT_EQ(blocks.at(0).hint, intPtr2.get());
ASSERT_EQ(*static_cast<int*>(blocks.at(0).hint), 43);
blocker.Signal();
});
ptr2 = region2->GetData();
{
FairMQMessagePtr msg1out(push.NewMessage(region1, ptr1, size1, intPtr1.get()));
FairMQMessagePtr msg2out(push.NewMessage(region2, ptr2, size2, intPtr2.get()));
ASSERT_EQ(push.Send(msg1out), size1);
ASSERT_EQ(push.Send(msg2out), size2);
}
{
FairMQMessagePtr msg1in(pull.NewMessage());
FairMQMessagePtr msg2in(pull.NewMessage());
ASSERT_EQ(pull.Receive(msg1in), size1);
ASSERT_EQ(pull.Receive(msg2in), size2);
}
LOG(info) << "waiting for blockers...";
blocker.Wait();
LOG(info) << "1 done.";
blocker.Wait();
LOG(info) << "2 done.";
}
TEST(EventSubscriptions, zeromq)
{
RegionEventSubscriptions("zeromq");
@@ -100,4 +166,14 @@ TEST(EventSubscriptions, shmem)
RegionEventSubscriptions("shmem");
}
TEST(Callbacks, zeromq)
{
RegionCallbacks("zeromq", "ipc://test_region_callbacks");
}
TEST(Callbacks, shmem)
{
RegionCallbacks("shmem", "ipc://test_region_callbacks");
}
} // namespace

View File

@@ -25,7 +25,7 @@ namespace
using namespace std;
void CheckOldOptionInterface(FairMQChannel& channel, const string& option, const string& transport)
void CheckOldOptionInterface(FairMQChannel& channel, const string& option)
{
int value = 500;
channel.GetSocket().SetOption(option, &value, sizeof(value));
@@ -44,11 +44,11 @@ void RunOptionsTest(const string& transport)
auto factory = FairMQTransportFactory::CreateTransportFactory(transport, fair::mq::tools::Uuid(), &config);
FairMQChannel channel("Push", "push", factory);
CheckOldOptionInterface(channel, "linger", transport);
CheckOldOptionInterface(channel, "snd-hwm", transport);
CheckOldOptionInterface(channel, "rcv-hwm", transport);
CheckOldOptionInterface(channel, "snd-size", transport);
CheckOldOptionInterface(channel, "rcv-size", transport);
CheckOldOptionInterface(channel, "linger");
CheckOldOptionInterface(channel, "snd-hwm");
CheckOldOptionInterface(channel, "rcv-hwm");
CheckOldOptionInterface(channel, "snd-size");
CheckOldOptionInterface(channel, "rcv-size");
channel.GetSocket().SetLinger(300);
ASSERT_EQ(channel.GetSocket().GetLinger(), 300);