diff --git a/fairmq/sdk/Topology.h b/fairmq/sdk/Topology.h index 180d2a7c..1ee1e654 100644 --- a/fairmq/sdk/Topology.h +++ b/fairmq/sdk/Topology.h @@ -68,6 +68,7 @@ const std::map expectedState = struct DeviceStatus { bool initialized; + DeviceState lastState; DeviceState state; DDSTask::Id taskId; DDSCollection::Id collectionId; @@ -192,13 +193,12 @@ class BasicTopology : public AsioBase switch (cmd->GetType()) { case Type::state_change: { auto _cmd = static_cast(*cmd); - DDSTask::Id taskId(_cmd.GetTaskId()); if (_cmd.GetCurrentState() == DeviceState::Exiting) { Cmds outCmds; outCmds.Add(); fDDSSession.SendCommand(outCmds.Serialize(), senderId); } - UpdateStateEntry(taskId, _cmd.GetCurrentState()); + HandleCmd(_cmd); } break; case Type::state_change_subscription: if (static_cast(*cmd).GetResult() != Result::Ok) { @@ -430,6 +430,212 @@ class BasicTopology : public AsioBase auto StateEqualsTo(DeviceState state) const -> bool { return sdk::StateEqualsTo(GetCurrentState(), state); } + + + + + + + + + + + + + using WaitForStateCompletionSignature = void(std::error_code); + + private: + struct WaitForStateOp + { + using Id = std::size_t; + using Count = unsigned int; + + template + WaitForStateOp(Id id, + DeviceState targetLastState, + DeviceState targetCurrentState, + std::vector tasks, + Duration timeout, + std::mutex& mutex, + Executor const & ex, + Allocator const & alloc, + Handler&& handler) + : fId(id) + , fOp(ex, alloc, std::move(handler)) + , fTimer(ex) + , fCount(0) + , fTasks(std::move(tasks)) + , fTargetLastState(targetLastState) + , fTargetCurrentState(targetCurrentState) + , fMtx(mutex) + , fCompleted(false) + { + if (timeout > std::chrono::milliseconds(0)) { + fTimer.expires_after(timeout); + fTimer.async_wait([&](std::error_code ec) { + if (!ec) { + std::lock_guard lk(fMtx); + fOp.Timeout(); + } + }); + } + } + WaitForStateOp() = delete; + WaitForStateOp(const WaitForStateOp&) = delete; + WaitForStateOp& operator=(const WaitForStateOp&) = delete; + WaitForStateOp(WaitForStateOp&&) = default; + WaitForStateOp& operator=(WaitForStateOp&&) = default; + ~WaitForStateOp() = default; + + /// precondition: fMtx is locked. + auto ResetCount(const TopologyStateIndex& stateIndex, const TopologyState& stateData) -> void + { + LOG(info) << "Resetting count and expecting fTargetLastState=" << fTargetLastState << ",fTargetCurrentState=" << fTargetCurrentState; + fCount = std::count_if(stateIndex.cbegin(), stateIndex.cend(), [=](const auto& s) { + if (ContainsTask(stateData.at(s.second).taskId)) { + if (stateData.at(s.second).state == fTargetCurrentState && + (stateData.at(s.second).lastState == fTargetLastState || + fTargetLastState == DeviceState::Ok)) { + return true; + } else { + return false; + } + } else { + return false; + } + }); + } + + /// precondition: fMtx is locked. + auto Update(const DDSTask::Id taskId, const DeviceState lastState, const DeviceState currentState) -> void + { + if (!fCompleted && ContainsTask(taskId)) { + LOG(info) << "Update: lastState=" << lastState << ",currentState=" << currentState; + if (currentState == fTargetCurrentState && + (lastState == fTargetLastState || + fTargetLastState == DeviceState::Ok)) { + ++fCount; + } + TryCompletion(); + } + } + + /// precondition: fMtx is locked. + auto TryCompletion() -> void + { + LOG(info) << "fCount: " << fCount; + if (!fOp.IsCompleted() && fCount == fTasks.size()) { + fCompleted = true; + fTimer.cancel(); + fOp.Complete(); + } + } + + private: + Id const fId; + AsioAsyncOp fOp; + asio::steady_timer fTimer; + Count fCount; + std::vector fTasks; + DeviceState fTargetLastState; + DeviceState fTargetCurrentState; + std::mutex& fMtx; + bool fCompleted; + + /// precondition: fMtx is locked. + auto ContainsTask(DDSTask::Id id) -> bool + { + auto it = std::find_if(fTasks.begin(), fTasks.end(), [id](const DDSTask& t) { return t.GetId() == id; }); + return it != fTasks.end(); + } + }; + + auto HandleCmd(cmd::StateChange const& cmd) -> void + { + DDSTask::Id taskId(cmd.GetTaskId()); + UpdateStateEntry(taskId, cmd.GetLastState(), cmd.GetCurrentState()); + + std::lock_guard lk(fMtx); + for (auto& op : fWaitForStateOps) { + op.second.Update(taskId, cmd.GetLastState(), cmd.GetCurrentState()); + } + } + + public: + template + auto AsyncWaitForState(const DeviceState targetLastState, + const DeviceState targetCurrentState, + const std::string& path, + Duration timeout, + CompletionToken&& token) + { + return asio::async_initiate([&](auto handler) { + typename GetPropertiesOp::Id const id(tools::UuidHash()); + + // TODO Implement garbage collection of completed ops + std::lock_guard lk(fMtx); + auto p = fWaitForStateOps.emplace( + std::piecewise_construct, + std::forward_as_tuple(id), + std::forward_as_tuple(id, + targetLastState, + targetCurrentState, + fDDSTopo.GetTasks(path), + timeout, + fMtx, + AsioBase::GetExecutor(), + AsioBase::GetAllocator(), + std::move(handler))); + p.first->second.ResetCount(fStateIndex, fStateData); + p.first->second.TryCompletion(); + }, + token); + } + + template + auto AsyncWaitForState(const DeviceState targetCurrentState, CompletionToken&& token) + { + return AsyncWaitForState(DeviceState::Ok, targetCurrentState, "", Duration(0), std::move(token)); + } + + template + auto AsyncWaitForState(const DeviceState targetLastState, const DeviceState targetCurrentState, CompletionToken&& token) + { + return AsyncWaitForState(targetLastState, targetCurrentState, "", Duration(0), std::move(token)); + } + + auto WaitForState(const DeviceState targetLastState, const DeviceState targetCurrentState, const std::string& path = "", Duration timeout = Duration(0)) + -> std::error_code + { + tools::SharedSemaphore blocker; + std::error_code ec; + AsyncWaitForState(targetLastState, targetCurrentState, path, timeout, [&, blocker](std::error_code _ec) mutable { + ec = _ec; + blocker.Signal(); + }); + blocker.Wait(); + return ec; + } + + auto WaitForState(const DeviceState targetCurrentState, const std::string& path = "", Duration timeout = Duration(0)) + -> std::error_code + { + return WaitForState(DeviceState::Ok, targetCurrentState, path, timeout); + } + + + + + + + + + + + + + + using GetPropertiesCompletionSignature = void(std::error_code, GetPropertiesResult); private: @@ -762,6 +968,7 @@ class BasicTopology : public AsioBase std::unordered_map fSetPropertiesOps; std::unordered_map fGetPropertiesOps; + std::unordered_map fWaitForStateOps; auto makeTopologyState() -> void { @@ -770,19 +977,20 @@ class BasicTopology : public AsioBase int index = 0; for (const auto& task : fDDSTopo.GetTasks()) { - fStateData.push_back(DeviceStatus{false, DeviceState::Ok, task.GetId(), task.GetCollectionId()}); + fStateData.push_back(DeviceStatus{false, DeviceState::Ok, DeviceState::Ok, task.GetId(), task.GetCollectionId()}); fStateIndex.emplace(task.GetId(), index); index++; } } - auto UpdateStateEntry(const DDSTask::Id taskId, const DeviceState state) -> void + auto UpdateStateEntry(const DDSTask::Id taskId, const DeviceState lastState, const DeviceState currentState) -> void { try { std::lock_guard lk(fMtx); DeviceStatus& task = fStateData.at(fStateIndex.at(taskId)); task.initialized = true; - task.state = state; + task.lastState = lastState; + task.state = currentState; if (task.state == fChangeStateTarget) { ++fTransitionedCount; } diff --git a/test/sdk/_topology.cxx b/test/sdk/_topology.cxx index 93b9fb0f..e5ea101b 100644 --- a/test/sdk/_topology.cxx +++ b/test/sdk/_topology.cxx @@ -26,8 +26,8 @@ TEST(TopologyHelper, MakeTopology) sdk::DDSEnv env(CMAKE_CURRENT_BINARY_DIR); ///////////////////////////////////// - dds::topology_api::CTopology nativeTopo( - tools::ToString(SDK_TESTSUITE_SOURCE_DIR, "/test_topo.xml")); + std::string topoFile(tools::ToString(SDK_TESTSUITE_SOURCE_DIR, "/test_topo.xml")); + dds::topology_api::CTopology nativeTopo(topoFile); auto nativeSession(std::make_shared()); nativeSession->create(); EXPECT_THROW(sdk::MakeTopology(nativeTopo, nativeSession, env), sdk::RuntimeError); @@ -219,6 +219,27 @@ TEST_F(Topology, ChangeStateFullDeviceLifecycle) } } +TEST_F(Topology, WaitForStateFullDeviceLifecycle) +{ + using namespace fair::mq; + using fair::mq::sdk::TopologyTransition; + + sdk::Topology topo(mDDSTopo, mDDSSession); + for (auto transition : {TopologyTransition::InitDevice, + TopologyTransition::CompleteInit, + TopologyTransition::Bind, + TopologyTransition::Connect, + TopologyTransition::InitTask, + TopologyTransition::Run, + TopologyTransition::Stop, + TopologyTransition::ResetTask, + TopologyTransition::ResetDevice, + TopologyTransition::End}) { + LOG(info) << topo.ChangeState(transition).first; + ASSERT_EQ(topo.WaitForState(sdk::expectedState.at(transition)), std::error_code()); + } +} + TEST_F(Topology, ChangeStateFullDeviceLifecycle2) { using namespace fair::mq;