Use enum transport types instead of strings in Channel/Device

This commit is contained in:
Alexey Rybalchenko 2018-05-04 16:12:37 +02:00 committed by Mohammad Al-Turany
parent 7a4fd96b27
commit d93dc2f7f7
5 changed files with 45 additions and 49 deletions

View File

@ -28,7 +28,6 @@ FairMQChannel::FairMQChannel()
, fType("unspecified") , fType("unspecified")
, fMethod("unspecified") , fMethod("unspecified")
, fAddress("unspecified") , fAddress("unspecified")
, fTransportName("default")
, fSndBufSize(1000) , fSndBufSize(1000)
, fRcvBufSize(1000) , fRcvBufSize(1000)
, fSndKernelSize(0) , fSndKernelSize(0)
@ -49,7 +48,6 @@ FairMQChannel::FairMQChannel(const string& type, const string& method, const str
, fType(type) , fType(type)
, fMethod(method) , fMethod(method)
, fAddress(address) , fAddress(address)
, fTransportName("default")
, fSndBufSize(1000) , fSndBufSize(1000)
, fRcvBufSize(1000) , fRcvBufSize(1000)
, fSndKernelSize(0) , fSndKernelSize(0)
@ -70,7 +68,6 @@ FairMQChannel::FairMQChannel(const string& name, const string& type, std::shared
, fType(type) , fType(type)
, fMethod("unspecified") , fMethod("unspecified")
, fAddress("unspecified") , fAddress("unspecified")
, fTransportName("default") // TODO refactor, either use string representation or enum type
, fSndBufSize(1000) , fSndBufSize(1000)
, fRcvBufSize(1000) , fRcvBufSize(1000)
, fSndKernelSize(0) , fSndKernelSize(0)
@ -91,7 +88,6 @@ FairMQChannel::FairMQChannel(const FairMQChannel& chan)
, fType(chan.fType) , fType(chan.fType)
, fMethod(chan.fMethod) , fMethod(chan.fMethod)
, fAddress(chan.fAddress) , fAddress(chan.fAddress)
, fTransportName(chan.fTransportName)
, fSndBufSize(chan.fSndBufSize) , fSndBufSize(chan.fSndBufSize)
, fRcvBufSize(chan.fRcvBufSize) , fRcvBufSize(chan.fRcvBufSize)
, fSndKernelSize(chan.fSndKernelSize) , fSndKernelSize(chan.fSndKernelSize)
@ -99,7 +95,7 @@ FairMQChannel::FairMQChannel(const FairMQChannel& chan)
, fRateLogging(chan.fRateLogging) , fRateLogging(chan.fRateLogging)
, fName(chan.fName) , fName(chan.fName)
, fIsValid(false) , fIsValid(false)
, fTransportType(fair::mq::Transport::DEFAULT) , fTransportType(chan.fTransportType)
, fTransportFactory(nullptr) , fTransportFactory(nullptr)
, fMultipart(chan.fMultipart) , fMultipart(chan.fMultipart)
, fModified(chan.fModified) , fModified(chan.fModified)
@ -111,7 +107,6 @@ FairMQChannel& FairMQChannel::operator=(const FairMQChannel& chan)
fType = chan.fType; fType = chan.fType;
fMethod = chan.fMethod; fMethod = chan.fMethod;
fAddress = chan.fAddress; fAddress = chan.fAddress;
fTransportName = chan.fTransportName;
fSndBufSize = chan.fSndBufSize; fSndBufSize = chan.fSndBufSize;
fRcvBufSize = chan.fRcvBufSize; fRcvBufSize = chan.fRcvBufSize;
fSndKernelSize = chan.fSndKernelSize; fSndKernelSize = chan.fSndKernelSize;
@ -120,7 +115,7 @@ FairMQChannel& FairMQChannel::operator=(const FairMQChannel& chan)
fSocket = nullptr; fSocket = nullptr;
fName = chan.fName; fName = chan.fName;
fIsValid = false; fIsValid = false;
fTransportType = fair::mq::Transport::DEFAULT; fTransportType = chan.fTransportType;
fTransportFactory = nullptr; fTransportFactory = nullptr;
return *this; return *this;
@ -199,7 +194,7 @@ string FairMQChannel::GetTransportName() const
try try
{ {
unique_lock<mutex> lock(fChannelMutex); unique_lock<mutex> lock(fChannelMutex);
return fTransportName; return fair::mq::TransportNames.at(fTransportType);
} }
catch (exception& e) catch (exception& e)
{ {
@ -332,7 +327,9 @@ void FairMQChannel::UpdateTransport(const string& transport)
{ {
unique_lock<mutex> lock(fChannelMutex); unique_lock<mutex> lock(fChannelMutex);
fIsValid = false; fIsValid = false;
fTransportName = transport; LOG(WARN) << fName << ": " << transport;
fTransportType = fair::mq::TransportTypes.at(transport);
LOG(WARN) << fName << ": " << fair::mq::TransportNames.at(fTransportType);
fModified = true; fModified = true;
} }
catch (exception& e) catch (exception& e)
@ -586,15 +583,6 @@ bool FairMQChannel::ValidateChannel()
} }
} }
// validate channel transport
if (fair::mq::TransportTypes.find(fTransportName) == fair::mq::TransportTypes.end())
{
ss << "INVALID";
LOG(debug) << ss.str();
LOG(error) << "Invalid channel transport: \"" << fTransportName << "\"";
exit(EXIT_FAILURE);
}
// validate socket buffer size for sending // validate socket buffer size for sending
if (fSndBufSize < 0) if (fSndBufSize < 0)
{ {

View File

@ -301,7 +301,7 @@ class FairMQChannel
std::string fType; std::string fType;
std::string fMethod; std::string fMethod;
std::string fAddress; std::string fAddress;
std::string fTransportName; fair::mq::Transport fTransportType;
int fSndBufSize; int fSndBufSize;
int fRcvBufSize; int fRcvBufSize;
int fSndKernelSize; int fSndKernelSize;
@ -311,7 +311,6 @@ class FairMQChannel
std::string fName; std::string fName;
std::atomic<bool> fIsValid; std::atomic<bool> fIsValid;
fair::mq::Transport fTransportType;
std::shared_ptr<FairMQTransportFactory> fTransportFactory; std::shared_ptr<FairMQTransportFactory> fTransportFactory;
bool CheckCompatibility(std::unique_ptr<FairMQMessage>& msg) const; bool CheckCompatibility(std::unique_ptr<FairMQMessage>& msg) const;

View File

@ -29,6 +29,7 @@
using namespace std; using namespace std;
FairMQDevice::FairMQDevice() FairMQDevice::FairMQDevice()
: fTransportFactory(nullptr) : fTransportFactory(nullptr)
, fTransports() , fTransports()
@ -42,7 +43,7 @@ FairMQDevice::FairMQDevice()
, fPortRangeMin(22000) , fPortRangeMin(22000)
, fPortRangeMax(32000) , fPortRangeMax(32000)
, fNetworkInterface() , fNetworkInterface()
, fDefaultTransportName("default") , fDefaultTransportType(fair::mq::Transport::DEFAULT)
, fInitializationTimeoutInS(120) , fInitializationTimeoutInS(120)
, fDataCallbacks(false) , fDataCallbacks(false)
, fMsgInputs() , fMsgInputs()
@ -72,7 +73,7 @@ FairMQDevice::FairMQDevice(const fair::mq::tools::Version version)
, fPortRangeMin(22000) , fPortRangeMin(22000)
, fPortRangeMax(32000) , fPortRangeMax(32000)
, fNetworkInterface() , fNetworkInterface()
, fDefaultTransportName("default") , fDefaultTransportType(fair::mq::Transport::DEFAULT)
, fInitializationTimeoutInS(120) , fInitializationTimeoutInS(120)
, fDataCallbacks(false) , fDataCallbacks(false)
, fMsgInputs() , fMsgInputs()
@ -246,15 +247,15 @@ bool FairMQDevice::AttachChannel(FairMQChannel& ch)
{ {
if (!ch.fTransportFactory) if (!ch.fTransportFactory)
{ {
if (ch.fTransportName == "default" || ch.fTransportName == fDefaultTransportName) if (ch.fTransportType == fair::mq::Transport::DEFAULT || ch.fTransportType == fTransportFactory->GetType())
{ {
LOG(debug) << ch.fName << ": using default transport"; LOG(debug) << ch.fName << ": using default transport";
ch.InitTransport(fTransportFactory); ch.InitTransport(fTransportFactory);
} }
else else
{ {
LOG(debug) << ch.fName << ": channel transport (" << fDefaultTransportName << ") overriden to " << ch.fTransportName; LOG(debug) << ch.fName << ": channel transport (" << fair::mq::TransportNames.at(fDefaultTransportType) << ") overriden to " << fair::mq::TransportNames.at(ch.fTransportType);
ch.InitTransport(AddTransport(ch.fTransportName)); ch.InitTransport(AddTransport(ch.fTransportType));
} }
ch.fTransportType = ch.fTransportFactory->GetType(); ch.fTransportType = ch.fTransportFactory->GetType();
} }
@ -760,24 +761,24 @@ void FairMQDevice::Pause()
LOG(debug) << "Unpausing"; LOG(debug) << "Unpausing";
} }
shared_ptr<FairMQTransportFactory> FairMQDevice::AddTransport(const string& transport) shared_ptr<FairMQTransportFactory> FairMQDevice::AddTransport(const fair::mq::Transport transport)
{ {
auto i = fTransports.find(fair::mq::TransportTypes.at(transport)); auto i = fTransports.find(transport);
if (i == fTransports.end()) if (i == fTransports.end())
{ {
auto tr = FairMQTransportFactory::CreateTransportFactory(transport, fId, fConfig); auto tr = FairMQTransportFactory::CreateTransportFactory(fair::mq::TransportNames.at(transport), fId, fConfig);
LOG(debug) << "Adding '" << transport << "' transport to the device."; LOG(debug) << "Adding '" << fair::mq::TransportNames.at(transport) << "' transport to the device.";
pair<fair::mq::Transport, shared_ptr<FairMQTransportFactory>> trPair(fair::mq::TransportTypes.at(transport), tr); pair<fair::mq::Transport, shared_ptr<FairMQTransportFactory>> trPair(transport, tr);
fTransports.insert(trPair); fTransports.insert(trPair);
return tr; return tr;
} }
else else
{ {
LOG(debug) << "Reusing existing '" << transport << "' transport."; LOG(debug) << "Reusing existing '" << fair::mq::TransportNames.at(transport) << "' transport.";
return i->second; return i->second;
} }
} }
@ -804,7 +805,11 @@ void FairMQDevice::CreateOwnConfig()
fNumIoThreads = fConfig->GetValue<int>("io-threads"); fNumIoThreads = fConfig->GetValue<int>("io-threads");
fInitializationTimeoutInS = fConfig->GetValue<int>("initialization-timeout"); fInitializationTimeoutInS = fConfig->GetValue<int>("initialization-timeout");
fRate = fConfig->GetValue<float>("rate"); fRate = fConfig->GetValue<float>("rate");
fDefaultTransportName = fConfig->GetValue<string>("transport"); try {
fDefaultTransportType = fair::mq::TransportTypes.at(fConfig->GetValue<string>("transport"));
} catch(const exception& e) {
LOG(ERROR) << "invalid transport type provided: " << fConfig->GetValue<string>("transport");
}
} }
void FairMQDevice::SetTransport(const string& transport) void FairMQDevice::SetTransport(const string& transport)
@ -819,7 +824,7 @@ void FairMQDevice::SetTransport(const string& transport)
if (fTransports.empty()) if (fTransports.empty())
{ {
LOG(debug) << "Requesting '" << transport << "' as default transport for the device"; LOG(debug) << "Requesting '" << transport << "' as default transport for the device";
fTransportFactory = AddTransport(transport); fTransportFactory = AddTransport(fair::mq::TransportTypes.at(transport));
} }
else else
{ {
@ -844,8 +849,12 @@ void FairMQDevice::SetConfig(FairMQProgOptions& config)
fNumIoThreads = config.GetValue<int>("io-threads"); fNumIoThreads = config.GetValue<int>("io-threads");
fInitializationTimeoutInS = config.GetValue<int>("initialization-timeout"); fInitializationTimeoutInS = config.GetValue<int>("initialization-timeout");
fRate = fConfig->GetValue<float>("rate"); fRate = fConfig->GetValue<float>("rate");
fDefaultTransportName = config.GetValue<string>("transport"); try {
SetTransport(fDefaultTransportName); fDefaultTransportType = fair::mq::TransportTypes.at(fConfig->GetValue<string>("transport"));
} catch(const exception& e) {
LOG(ERROR) << "invalid transport type provided: " << fConfig->GetValue<string>("transport");
}
SetTransport(fConfig->GetValue<string>("transport"));
} }
void FairMQDevice::LogSocketRates() void FairMQDevice::LogSocketRates()

