diff --git a/fairmq/sdk/Topology.h b/fairmq/sdk/Topology.h index 3f3a4d72..fbaa451d 100644 --- a/fairmq/sdk/Topology.h +++ b/fairmq/sdk/Topology.h @@ -174,9 +174,6 @@ class BasicTopology : public AsioBase , fDDSTopo(std::move(topo)) , fStateData() , fStateIndex() - , fChangeStateOp() - , fChangeStateOpTimer(ex) - , fChangeStateTarget(DeviceState::Idle) { makeTopologyState(); @@ -216,10 +213,13 @@ class BasicTopology : public AsioBase auto _cmd = static_cast(*cmd); if (_cmd.GetResult() != Result::Ok) { FAIR_LOG(error) << _cmd.GetTransition() << " transition failed for " << _cmd.GetDeviceId(); + DDSTask::Id id(_cmd.GetTaskId()); std::lock_guard lk(fMtx); - if (!fChangeStateOp.IsCompleted() && fStateData.at(fStateIndex.at(_cmd.GetTaskId())).state != fChangeStateTarget) { - fChangeStateOpTimer.cancel(); - fChangeStateOp.Complete(MakeErrorCode(ErrorCode::DeviceChangeStateFailed), fStateData); + for (auto& op : fChangeStateOps) { + if (!op.second.IsCompleted() && op.second.ContainsTask(id) && + fStateData.at(fStateIndex.at(id)).state != op.second.GetTargetState()) { + op.second.Complete(MakeErrorCode(ErrorCode::DeviceChangeStateFailed)); + } } } } @@ -270,15 +270,172 @@ class BasicTopology : public AsioBase std::lock_guard lk(fMtx); fDDSSession.UnsubscribeFromCommands(); try { - fChangeStateOp.Cancel(fStateData); + for (auto& op : fChangeStateOps) { + op.second.Complete(MakeErrorCode(ErrorCode::OperationCanceled)); + } } catch (...) 
{} } + auto HandleCmd(cmd::StateChange const& cmd) -> void + { + DDSTask::Id taskId(cmd.GetTaskId()); + std::lock_guard lk(fMtx); + + try { + DeviceStatus& task = fStateData.at(fStateIndex.at(taskId)); + task.initialized = true; + task.lastState = cmd.GetLastState(); + task.state = cmd.GetCurrentState(); + // FAIR_LOG(debug) << "Updated state entry: taskId=" << taskId << ", state=" << state; + + for (auto& op : fChangeStateOps) { + op.second.Update(taskId, cmd.GetCurrentState()); + } + for (auto& op : fWaitForStateOps) { + op.second.Update(taskId, cmd.GetLastState(), cmd.GetCurrentState()); + } + } catch (const std::exception& e) { + FAIR_LOG(error) << "Exception in HandleCmd(cmd::StateChange const&): " << e.what(); + } + } + + auto HandleCmd(cmd::Properties const& cmd) -> void + { + std::unique_lock lk(fMtx); + try { + auto& op(fGetPropertiesOps.at(cmd.GetRequestId())); + lk.unlock(); + op.Update(cmd.GetDeviceId(), cmd.GetResult(), cmd.GetProps()); + } catch (std::out_of_range& e) { + FAIR_LOG(debug) << "GetProperties operation (request id: " << cmd.GetRequestId() + << ") not found (probably completed or timed out), " + << "discarding reply of device " << cmd.GetDeviceId(); + } + } + + auto HandleCmd(cmd::PropertiesSet const& cmd) -> void + { + std::unique_lock lk(fMtx); + try { + auto& op(fSetPropertiesOps.at(cmd.GetRequestId())); + lk.unlock(); + op.Update(cmd.GetDeviceId(), cmd.GetResult()); + } catch (std::out_of_range& e) { + FAIR_LOG(debug) << "SetProperties operation (request id: " << cmd.GetRequestId() + << ") not found (probably completed or timed out), " + << "discarding reply of device " << cmd.GetDeviceId(); + } + } + using Duration = std::chrono::milliseconds; using ChangeStateCompletionSignature = void(std::error_code, TopologyState); + private: + struct ChangeStateOp + { + using Id = std::size_t; + using Count = unsigned int; + + template + ChangeStateOp(Id id, + const TopologyTransition transition, + std::vector tasks, + TopologyState& stateData, + 
Duration timeout, + std::mutex& mutex, + Executor const & ex, + Allocator const & alloc, + Handler&& handler) + : fId(id) + , fOp(ex, alloc, std::move(handler)) + , fStateData(stateData) + , fTimer(ex) + , fCount(0) + , fTasks(std::move(tasks)) + , fTargetState(expectedState.at(transition)) + , fMtx(mutex) + { + if (timeout > std::chrono::milliseconds(0)) { + fTimer.expires_after(timeout); + fTimer.async_wait([&](std::error_code ec) { + if (!ec) { + std::lock_guard lk(fMtx); + fOp.Timeout(fStateData); + } + }); + } + } + ChangeStateOp() = delete; + ChangeStateOp(const ChangeStateOp&) = delete; + ChangeStateOp& operator=(const ChangeStateOp&) = delete; + ChangeStateOp(ChangeStateOp&&) = default; + ChangeStateOp& operator=(ChangeStateOp&&) = default; + ~ChangeStateOp() = default; + + /// precondition: fMtx is locked. + auto ResetCount(const TopologyStateIndex& stateIndex, const TopologyState& stateData) -> void + { + fCount = std::count_if(stateIndex.cbegin(), stateIndex.cend(), [=](const auto& s) { + if (ContainsTask(stateData.at(s.second).taskId)) { + return stateData.at(s.second).state == fTargetState; + } else { + return false; + } + }); + } + + /// precondition: fMtx is locked. + auto Update(const DDSTask::Id taskId, const DeviceState currentState) -> void + { + if (!fOp.IsCompleted() && ContainsTask(taskId)) { + if (currentState == fTargetState) { + ++fCount; + } + TryCompletion(); + } + } + + /// precondition: fMtx is locked. + auto TryCompletion() -> void + { + if (!fOp.IsCompleted() && fCount == fTasks.size()) { + Complete(std::error_code()); + } + } + + /// precondition: fMtx is locked. + auto Complete(std::error_code ec) -> void + { + fTimer.cancel(); + fOp.Complete(ec, fStateData); + } + + /// precondition: fMtx is locked. 
+ auto ContainsTask(DDSTask::Id id) -> bool + { + auto it = std::find_if(fTasks.begin(), fTasks.end(), [id](const DDSTask& t) { return t.GetId() == id; }); + return it != fTasks.end(); + } + + bool IsCompleted() { return fOp.IsCompleted(); } + + auto GetTargetState() const -> DeviceState { return fTargetState; } + + private: + Id const fId; + AsioAsyncOp fOp; + TopologyState& fStateData; + asio::steady_timer fTimer; + Count fCount; + std::vector fTasks; + DeviceState fTargetState; + std::mutex& fMtx; + }; + + public: /// @brief Initiate state transition on all FairMQ devices in this topology /// @param transition FairMQ device state machine transition + /// @param path Select a subset of FairMQ devices in this topology, empty selects all /// @param timeout Timeout in milliseconds, 0 means no timeout /// @param token Asio completion token /// @tparam CompletionToken Asio completion token type @@ -301,8 +458,6 @@ class BasicTopology : public AsioBase /// // async operation canceled /// case fair::mq::ErrorCode::DeviceChangeStateFailed: /// // failed to change state of a fairmq device - /// case fair::mq::ErrorCode::OperationInProgress: - /// // async operation already in progress /// default: /// } /// } @@ -327,8 +482,6 @@ class BasicTopology : public AsioBase /// // async operation canceled /// case fair::mq::ErrorCode::DeviceChangeStateFailed: /// // failed to change state of a fairmq device - /// case fair::mq::ErrorCode::OperationInProgress: - /// // async operation already in progress /// default: /// } /// } @@ -352,8 +505,6 @@ class BasicTopology : public AsioBase /// // async operation canceled /// case fair::mq::ErrorCode::DeviceChangeStateFailed: /// // failed to change state of a fairmq device - /// case fair::mq::ErrorCode::OperationInProgress: - /// // async operation already in progress /// default: /// } /// } @@ -361,76 +512,112 @@ class BasicTopology : public AsioBase /// @endcode template auto AsyncChangeState(const TopologyTransition transition, + 
const std::string& path, Duration timeout, CompletionToken&& token) { return asio::async_initiate([&](auto handler) { + typename ChangeStateOp::Id const id(tools::UuidHash()); + std::lock_guard lk(fMtx); - if (fChangeStateOp.IsCompleted()) { - fChangeStateOp = ChangeStateOp(AsioBase::GetExecutor(), - AsioBase::GetAllocator(), - std::move(handler)); - fChangeStateTarget = expectedState.at(transition); - ResetTransitionedCount(fChangeStateTarget); - cmd::Cmds cmds(cmd::make(transition)); - fDDSSession.SendCommand(cmds.Serialize()); - if (timeout > std::chrono::milliseconds(0)) { - fChangeStateOpTimer.expires_after(timeout); - fChangeStateOpTimer.async_wait([&](std::error_code ec) { - if (!ec) { - std::lock_guard lk2(fMtx); - fChangeStateOp.Timeout(fStateData); - } - }); + for (auto it = begin(fChangeStateOps); it != end(fChangeStateOps);) { + if (it->second.IsCompleted()) { + it = fChangeStateOps.erase(it); + } else { + ++it; } - } else { - // TODO refactor to hide boiler plate - auto ex2(asio::get_associated_executor(handler, AsioBase::GetExecutor())); - auto alloc2(asio::get_associated_allocator(handler, AsioBase::GetAllocator())); - auto state(GetCurrentStateUnsafe()); - - ex2.post([h = std::move(handler), s = std::move(state)]() mutable { - try { - h(MakeErrorCode(ErrorCode::OperationInProgress), s); - } catch (const std::exception& e) { - FAIR_LOG(error) << "Uncaught exception in completion handler: " << e.what(); - } catch (...) 
{ - FAIR_LOG(error) << "Unknown uncaught exception in completion handler."; - } - }, - alloc2); } + + auto p = fChangeStateOps.emplace( + std::piecewise_construct, + std::forward_as_tuple(id), + std::forward_as_tuple(id, + transition, + fDDSTopo.GetTasks(path), + fStateData, + timeout, + fMtx, + AsioBase::GetExecutor(), + AsioBase::GetAllocator(), + std::move(handler))); + + cmd::Cmds cmds(cmd::make(transition)); + fDDSSession.SendCommand(cmds.Serialize(), path); + + p.first->second.ResetCount(fStateIndex, fStateData); + // TODO: make sure following operation properly queues the completion and not doing it directly out of initiation call. + p.first->second.TryCompletion(); + }, token); } + /// @brief Initiate state transition on all FairMQ devices in this topology + /// @param transition FairMQ device state machine transition + /// @param token Asio completion token + /// @tparam CompletionToken Asio completion token type + /// @throws std::system_error template auto AsyncChangeState(const TopologyTransition transition, CompletionToken&& token) { - return AsyncChangeState(transition, Duration(0), std::move(token)); + return AsyncChangeState(transition, "", Duration(0), std::move(token)); } - /// @brief Perform state transition on all FairMQ devices in this topology + /// @brief Initiate state transition on all FairMQ devices in this topology with a timeout /// @param transition FairMQ device state machine transition /// @param timeout Timeout in milliseconds, 0 means no timeout + /// @param token Asio completion token + /// @tparam CompletionToken Asio completion token type /// @throws std::system_error - auto ChangeState(const TopologyTransition transition, Duration timeout = Duration(0)) + template + auto AsyncChangeState(const TopologyTransition transition, Duration timeout, CompletionToken&& token) + { + return AsyncChangeState(transition, "", timeout, std::move(token)); + } + + /// @brief Initiate state transition on FairMQ devices in this topology for a
specified topology path + /// @param transition FairMQ device state machine transition + /// @param path Select a subset of FairMQ devices in this topology, empty selects all + /// @param token Asio completion token + /// @tparam CompletionToken Asio completion token type + /// @throws std::system_error + template + auto AsyncChangeState(const TopologyTransition transition, const std::string& path, CompletionToken&& token) + { + return AsyncChangeState(transition, path, Duration(0), std::move(token)); + } + + /// @brief Perform state transition on FairMQ devices in this topology for a specified topology path + /// @param transition FairMQ device state machine transition + /// @param path Select a subset of FairMQ devices in this topology, empty selects all + /// @param timeout Timeout in milliseconds, 0 means no timeout + /// @throws std::system_error + auto ChangeState(const TopologyTransition transition, const std::string& path = "", Duration timeout = Duration(0)) -> std::pair { tools::SharedSemaphore blocker; std::error_code ec; TopologyState state; - AsyncChangeState( - transition, timeout, [&, blocker](std::error_code _ec, TopologyState _state) mutable { - ec = _ec; - state = _state; - blocker.Signal(); - }); + AsyncChangeState(transition, path, timeout, [&, blocker](std::error_code _ec, TopologyState _state) mutable { + ec = _ec; + state = _state; + blocker.Signal(); + }); blocker.Wait(); return {ec, state}; } + /// @brief Perform state transition on all FairMQ devices in this topology with a timeout + /// @param transition FairMQ device state machine transition + /// @param timeout Timeout in milliseconds, 0 means no timeout + /// @throws std::system_error + auto ChangeState(const TopologyTransition transition, Duration timeout) + -> std::pair + { + return ChangeState(transition, "", timeout); + } + /// @brief Returns the current state of the topology /// @return map of id : DeviceStatus (initialized, state) auto GetCurrentState() const -> TopologyState @@ -492,13 +679,9
@@ class BasicTopology : public AsioBase { fCount = std::count_if(stateIndex.cbegin(), stateIndex.cend(), [=](const auto& s) { if (ContainsTask(stateData.at(s.second).taskId)) { - if (stateData.at(s.second).state == fTargetCurrentState && - (stateData.at(s.second).lastState == fTargetLastState || - fTargetLastState == DeviceState::Ok)) { - return true; - } else { - return false; - } + return stateData.at(s.second).state == fTargetCurrentState + && + (stateData.at(s.second).lastState == fTargetLastState || fTargetLastState == DeviceState::Ok); } else { return false; } @@ -547,17 +730,6 @@ class BasicTopology : public AsioBase } }; - auto HandleCmd(cmd::StateChange const& cmd) -> void - { - DDSTask::Id taskId(cmd.GetTaskId()); - UpdateStateEntry(taskId, cmd.GetLastState(), cmd.GetCurrentState()); - - std::lock_guard lk(fMtx); - for (auto& op : fWaitForStateOps) { - op.second.Update(taskId, cmd.GetLastState(), cmd.GetCurrentState()); - } - } - public: /// @brief Initiate waiting for selected FairMQ devices to reach given last & current state in this topology /// @param targetLastState the target last device state to wait for @@ -600,6 +772,7 @@ class BasicTopology : public AsioBase AsioBase::GetAllocator(), std::move(handler))); p.first->second.ResetCount(fStateIndex, fStateData); + // TODO: make sure following operation properly queues the completion and not doing it directly out of initiation call. 
p.first->second.TryCompletion(); }, token); @@ -736,20 +909,6 @@ class BasicTopology : public AsioBase } }; - auto HandleCmd(cmd::Properties const& cmd) -> void - { - std::unique_lock lk(fMtx); - try { - auto& op(fGetPropertiesOps.at(cmd.GetRequestId())); - lk.unlock(); - op.Update(cmd.GetDeviceId(), cmd.GetResult(), cmd.GetProps()); - } catch (std::out_of_range& e) { - FAIR_LOG(debug) << "GetProperties operation (request id: " << cmd.GetRequestId() - << ") not found (probably completed or timed out), " - << "discarding reply of device " << cmd.GetDeviceId(); - } - } - public: /// @brief Initiate property query on selected FairMQ devices in this topology /// @param query Key(s) to be queried (regex) @@ -903,20 +1062,6 @@ class BasicTopology : public AsioBase } }; - auto HandleCmd(cmd::PropertiesSet const& cmd) -> void - { - std::unique_lock lk(fMtx); - try { - auto& op(fSetPropertiesOps.at(cmd.GetRequestId())); - lk.unlock(); - op.Update(cmd.GetDeviceId(), cmd.GetResult()); - } catch (std::out_of_range& e) { - FAIR_LOG(debug) << "SetProperties operation (request id: " << cmd.GetRequestId() - << ") not found (probably completed or timed out), " - << "discarding reply of device " << cmd.GetDeviceId(); - } - } - public: /// @brief Initiate property update on selected FairMQ devices in this topology /// @param props Properties to set @@ -1002,15 +1147,10 @@ class BasicTopology : public AsioBase TopologyStateIndex fStateIndex; mutable std::mutex fMtx; - using ChangeStateOp = AsioAsyncOp; - ChangeStateOp fChangeStateOp; - asio::steady_timer fChangeStateOpTimer; - DeviceState fChangeStateTarget; - TransitionedCount fTransitionedCount; - + std::unordered_map fChangeStateOps; + std::unordered_map fWaitForStateOps; std::unordered_map fSetPropertiesOps; std::unordered_map fGetPropertiesOps; - std::unordered_map fWaitForStateOps; auto makeTopologyState() -> void { @@ -1025,41 +1165,6 @@ class BasicTopology : public AsioBase } } - auto UpdateStateEntry(const DDSTask::Id taskId, 
const DeviceState lastState, const DeviceState currentState) -> void - { - try { - std::lock_guard lk(fMtx); - DeviceStatus& task = fStateData.at(fStateIndex.at(taskId)); - task.initialized = true; - task.lastState = lastState; - task.state = currentState; - if (task.state == fChangeStateTarget) { - ++fTransitionedCount; - } - // FAIR_LOG(debug) << "Updated state entry: taskId=" << taskId << ", state=" << state; - TryChangeStateCompletion(); - } catch (const std::exception& e) { - FAIR_LOG(error) << "Exception in UpdateStateEntry: " << e.what(); - } - } - - /// precodition: fMtx is locked. - auto TryChangeStateCompletion() -> void - { - if (!fChangeStateOp.IsCompleted() && fTransitionedCount == fStateData.size()) { - fChangeStateOpTimer.cancel(); - fChangeStateOp.Complete(fStateData); - } - } - - /// precodition: fMtx is locked. - auto ResetTransitionedCount(DeviceState targetState) -> void - { - fTransitionedCount = std::count_if(fStateIndex.cbegin(), fStateIndex.cend(), [=](const auto& s) { - return fStateData.at(s.second).state == targetState; - }); - } - /// precodition: fMtx is locked. 
auto GetCurrentStateUnsafe() const -> TopologyState { diff --git a/test/sdk/_topology.cxx b/test/sdk/_topology.cxx index 05ecc6a0..c9746cdc 100644 --- a/test/sdk/_topology.cxx +++ b/test/sdk/_topology.cxx @@ -144,18 +144,21 @@ TEST_F(Topology, AsyncChangeStateConcurrent) using namespace fair::mq; sdk::Topology topo(mDDSTopo, mDDSSession); - tools::SharedSemaphore blocker; - topo.AsyncChangeState(sdk::TopologyTransition::InitDevice, - [blocker](std::error_code ec, sdk::TopologyState) mutable { - LOG(info) << "result for valid ChangeState: " << ec; - blocker.Signal(); + topo.AsyncChangeState(sdk::TopologyTransition::InitDevice, "main/Sampler.*", + [](std::error_code ec, sdk::TopologyState) mutable { + LOG(info) << "ChangeState for Sampler: " << ec; + EXPECT_EQ(ec, std::error_code()); }); - topo.AsyncChangeState(sdk::TopologyTransition::Stop, - [](std::error_code ec, sdk::TopologyState) { - LOG(ERROR) << "Expected error: " << ec; - EXPECT_EQ(ec, MakeErrorCode(ErrorCode::OperationInProgress)); + topo.AsyncChangeState(sdk::TopologyTransition::InitDevice, "main/SinkGroup/.*", + [](std::error_code ec, sdk::TopologyState) mutable { + LOG(info) << "ChangeState for Sinks: " << ec; + EXPECT_EQ(ec, std::error_code()); }); - blocker.Wait(); + + topo.WaitForState(sdk::DeviceState::InitializingDevice); + auto const currentState = topo.GetCurrentState(); + EXPECT_NO_THROW(sdk::AggregateState(currentState)); + EXPECT_EQ(sdk::StateEqualsTo(currentState, sdk::DeviceState::InitializingDevice), true); } TEST_F(Topology, AsyncChangeStateTimeout)