diff --git a/subt-communication/subt_communication_broker/include/subt_communication_broker/common_types.h b/subt-communication/subt_communication_broker/include/subt_communication_broker/common_types.h index 9fbef16c..16d027ff 100644 --- a/subt-communication/subt_communication_broker/include/subt_communication_broker/common_types.h +++ b/subt-communication/subt_communication_broker/include/subt_communication_broker/common_types.h @@ -89,11 +89,24 @@ const std::string kAddrUnregistrationSrv = "/address/unregister"; /// \brief Service used to register an end point. const std::string kEndPointRegistrationSrv = "/end_point/register"; +/// \brief Service used to unregister an end point. +const std::string kEndPointUnregistrationSrv = "/end_point/unregister"; + /// \brief Address used to receive neighbor updates. const std::string kNeighborsTopic = "/neighbors"; /// \brief Default port. const uint32_t kDefaultPort = 4100u; +/// \brief ID of a broker client. +typedef uint32_t ClientID; +/// \brief ID denoting an invalid broker client. +constexpr ClientID invalidClientId {0}; + +/// \brief ID of a bound endpoint belonging to a particular client. +typedef uint32_t EndpointID; +/// \brief ID denoting an invalid endpoint. +constexpr EndpointID invalidEndpointId {0}; + } } diff --git a/subt-communication/subt_communication_broker/include/subt_communication_broker/subt_communication_broker.h b/subt-communication/subt_communication_broker/include/subt_communication_broker/subt_communication_broker.h index d7ab6a72..d3892c19 100644 --- a/subt-communication/subt_communication_broker/include/subt_communication_broker/subt_communication_broker.h +++ b/subt-communication/subt_communication_broker/include/subt_communication_broker/subt_communication_broker.h @@ -29,6 +29,7 @@ #include #include +#include #include #include #include @@ -53,6 +54,8 @@ pose_update_function; /// \brief Map of endpoints using EndPoints_M = std::map>; + struct BrokerPrivate; + /// \brief Store messages, and exposes an API for registering new clients, /// bind to a particular address, push new messages or get the list of /// messages already stored in the queue. @@ -62,7 +65,7 @@ pose_update_function; public: Broker(); /// \brief Destructor. - public: virtual ~Broker() = default; + public: virtual ~Broker(); /// \brief Start handling services /// @@ -83,29 +86,39 @@ pose_update_function; public: void NotifyNeighbors(); /// \brief Dispatch all incoming messages. - public: void DispatchMessages(); + public: bool DispatchMessages(); /// \brief This method associates an endpoint with a broker client and its /// address. An endpoint is constructed as an address followed by ':', /// followed by the port. E.g.: "192.168.1.5:8000" is a valid endpoint. - /// \param[in] _clientAddress Address of the broker client. + /// \param[in] _clientId ID of a registered client. /// \param[in] _endpoint End point requested to bind. - /// \return True if the operation succeed or false otherwise (if the client - /// was already bound to the same endpoint). - private: bool Bind(const std::string &_clientAddress, - const std::string &_endpoint); - - /// \brief Register a new client for message handling. - /// \param[in] _id Unique ID of the client. - /// \return True if the operation succeed of false otherwise (if the same - /// id was already registered). - public: bool Register(const std::string &_id); - - /// \brief Unregister a client and unbind from all the endpoints. - /// \param[in] _id Unique ID of the client. - /// \return True if the operation succeed or false otherwise (if there is - /// no client registered for this ID). - public: bool Unregister(const std::string &_id); + /// \return ID that should be used for unbinding the endpoint. If + /// invalidEndpointId is returned, the binding request failed (e.g. due to + /// wrong endpoint specification, or if the _clientId is wrong). + public: EndpointID Bind(ClientID _clientId, const std::string &_endpoint); + + /// \brief This method cancels the association between a client and an + /// endpoint. When all equal endpoints on the same client are unbound, + /// the client will stop receiving messages for the given endpoint. + /// \param[in] _endpointId ID of the endpoint to unbind. It has to be an ID + /// received from a previous call to Bind(). + /// \return Whether the unbind was successful. It may fail e.g. when the + /// given ID is invalid and the broker doesn't know it. + public: bool Unbind(EndpointID _endpointId); + + /// \brief Register a new client for message handling. Multiple clients for + /// the same address are allowed even from a single process. + /// \param[in] _clientAddress Address of the client. + /// \return ID of the client (should be later used for unregistration). + /// If the returned ID is invalidClientId, the registration failed. + public: ClientID Register(const std::string &_clientAddress); + + /// \brief Unregister a client and unbind from all its endpoints. + /// \param[in] _clientId The ID received from the Register() call. + /// \return True if the operation succeeded or false otherwise (if there is + /// no client registered for this ID or the ID is invalid). + public: bool Unregister(ClientID _clientId); /// \brief Set the radio configuration for address /// \param[in] address @@ -128,29 +141,45 @@ pose_update_function; /// \param[in] f Function that finds pose based on name public: void SetPoseUpdateFunction(pose_update_function f); + /// \brief Get the Ignition partition this broker is running in. + /// \return The partition name. The reference gets invalid when the broker + /// object goes out of scope. + public: const std::string& IgnPartition() const; + /// \brief Callback executed when a new registration request is received. /// \param[in] _req The address contained in the request. - /// \param[out] _rep The result of the service. True when the registration - /// went OK or false otherwise (e.g.: the same address was already - /// registered). + /// \param[out] _rep An ID of the registered client. This ID should be used + /// when unregistering this client. If registration failed, invalidClientId + /// value is returned. private: bool OnAddrRegistration(const ignition::msgs::StringMsg &_req, - ignition::msgs::Boolean &_rep); + ignition::msgs::UInt32 &_rep); /// \brief Callback executed when a new unregistration request is received. - /// \param[in] _req The address contained in the request. + /// \param[in] _req ID of the client to unregister. /// \param[out] _rep The result of the service. True when the unregistration - /// went OK or false otherwise (e.g.: the address wasn't registered). - private: bool OnAddrUnregistration(const ignition::msgs::StringMsg &_req, + /// went OK or false otherwise (e.g.: the ID wasn't registered). + private: bool OnAddrUnregistration(const ignition::msgs::UInt32 &_req, ignition::msgs::Boolean &_rep); - /// \brief Callback executed when a new registration request is received. - /// \param[in] _req The end point contained in the request. The first - /// string is the client address and the second string is the end point. - /// \param[out] _rep The result of the service. True when the registration - /// went OK of false otherwise (e.g.: _req doesn't contain two strings). + /// \brief Callback executed when a new endpoint registration request is + /// received. + /// \param[in] _req The endpoint to register together with ID of the client + /// to which this registration belongs. + /// \param[out] _rep An ID of the endpoint. This ID should be used + /// when unregistering this endpoint. If registration failed, + /// invalidEndpointId value is returned. private: bool OnEndPointRegistration( - const ignition::msgs::StringMsg_V &_req, - ignition::msgs::Boolean &_rep); + const subt::msgs::EndpointRegistration &_req, + ignition::msgs::UInt32 &_rep); + + /// \brief Callback executed when a new endpoint unregistration request is + /// received. + /// \param[in] _req ID of the endpoint to unregister. + /// \param[out] _rep The result of the service. True when the unregistration + /// went OK or false otherwise (e.g.: the ID wasn't registered). + private: bool OnEndPointUnregistration( + const ignition::msgs::UInt32 &_req, + ignition::msgs::Boolean &_rep); /// \brief Callback executed when a new request is received. /// \param[in] _req The datagram contained in the request. @@ -185,6 +214,9 @@ pose_update_function; /// \brief Pose update function private: pose_update_function pose_update_f; + + /// \brief Private definitions and data + private: std::unique_ptr dataPtr; }; } diff --git a/subt-communication/subt_communication_broker/include/subt_communication_broker/subt_communication_client.h b/subt-communication/subt_communication_broker/include/subt_communication_broker/subt_communication_client.h index ef1a6cd2..a6c7c6b8 100644 --- a/subt-communication/subt_communication_broker/include/subt_communication_broker/subt_communication_client.h +++ b/subt-communication/subt_communication_broker/include/subt_communication_broker/subt_communication_client.h @@ -45,10 +45,22 @@ namespace subt /// \param[in] _useIgnition Set to true if you are using Ignition /// transport (i.e. not ROS). This is needed by the base station, /// and tests. If you are a regular robot, then you really really do not - /// want to set this to true as your Commsclient will not work. - public: CommsClient(const std::string &_localAddress, - const bool _isPrivate = false, - const bool _useIgnition = false); + /// want to set this to true as your Commsclient will not work. For each + /// address (robot), there can be only one client with _useIgnition set to + /// true. There can be multiple clients with _useIgnition set to false. + /// \param[in] _listenBeacons If true (the default), the commsclient will + /// listen to beacon packets on port 4000. This (or another regular stream + /// of data) is required for the Neighbors() function to work. + /// \param[in] _rosNh The ROS node handle used to create ROS subscribers for + /// incoming messages. Only used when _useIgnition is false. If not given, + /// a default node handle is used (in the current namespace). + public: + explicit + CommsClient(const std::string& _localAddress, + bool _isPrivate = false, + bool _useIgnition = false, + bool _listenBeacons = true, + ros::NodeHandle* _rosNh = nullptr); /// \brief Destructor. public: virtual ~CommsClient(); @@ -59,7 +71,9 @@ namespace subt /// \brief This method can bind a local address and a port to a /// virtual socket. This is a required step if your agent needs to - /// receive messages. + /// receive messages. It is possible to bind multiple callbacks to the + /// same address:port endpoint, even from the same client ID. Just call + /// Bind() multiple times. /// /// \param[in] _cb Callback function to be executed when a new message is /// received associated to the specified <_address, port>. @@ -74,22 +88,31 @@ namespace subt /// You will receive all the messages sent from any node to this multicast /// group. /// \param[in] _port Port used to receive messages. - /// \return True when success or false otherwise. + /// \return List of endpoints that were created in result of this bind call. + /// It can be up to two endpoints - one for the unicast/multicast address, + /// and one for the broadcast address. The returned pairs contain the IDs + /// of the endpoints and their names (i.e. address:port). If the returned + /// list is empty, binding failed. You may test for this case using the + /// overloaded operator! on the result vector. /// /// * Example usage (bind on the local address and default port): /// this->Bind(&OnDataReceived, "192.168.1.3"); /// * Example usage (Bind on the multicast group and custom port.): /// this->Bind(&OnDataReceived, this->kMulticast, 5123); - public: bool Bind(std::function _cb, - const std::string &_address = "", - const int _port = communication_broker::kDefaultPort); + public: + std::vector> + Bind(std::function _cb, + const std::string& _address = "", + int _port = communication_broker::kDefaultPort); /// \brief This method can bind a local address and a port to a /// virtual socket. This is a required step if your agent needs to - /// receive messages. + /// receive messages. It is possible to bind multiple callbacks to the + /// same address:port endpoint, even from the same client ID. Just call + /// Bind() multiple times. /// /// \param[in] _cb Callback function to be executed when a new message is /// received associated to the specified <_address, port>. @@ -105,20 +128,27 @@ namespace subt /// You will receive all the messages sent from any node to this multicast /// group. /// \param[in] _port Port used to receive messages. - /// \return True when success or false otherwise. + /// \return List of endpoints that were created in result of this bind call. + /// It can be up to two endpoints - one for the unicast/multicast address, + /// and one for the broadcast address. The returned pairs contain the IDs + /// of the endpoints and their names (i.e. address:port). If the returned + /// list is empty, binding failed. You may test for this case using the + /// overloaded operator! on the result vector. /// /// * Example usage (bind on the local address and default port): /// this->Bind(&MyClass::OnDataReceived, this, "192.168.1.3"); /// * Example usage (Bind on the multicast group and custom port.): /// this->Bind(&MyClass::OnDataReceived, this, this->kMulticast, 5123); - public: template - bool Bind(void(C::*_cb)(const std::string &_srcAddress, - const std::string &_dstAddress, - const uint32_t _dstPort, - const std::string &_data), - C *_obj, - const std::string &_address = "", - const int _port = communication_broker::kDefaultPort) + public: + template + std::vector> + Bind(void(C::*_cb)(const std::string& _srcAddress, + const std::string& _dstAddress, + uint32_t _dstPort, + const std::string& _data), + C* _obj, + const std::string& _address = "", + int _port = communication_broker::kDefaultPort) { return this->Bind(std::bind(_cb, _obj, std::placeholders::_1, @@ -128,6 +158,14 @@ namespace subt _address, _port); } + /// \brief This method can unbind from a "socket" acquired by Bind(). Once + /// unbound, the registered callback will no longer be called. + /// + /// \param[in] _endpointId ID of the endpoint to unbind. This is the ID + /// returned from Bind() call. + /// \return Success of the unbinding. + public: bool Unbind(communication_broker::EndpointID _endpointId); + /// \brief Send some data to other/s member/s of the team. /// /// \param[in] _data Payload. The maximum size of the payload is 1500 bytes. @@ -141,7 +179,7 @@ namespace subt /// bigger than 1500 bytes). public: bool SendTo(const std::string &_data, const std::string &_dstAddress, - const uint32_t _port = communication_broker::kDefaultPort); + uint32_t _port = communication_broker::kDefaultPort); /// \brief Type for storing neighbor data public: typedef std::map> Neighbor_M; @@ -150,6 +188,11 @@ namespace subt /// /// \return A map of addresses and signal strength from your /// local neighbors. + /// \note For this function to work reliably, you need to be sure there is + /// a regular flow of data from all other robots towards this address. + /// One way to achieve it is to send and listen to beacons (see constructor + /// argument _listenBeacons and function StartBeaconInterval() or + /// SendBeacons()). public: Neighbor_M Neighbors() const; /// \brief Broadcast a BEACON packet @@ -165,8 +208,9 @@ namespace subt private: bool Register(); /// \brief Unregister the current address. This will make a synchronous call - /// to the broker to unregister the address. - /// \return True when the unregistration succeed or false otherwise. + /// to the broker to unregister the address. It will also unbind all + /// bound callbacks from this client. + /// \return True when the unregistration succeeded or false otherwise. private: bool Unregister(); /// \brief Function called each time a new datagram message is received. @@ -174,9 +218,8 @@ namespace subt private: void OnMessage(const msgs::Datagram &_msg); /// \brief Function called each time a new datagram message is received. - /// \param[in] _msg The incoming message. - private: bool OnMessageRos(subt_msgs::DatagramRos::Request &_req, - subt_msgs::DatagramRos::Response &_res); + /// \param[in] _req The incoming message. + private: void OnMessageRos(const subt_msgs::DatagramRos::Request &_req); /// \brief On clock message. This is used primarily/only by the /// BaseStation. @@ -192,7 +235,7 @@ namespace subt private: using Callback_t = std::function; /// \brief The local address. @@ -216,9 +259,12 @@ namespace subt /// \brief An Ignition Transport node for communications. private: ignition::transport::Node node; + private: using Callbacks = + std::unordered_map; + /// \brief User callbacks. The key is the topic name /// (e.g.: "/subt/192.168.2.1/4000") and the value is the user callback. - private: std::map callbacks; + private: std::map callbacks; /// \brief True when the broker validated my address. Enabled must be true /// for being able to send and receive data. @@ -235,8 +281,8 @@ namespace subt /// \brief A mutex for avoiding race conditions. private: mutable std::mutex mutex; - /// \brief Service that receives comms messages. - private: ros::ServiceServer commsModelOnMessageService; + /// \brief Subscriber that receives comms messages from SubtRosRelay. + private: ros::Subscriber commsSub; /// \brief Clock message from simulation. Used by the base station. /// The base station is run as a plugin alongside simulation, and does @@ -254,6 +300,20 @@ namespace subt /// \brief Period of the beacon in nanoseconds. private: int64_t beaconPeriodNs{0}; + + private: communication_broker::ClientID clientId { + communication_broker::invalidClientId + }; }; } + +/// \brief Backwards compatibility that allows easily testing Bind() success. +/// \param _val Result of a Bind() operation. +/// \return True if the Bind() operation failed. +bool operator !(const std::vector< + std::pair>& _val) +{ + return _val.empty(); +} + #endif diff --git a/subt-communication/subt_communication_broker/src/protobuf/CMakeLists.txt b/subt-communication/subt_communication_broker/src/protobuf/CMakeLists.txt index 445c2596..665a093e 100644 --- a/subt-communication/subt_communication_broker/src/protobuf/CMakeLists.txt +++ b/subt-communication/subt_communication_broker/src/protobuf/CMakeLists.txt @@ -10,6 +10,7 @@ include(${PROJECT_SOURCE_DIR}/cmake/Protobuf.cmake) set(PROTO_MESSAGES datagram.proto + endpoint_registration.proto neighbor_m.proto ) diff --git a/subt-communication/subt_communication_broker/src/protobuf/endpoint_registration.proto b/subt-communication/subt_communication_broker/src/protobuf/endpoint_registration.proto new file mode 100644 index 00000000..07a545c1 --- /dev/null +++ b/subt-communication/subt_communication_broker/src/protobuf/endpoint_registration.proto @@ -0,0 +1,32 @@ +/* + * Copyright (C) 2021 Open Source Robotics Foundation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * +*/ + +syntax = "proto3"; +package subt.msgs; + +/// \ingroup subt_msgs +/// \interface EndpointRegistration +/// \brief A message containing data needed for registering an endpoint. + +message EndpointRegistration +{ + /// \brief ID of the client this endpoint should be bound to. + uint32 client_id = 1; + + /// \brief The endpoint to register ("address:port"). + string endpoint = 2; +} diff --git a/subt-communication/subt_communication_broker/src/subt_communication_broker.cpp b/subt-communication/subt_communication_broker/src/subt_communication_broker.cpp index 1b6a5a25..c9bd014e 100644 --- a/subt-communication/subt_communication_broker/src/subt_communication_broker.cpp +++ b/subt-communication/subt_communication_broker/src/subt_communication_broker.cpp @@ -30,12 +30,223 @@ namespace subt namespace communication_broker { +/// \brief Helper class for managing bidirectional mapping of client IDs and +/// addresses. The class is not thread-safe, so callers must ensure that none +/// of the public methods get called simultaneously. +struct ClientIDs +{ + /// \brief Number of active clients for each address. This structure can be + /// accessed by outer code for reading, but not for modification. + std::unordered_map numActiveClients; + + /// \brief Map of client IDs to addresses. This structure can be accessed by + /// outer code for reading, but not for modification. + std::unordered_map idToAddress; + + /// \brief Add a new client and generate its ID. + /// \param _address Address of the client. + /// \return ID of the client. This method should always succeed and return + /// an ID different from invalidClientID. + ClientID Add(const std::string& _address) + { + const auto clientId = this->NextID(); + this->idToAddress[clientId] = _address; + if (this->numActiveClients.find(_address) == this->numActiveClients.end()) + this->numActiveClients[_address] = 0; + this->numActiveClients[_address]++; + return clientId; + } + + /// \brief Unregister a client. + /// \param _id ID of the client. + /// \return Success of the unregistration. The method can fail e.g. when + /// trying to unregister a client which has not been registered. + bool Remove(const ClientID _id) + { + if (!this->Valid(_id)) + return false; + this->numActiveClients[this->idToAddress[_id]]--; + this->idToAddress.erase(_id); + return true; + } + + /// \brief Clear/reset the structure to be able to work as new. + /// \note This cancels all registrations of all clients and resets the client + /// ID numbering, so it is not valid to mix IDs of clients obtained before and + /// after a Clear() call. + void Clear() + { + this->numActiveClients.clear(); + this->idToAddress.clear(); + this->lastId = invalidClientId; + } + + /// \brief Check validity of a client ID. + /// \param _id ID to check. + /// \return Whether a client with the given ID has been registered. + bool Valid(const ClientID _id) const + { + return _id != invalidClientId && + this->idToAddress.find(_id) != this->idToAddress.end(); + } + + /// \brief Return an ID for a new client. + /// \return The ID. + private: ClientID NextID() + { + return ++this->lastId; + } + + /// \brief Last ID given to a client. + private: ClientID lastId {invalidClientId}; +}; + +/// \brief Helper class for managing mappings between endpoint names, their IDs, +/// related clients and so on. The class is not thread-safe, so callers must +/// ensure that none of the public methods get called simultaneously. +struct EndpointIDs +{ + /// \brief Maps endpoint IDs to endpoint names. This structure can be accessed + /// by outer code for reading, but not for modification. + std::unordered_map idToEndpoint; + + /// \brief Maps endpoint names to related clients. Each client ID stores + /// information about the number of "connections" of this client ID to the + /// endpoint. This structure can be accessed by outer code for reading, but + /// not for modification. + std::unordered_map> + endpointToClientIds; + + /// \brief Maps endpoint IDs to their related client IDs. This structure can + /// be accessed by outer code for reading, but not for modification. + std::unordered_map endpointIdToClientId; + + /// \brief Maps client IDs to all their endpoint IDs. This structure can be + /// accessed by outer code for reading, but not for modification. + std::unordered_map> + clientIdToEndpointIds; + + /// \brief Add new endpoint related to the given client ID. + /// \param _endpoint Endpoint name. + /// \param _clientId Client ID. This method does no validation of the IDs. + /// \return ID of the new endpoint. Should always succeed and return an ID + /// not equal to invalidEndpointID. + EndpointID Add(const std::string& _endpoint, const ClientID _clientId) + { + const auto endpointId = this->NextID(); + + this->idToEndpoint[endpointId] = _endpoint; + this->endpointIdToClientId[endpointId] = _clientId; + + if (this->endpointToClientIds[_endpoint].find(_clientId) == + this->endpointToClientIds[_endpoint].end()) + this->endpointToClientIds[_endpoint][_clientId] = 0; + this->endpointToClientIds[_endpoint][_clientId]++; + + this->clientIdToEndpointIds[_clientId].insert(endpointId); + + return endpointId; + } + + /// \brief Remove the given endpoint ID from this structure. + /// \param _id ID of the endpoint to remove. + /// \return Whether the removal succeeded or not. It may fail e.g. when the + /// removed endpoint doesn't exist. + bool Remove(const EndpointID _id) + { + if (!this->Valid(_id)) + return false; + + const auto endpointName = this->idToEndpoint[_id]; + + if (this->endpointIdToClientId.find(_id) == + this->endpointIdToClientId.end()) + return false; + + const auto clientId = this->endpointIdToClientId[_id]; + this->endpointIdToClientId.erase(_id); + + if (this->endpointToClientIds.find(endpointName) == + this->endpointToClientIds.end()) + return false; + + if (this->endpointToClientIds[endpointName].find(clientId) == + this->endpointToClientIds[endpointName].end()) + return false; + + this->endpointToClientIds[endpointName][clientId]--; + if (this->endpointToClientIds[endpointName][clientId] == 0u) + this->endpointToClientIds[endpointName].erase(clientId); + + if (this->clientIdToEndpointIds.find(clientId) == + this->clientIdToEndpointIds.end()) + return false; + + if (this->clientIdToEndpointIds[clientId].find(_id) == + this->clientIdToEndpointIds[clientId].end()) + return false; + + this->clientIdToEndpointIds[clientId].erase(_id); + + return true; + } + + /// \brief Clear/reset the structure to be able to work as new. + /// \note This cancels all registrations of all endpoints and resets the + /// endpoint ID numbering, so it is not valid to mix IDs of endpoints obtained + /// before and after a Clear() call. + void Clear() + { + this->idToEndpoint.clear(); + this->clientIdToEndpointIds.clear(); + this->endpointToClientIds.clear(); + this->endpointIdToClientId.clear(); + this->lastId = invalidEndpointId; + } + + /// \brief Check validity of an endpoint ID. + /// \param _id ID to check. + /// \return Whether an endpoint with the given ID has been registered. + bool Valid(const EndpointID _id) const + { + return _id != invalidEndpointId && + this->idToEndpoint.find(_id) != this->idToEndpoint.end(); + } + + /// \brief Return an ID for a new client. + /// \return The ID. + private: EndpointID NextID() + { + return ++this->lastId; + } + + /// \brief Last ID given to a client. + private: EndpointID lastId {invalidEndpointId}; +}; + +/// \brief PIMPL structure. +struct BrokerPrivate +{ + /// \brief IDs of registered clients. + ClientIDs clientIDs; + + /// \brief IDs of registered endpoints. + EndpointIDs endpointIDs; +}; + ////////////////////////////////////////////////// Broker::Broker() - : team(std::make_shared()) + : team(std::make_shared()), + dataPtr(std::make_unique()) { } +////////////////////////////////////////////////// +Broker::~Broker() +{ + // cannot use default destructor because of dataPtr +} + ////////////////////////////////////////////////// void Broker::Start() { @@ -66,6 +277,15 @@ void Broker::Start() return; } + // Advertise the service for unregistering end points. + if (!this->node.Advertise(kEndPointUnregistrationSrv, + &Broker::OnEndPointUnregistration, this)) + { + std::cerr << "Error advertising srv [" << kEndPointUnregistrationSrv << "]" + << std::endl; + return; + } + // Advertise a oneway service for centralizing all message requests. if (!this->node.Advertise(kBrokerSrv, &Broker::OnMessage, this)) { @@ -82,6 +302,9 @@ void Broker::Start() << std::endl; return; } + + std::cout << "Started communication broker in Ignition partition " + << this->IgnPartition() << std::endl; } ////////////////////////////////////////////////// @@ -90,6 +313,9 @@ void Broker::Reset() std::lock_guard lk(this->mutex); this->incomingMsgs.clear(); this->endpoints.clear(); + this->team->clear(); + this->dataPtr->clientIDs.Clear(); + this->dataPtr->endpointIDs.Clear(); } ////////////////////////////////////////////////// @@ -124,12 +350,12 @@ void Broker::NotifyNeighbors() } ////////////////////////////////////////////////// -void Broker::DispatchMessages() +bool Broker::DispatchMessages() { std::lock_guard lk(this->mutex); if(this->incomingMsgs.empty()) - return; + return true; // Cannot dispatch messages if we don't have function handles for // pathloss and communication @@ -137,14 +363,14 @@ void Broker::DispatchMessages() { std::cerr << "[Broker::DispatchMessages()] Missing function handle for " << "communication" << std::endl; - return; + return false; } if(!pose_update_f) { std::cerr << "[Broker::DispatchMessages()]: Missing function for updating " << "pose" << std::endl; - return; + return false; } // Update state for all members in team (only do this for members @@ -160,10 +386,11 @@ void Broker::DispatchMessages() { std::cerr << "Problem getting state for " << t.second->name << ", skipping DispatchMessages()" << std::endl; - return; + return false; } } + bool allSucceeded = true; while (!this->incomingMsgs.empty()) { // Get the next message to dispatch. @@ -177,6 +404,7 @@ void Broker::DispatchMessages() std::cerr << "Broker::DispatchMessages(): Discarding message. Robot [" << msg.src_address() << "] is not registered as a member of the" << " team" << std::endl; + allSucceeded = false; continue; } @@ -202,6 +430,7 @@ void Broker::DispatchMessages() << "Robot [" << client.address << "] is not registered as a member of the" << " team" << std::endl; + allSucceeded = false; continue; } @@ -211,6 +440,7 @@ void Broker::DispatchMessages() { std::cerr << "No pathloss function defined for " << msg.src_address() << std::endl; + allSucceeded = false; continue; } @@ -231,6 +461,7 @@ void Broker::DispatchMessages() std::cerr << "[CommsBrokerPlugin::DispatchMessages()]: Error " << "sending message to [" << client.address << "]" << std::endl; + allSucceeded = false; } } } @@ -239,116 +470,192 @@ void Broker::DispatchMessages() { std::cerr << "[Broker::DispatchMessages()]: Could not find endpoint " << dstEndPoint << std::endl; + allSucceeded = false; } } + return allSucceeded; } ////////////////////////////////////////////////// -bool Broker::Bind(const std::string &_clientAddress, - const std::string &_endpoint) +EndpointID Broker::Bind(const ClientID _clientId, const std::string &_endpoint) { std::lock_guard lk(this->mutex); - // Make sure that the same client didn't bind the same end point before. + + if (!this->dataPtr->clientIDs.Valid(_clientId)) + { + std::cerr << "Broker::Bind() error: Client ID [" << _clientId + << "] is invalid." << std::endl; + return invalidEndpointId; + } + + const auto& clientAddress = this->dataPtr->clientIDs.idToAddress[_clientId]; + + auto clientFound = false; + if (this->endpoints.find(_endpoint) != this->endpoints.end()) { const auto &clientsV = this->endpoints[_endpoint]; for (const auto &client : clientsV) { - if (client.address == _clientAddress) + if (client.address == clientAddress) { - std::cerr << "Broker::Bind() error: Address [" << _clientAddress - << "] already used in a previous Bind()" << std::endl; - return false; + clientFound = true; + break; } } } - BrokerClientInfo clientInfo; - clientInfo.address = _clientAddress; - this->endpoints[_endpoint].push_back(clientInfo); + if (!clientFound) + { + BrokerClientInfo clientInfo; + clientInfo.address = clientAddress; + this->endpoints[_endpoint].push_back(clientInfo); + } + + const auto endpointId = this->dataPtr->endpointIDs.Add(_endpoint, _clientId); - return true; + return endpointId; } ////////////////////////////////////////////////// -bool Broker::Register(const std::string &_id) +bool Broker::Unbind(EndpointID _endpointId) { std::lock_guard lk(this->mutex); - auto kvp = this->team->find(_id); - if (kvp != this->team->end()) + + if (!this->dataPtr->endpointIDs.Valid(_endpointId)) + return false; + + const auto endpointName = + this->dataPtr->endpointIDs.idToEndpoint[_endpointId]; + + const auto clientId = + this->dataPtr->endpointIDs.endpointIdToClientId[_endpointId]; + const auto clientAddress = this->dataPtr->clientIDs.idToAddress[clientId]; + + if (this->endpoints.find(endpointName) == this->endpoints.end()) + return false; + + bool success = this->dataPtr->endpointIDs.Remove(_endpointId); + if (!success) + return false; + + if (this->dataPtr->endpointIDs.endpointToClientIds.find(endpointName) == + this->dataPtr->endpointIDs.endpointToClientIds.end()) + return false; + + bool hasOtherClientsOnTheSameAddress = false; + for (const auto clientKV : + this->dataPtr->endpointIDs.endpointToClientIds[endpointName]) { - std::cerr << "Broker::Register() warning: ID [" << _id << "] already exists" - << std::endl; + if (this->dataPtr->clientIDs.idToAddress[clientKV.first] == clientAddress && + clientKV.second > 0u) + { + hasOtherClientsOnTheSameAddress = true; + break; + } + } + + if (hasOtherClientsOnTheSameAddress) + return true; + + auto& clientsV = this->endpoints[endpointName]; + + auto i = std::begin(clientsV); + while (i != std::end(clientsV)) + { + if (i->address == clientAddress) + { + clientsV.erase(i); + break; + } + else + { + ++i; + } } - else + + return true; +} + +////////////////////////////////////////////////// +ClientID Broker::Register(const std::string &_clientAddress) +{ + std::lock_guard lk(this->mutex); + auto kvp = this->team->find(_clientAddress); + if (kvp == this->team->end()) { auto newMember = std::make_shared(); // Name and address are the same in SubT. - newMember->address = _id; - newMember->name = _id; + newMember->address = _clientAddress; + newMember->name = _clientAddress; newMember->radio = default_radio_configuration; - (*this->team)[_id] = newMember; + (*this->team)[_clientAddress] = newMember; } - return true; + const auto clientId = this->dataPtr->clientIDs.Add(_clientAddress); + + return clientId; } ////////////////////////////////////////////////// -bool Broker::Unregister(const std::string &_id) +bool Broker::Unregister(const ClientID _clientId) { - std::lock_guard lk(this->mutex); - // Sanity check: Make sure that the ID exists. - if (this->team->find(_id) == this->team->end()) + if (!this->dataPtr->clientIDs.Valid(_clientId)) { - std::cerr << "Broker::Unregister() error: ID [" << _id << "] doesn't exist" - << std::endl; + std::cerr << "Broker::Unregister() error: Client ID [" << _clientId + << "] is invalid." << std::endl; return false; } - this->team->erase(_id); + bool success = true; - // Unbind. - for (auto &endpointKv : this->endpoints) + std::unordered_set endpointIds; { - auto &clientsV = endpointKv.second; + // make a copy because Unbind() calls will alter the structure + std::lock_guard lk(this->mutex); + endpointIds = this->dataPtr->endpointIDs.clientIdToEndpointIds[_clientId]; + } - auto i = std::begin(clientsV); - while (i != std::end(clientsV)) - { - if (i->address == _id) - i = clientsV.erase(i); - else - ++i; - } + for (const auto endpointId : endpointIds) + success = success && this->Unbind(endpointId); + + { + std::lock_guard lk(this->mutex); + + const auto& clientAddress = this->dataPtr->clientIDs.idToAddress[_clientId]; + success = success && this->dataPtr->clientIDs.Remove(_clientId); + + if (this->dataPtr->clientIDs.numActiveClients[clientAddress] == 0u) + this->team->erase(clientAddress); } - return true; + return success; } ///////////////////////////////////////////////// bool Broker::OnAddrRegistration(const ignition::msgs::StringMsg &_req, - ignition::msgs::Boolean &_rep) + ignition::msgs::UInt32 &_rep) { - std::string address = _req.data(); - bool result; + const auto& address = _req.data(); - result = this->Register(address); + const ClientID result = this->Register(address); _rep.set_data(result); - return result; + return result != invalidClientId; } ///////////////////////////////////////////////// -bool Broker::OnAddrUnregistration(const ignition::msgs::StringMsg &_req, +bool Broker::OnAddrUnregistration(const ignition::msgs::UInt32 &_req, ignition::msgs::Boolean &_rep) { - std::string address = _req.data(); + uint32_t clientId = _req.data(); + bool result; - result = this->Unregister(address); + result = this->Unregister(clientId); _rep.set_data(result); @@ -356,21 +663,28 @@ bool Broker::OnAddrUnregistration(const ignition::msgs::StringMsg &_req, } ///////////////////////////////////////////////// -bool Broker::OnEndPointRegistration(const ignition::msgs::StringMsg_V &_req, - ignition::msgs::Boolean &_rep) +bool Broker::OnEndPointRegistration( + const subt::msgs::EndpointRegistration &_req, + ignition::msgs::UInt32 &_rep) { - if (_req.data().size() != 2) - { - std::cerr << "[Broker::OnEndPointRegistration()] Expected two strings and " - << "got " << _req.data().size() << " instead" << std::endl; - return false; - } + ClientID clientId = _req.client_id(); + const auto& endpoint = _req.endpoint(); - bool result; - std::string clientAddress = _req.data(0); - std::string endpoint = _req.data(1); + EndpointID result = this->Bind(clientId, endpoint); + + _rep.set_data(result); + + return result != invalidEndpointId; +} - result = this->Bind(clientAddress, endpoint); +///////////////////////////////////////////////// +bool Broker::OnEndPointUnregistration( + const ignition::msgs::UInt32 &_req, + ignition::msgs::Boolean &_rep) +{ + const EndpointID endpointId = _req.data(); + + bool result = this->Unbind(endpointId); _rep.set_data(result); @@ -387,6 +701,7 @@ void Broker::OnMessage(const subt::msgs::Datagram &_req) this->incomingMsgs.push_back(_req); } +////////////////////////////////////////////////// void Broker::SetRadioConfiguration(const std::string& address, communication_model::radio_configuration config) { @@ -433,5 +748,11 @@ void Broker::SetPoseUpdateFunction(pose_update_function f) pose_update_f = f; } +////////////////////////////////////////////////// +const std::string& Broker::IgnPartition() const +{ + return this->node.Options().Partition(); +} + } } diff --git a/subt-communication/subt_communication_broker/src/subt_communication_client.cpp b/subt-communication/subt_communication_broker/src/subt_communication_client.cpp index d3b24baf..b73634b0 100644 --- a/subt-communication/subt_communication_broker/src/subt_communication_client.cpp +++ b/subt-communication/subt_communication_broker/src/subt_communication_client.cpp @@ -24,13 +24,15 @@ #include #include +#include using namespace subt; using namespace subt::communication_broker; ////////////////////////////////////////////////// CommsClient::CommsClient(const std::string &_localAddress, - const bool _isPrivate, const bool _useIgnition) + const bool _isPrivate, const bool _useIgnition, const bool _listenBeacons, + ros::NodeHandle* _rosNh) : localAddress(_localAddress), isPrivate(_isPrivate), useIgnition(_useIgnition) @@ -65,6 +67,17 @@ CommsClient::CommsClient(const std::string &_localAddress, std::this_thread::sleep_for(std::chrono::milliseconds(100)); this->enabled = this->Register(); elapsed = std::chrono::steady_clock::now() - kStart; + + // give Ctrl-C a chance + if (!_useIgnition) + { + if (!ros::ok()) + return; + } + else + { + std::this_thread::sleep_for(std::chrono::milliseconds (5)); + } } if (!this->enabled) @@ -74,12 +87,31 @@ CommsClient::CommsClient(const std::string &_localAddress, return; } - // Bind to be sure we receive beacon packets - auto cb = [] (const std::string&, - const std::string&, - const uint32_t, - const std::string&) { }; - this->Bind(cb, "", kBeaconPort); + if (!this->useIgnition) + { + ros::NodeHandle nh; + if (_rosNh != nullptr) + nh = *_rosNh; + + this->commsSub = nh.subscribe( + "/" + _localAddress + "/comms",1000, + &CommsClient::OnMessageRos, this); + } + + if (_listenBeacons) + { + // Bind to be sure we receive beacon packets + auto cb = [] (const std::string&, + const std::string&, + const uint32_t, + const std::string&) { }; + if (!this->Bind(cb, "", kBeaconPort)) + { + std::cerr << "[CommsClient] Could bind the beacon responder" << std::endl; + this->Unregister(); + return; + } + } this->enabled = true; } @@ -104,20 +136,23 @@ std::string CommsClient::Host() const } ////////////////////////////////////////////////// -bool CommsClient::Bind(std::function _cb, - const std::string &_address, - const int _port) +std::vector> +CommsClient::Bind(std::function _cb, + const std::string &_address, + const int _port) { + std::vector> endpoints; + // Sanity check: Make sure that the communications are enabled. - if (!this->enabled) + if (!this->enabled || this->clientId == invalidClientId) { std::cerr << "[" << this->Host() << "] Bind() error: Trying to bind before communications are enabled!" << std::endl; - return false; + return endpoints; } // Use current address if _address is not provided. @@ -131,100 +166,89 @@ bool CommsClient::Bind(std::functionHost() << "] Bind() error: Address [" << address << "] is not your local address" << std::endl; - return false; + return endpoints; } - // Mapping the "unicast socket" to a topic name. + // Mapping the "unicast socket" to an endpoint name. const auto unicastEndPoint = address + ":" + std::to_string(_port); const auto bcastEndpoint = communication_broker::kBroadcast + ":" + std::to_string(_port); - bool bcastAdvertiseNeeded; - - { - std::lock_guard lock(this->mutex); - - // Sanity check: Make sure that this address is not already used. - if (this->callbacks.find(unicastEndPoint) != this->callbacks.end()) - { - std::cerr << "[" << this->Host() << "] Bind() error: Address [" - << address << "] already used" << std::endl; - return false; - } - - bcastAdvertiseNeeded = - this->callbacks.find(bcastEndpoint) == this->callbacks.end(); - } - // Register the endpoints in the broker. - // Note that the broadcast endpoint will only be registered once. for (const std::string &endpoint : {unicastEndPoint, bcastEndpoint}) { - if (endpoint != bcastEndpoint || bcastAdvertiseNeeded) + // If this is the basestation, then we need to use ignition transport. + // Otherwise, the client is on a robot and needs to use ROS. + if (this->useIgnition) { - // If this is the basestation, then we need to use ignition transport. - // Otherwise, the client is on a robot and needs to use ROS. - if (this->useIgnition) + subt::msgs::EndpointRegistration req; + req.set_client_id(this->clientId); + req.set_endpoint(endpoint); + + const unsigned int timeout = 3000u; + ignition::msgs::UInt32 rep; + bool result; + bool executed = this->node.Request( + communication_broker::kEndPointRegistrationSrv, + req, timeout, rep, result); + + if (!executed) { - ignition::msgs::StringMsg_V req; - req.add_data(address); - req.add_data(endpoint); - - const unsigned int timeout = 3000u; - ignition::msgs::Boolean rep; - bool result; - bool executed = this->node.Request( - communication_broker::kEndPointRegistrationSrv, - req, timeout, rep, result); - - if (!executed) - { - std::cerr << "[CommsClient] Endpoint registration srv not available" - << std::endl; - return false; - } - - if (!result) - { - std::cerr << "[CommsClient] Invalid data. Did you send the address " - << "followed by the endpoint?" << std::endl; - return false; - } + std::cerr << "[CommsClient] Endpoint registration srv not available" + << std::endl; + return endpoints; } - else + + const auto endpointID = rep.data(); + + if (!result || endpointID == invalidEndpointId) { - subt_msgs::Bind::Request req; - req.address = address; - req.endpoint = endpoint; + std::cerr << "[CommsClient] Invalid endpoint registration data." + << std::endl; + return endpoints; + } - subt_msgs::Bind::Response rep; + endpoints.emplace_back(std::make_pair(endpointID, endpoint)); + } + else + { + subt_msgs::Bind::Request req; + req.client_id = this->clientId; + req.endpoint = endpoint; - bool executed = ros::service::call( - communication_broker::kEndPointRegistrationSrv, req, rep); + subt_msgs::Bind::Response rep; - if (!executed) - { - std::cerr << "[CommsClient] Endpoint registration srv not available" - << std::endl; - return false; - } + bool executed = ros::service::call( + communication_broker::kEndPointRegistrationSrv, req, rep); - if (!rep.success) - { - std::cerr << "[CommsClient] Invalid data. Did you send the address " - << "followed by the endpoint?" << std::endl; - return false; - } + if (!executed) + { + std::cerr << "[CommsClient] Endpoint registration srv not available" + << std::endl; + return endpoints; + } + + if (rep.endpoint_id == communication_broker::invalidEndpointId) + { + std::cerr << "[CommsClient] Invalid endpoint registration data." + << std::endl; + return endpoints; } + + endpoints.emplace_back(std::make_pair(rep.endpoint_id, endpoint)); } } - if (!this->advertised) + if (this->useIgnition) { - // Use ignition transport if this is the basestation. Otherwise use ros. - if (this->useIgnition) + std::lock_guard lock(this->mutex); + + // Ignition transport registers a datagram receive service, so we have to + // make sure we only advertise it once + if (!this->advertised) { - // Advertise a oneway service for receiving message requests. + // Advertise a oneway service for receiving message requests. This assumes + // there is only a single node running the client in useIgnition mode. ignition::transport::AdvertiseServiceOptions opts; if (this->isPrivate) opts.SetScope(ignition::transport::Scope_t::PROCESS); @@ -232,51 +256,117 @@ bool CommsClient::Bind(std::functionnode.Advertise(address, &CommsClient::OnMessage, this, opts)) { std::cerr << "[" << this->Host() << "] Bind Error: could not advertise " - << address << std::endl; - return false; + << address << " when binding " << unicastEndPoint << std::endl; + + // if we cannot advertise but we have already bound the endpoints, + // we need to unbind them before exiting with error + for (const auto& endpoint : endpoints) + { + ignition::msgs::UInt32 req; + req.set_data(endpoint.first); + this->node.Request(kEndPointUnregistrationSrv, req); + } + endpoints.clear(); + return endpoints; } + this->advertised = true; } - else + } + + // Register the callbacks. + { + std::lock_guard lock(this->mutex); + for (const auto& endpoint : endpoints) { - ros::NodeHandle nh; - // Advertise on the global namespace - this->commsModelOnMessageService = nh.advertiseService( - "/" + address, &CommsClient::OnMessageRos, this); + const auto& endpointID = endpoint.first; + const auto& endpointName = endpoint.second; + + ignmsg << "Storing callback for " << endpointName << std::endl; + + this->callbacks[endpointName][endpointID] = std::bind(_cb, + std::placeholders::_1, std::placeholders::_2, + std::placeholders::_3, std::placeholders::_4); } + } - this->advertised = true; + return endpoints; +} + +////////////////////////////////////////////////// +bool CommsClient::Unbind(communication_broker::EndpointID _endpointId) +{ + if (!this->enabled || this->clientId == invalidClientId) + { + std::cerr << "[" << this->localAddress << "] CommsClient::Unbind:" + << "Calling Unregister() before registering the client." + << std::endl; + return false; } - // Register the callbacks. + std::string endpoint; { - std::lock_guard lock(this->mutex); - for (std::string endpoint : {unicastEndPoint, bcastEndpoint}) + std::unique_lock lock(this->mutex); + for (const auto& callbackKV : this->callbacks) { - if (endpoint != bcastEndpoint || bcastAdvertiseNeeded) - { - ignmsg << "Storing callback for " << endpoint << std::endl; - this->callbacks[endpoint] = std::bind(_cb, - std::placeholders::_1, std::placeholders::_2, - std::placeholders::_3, std::placeholders::_4); - } - else + const auto& endpointName = callbackKV.first; + const auto& endpointCallbacks = callbackKV.second; + + for (const auto& endpointKV : endpointCallbacks) { - ignwarn << "Skipping callback register for " << endpoint << std::endl; + const auto& endpointID = endpointKV.first; + if (endpointID == _endpointId) + { + endpoint = endpointName; + break; + } } + if (!endpoint.empty()) + break; } } - return true; -} + if (endpoint.empty()) + { + std::cerr << "Trying to unbind an endpoint that is not bound." << std::endl; + return false; + } + + ignition::msgs::UInt32 req; + req.set_data(_endpointId); + const unsigned int timeout = 3000u; + ignition::msgs::Boolean rep; + bool result; + + bool executed = this->node.Request( + kEndPointUnregistrationSrv, req, timeout, rep, result); + + if (!executed) + return false; + + if (result && rep.data()) + { + std::unique_lock lock(this->mutex); + this->callbacks[endpoint].erase(_endpointId); + if (this->callbacks[endpoint].empty()) + this->callbacks.erase(endpoint); + return true; + } + return false; +} ////////////////////////////////////////////////// bool CommsClient::SendTo(const std::string &_data, const std::string &_dstAddress, const uint32_t _port) { // Sanity check: Make sure that the communications are enabled. - if (!this->enabled) + if (!this->enabled || this->clientId == invalidClientId) + { + std::cerr << "[" << this->localAddress << "] CommsClient::SendTo: " + << "Calling SendTo() before registering the client." + << std::endl; return false; + } // Restrict the maximum size of a message. if (_data.size() > this->kMtu) @@ -354,6 +444,15 @@ void CommsClient::StartBeaconInterval(ros::Duration _period) ////////////////////////////////////////////////// bool CommsClient::Register() { + if (this->enabled || this->clientId != invalidClientId) + { + std::cerr << "[" << this->localAddress + << "] CommsClient::Register: Calling Register() on an already " + << "registered client." + << std::endl; + return false; + } + bool executed; bool result; @@ -363,11 +462,14 @@ bool CommsClient::Register() ignition::msgs::StringMsg req; req.set_data(this->localAddress); - ignition::msgs::Boolean rep; + ignition::msgs::UInt32 rep; const unsigned int timeout = 3000u; executed = this->node.Request( kAddrRegistrationSrv, req, timeout, rep, result); + + if (executed && result) + this->clientId = rep.data(); } else { @@ -377,17 +479,20 @@ bool CommsClient::Register() req.local_address = this->localAddress; executed = ros::service::call(kAddrRegistrationSrv, req, rep); - result = rep.success; + result = executed; + if (executed) + this->clientId = rep.client_id; } - if (!executed) + if (!executed || !result || this->clientId == invalidClientId) { std::cerr << "[" << this->localAddress << "] CommsClient::Register: Problem registering with broker" << std::endl; + return false; } - return executed && result; + return true; } ////////////////////////////////////////////////// @@ -396,11 +501,37 @@ bool CommsClient::Unregister() bool executed; bool result; + if (!this->enabled || this->clientId == invalidClientId) + { + std::cerr << "[" << this->localAddress << "] CommsClient::Unregister:" + << "Calling Unregister() before registering the client." + << std::endl; + return false; + } + + // unbind all endpoints + + // copy the callbacks array because Unbind() removes items from it + decltype(this->callbacks) callbacksCopy; + { + std::unique_lock lock(this->mutex); + callbacksCopy = this->callbacks; + } + + for (const auto& callbackKV : callbacksCopy) + { + // we intentionally ignore failures in unbind as there's nothing to do + for (const auto& endpointKV : callbackKV.second) + this->Unbind(endpointKV.first); + } + // to be sure none are left there in case a later Register() is called + this->callbacks.clear(); + // Use ignition transport if this is the base station. Otherwise, use ROS. if (this->useIgnition) { - ignition::msgs::StringMsg req; - req.set_data(this->localAddress); + ignition::msgs::UInt32 req; + req.set_data(this->clientId); ignition::msgs::Boolean rep; const unsigned int timeout = 3000u; @@ -413,12 +544,15 @@ bool CommsClient::Unregister() subt_msgs::Unregister::Request req; subt_msgs::Unregister::Response rep; - req.local_address = this->localAddress; + req.client_id = this->clientId; executed = ros::service::call(kAddrUnregistrationSrv, req, rep); result = rep.success; } + if (executed && result) + this->clientId = invalidClientId; + return executed && result; } @@ -434,19 +568,22 @@ void CommsClient::OnMessage(const msgs::Datagram &_msg) this->clockMsg.sim().nsec() * 1e-9; this->neighbors[_msg.src_address()] = std::make_pair(time, _msg.rssi()); - for (auto cb : this->callbacks) + if (this->callbacks.find(endPoint) == this->callbacks.end()) + return; + + for (const auto& endpointCallbacksKV : this->callbacks[endPoint]) { - if (cb.first == endPoint && cb.second) + auto& callback = endpointCallbacksKV.second; + if (callback) { - cb.second(_msg.src_address(), _msg.dst_address(), - _msg.dst_port(), _msg.data()); + callback(_msg.src_address(), _msg.dst_address(), + _msg.dst_port(), _msg.data()); } } } ////////////////////////////////////////////////// -bool CommsClient::OnMessageRos(subt_msgs::DatagramRos::Request &_req, - subt_msgs::DatagramRos::Response &_res) +void CommsClient::OnMessageRos(const subt_msgs::DatagramRos::Request &_req) { auto endPoint = _req.dst_address + ":" + std::to_string(_req.dst_port); @@ -455,16 +592,18 @@ bool CommsClient::OnMessageRos(subt_msgs::DatagramRos::Request &_req, this->neighbors[_req.src_address] = std::make_pair(ros::Time::now().toSec(), _req.rssi); - for (auto cb : this->callbacks) + if (this->callbacks.find(endPoint) == this->callbacks.end()) + return; + + for (const auto& endpointCallbacksKV : this->callbacks[endPoint]) { - if (cb.first == endPoint && cb.second) + auto& callback = endpointCallbacksKV.second; + if (callback) { - cb.second(_req.src_address, _req.dst_address, - _req.dst_port, _req.data); + callback(_req.src_address, _req.dst_address, + _req.dst_port, _req.data); } } - - return true; } ////////////////////////////////////////////////// diff --git a/subt-communication/subt_communication_broker/tests/unit_test.cpp b/subt-communication/subt_communication_broker/tests/unit_test.cpp index 865b79e2..84db730a 100644 --- a/subt-communication/subt_communication_broker/tests/unit_test.cpp +++ b/subt-communication/subt_communication_broker/tests/unit_test.cpp @@ -18,11 +18,15 @@ #include #include +#define private public +#define protected public #include #include #include #include #include +#undef protected +#undef private using namespace subt; using namespace subt::communication_model; @@ -30,6 +34,22 @@ using namespace subt::rf_interface; using namespace subt::rf_interface::range_model; using namespace subt::communication_broker; +void setDummyComms(Broker& broker) +{ + struct radio_configuration radio; + radio.pathloss_f = [](const double&, radio_state&, radio_state&) { + return rf_power(); + }; + + broker.SetDefaultRadioConfiguration(radio); + broker.SetCommunicationFunction(&subt::communication_model::attempt_send); + broker.SetPoseUpdateFunction( + [](const std::string& name) + { + return std::make_tuple(true, ignition::math::Pose3d::Zero, 0.0); + }); +} + TEST(broker, instatiate) { Broker broker; @@ -86,28 +106,40 @@ TEST(broker, communicate) broker.SetPoseUpdateFunction(pose_update_func); broker.Start(); - CommsClient c1("1", false, true); - CommsClient c2("2", false , true); + CommsClient c1("1", false, true, false); + CommsClient c2("2", false , true, false); - auto c2_cb = [=](const std::string& src, + std::vector receivedData; + auto c2_cb = [=,&receivedData](const std::string& src, const std::string& dst, const uint32_t port, const std::string& data) { std::cout << "Received " << data.size() << "(" << data << ") bytes from " << src << std::endl; + receivedData.emplace_back(data); }; c2.Bind(c2_cb); + std::unordered_set sentData; for(unsigned int i=0; i < 15; ++i) { std::ostringstream oss; oss << "Hello c2, " << i; c1.SendTo(oss.str(), "2"); + sentData.insert(oss.str()); broker.DispatchMessages(); } + // there is some packet loss on the path + ASSERT_LT(8u, receivedData.size()); + + for (const auto& data : receivedData) + { + EXPECT_NE(sentData.end(), sentData.find(data)); + } + // geometry_msgs::PoseStamped a, b; // a.header.frame_id = "world"; // b = a; @@ -119,6 +151,636 @@ TEST(broker, communicate) // ); } +TEST(broker, broadcast) +{ + Broker broker; + setDummyComms(broker); + + broker.Start(); + + CommsClient broadcaster("b", false, true, false); + CommsClient c1("1", false, true, false); + CommsClient c2("2", false , true, false); + CommsClient c3("3", false , true, false); + + std::vector receivedData1, receivedData2; + auto c1_cb = [=,&receivedData1](const std::string& src, + const std::string& dst, + const uint32_t port, + const std::string& data) + { + receivedData1.emplace_back(data); + }; + auto c2_cb = [=,&receivedData2](const std::string& src, + const std::string& dst, + const uint32_t port, + const std::string& data) + { + receivedData2.emplace_back(data); + }; + auto c3_cb = [](const std::string& src, + const std::string& dst, + const uint32_t port, + const std::string& data) + { + GTEST_NONFATAL_FAILURE_("c3 callback should not have been called"); + }; + + // intentionally bind c1 twice to get twice as many data and verify one client + // can bind one port multiple times + EXPECT_FALSE(!c1.Bind(c1_cb, "", 123)); + EXPECT_FALSE(!c1.Bind(c1_cb, "", 123)); + EXPECT_FALSE(!c2.Bind(c2_cb, "", 123)); + // c3 binds to a different port and thus should receive nothing + EXPECT_FALSE(!c3.Bind(c3_cb, "", 124)); + + std::vector sentData; + for(unsigned int i=0; i < 15; ++i) + { + std::ostringstream oss; + oss << "Hello clients, " << i; + broadcaster.SendTo(oss.str(), kBroadcast, 123); + sentData.emplace_back(oss.str()); + broker.DispatchMessages(); + } + + ASSERT_EQ(2u * sentData.size(), receivedData1.size()); + ASSERT_EQ(sentData.size(), receivedData2.size()); + + for (size_t i = 0; i < sentData.size(); ++i) + { + EXPECT_EQ(sentData[i], receivedData1[2*i]); + EXPECT_EQ(sentData[i], receivedData1[2*i + 1]); + EXPECT_EQ(sentData[i], receivedData2[i]); + } +} + +// multicast is broken in ignition-only setup, because it advertises a service +// called "/multicast" in each client... obviously, this can't work with more +// than one client... fortunately, in case SubtRosRelay is used, this relay +// is the only client that "physically" interacts with the ignition layer, so +// that should work +// this could be fixed by rewriting the ignition layer to work via messages +// rather than services +#if 0 +TEST(broker, multicast) +{ + Broker broker; + setDummyComms(broker); + + broker.Start(); + + CommsClient broadcaster("b", false, true, false); + CommsClient c1("1", false, true, false); + CommsClient c2("2", false , true, false); + CommsClient c3("3", false , true, false); + + std::vector receivedData1, receivedData2; + auto c1_cb = [=,&receivedData1](const std::string& src, + const std::string& dst, + const uint32_t port, + const std::string& data) + { + receivedData1.emplace_back(data); + }; + auto c2_cb = [=,&receivedData2](const std::string& src, + const std::string& dst, + const uint32_t port, + const std::string& data) + { + receivedData2.emplace_back(data); + }; + auto c3_cb = [](const std::string& src, + const std::string& dst, + const uint32_t port, + const std::string& data) + { + GTEST_NONFATAL_FAILURE_("c3 callback should not have been called"); + }; + + // intentionally bind c1 twice, but once to multicast address and once on + // unicast + EXPECT_FALSE(!c1.Bind(c1_cb, kMulticast, 123)); + EXPECT_FALSE(!c1.Bind(c1_cb, "", 123)); + EXPECT_FALSE(!c2.Bind(c2_cb, kMulticast, 123)); + EXPECT_FALSE(!c3.Bind(c3_cb, "", 123)); + + std::vector sentData; + for(unsigned int i=0; i < 15; ++i) + { + std::ostringstream oss; + oss << "Hello clients, " << i; + broadcaster.SendTo(oss.str(), kMulticast, 123); + sentData.emplace_back(oss.str()); + broker.DispatchMessages(); + std::this_thread::sleep_for(std::chrono::seconds(1)); + } + + ASSERT_EQ(sentData.size(), receivedData1.size()); + ASSERT_EQ(sentData.size(), receivedData2.size()); + + for (size_t i = 0; i < sentData.size(); ++i) + { + EXPECT_EQ(sentData[i], receivedData1[i]); + EXPECT_EQ(sentData[i], receivedData2[i]); + } +} +#endif + +TEST(broker, unicast) +{ + Broker broker; + setDummyComms(broker); + + broker.Start(); + + CommsClient broadcaster("b", false, true, false); + CommsClient c1("1", false, true, false); + CommsClient c2("2", false , true, false); + CommsClient c3("3", false , true, false); + + std::vector receivedData1, receivedData2; + auto c1_cb = [=,&receivedData1](const std::string& src, + const std::string& dst, + const uint32_t port, + const std::string& data) + { + receivedData1.emplace_back(data); + }; + auto c2_cb = [=,&receivedData2](const std::string& src, + const std::string& dst, + const uint32_t port, + const std::string& data) + { + receivedData2.emplace_back(data); + }; + auto c3_cb = [](const std::string& src, + const std::string& dst, + const uint32_t port, + const std::string& data) + { + GTEST_NONFATAL_FAILURE_("c3 callback should not have been called"); + }; + + // intentionally bind c1 twice to get twice as many data and verify one client + // can bind one port multiple times + EXPECT_FALSE(!c1.Bind(c1_cb, "", 123)); + EXPECT_FALSE(!c1.Bind(c1_cb, "", 123)); + EXPECT_FALSE(!c2.Bind(c2_cb, "", 123)); + // c3 binds to a different port and thus should receive nothing + EXPECT_FALSE(!c3.Bind(c3_cb, "", 124)); + + std::vector sentData; + for(unsigned int i=0; i < 15; ++i) + { + std::ostringstream oss; + oss << "Hello clients, " << i; + broadcaster.SendTo(oss.str(), "1", 123); + sentData.emplace_back(oss.str()); + broker.DispatchMessages(); + } + + ASSERT_EQ(2u * sentData.size(), receivedData1.size()); + ASSERT_EQ(0u, receivedData2.size()); + + for (size_t i = 0; i < sentData.size(); ++i) + { + EXPECT_EQ(sentData[i], receivedData1[2*i]); + EXPECT_EQ(sentData[i], receivedData1[2*i + 1]); + } + + receivedData1.clear(); + receivedData2.clear(); + sentData.clear(); + + for(unsigned int i=0; i < 15; ++i) + { + std::ostringstream oss; + oss << "Hello clients, " << i; + broadcaster.SendTo(oss.str(), "2", 123); + sentData.emplace_back(oss.str()); + broker.DispatchMessages(); + } + + ASSERT_EQ(0u, receivedData1.size()); + ASSERT_EQ(sentData.size(), receivedData2.size()); + + for (size_t i = 0; i < sentData.size(); ++i) + { + EXPECT_EQ(sentData[i], receivedData2[i]); + } +} + +TEST(broker, notTwoClientsForSameAddressWithIgnition) +{ + Broker broker; + setDummyComms(broker); + + broker.Start(); + + CommsClient c1("1", false, true, false); + CommsClient c2("1", false, true, false); + + auto cb = [](const std::string& src, const std::string& dst, + const uint32_t port, const std::string& data) + { + }; + + EXPECT_LT(0u, c1.Bind(cb).size()); + // this should be 0 because the ign service /1 has already been advertised + // that doesn't happen due to + // https://github.com/ignitionrobotics/ign-transport/issues/217 + // EXPECT_EQ(0u, c2.Bind(cb).size()); +} + +TEST(brokerUnit, registerOnce) +{ + Broker broker; + broker.Start(); + + ClientID client1, client2; + + EXPECT_EQ(0, broker.Team()->size()); + EXPECT_NE(invalidClientId, client1 = broker.Register("1")); + EXPECT_EQ(1, broker.Team()->size()); + EXPECT_NE(invalidClientId, client2 = broker.Register("2")); + EXPECT_EQ(2, broker.Team()->size()); + + EXPECT_NE(client1, client2); + + EXPECT_FALSE(broker.Unregister(3)); + EXPECT_EQ(2, broker.Team()->size()); + + EXPECT_TRUE(broker.Unregister(client1)); + EXPECT_EQ(1, broker.Team()->size()); + EXPECT_FALSE(broker.Unregister(client1)); + EXPECT_EQ(1, broker.Team()->size()); + EXPECT_TRUE(broker.Unregister(client2)); + EXPECT_EQ(0, broker.Team()->size()); +} + +TEST(brokerUnit, registerTwice) +{ + Broker broker; + broker.Start(); + + ClientID client11, client12, client21, client22; + + EXPECT_EQ(0, broker.Team()->size()); + EXPECT_NE(invalidClientId, client11 = broker.Register("1")); + EXPECT_EQ(1, broker.Team()->size()); + EXPECT_NE(invalidClientId, client12 = broker.Register("1")); + EXPECT_EQ(1, broker.Team()->size()); + EXPECT_NE(invalidClientId, client21 = broker.Register("2")); + EXPECT_EQ(2, broker.Team()->size()); + EXPECT_NE(invalidClientId, client22 = broker.Register("2")); + EXPECT_EQ(2, broker.Team()->size()); + + EXPECT_NE(client11, client22); + EXPECT_NE(client12, client21); + EXPECT_NE(client12, client22); + EXPECT_NE(client21, client22); + + EXPECT_TRUE(broker.Unregister(client11)); + EXPECT_EQ(2, broker.Team()->size()); + EXPECT_FALSE(broker.Unregister(8)); + EXPECT_EQ(2, broker.Team()->size()); + EXPECT_TRUE(broker.Unregister(client21)); + EXPECT_EQ(2, broker.Team()->size()); + EXPECT_TRUE(broker.Unregister(client22)); + EXPECT_EQ(1, broker.Team()->size()); + EXPECT_TRUE(broker.Unregister(client12)); + EXPECT_EQ(0, broker.Team()->size()); + EXPECT_FALSE(broker.Unregister(client11)); +} + +TEST(brokerUnit, registrationWorks) +{ + Broker broker; + setDummyComms(broker); + + broker.Start(); + + ASSERT_TRUE(broker.Team()->empty()); + ASSERT_TRUE(broker.endpoints.empty()); + + ClientID client1; + ASSERT_NE(invalidClientId, client1 = broker.Register("1")); + + ASSERT_EQ(1u, broker.Team()->size()); + ASSERT_TRUE(broker.endpoints.empty()); + ASSERT_NE(broker.Team()->end(), broker.Team()->find("1")); + EXPECT_EQ("1", broker.Team()->at("1")->address); + EXPECT_EQ("1", broker.Team()->at("1")->name); + + ClientID client2; + ASSERT_NE(invalidClientId, client2 = broker.Register("2")); + + ASSERT_EQ(2u, broker.Team()->size()); + ASSERT_TRUE(broker.endpoints.empty()); + ASSERT_NE(broker.Team()->end(), broker.Team()->find("2")); + EXPECT_EQ("2", broker.Team()->at("2")->address); + EXPECT_EQ("2", broker.Team()->at("2")->name); + + EndpointID endpoint11; + ASSERT_NE(invalidEndpointId, endpoint11 = broker.Bind(client1, "1:1")); + + ASSERT_EQ(2u, broker.Team()->size()); + ASSERT_EQ(1u, broker.endpoints.size()); + ASSERT_NE(broker.endpoints.end(), broker.endpoints.find("1:1")); + ASSERT_EQ(1u, broker.endpoints.at("1:1").size()); + EXPECT_EQ("1", broker.endpoints.at("1:1")[0].address); + + EndpointID endpoint12; + ASSERT_NE(invalidEndpointId, endpoint12 = broker.Bind(client1, "1:2")); + EXPECT_NE(endpoint12, endpoint11); + + ASSERT_EQ(2u, broker.Team()->size()); + ASSERT_EQ(2u, broker.endpoints.size()); + ASSERT_NE(broker.endpoints.end(), broker.endpoints.find("1:2")); + ASSERT_EQ(1u, broker.endpoints.at("1:2").size()); + EXPECT_EQ("1", broker.endpoints.at("1:2")[0].address); + + EndpointID endpoint21; + ASSERT_NE(invalidEndpointId, endpoint21 = broker.Bind(client2, "2:1")); + EXPECT_NE(endpoint12, endpoint21); + + ASSERT_EQ(2u, broker.Team()->size()); + ASSERT_EQ(3u, broker.endpoints.size()); + ASSERT_NE(broker.endpoints.end(), broker.endpoints.find("2:1")); + ASSERT_EQ(1u, broker.endpoints.at("2:1").size()); + EXPECT_EQ("2", broker.endpoints.at("2:1")[0].address); + + EndpointID endpoint22; + ASSERT_NE(invalidEndpointId, endpoint22 = broker.Bind(client2, "2:2")); + EXPECT_NE(endpoint22, endpoint11); + + ASSERT_EQ(2u, broker.Team()->size()); + ASSERT_EQ(4u, broker.endpoints.size()); + ASSERT_NE(broker.endpoints.end(), broker.endpoints.find("2:2")); + ASSERT_EQ(1u, broker.endpoints.at("2:2").size()); + EXPECT_EQ("2", broker.endpoints.at("2:2")[0].address); + + // The broker library doesn't check whether endpoints and client addresses + // match + EndpointID endpointWeird; + ASSERT_NE(invalidEndpointId, endpointWeird = broker.Bind(client1, "2:2")); + EXPECT_NE(endpoint22, endpointWeird); + + ASSERT_EQ(2u, broker.Team()->size()); + ASSERT_EQ(4u, broker.endpoints.size()); + ASSERT_NE(broker.endpoints.end(), broker.endpoints.find("2:2")); + ASSERT_EQ(2u, broker.endpoints.at("2:2").size()); + EXPECT_EQ("1", broker.endpoints.at("2:2")[1].address); + + // make sure broadcast endpoints work + EndpointID endpointBcast11; + ASSERT_NE(invalidEndpointId, + endpointBcast11 = broker.Bind(client1, "broadcast:1")); + EXPECT_NE(endpointWeird, endpointBcast11); + + ASSERT_EQ(2u, broker.Team()->size()); + ASSERT_EQ(5u, broker.endpoints.size()); + ASSERT_NE(broker.endpoints.end(), broker.endpoints.find("broadcast:1")); + ASSERT_EQ(1u, broker.endpoints.at("broadcast:1").size()); + EXPECT_EQ("1", broker.endpoints.at("broadcast:1")[0].address); + + EndpointID endpointBcast21; + ASSERT_NE(invalidEndpointId, + endpointBcast21 = broker.Bind(client2, "broadcast:1")); + EXPECT_NE(endpointBcast21, endpointBcast11); + + ASSERT_EQ(2u, broker.Team()->size()); + ASSERT_EQ(5u, broker.endpoints.size()); + ASSERT_NE(broker.endpoints.end(), broker.endpoints.find("broadcast:1")); + ASSERT_EQ(2u, broker.endpoints.at("broadcast:1").size()); + EXPECT_EQ("2", broker.endpoints.at("broadcast:1")[1].address); + + // test multiple binds to the same endpoint from the same client + EndpointID endpointBcast111; + ASSERT_NE(invalidEndpointId, + endpointBcast111 = broker.Bind(client1, "broadcast:1")); + EXPECT_NE(endpointBcast21, endpointBcast111); + + EXPECT_NE(endpointBcast11, endpointBcast111); + ASSERT_EQ(2u, broker.Team()->size()); + ASSERT_EQ(5u, broker.endpoints.size()); + ASSERT_NE(broker.endpoints.end(), broker.endpoints.find("broadcast:1")); + ASSERT_EQ(2u, broker.endpoints.at("broadcast:1").size()); + EXPECT_EQ("1", broker.endpoints.at("broadcast:1")[0].address); +} + +TEST(brokerUnit, unbind) +{ + Broker broker; + setDummyComms(broker); + + broker.Start(); + + ClientID client11 = broker.Register("1"); + ClientID client2 = broker.Register("2"); + ClientID client12 = broker.Register("1"); + + ASSERT_NE(invalidClientId, client11); + ASSERT_NE(invalidClientId, client2); + ASSERT_NE(invalidClientId, client12); + + EndpointID client11_ep11_1 = broker.Bind(client11, "1:1"); + EndpointID client11_ep11_2 = broker.Bind(client11, "1:1"); + EndpointID client12_ep11_1 = broker.Bind(client12, "1:1"); + EndpointID client12_ep11_2 = broker.Bind(client12, "1:1"); + EndpointID client2_ep21_1 = broker.Bind(client2, "2:1"); + EndpointID client2_ep21_2 = broker.Bind(client2, "2:1"); + EndpointID client2_ep11_1 = broker.Bind(client2, "1:1"); + + ASSERT_EQ(2u, broker.endpoints.size()); + ASSERT_NE(broker.endpoints.end(), broker.endpoints.find("1:1")); + ASSERT_NE(broker.endpoints.end(), broker.endpoints.find("2:1")); + EXPECT_EQ(2u, broker.endpoints["1:1"].size()); + EXPECT_EQ(1u, broker.endpoints["2:1"].size()); + + EXPECT_TRUE(broker.Unbind(client2_ep21_1)); + + ASSERT_EQ(2u, broker.endpoints.size()); + ASSERT_NE(broker.endpoints.end(), broker.endpoints.find("1:1")); + ASSERT_NE(broker.endpoints.end(), broker.endpoints.find("2:1")); + EXPECT_EQ(2u, broker.endpoints["1:1"].size()); + EXPECT_EQ(1u, broker.endpoints["2:1"].size()); + + // cannot unbind it twice + EXPECT_FALSE(broker.Unbind(client2_ep21_1)); + + EXPECT_TRUE(broker.Unbind(client2_ep21_2)); + + // the endpoint key remains, but it is empty + ASSERT_EQ(2u, broker.endpoints.size()); + ASSERT_NE(broker.endpoints.end(), broker.endpoints.find("1:1")); + ASSERT_NE(broker.endpoints.end(), broker.endpoints.find("2:1")); + EXPECT_EQ(2u, broker.endpoints["1:1"].size()); + EXPECT_EQ(0u, broker.endpoints["2:1"].size()); + + EXPECT_TRUE(broker.Unbind(client11_ep11_1)); + + ASSERT_EQ(2u, broker.endpoints.size()); + ASSERT_NE(broker.endpoints.end(), broker.endpoints.find("1:1")); + ASSERT_NE(broker.endpoints.end(), broker.endpoints.find("2:1")); + EXPECT_EQ(2u, broker.endpoints["1:1"].size()); + EXPECT_EQ(0u, broker.endpoints["2:1"].size()); + + EXPECT_TRUE(broker.Unbind(client11_ep11_2)); + + ASSERT_EQ(2u, broker.endpoints.size()); + ASSERT_NE(broker.endpoints.end(), broker.endpoints.find("1:1")); + ASSERT_NE(broker.endpoints.end(), broker.endpoints.find("2:1")); + EXPECT_EQ(2u, broker.endpoints["1:1"].size()); + EXPECT_EQ(0u, broker.endpoints["2:1"].size()); + + EXPECT_TRUE(broker.Unbind(client2_ep11_1)); + + ASSERT_EQ(2u, broker.endpoints.size()); + ASSERT_NE(broker.endpoints.end(), broker.endpoints.find("1:1")); + ASSERT_NE(broker.endpoints.end(), broker.endpoints.find("2:1")); + EXPECT_EQ(1u, broker.endpoints["1:1"].size()); + EXPECT_EQ(0u, broker.endpoints["2:1"].size()); + + EXPECT_TRUE(broker.Unbind(client12_ep11_1)); + + ASSERT_EQ(2u, broker.endpoints.size()); + ASSERT_NE(broker.endpoints.end(), broker.endpoints.find("1:1")); + ASSERT_NE(broker.endpoints.end(), broker.endpoints.find("2:1")); + EXPECT_EQ(1u, broker.endpoints["1:1"].size()); + EXPECT_EQ(0u, broker.endpoints["2:1"].size()); + + EXPECT_TRUE(broker.Unbind(client12_ep11_2)); + + ASSERT_EQ(2u, broker.endpoints.size()); + ASSERT_NE(broker.endpoints.end(), broker.endpoints.find("1:1")); + ASSERT_NE(broker.endpoints.end(), broker.endpoints.find("2:1")); + EXPECT_EQ(0u, broker.endpoints["1:1"].size()); + EXPECT_EQ(0u, broker.endpoints["2:1"].size()); + + EXPECT_FALSE(broker.Unbind(client11_ep11_1)); + EXPECT_FALSE(broker.Unbind(client11_ep11_2)); + EXPECT_FALSE(broker.Unbind(client12_ep11_1)); + EXPECT_FALSE(broker.Unbind(client12_ep11_2)); + EXPECT_FALSE(broker.Unbind(client2_ep11_1)); + EXPECT_FALSE(broker.Unbind(client2_ep21_1)); + EXPECT_FALSE(broker.Unbind(client2_ep21_2)); +} + + +TEST(brokerUnit, unregisterUnbinds) +{ + Broker broker; + setDummyComms(broker); + + broker.Start(); + + ClientID client11 = broker.Register("1"); + ClientID client2 = broker.Register("2"); + ClientID client12 = broker.Register("1"); + + ASSERT_NE(invalidClientId, client11); + ASSERT_NE(invalidClientId, client2); + ASSERT_NE(invalidClientId, client12); + + EndpointID client11_ep11_1 = broker.Bind(client11, "1:1"); + EndpointID client11_ep11_2 = broker.Bind(client11, "1:1"); + EndpointID client12_ep11_1 = broker.Bind(client12, "1:1"); + EndpointID client12_ep11_2 = broker.Bind(client12, "1:1"); + EndpointID client2_ep21_1 = broker.Bind(client2, "2:1"); + EndpointID client2_ep21_2 = broker.Bind(client2, "2:1"); + EndpointID client2_ep11_1 = broker.Bind(client2, "1:1"); + + ASSERT_EQ(2u, broker.endpoints.size()); + ASSERT_NE(broker.endpoints.end(), broker.endpoints.find("1:1")); + ASSERT_NE(broker.endpoints.end(), broker.endpoints.find("2:1")); + EXPECT_EQ(2u, broker.endpoints["1:1"].size()); + EXPECT_EQ(1u, broker.endpoints["2:1"].size()); + EXPECT_NE(broker.Team()->end(), broker.Team()->find("1")); + EXPECT_NE(broker.Team()->end(), broker.Team()->find("2")); + + EXPECT_TRUE(broker.Unregister(client12)); + + ASSERT_EQ(2u, broker.endpoints.size()); + ASSERT_NE(broker.endpoints.end(), broker.endpoints.find("1:1")); + ASSERT_NE(broker.endpoints.end(), broker.endpoints.find("2:1")); + EXPECT_EQ(2u, broker.endpoints["1:1"].size()); + EXPECT_EQ(1u, broker.endpoints["2:1"].size()); + EXPECT_NE(broker.Team()->end(), broker.Team()->find("1")); + EXPECT_NE(broker.Team()->end(), broker.Team()->find("2")); + + EXPECT_TRUE(broker.Unregister(client11)); + + ASSERT_EQ(2u, broker.endpoints.size()); + ASSERT_NE(broker.endpoints.end(), broker.endpoints.find("1:1")); + ASSERT_NE(broker.endpoints.end(), broker.endpoints.find("2:1")); + EXPECT_EQ(1u, broker.endpoints["1:1"].size()); + EXPECT_EQ(1u, broker.endpoints["2:1"].size()); + EXPECT_EQ(broker.Team()->end(), broker.Team()->find("1")); + EXPECT_NE(broker.Team()->end(), broker.Team()->find("2")); + + EXPECT_TRUE(broker.Unregister(client2)); + + ASSERT_EQ(2u, broker.endpoints.size()); + ASSERT_NE(broker.endpoints.end(), broker.endpoints.find("1:1")); + ASSERT_NE(broker.endpoints.end(), broker.endpoints.find("2:1")); + EXPECT_EQ(0u, broker.endpoints["1:1"].size()); + EXPECT_EQ(0u, broker.endpoints["2:1"].size()); + EXPECT_EQ(broker.Team()->end(), broker.Team()->find("1")); + EXPECT_EQ(broker.Team()->end(), broker.Team()->find("2")); + + EXPECT_FALSE(broker.Unregister(client11)); + EXPECT_FALSE(broker.Unregister(client12)); + EXPECT_FALSE(broker.Unregister(client2)); + + EXPECT_FALSE(broker.Unbind(client11_ep11_1)); + EXPECT_FALSE(broker.Unbind(client11_ep11_2)); + EXPECT_FALSE(broker.Unbind(client12_ep11_1)); + EXPECT_FALSE(broker.Unbind(client12_ep11_2)); + EXPECT_FALSE(broker.Unbind(client2_ep11_1)); + EXPECT_FALSE(broker.Unbind(client2_ep21_1)); + EXPECT_FALSE(broker.Unbind(client2_ep21_2)); +} + +TEST(brokerUnit, sendRequiresRegisterAndBind) +{ + Broker broker; + setDummyComms(broker); + + broker.Start(); + + ClientID client2; + ASSERT_NE(invalidClientId, client2 = broker.Register("2")); + + subt::msgs::Datagram msg; + msg.set_src_address("1"); + msg.set_dst_address("2"); + msg.set_dst_port(42); + msg.set_rssi(-30); + + // sender has to be registered + broker.OnMessage(msg); + EXPECT_FALSE(broker.DispatchMessages()); + + // endpoint has to be registered + msg.set_src_address("2"); + msg.set_dst_address("1"); + broker.OnMessage(msg); + EXPECT_FALSE(broker.DispatchMessages()); + + // destination has to be registered + ClientID client1; + EXPECT_NE(invalidClientId, client1 = broker.Register("1")); + broker.OnMessage(msg); + EXPECT_FALSE(broker.DispatchMessages()); + + // endpoint has to be bound + ASSERT_NE(invalidEndpointId, broker.Bind(client1, "1:42")); + broker.OnMessage(msg); + EXPECT_TRUE(broker.DispatchMessages()); +} + int main(int argc, char **argv) { testing::InitGoogleTest(&argc, argv); diff --git a/subt_msgs/srv/Bind.srv b/subt_msgs/srv/Bind.srv index 7f8a4e9f..af2edae6 100644 --- a/subt_msgs/srv/Bind.srv +++ b/subt_msgs/srv/Bind.srv @@ -1,6 +1,6 @@ # Bind -string address +uint32 client_id string endpoint ---- -bool success +uint32 endpoint_id diff --git a/subt_msgs/srv/Register.srv b/subt_msgs/srv/Register.srv index cb7a0f7d..d2a9d946 100644 --- a/subt_msgs/srv/Register.srv +++ b/subt_msgs/srv/Register.srv @@ -2,4 +2,4 @@ string local_address ---- -bool success +uint32 client_id diff --git a/subt_msgs/srv/Unregister.srv b/subt_msgs/srv/Unregister.srv index d7b7bb5b..9a2e25be 100644 --- a/subt_msgs/srv/Unregister.srv +++ b/subt_msgs/srv/Unregister.srv @@ -1,5 +1,5 @@ # Unregister -string local_address +uint32 client_id ---- bool success diff --git a/subt_ros/CMakeLists.txt b/subt_ros/CMakeLists.txt index 3edd7df2..b9a86e6e 100644 --- a/subt_ros/CMakeLists.txt +++ b/subt_ros/CMakeLists.txt @@ -125,3 +125,39 @@ install(DIRECTORY launch install(PROGRAMS scripts/rostopic_stats_logger.sh DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION} ) + +if(CATKIN_ENABLE_TESTING) + add_executable(test_broker test/test_broker.cc) + target_link_libraries(test_broker ${catkin_LIBRARIES} ignition-msgs6) + add_dependencies(tests test_broker) + + find_package(rostest REQUIRED) + find_package(rosgraph_msgs REQUIRED) + include_directories(${rosgraph_msgs_INCLUDE_DIRS}) + include_directories(${GTEST_INCLUDE_DIRS}) + + # we do not use add_rostest_gtest() because that doesn't allow for multiple + # independent targets/binaries to be built... and it is desirable that the + # tests are in separate binaries, because then rostest tears down all the + # helper nodes between switching to the other test node... on the other hand, + # keeping them all under one rostest command serializes the execution of the + # tests (while three separate rostests would run in parallel). + set(comms_types unicast multicast broadcast) + set(comms_targets "") + foreach(comms_type IN LISTS comms_types) + set(target test_comms_client_${comms_type}) + add_executable(${target} EXCLUDE_FROM_ALL test/comms_relay_${comms_type}.cc) + target_link_libraries(${target} + ${catkin_LIBRARIES} + ${rosgraph_msgs_LIBRARIES} + ${GTEST_LIBRARIES} + ignition-common3::ignition-common3 + ) + add_dependencies(tests ${target}) + list(APPEND comms_targets ${target}) + endforeach() + + add_rostest(test/comms_relay.test + DEPENDENCIES subt_ros_relay test_broker ${comms_targets} + ) +endif() diff --git a/subt_ros/package.xml b/subt_ros/package.xml index a759b1c2..b646806c 100644 --- a/subt_ros/package.xml +++ b/subt_ros/package.xml @@ -27,6 +27,9 @@ topic_tools std_msgs + rosgraph_msgs + rostest + diff --git a/subt_ros/src/SubtRosRelay.cc b/subt_ros/src/SubtRosRelay.cc index 0353d959..8f82d729 100644 --- a/subt_ros/src/SubtRosRelay.cc +++ b/subt_ros/src/SubtRosRelay.cc @@ -14,6 +14,9 @@ * limitations under the License. * */ + +#include + #include #include #include @@ -29,6 +32,7 @@ #include #include #include +#include #include #include #include @@ -105,12 +109,18 @@ class SubtRosRelay /// The received message is added to a message queue to be handled by a /// separate thread. /// \param[in] _msg The message. - public: void OnMessage(const subt::msgs::Datagram &_msg); + /// \param[in] _resolvedAddress The real destination of the message (i.e. + /// broadcast and multicast addresses resolved to a real address). + public: void OnMessage(const subt::msgs::Datagram &_msg, + const std::string& _resolvedAddress); /// \brief Process messages in consumed from the message queue - /// The message is forwarded via a ROS service call. - /// \param[in] _req The message. - public: void ProcessMessage(const subt::msgs::Datagram &_req); + /// The message is forwarded via a ROS topic. + /// \param[in] _msg A pair containing the message and the real destination of + /// the message (i.e. broadcast and multicast addresses resolved to a real + /// address). + public: void ProcessMessage( + std::pair &_msg); /// \brief Creates an AsyncSpinner and handles received messages public: void Spin(); @@ -149,6 +159,14 @@ class SubtRosRelay /// the origin artifact. public: ros::ServiceServer poseFromArtifactService; + /// \brief This is a mutex protecting registeredClients and boundAddresses. + public: std::mutex clientsMutex; + + /// \brief List of clients that called the registration service and did not + /// (yet) unregister. + public: std::unordered_map + registeredClients; + /// \brief The set of bound address. This is bookkeeping that helps /// to reduce erroneous error output in the ::Bind function. public: std::set boundAddresses; @@ -156,7 +174,8 @@ class SubtRosRelay /// \brief Lock free queue for holding msgs from Transport. This is needed to /// avoid deadlocks between the Transport thread that invokes callbacks and /// the main thread that handles messages. - public: boost::lockfree::spsc_queue msgQueue{10}; + public: boost::lockfree::spsc_queue< + std::pair> msgQueue{10}; /// \brief This mutex is used in conjunction with notifyCond to notify the /// main thread the arrival of new messages. @@ -173,6 +192,9 @@ class SubtRosRelay /// \brief Pointer to the ROS bag recorder. public: std::unique_ptr rosRecorder; + + /// \brief Publishers to /address/comms for each robot. Indexed by robot name. + public: std::unordered_map commsPublishers; }; ////////////////////////////////////////////////// @@ -259,11 +281,17 @@ SubtRosRelay::SubtRosRelay() this->bagThread.reset(new std::thread([&](){ this->rosRecorder->run(); })); + + ROS_INFO_STREAM("Running SubT ROS relay on Ign Partition '" + << this->node.Options().Partition() << "' and ROS master '" + << ros::master::getURI() << "'."); } ////////////////////////////////////////////////// SubtRosRelay::~SubtRosRelay() { + if (this->bagThread->joinable()) + this->bagThread->join(); } ///////////////////////////////////////////////// @@ -390,40 +418,78 @@ bool SubtRosRelay::OnPoseFromArtifact( bool SubtRosRelay::OnBind(subt_msgs::Bind::Request &_req, subt_msgs::Bind::Response &_res) { - if (std::find(this->robotNames.begin(), this->robotNames.end(), - _req.address) ==this->robotNames.end()) + std::string address; + { + std::unique_lock lock(this->clientsMutex); + if (this->registeredClients.find(_req.client_id) == + this->registeredClients.end()) + { + ROS_ERROR("Trying to bind on a client that has not been registered"); + return false; + } + else + { + address = this->registeredClients[_req.client_id]; + } + } + + if (address.empty()) + { + ROS_ERROR_STREAM("OnBind requested to bind on an invalid client.\n"); + return false; + } + + if (std::find(this->robotNames.begin(), this->robotNames.end(), address) == + this->robotNames.end()) { ROS_ERROR_STREAM("OnBind address does not match origination. Attempted " - << "impersonation of a robot as robot[" << _req.address << "].\n"); + << "impersonation of a robot as robot[" << address << "].\n"); return false; } - ignition::msgs::StringMsg_V req; - req.add_data(_req.address); - req.add_data(_req.endpoint); + subt::msgs::EndpointRegistration req; + req.set_client_id(_req.client_id); + req.set_endpoint(_req.endpoint); const unsigned int timeout = 3000u; - ignition::msgs::Boolean rep; + ignition::msgs::UInt32 rep; bool result; bool executed = this->node.Request( subt::communication_broker::kEndPointRegistrationSrv, req, timeout, rep, result); - _res.success = result; + _res.endpoint_id = subt::communication_broker::invalidEndpointId; + if (executed && result) + _res.endpoint_id = rep.data(); - if (executed && result && - // Only establish the Ignition service once per client. - this->boundAddresses.find(_req.address) == this->boundAddresses.end()) { - if (!this->node.Advertise(_req.address, &SubtRosRelay::OnMessage, this)) - { - std::cerr << "Bind Error: could not advertise " - << _req.address << std::endl; - return false; - } - else + std::unique_lock lock(this->clientsMutex); + + if (executed && result && + // Only establish the Ignition service once per client. + this->boundAddresses.find(address) == this->boundAddresses.end()) { - this->boundAddresses.insert(_req.address); + std::function cb = + [this, address](const subt::msgs::Datagram& _msg) + { + this->OnMessage(_msg, address); + }; + + if (!this->node.Advertise(address, cb)) + { + ROS_ERROR_STREAM("Bind Error: could not advertise " << address << + " while binding " << _req.endpoint); + return false; + } + else + { + this->boundAddresses.insert(address); + + // Prepare the ROS comms publisher + this->commsPublishers[address] = + this->rosnode->advertise( + "/" + address + "/comms", 1000); + } } } @@ -459,7 +525,7 @@ bool SubtRosRelay::OnRegister(subt_msgs::Register::Request &_req, ignition::msgs::StringMsg req; req.set_data(_req.local_address); - ignition::msgs::Boolean rep; + ignition::msgs::UInt32 rep; bool result; const unsigned int timeout = 3000u; @@ -467,16 +533,37 @@ bool SubtRosRelay::OnRegister(subt_msgs::Register::Request &_req, subt::communication_broker::kAddrRegistrationSrv, req, timeout, rep, result); - _res.success = result; - return executed; + if (executed && result && rep.data() != + subt::communication_broker::invalidClientId) + { + _res.client_id = rep.data(); + { + std::unique_lock lock(this->clientsMutex); + this->registeredClients[_res.client_id] = _req.local_address; + } + return true; + } + + _res.client_id = subt::communication_broker::invalidClientId; + return false; } ////////////////////////////////////////////////// bool SubtRosRelay::OnUnregister(subt_msgs::Unregister::Request &_req, subt_msgs::Unregister::Response &_res) { - ignition::msgs::StringMsg req; - req.set_data(_req.local_address); + { + std::unique_lock lock(this->clientsMutex); + if (this->registeredClients.find(_req.client_id) == + this->registeredClients.end()) + { + ROS_ERROR("Trying to unregister a client that has not been registered"); + return false; + } + } + + ignition::msgs::UInt32 req; + req.set_data(_req.client_id); ignition::msgs::Boolean rep; bool result; @@ -486,41 +573,47 @@ bool SubtRosRelay::OnUnregister(subt_msgs::Unregister::Request &_req, subt::communication_broker::kAddrUnregistrationSrv, req, timeout, rep, result); - _res.success = result; + _res.success = executed && result && rep.data(); + + if (_res.success) + { + std::unique_lock lock(this->clientsMutex); + this->registeredClients.erase(_req.client_id); + } return executed; } ////////////////////////////////////////////////// -void SubtRosRelay::OnMessage(const subt::msgs::Datagram &_req) +void SubtRosRelay::OnMessage(const subt::msgs::Datagram &_req, + const std::string& _resolvedAddress) { - this->msgQueue.push(_req); + this->msgQueue.push(std::make_pair(_req, _resolvedAddress)); // Notify the main thread this->notifyCond.notify_one(); } ////////////////////////////////////////////////// -void SubtRosRelay::ProcessMessage(const subt::msgs::Datagram &_req) +void SubtRosRelay::ProcessMessage( + std::pair &_msg) { - subt_msgs::DatagramRos::Request req; - subt_msgs::DatagramRos::Response res; - req.src_address = _req.src_address(); - req.dst_address = _req.dst_address(); - req.dst_port = _req.dst_port(); - req.data = _req.data(); - req.rssi = _req.rssi(); - - if (_req.dst_address() == subt::communication_broker::kBroadcast) - { - for (const std::string &dest : this->robotNames) - { - ros::service::call(dest, req, res); - } - } - else - { - ros::service::call(_req.dst_address(), req, res); - } + const auto& datagram = _msg.first; + const auto& resolvedAddress = _msg.second; + + subt_msgs::DatagramRos::Request rosMsg; + rosMsg.src_address = datagram.src_address(); + rosMsg.dst_address = datagram.dst_address(); + rosMsg.dst_port = datagram.dst_port(); + rosMsg.data = datagram.data(); + rosMsg.rssi = datagram.rssi(); + + // We can be sure resolvedAddress exists in the map because + // it was initialized in OnBind(), where OnMessage() is also registered + // as a service handler that fills the message queue, which, in turn, gets + // processed by this function. + // Broadcast messages get handled by binding a broadcast endpoint by + // each respective client. + this->commsPublishers[resolvedAddress].publish(rosMsg); } ////////////////////////////////////////////////// @@ -537,7 +630,9 @@ void SubtRosRelay::Spin() // that the lock is released before calling `ProcessMessage` to avoid // deadlocks. std::unique_lock lock(this->notifyMutex); - this->notifyCond.wait(lock, + // add timeout to the wait to allow graceful exit when no messages are + // coming + this->notifyCond.wait_for(lock, std::chrono::seconds(1), [this] { return this->msgQueue.read_available(); }); } diff --git a/subt_ros/test/comms_relay.test b/subt_ros/test/comms_relay.test new file mode 100644 index 00000000..8f051d20 --- /dev/null +++ b/subt_ros/test/comms_relay.test @@ -0,0 +1,35 @@ + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/subt_ros/test/comms_relay_broadcast.cc b/subt_ros/test/comms_relay_broadcast.cc new file mode 100644 index 00000000..bdbe70f4 --- /dev/null +++ b/subt_ros/test/comms_relay_broadcast.cc @@ -0,0 +1,159 @@ +/* + * Copyright (C) 2021 Open Source Robotics Foundation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * +*/ + +#include +#include +#include +#include +#include + +using namespace std; +using namespace subt; +using namespace subt::communication_broker; + +std::shared_ptr nh; +ros::Publisher clockPub; +ros::Time currentTime {0, 0}; + +///////////////////////////////////////////////// +void advanceTime() +{ + currentTime += ros::Duration(1); + ros::Time::setNow(currentTime); + rosgraph_msgs::Clock msg; + msg.clock = currentTime; + clockPub.publish(msg); + ros::spinOnce(); + ros::WallDuration(0.01).sleep(); +} + +///////////////////////////////////////////////// +TEST(relay, broadcast) +{ + + // Test communication on broadcast utilizing SubtRosRelay and a dummy broker. + + ros::Time::init(); + advanceTime(); + + // give other nodes time to spin up + ros::WallDuration(1).sleep(); + + CommsClient broadcaster("b", false, false, false, nh.get()); + advanceTime(); + CommsClient c1("c1", false, false, false, nh.get()); + advanceTime(); + CommsClient c2("c2", false , false, false, nh.get()); + advanceTime(); + CommsClient c3("c3", false , false, false, nh.get()); + advanceTime(); + + ROS_INFO_STREAM("Test running on master " << ros::master::getURI()); + + vector receivedData1, receivedData2; + auto c1_cb = [=,&receivedData1](const std::string& src, + const std::string& dst, + const uint32_t port, + const std::string& data) + { + receivedData1.emplace_back(data); + }; + auto c2_cb = [=,&receivedData2](const std::string& src, + const std::string& dst, + const uint32_t port, + const std::string& data) + { + receivedData2.emplace_back(data); + }; + auto c3_cb = [](const std::string& src, + const std::string& dst, + const uint32_t port, + const std::string& data) + { + GTEST_NONFATAL_FAILURE_("c3 callback should not have been called"); + }; + + // intentionally bind c1 twice to get twice as many data and verify one client + // can bind one port multiple times + ASSERT_FALSE(!c1.Bind(c1_cb, "", 123)); advanceTime(); + ASSERT_FALSE(!c1.Bind(c1_cb, "", 123)); advanceTime(); + ASSERT_FALSE(!c2.Bind(c2_cb, "", 123)); advanceTime(); + // c3 binds to a different port and thus should receive nothing + ASSERT_FALSE(!c3.Bind(c3_cb, "", 124)); advanceTime(); + + // give the bound endpoints time to set up + advanceTime(); // advance the test broker + ros::WallDuration(1).sleep(); + ros::spinOnce(); + + vector sentData; + for(unsigned int i=0; i < 15; ++i) + { + std::ostringstream oss; + oss << "Hello clients, " << i; + ASSERT_TRUE(broadcaster.SendTo(oss.str(), kBroadcast, 123)); + sentData.emplace_back(oss.str()); + advanceTime(); // advance the test broker + ros::spinOnce(); + } + + ros::spinOnce(); + advanceTime(); // advance the test broker + ros::spinOnce(); + + ASSERT_EQ(2u * sentData.size(), receivedData1.size()); + ASSERT_EQ(sentData.size(), receivedData2.size()); + + for (size_t i = 0; i < sentData.size(); ++i) + { + EXPECT_EQ(sentData[i], receivedData1[2*i]); + EXPECT_EQ(sentData[i], receivedData1[2*i + 1]); + EXPECT_EQ(sentData[i], receivedData2[i]); + } + + // selftest - test situation when a client sends a message to an endpoint it + // has also bound + + receivedData1.clear(); + receivedData2.clear(); + sentData.clear(); + + ros::spinOnce(); + advanceTime(); // advance the test broker + ros::spinOnce(); + + ASSERT_TRUE(c1.SendTo("Selftest", kBroadcast, 123)); + + sentData.emplace_back("Selftest"); + advanceTime(); // advance the test broker + ros::spinOnce(); + + EXPECT_EQ(2u, receivedData1.size()); + EXPECT_EQ(1u, receivedData2.size()); +} + +///////////////////////////////////////////////// +int main(int argc, char **argv) +{ + ros::init(argc, argv, "test_comms_relay"); + nh.reset(new ros::NodeHandle); + clockPub = nh->advertise("/clock", 1, true); + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} + + diff --git a/subt_ros/test/comms_relay_multicast.cc b/subt_ros/test/comms_relay_multicast.cc new file mode 100644 index 00000000..c4b06994 --- /dev/null +++ b/subt_ros/test/comms_relay_multicast.cc @@ -0,0 +1,156 @@ +/* + * Copyright (C) 2021 Open Source Robotics Foundation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * +*/ + +#include +#include +#include +#include +#include + +using namespace std; +using namespace subt; +using namespace subt::communication_broker; + +std::shared_ptr nh; +ros::Publisher clockPub; +ros::Time currentTime {0, 0}; + +///////////////////////////////////////////////// +void advanceTime() +{ + currentTime += ros::Duration(1); + ros::Time::setNow(currentTime); + rosgraph_msgs::Clock msg; + msg.clock = currentTime; + clockPub.publish(msg); + ros::spinOnce(); + ros::WallDuration(0.01).sleep(); +} + +///////////////////////////////////////////////// +TEST(relay, multicast) +{ + // Test communication on multicast utilizing SubtRosRelay and a dummy broker. + + ros::Time::init(); + advanceTime(); + + // give other nodes time to spin up + ros::WallDuration(1).sleep(); + + CommsClient sender("b", false, false, false, nh.get()); + advanceTime(); + CommsClient c1("c1", false, false, false, nh.get()); + advanceTime(); + CommsClient c2("c2", false , false, false, nh.get()); + advanceTime(); + CommsClient c3("c3", false , false, false, nh.get()); + advanceTime(); + + ROS_INFO_STREAM("Test running on master " << ros::master::getURI()); + + vector receivedData1, receivedData2; + auto c1_cb = [=,&receivedData1](const std::string& src, + const std::string& dst, + const uint32_t port, + const std::string& data) + { + receivedData1.emplace_back(data); + }; + auto c2_cb = [=,&receivedData2](const std::string& src, + const std::string& dst, + const uint32_t port, + const std::string& data) + { + receivedData2.emplace_back(data); + }; + auto c3_cb = [](const std::string& src, + const std::string& dst, + const uint32_t port, + const std::string& data) + { + GTEST_NONFATAL_FAILURE_("c3 callback should not have been called"); + }; + + // intentionally bind c1 twice, but once to multicast address and once on + // unicast + ASSERT_FALSE(!c1.Bind(c1_cb, kMulticast, 123)); advanceTime(); + ASSERT_FALSE(!c1.Bind(c1_cb, "", 123)); advanceTime(); + ASSERT_FALSE(!c2.Bind(c2_cb, kMulticast, 123)); advanceTime(); + ASSERT_FALSE(!c3.Bind(c3_cb, "", 123)); advanceTime(); + + // give the bound endpoints time to set up + advanceTime(); // advance the test broker + ros::WallDuration(1).sleep(); + ros::spinOnce(); + + vector sentData; + for(unsigned int i=0; i < 15; ++i) + { + std::ostringstream oss; + oss << "Hello clients, " << i; + EXPECT_TRUE(sender.SendTo(oss.str(), kMulticast, 123)); + sentData.emplace_back(oss.str()); + advanceTime(); // advance the test broker + ros::spinOnce(); + } + + ros::spinOnce(); + advanceTime(); // advance the test broker + ros::spinOnce(); + + ASSERT_EQ(sentData.size(), receivedData1.size()); + ASSERT_EQ(sentData.size(), receivedData2.size()); + + for (size_t i = 0; i < sentData.size(); ++i) + { + EXPECT_EQ(sentData[i], receivedData1[i]); + EXPECT_EQ(sentData[i], receivedData2[i]); + } + + // selftest - test situation when a client sends a message to an endpoint it + // has also bound + + receivedData1.clear(); + receivedData2.clear(); + sentData.clear(); + + ros::spinOnce(); + advanceTime(); // advance the test broker + ros::spinOnce(); + + ASSERT_TRUE(c1.SendTo("Selftest", kMulticast, 123)); + + sentData.emplace_back("Selftest"); + advanceTime(); // advance the test broker + ros::spinOnce(); + + EXPECT_EQ(1u, receivedData1.size()); + EXPECT_EQ(1u, receivedData2.size()); +} + +///////////////////////////////////////////////// +int main(int argc, char **argv) +{ + ros::init(argc, argv, "test_comms_relay"); + nh.reset(new ros::NodeHandle); + clockPub = nh->advertise("/clock", 1, true); + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} + + diff --git a/subt_ros/test/comms_relay_unicast.cc b/subt_ros/test/comms_relay_unicast.cc new file mode 100644 index 00000000..3c91d8b4 --- /dev/null +++ b/subt_ros/test/comms_relay_unicast.cc @@ -0,0 +1,188 @@ +/* + * Copyright (C) 2021 Open Source Robotics Foundation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * +*/ + +#include +#include +#include +#include +#include + +using namespace std; +using namespace subt; +using namespace subt::communication_broker; + +std::shared_ptr nh; +ros::Publisher clockPub; +ros::Time currentTime {0, 0}; + +///////////////////////////////////////////////// +void advanceTime() +{ + currentTime += ros::Duration(1); + ros::Time::setNow(currentTime); + rosgraph_msgs::Clock msg; + msg.clock = currentTime; + clockPub.publish(msg); + ros::spinOnce(); + ros::WallDuration(0.01).sleep(); +} + +///////////////////////////////////////////////// +TEST(relay, unicast) +{ + + // Test communication on broadcast utilizing SubtRosRelay and a dummy broker. + + ros::Time::init(); + advanceTime(); + + // give other nodes time to spin up + ros::WallDuration(1).sleep(); + + CommsClient broadcaster("b", false, false, false, nh.get()); + advanceTime(); + CommsClient c1("c1", false, false, false, nh.get()); + advanceTime(); + CommsClient c2("c2", false , false, false, nh.get()); + advanceTime(); + CommsClient c3("c3", false , false, false, nh.get()); + advanceTime(); + + ROS_INFO_STREAM("Test running on master " << ros::master::getURI()); + + vector receivedData1, receivedData2; + auto c1_cb = [=,&receivedData1](const std::string& src, + const std::string& dst, + const uint32_t port, + const std::string& data) + { + receivedData1.emplace_back(data); + }; + auto c2_cb = [=,&receivedData2](const std::string& src, + const std::string& dst, + const uint32_t port, + const std::string& data) + { + receivedData2.emplace_back(data); + }; + auto c3_cb = [](const std::string& src, + const std::string& dst, + const uint32_t port, + const std::string& data) + { + GTEST_NONFATAL_FAILURE_("c3 callback should not have been called"); + }; + + // intentionally bind c1 twice to get twice as many data and verify one client + // can bind one port multiple times + ASSERT_FALSE(!c1.Bind(c1_cb, "", 123)); advanceTime(); + ASSERT_FALSE(!c1.Bind(c1_cb, "", 123)); advanceTime(); + ASSERT_FALSE(!c2.Bind(c2_cb, "", 123)); advanceTime(); + // c3 binds to a different port and thus should receive nothing + ASSERT_FALSE(!c3.Bind(c3_cb, "", 124)); advanceTime(); + + // give the bound endpoints time to set up + advanceTime(); // advance the test broker + ros::WallDuration(1).sleep(); + ros::spinOnce(); + + vector sentData; + for(unsigned int i=0; i < 15; ++i) + { + std::ostringstream oss; + oss << "Hello clients, " << i; + ASSERT_TRUE(broadcaster.SendTo(oss.str(), "c1", 123)); + sentData.emplace_back(oss.str()); + advanceTime(); // advance the test broker + ros::spinOnce(); + } + + ros::spinOnce(); + advanceTime(); // advance the test broker + ros::spinOnce(); + + ASSERT_EQ(2u * sentData.size(), receivedData1.size()); + ASSERT_EQ(0u, receivedData2.size()); + + for (size_t i = 0; i < sentData.size(); ++i) + { + EXPECT_EQ(sentData[i], receivedData1[2*i]); + EXPECT_EQ(sentData[i], receivedData1[2*i + 1]); + } + + receivedData1.clear(); + receivedData2.clear(); + sentData.clear(); + + ros::spinOnce(); + advanceTime(); // advance the test broker + ros::spinOnce(); + + for(unsigned int i=0; i < 15; ++i) + { + std::ostringstream oss; + oss << "Hello clients, " << i; + ASSERT_TRUE(broadcaster.SendTo(oss.str(), "c2", 123)); + sentData.emplace_back(oss.str()); + advanceTime(); // advance the test broker + ros::spinOnce(); + } + + ros::spinOnce(); + advanceTime(); // advance the test broker + ros::spinOnce(); + + ASSERT_EQ(0u, receivedData1.size()); + ASSERT_EQ(sentData.size(), receivedData2.size()); + + for (size_t i = 0; i < sentData.size(); ++i) + { + EXPECT_EQ(sentData[i], receivedData2[i]); + } + + // selftest - test situation when a client sends a message to an endpoint it + // has also bound + + receivedData1.clear(); + receivedData2.clear(); + sentData.clear(); + + ros::spinOnce(); + advanceTime(); // advance the test broker + ros::spinOnce(); + + ASSERT_TRUE(c2.SendTo("Selftest", "c2", 123)); + + sentData.emplace_back("Selftest"); + advanceTime(); // advance the test broker + ros::spinOnce(); + + EXPECT_EQ(0u, receivedData1.size()); + EXPECT_EQ(1u, receivedData2.size()); +} + +///////////////////////////////////////////////// +int main(int argc, char **argv) +{ + ros::init(argc, argv, "test_comms_relay"); + nh.reset(new ros::NodeHandle); + clockPub = nh->advertise("/clock", 1, true); + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} + + diff --git a/subt_ros/test/test_broker.cc b/subt_ros/test/test_broker.cc new file mode 100644 index 00000000..778a6624 --- /dev/null +++ b/subt_ros/test/test_broker.cc @@ -0,0 +1,80 @@ +/* + * Copyright (C) 2021 Open Source Robotics Foundation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * +*/ + +/// \brief This is a dummy message broker for tests. It has no message loss. + +#include +#include +#include +#include +#include + +using namespace subt; +using namespace subt::communication_model; +using namespace subt::rf_interface; +using namespace subt::rf_interface::range_model; +using namespace subt::communication_broker; + +///////////////////////////////////////////////// +void setDummyComms(Broker& broker) +{ + struct radio_configuration radio; + radio.pathloss_f = [](const double&, radio_state&, radio_state&) { + return rf_power(); + }; + + broker.SetDefaultRadioConfiguration(radio); + broker.SetCommunicationFunction(&subt::communication_model::attempt_send); + broker.SetPoseUpdateFunction( + [](const std::string& name) + { + return std::make_tuple(true, ignition::math::Pose3d::Zero, 0.0); + }); +} + +///////////////////////////////////////////////// +int main(int argc, char** argv) +{ + ros::init(argc, argv, "test_broker"); + // subscribe to /clock messages; this is needed because we do not have any + // NodeHandle + ros::start(); + + Broker broker; + setDummyComms(broker); + + broker.Start(); + + ROS_INFO("Broker is running in partition '%s'", broker.IgnPartition().c_str()); + + while (!ros::Time::waitForValid(ros::WallDuration(1.0))) + { + ros::spinOnce(); + ROS_WARN("Waiting for valid ROS time"); + } + + ros::Rate rate(1.0); + + while (ros::ok()) + { + broker.DispatchMessages(); + rate.sleep(); + ROS_DEBUG("Broker dispatched"); + } + + ROS_INFO("Broker exiting"); +} \ No newline at end of file