View File

@ -196,7 +196,7 @@ class FairMQDevice : public FairMQStateMachine
/// @brief Getter for default transport factory /// @brief Getter for default transport factory
auto Transport() const -> const FairMQTransportFactory* auto Transport() const -> const FairMQTransportFactory*
{ {
return fTransportFactory.get();; return fTransportFactory.get();
} }
template<typename... Args> template<typename... Args>
@ -293,7 +293,7 @@ class FairMQDevice : public FairMQStateMachine
/// Adds a transport to the device if it doesn't exist /// Adds a transport to the device if it doesn't exist
/// @param transport Transport string ("zeromq"/"nanomsg"/"shmem") /// @param transport Transport string ("zeromq"/"nanomsg"/"shmem")
std::shared_ptr<FairMQTransportFactory> AddTransport(const std::string& transport); std::shared_ptr<FairMQTransportFactory> AddTransport(const fair::mq::Transport transport);
/// Sets the default transport for the device /// Sets the default transport for the device
/// @param transport Transport string ("zeromq"/"nanomsg"/"shmem") /// @param transport Transport string ("zeromq"/"nanomsg"/"shmem")
void SetTransport(const std::string& transport = "zeromq"); void SetTransport(const std::string& transport = "zeromq");
@ -407,14 +407,14 @@ class FairMQDevice : public FairMQStateMachine
void SetNetworkInterface(const std::string& networkInterface) { fNetworkInterface = networkInterface; } void SetNetworkInterface(const std::string& networkInterface) { fNetworkInterface = networkInterface; }
std::string GetNetworkInterface() const { return fNetworkInterface; } std::string GetNetworkInterface() const { return fNetworkInterface; }
void SetDefaultTransportName(const std::string& defaultTransportName) { fDefaultTransportName = defaultTransportName; } void SetDefaultTransport(const std::string& name) { fDefaultTransportType = fair::mq::TransportTypes.at(name); }
std::string GetDefaultTransportName() const { return fDefaultTransportName; } std::string GetDefaultTransport() const { return fair::mq::TransportNames.at(fDefaultTransportType); }
void SetInitializationTimeoutInS(int initializationTimeoutInS) { fInitializationTimeoutInS = initializationTimeoutInS; } void SetInitializationTimeoutInS(int initializationTimeoutInS) { fInitializationTimeoutInS = initializationTimeoutInS; }
int GetInitializationTimeoutInS() const { return fInitializationTimeoutInS; } int GetInitializationTimeoutInS() const { return fInitializationTimeoutInS; }
protected: protected:
std::shared_ptr<FairMQTransportFactory> fTransportFactory; ///< Transport factory std::shared_ptr<FairMQTransportFactory> fTransportFactory; ///< Default transport factory
std::unordered_map<fair::mq::Transport, std::shared_ptr<FairMQTransportFactory>> fTransports; ///< Container for transports std::unordered_map<fair::mq::Transport, std::shared_ptr<FairMQTransportFactory>> fTransports; ///< Container for transports
public: public:
@ -472,7 +472,7 @@ class FairMQDevice : public FairMQStateMachine
int fPortRangeMax; ///< Maximum value for the port range (if dynamic) int fPortRangeMax; ///< Maximum value for the port range (if dynamic)
std::string fNetworkInterface; ///< Network interface to use for dynamic binding std::string fNetworkInterface; ///< Network interface to use for dynamic binding
std::string fDefaultTransportName; ///< Default transport for the device fair::mq::Transport fDefaultTransportType; ///< Default transport for the device
int fInitializationTimeoutInS; ///< Timeout for the initialization (in seconds) int fInitializationTimeoutInS; ///< Timeout for the initialization (in seconds)

View File

@ -29,14 +29,6 @@ enum class Transport
OFI OFI
}; };
static std::unordered_map<std::string, Transport> TransportTypes {
{ "default", Transport::DEFAULT },
{ "zeromq", Transport::ZMQ },
{ "nanomsg", Transport::NN },
{ "shmem", Transport::SHM },
{ "ofi", Transport::OFI }
};
} /* namespace mq */ } /* namespace mq */
} /* namespace fair */ } /* namespace fair */
@ -53,6 +45,14 @@ namespace fair
namespace mq namespace mq
{ {
static std::unordered_map<std::string, Transport> TransportTypes {
{ "default", Transport::DEFAULT },
{ "zeromq", Transport::ZMQ },
{ "nanomsg", Transport::NN },
{ "shmem", Transport::SHM },
{ "ofi", Transport::OFI }
};
static std::unordered_map<Transport, std::string> TransportNames { static std::unordered_map<Transport, std::string> TransportNames {
{ Transport::DEFAULT, "default" }, { Transport::DEFAULT, "default" },
{ Transport::ZMQ, "zeromq" }, { Transport::ZMQ, "zeromq" },