From 754162bc7f994a55c31e23a94d246b4a13685228 Mon Sep 17 00:00:00 2001 From: Martin Pecka Date: Mon, 1 Feb 2021 18:16:13 +0100 Subject: [PATCH 1/2] Initial implementation of a shared version of CommsClient. This allows multiple clients from the same ROS node (or even from different ROS nodes), adds support for Unbind() command and allows ROS nodes with CommsClient to be restarted and still able to bind to the comms broker. --- .../subt_communication_broker/common_types.h | 13 + .../subt_communication_broker.h | 98 ++- .../subt_communication_client.h | 124 +++- .../src/protobuf/CMakeLists.txt | 1 + .../src/protobuf/endpoint_registration.proto | 32 + .../src/subt_communication_broker.cpp | 453 ++++++++++-- .../src/subt_communication_client.cpp | 390 ++++++---- .../tests/unit_test.cpp | 668 +++++++++++++++++- subt_msgs/srv/Bind.srv | 4 +- subt_msgs/srv/Register.srv | 2 +- subt_msgs/srv/Unregister.srv | 2 +- subt_ros/CMakeLists.txt | 36 + subt_ros/package.xml | 3 + subt_ros/src/SubtRosRelay.cc | 184 +++-- subt_ros/test/comms_relay.test | 35 + subt_ros/test/comms_relay_broadcast.cc | 159 +++++ subt_ros/test/comms_relay_multicast.cc | 156 ++++ subt_ros/test/comms_relay_unicast.cc | 188 +++++ subt_ros/test/test_broker.cc | 80 +++ 19 files changed, 2317 insertions(+), 311 deletions(-) create mode 100644 subt-communication/subt_communication_broker/src/protobuf/endpoint_registration.proto create mode 100644 subt_ros/test/comms_relay.test create mode 100644 subt_ros/test/comms_relay_broadcast.cc create mode 100644 subt_ros/test/comms_relay_multicast.cc create mode 100644 subt_ros/test/comms_relay_unicast.cc create mode 100644 subt_ros/test/test_broker.cc diff --git a/subt-communication/subt_communication_broker/include/subt_communication_broker/common_types.h b/subt-communication/subt_communication_broker/include/subt_communication_broker/common_types.h index 9fbef16c..16d027ff 100644 --- a/subt-communication/subt_communication_broker/include/subt_communication_broker/common_types.h +++ b/subt-communication/subt_communication_broker/include/subt_communication_broker/common_types.h @@ -89,11 +89,24 @@ const std::string kAddrUnregistrationSrv = "/address/unregister"; /// \brief Service used to register an end point. const std::string kEndPointRegistrationSrv = "/end_point/register"; +/// \brief Service used to unregister an end point. +const std::string kEndPointUnregistrationSrv = "/end_point/unregister"; + /// \brief Address used to receive neighbor updates. const std::string kNeighborsTopic = "/neighbors"; /// \brief Default port. const uint32_t kDefaultPort = 4100u; +/// \brief ID of a broker client. +typedef uint32_t ClientID; +/// \brief ID denoting an invalid broker client. +constexpr ClientID invalidClientId {0}; + +/// \brief ID of a bound endpoint belonging to a particular client. +typedef uint32_t EndpointID; +/// \brief ID denoting an invalid endpoint. +constexpr EndpointID invalidEndpointId {0}; + } } diff --git a/subt-communication/subt_communication_broker/include/subt_communication_broker/subt_communication_broker.h b/subt-communication/subt_communication_broker/include/subt_communication_broker/subt_communication_broker.h index d7ab6a72..d3892c19 100644 --- a/subt-communication/subt_communication_broker/include/subt_communication_broker/subt_communication_broker.h +++ b/subt-communication/subt_communication_broker/include/subt_communication_broker/subt_communication_broker.h @@ -29,6 +29,7 @@ #include #include +#include #include #include #include @@ -53,6 +54,8 @@ pose_update_function; /// \brief Map of endpoints using EndPoints_M = std::map>; + struct BrokerPrivate; + /// \brief Store messages, and exposes an API for registering new clients, /// bind to a particular address, push new messages or get the list of /// messages already stored in the queue. @@ -62,7 +65,7 @@ pose_update_function; public: Broker(); /// \brief Destructor. - public: virtual ~Broker() = default; + public: virtual ~Broker(); /// \brief Start handling services /// @@ -83,29 +86,39 @@ pose_update_function; public: void NotifyNeighbors(); /// \brief Dispatch all incoming messages. - public: void DispatchMessages(); + public: bool DispatchMessages(); /// \brief This method associates an endpoint with a broker client and its /// address. An endpoint is constructed as an address followed by ':', /// followed by the port. E.g.: "192.168.1.5:8000" is a valid endpoint. - /// \param[in] _clientAddress Address of the broker client. + /// \param[in] _clientId ID of a registered client. /// \param[in] _endpoint End point requested to bind. - /// \return True if the operation succeed or false otherwise (if the client - /// was already bound to the same endpoint). - private: bool Bind(const std::string &_clientAddress, - const std::string &_endpoint); - - /// \brief Register a new client for message handling. - /// \param[in] _id Unique ID of the client. - /// \return True if the operation succeed of false otherwise (if the same - /// id was already registered). - public: bool Register(const std::string &_id); - - /// \brief Unregister a client and unbind from all the endpoints. - /// \param[in] _id Unique ID of the client. - /// \return True if the operation succeed or false otherwise (if there is - /// no client registered for this ID). - public: bool Unregister(const std::string &_id); + /// \return ID that should be used for unbinding the endpoint. If + /// invalidEndpointId is returned, the binding request failed (e.g. due to + /// wrong endpoint specification, or if the _clientId is wrong). + public: EndpointID Bind(ClientID _clientId, const std::string &_endpoint); + + /// \brief This method cancels the association between a client and an + /// endpoint. When all equal endpoints on the same client are unbound, + /// the client will stop receiving messages for the given endpoint. + /// \param[in] _endpointId ID of the endpoint to unbind. It has to be an ID + /// received from a previous call to Bind(). + /// \return Whether the unbind was successful. It may fail e.g. when the + /// given ID is invalid and the broker doesn't know it. + public: bool Unbind(EndpointID _endpointId); + + /// \brief Register a new client for message handling. Multiple clients for + /// the same address are allowed even from a single process. + /// \param[in] _clientAddress Address of the client. + /// \return ID of the client (should be later used for unregistration). + /// If the returned ID is invalidClientId, the registration failed. + public: ClientID Register(const std::string &_clientAddress); + + /// \brief Unregister a client and unbind from all its endpoints. + /// \param[in] _clientId The ID received from the Register() call. + /// \return True if the operation succeeded or false otherwise (if there is + /// no client registered for this ID or the ID is invalid). + public: bool Unregister(ClientID _clientId); /// \brief Set the radio configuration for address /// \param[in] address @@ -128,29 +141,45 @@ pose_update_function; /// \param[in] f Function that finds pose based on name public: void SetPoseUpdateFunction(pose_update_function f); + /// \brief Get the Ignition partition this broker is running in. + /// \return The partition name. The reference gets invalid when the broker + /// object goes out of scope. + public: const std::string& IgnPartition() const; + /// \brief Callback executed when a new registration request is received. /// \param[in] _req The address contained in the request. - /// \param[out] _rep The result of the service. True when the registration - /// went OK or false otherwise (e.g.: the same address was already - /// registered). + /// \param[out] _rep An ID of the registered client. This ID should be used + /// when unregistering this client. If registration failed, invalidClientId + /// value is returned. private: bool OnAddrRegistration(const ignition::msgs::StringMsg &_req, - ignition::msgs::Boolean &_rep); + ignition::msgs::UInt32 &_rep); /// \brief Callback executed when a new unregistration request is received. - /// \param[in] _req The address contained in the request. + /// \param[in] _req ID of the client to unregister. /// \param[out] _rep The result of the service. True when the unregistration - /// went OK or false otherwise (e.g.: the address wasn't registered). - private: bool OnAddrUnregistration(const ignition::msgs::StringMsg &_req, + /// went OK or false otherwise (e.g.: the ID wasn't registered). + private: bool OnAddrUnregistration(const ignition::msgs::UInt32 &_req, ignition::msgs::Boolean &_rep); - /// \brief Callback executed when a new registration request is received. - /// \param[in] _req The end point contained in the request. The first - /// string is the client address and the second string is the end point. - /// \param[out] _rep The result of the service. True when the registration - /// went OK of false otherwise (e.g.: _req doesn't contain two strings). + /// \brief Callback executed when a new endpoint registration request is + /// received. + /// \param[in] _req The endpoint to register together with ID of the client + /// to which this registration belongs. + /// \param[out] _rep An ID of the endpoint. This ID should be used + /// when unregistering this endpoint. If registration failed, + /// invalidEndpointId value is returned. private: bool OnEndPointRegistration( - const ignition::msgs::StringMsg_V &_req, - ignition::msgs::Boolean &_rep); + const subt::msgs::EndpointRegistration &_req, + ignition::msgs::UInt32 &_rep); + + /// \brief Callback executed when a new endpoint unregistration request is + /// received. + /// \param[in] _req ID of the endpoint to unregister. + /// \param[out] _rep The result of the service. True when the unregistration + /// went OK or false otherwise (e.g.: the ID wasn't registered). + private: bool OnEndPointUnregistration( + const ignition::msgs::UInt32 &_req, + ignition::msgs::Boolean &_rep); /// \brief Callback executed when a new request is received. /// \param[in] _req The datagram contained in the request. @@ -185,6 +214,9 @@ pose_update_function; /// \brief Pose update function private: pose_update_function pose_update_f; + + /// \brief Private definitions and data + private: std::unique_ptr dataPtr; }; } diff --git a/subt-communication/subt_communication_broker/include/subt_communication_broker/subt_communication_client.h b/subt-communication/subt_communication_broker/include/subt_communication_broker/subt_communication_client.h index ef1a6cd2..a6c7c6b8 100644 --- a/subt-communication/subt_communication_broker/include/subt_communication_broker/subt_communication_client.h +++ b/subt-communication/subt_communication_broker/include/subt_communication_broker/subt_communication_client.h @@ -45,10 +45,22 @@ namespace subt /// \param[in] _useIgnition Set to true if you are using Ignition /// transport (i.e. not ROS). This is needed by the base station, /// and tests. If you are a regular robot, then you really really do not - /// want to set this to true as your Commsclient will not work. - public: CommsClient(const std::string &_localAddress, - const bool _isPrivate = false, - const bool _useIgnition = false); + /// want to set this to true as your Commsclient will not work. For each + /// address (robot), there can be only one client with _useIgnition set to + /// true. There can be multiple clients with _useIgnition set to false. + /// \param[in] _listenBeacons If true (the default), the commsclient will + /// listen to beacon packets on port 4000. This (or another regular stream + /// of data) is required for the Neighbors() function to work. + /// \param[in] _rosNh The ROS node handle used to create ROS subscribers for + /// incoming messages. Only used when _useIgnition is false. If not given, + /// a default node handle is used (in the current namespace). + public: + explicit + CommsClient(const std::string& _localAddress, + bool _isPrivate = false, + bool _useIgnition = false, + bool _listenBeacons = true, + ros::NodeHandle* _rosNh = nullptr); /// \brief Destructor. public: virtual ~CommsClient(); @@ -59,7 +71,9 @@ namespace subt /// \brief This method can bind a local address and a port to a /// virtual socket. This is a required step if your agent needs to - /// receive messages. + /// receive messages. It is possible to bind multiple callbacks to the + /// same address:port endpoint, even from the same client ID. Just call + /// Bind() multiple times. /// /// \param[in] _cb Callback function to be executed when a new message is /// received associated to the specified <_address, port>. @@ -74,22 +88,31 @@ namespace subt /// You will receive all the messages sent from any node to this multicast /// group. /// \param[in] _port Port used to receive messages. - /// \return True when success or false otherwise. + /// \return List of endpoints that were created in result of this bind call. + /// It can be up to two endpoints - one for the unicast/multicast address, + /// and one for the broadcast address. The returned pairs contain the IDs + /// of the endpoints and their names (i.e. address:port). If the returned + /// list is empty, binding failed. You may test for this case using the + /// overloaded operator! on the result vector. /// /// * Example usage (bind on the local address and default port): /// this->Bind(&OnDataReceived, "192.168.1.3"); /// * Example usage (Bind on the multicast group and custom port.): /// this->Bind(&OnDataReceived, this->kMulticast, 5123); - public: bool Bind(std::function _cb, - const std::string &_address = "", - const int _port = communication_broker::kDefaultPort); + public: + std::vector> + Bind(std::function _cb, + const std::string& _address = "", + int _port = communication_broker::kDefaultPort); /// \brief This method can bind a local address and a port to a /// virtual socket. This is a required step if your agent needs to - /// receive messages. + /// receive messages. It is possible to bind multiple callbacks to the + /// same address:port endpoint, even from the same client ID. Just call + /// Bind() multiple times. /// /// \param[in] _cb Callback function to be executed when a new message is /// received associated to the specified <_address, port>. @@ -105,20 +128,27 @@ namespace subt /// You will receive all the messages sent from any node to this multicast /// group. /// \param[in] _port Port used to receive messages. - /// \return True when success or false otherwise. + /// \return List of endpoints that were created in result of this bind call. + /// It can be up to two endpoints - one for the unicast/multicast address, + /// and one for the broadcast address. The returned pairs contain the IDs + /// of the endpoints and their names (i.e. address:port). If the returned + /// list is empty, binding failed. You may test for this case using the + /// overloaded operator! on the result vector. /// /// * Example usage (bind on the local address and default port): /// this->Bind(&MyClass::OnDataReceived, this, "192.168.1.3"); /// * Example usage (Bind on the multicast group and custom port.): /// this->Bind(&MyClass::OnDataReceived, this, this->kMulticast, 5123); - public: template - bool Bind(void(C::*_cb)(const std::string &_srcAddress, - const std::string &_dstAddress, - const uint32_t _dstPort, - const std::string &_data), - C *_obj, - const std::string &_address = "", - const int _port = communication_broker::kDefaultPort) + public: + template + std::vector> + Bind(void(C::*_cb)(const std::string& _srcAddress, + const std::string& _dstAddress, + uint32_t _dstPort, + const std::string& _data), + C* _obj, + const std::string& _address = "", + int _port = communication_broker::kDefaultPort) { return this->Bind(std::bind(_cb, _obj, std::placeholders::_1, @@ -128,6 +158,14 @@ namespace subt _address, _port); } + /// \brief This method can unbind from a "socket" acquired by Bind(). Once + /// unbound, the registered callback will no longer be called. + /// + /// \param[in] _endpointId ID of the endpoint to unbind. This is the ID + /// returned from Bind() call. + /// \return Success of the unbinding. + public: bool Unbind(communication_broker::EndpointID _endpointId); + /// \brief Send some data to other/s member/s of the team. /// /// \param[in] _data Payload. The maximum size of the payload is 1500 bytes. @@ -141,7 +179,7 @@ namespace subt /// bigger than 1500 bytes). public: bool SendTo(const std::string &_data, const std::string &_dstAddress, - const uint32_t _port = communication_broker::kDefaultPort); + uint32_t _port = communication_broker::kDefaultPort); /// \brief Type for storing neighbor data public: typedef std::map> Neighbor_M; @@ -150,6 +188,11 @@ namespace subt /// /// \return A map of addresses and signal strength from your /// local neighbors. + /// \note For this function to work reliably, you need to be sure there is + /// a regular flow of data from all other robots towards this address. + /// One way to achieve it is to send and listen to beacons (see constructor + /// argument _listenBeacons and function StartBeaconInterval() or + /// SendBeacons()). public: Neighbor_M Neighbors() const; /// \brief Broadcast a BEACON packet @@ -165,8 +208,9 @@ namespace subt private: bool Register(); /// \brief Unregister the current address. This will make a synchronous call - /// to the broker to unregister the address. - /// \return True when the unregistration succeed or false otherwise. + /// to the broker to unregister the address. It will also unbind all + /// bound callbacks from this client. + /// \return True when the unregistration succeeded or false otherwise. private: bool Unregister(); /// \brief Function called each time a new datagram message is received. @@ -174,9 +218,8 @@ namespace subt private: void OnMessage(const msgs::Datagram &_msg); /// \brief Function called each time a new datagram message is received. - /// \param[in] _msg The incoming message. - private: bool OnMessageRos(subt_msgs::DatagramRos::Request &_req, - subt_msgs::DatagramRos::Response &_res); + /// \param[in] _req The incoming message. + private: void OnMessageRos(const subt_msgs::DatagramRos::Request &_req); /// \brief On clock message. This is used primarily/only by the /// BaseStation. @@ -192,7 +235,7 @@ namespace subt private: using Callback_t = std::function; /// \brief The local address. @@ -216,9 +259,12 @@ namespace subt /// \brief An Ignition Transport node for communications. private: ignition::transport::Node node; + private: using Callbacks = + std::unordered_map; + /// \brief User callbacks. The key is the topic name /// (e.g.: "/subt/192.168.2.1/4000") and the value is the user callback. - private: std::map callbacks; + private: std::map callbacks; /// \brief True when the broker validated my address. Enabled must be true /// for being able to send and receive data. @@ -235,8 +281,8 @@ namespace subt /// \brief A mutex for avoiding race conditions. private: mutable std::mutex mutex; - /// \brief Service that receives comms messages. - private: ros::ServiceServer commsModelOnMessageService; + /// \brief Subscriber that receives comms messages from SubtRosRelay. + private: ros::Subscriber commsSub; /// \brief Clock message from simulation. Used by the base station. /// The base station is run as a plugin alongside simulation, and does @@ -254,6 +300,20 @@ namespace subt /// \brief Period of the beacon in nanoseconds. private: int64_t beaconPeriodNs{0}; + + private: communication_broker::ClientID clientId { + communication_broker::invalidClientId + }; }; } + +/// \brief Backwards compatibility that allows easily testing Bind() success. +/// \param _val Result of a Bind() operation. +/// \return True if the Bind() operation failed. +bool operator !(const std::vector< + std::pair>& _val) +{ + return _val.empty(); +} + #endif diff --git a/subt-communication/subt_communication_broker/src/protobuf/CMakeLists.txt b/subt-communication/subt_communication_broker/src/protobuf/CMakeLists.txt index 445c2596..665a093e 100644 --- a/subt-communication/subt_communication_broker/src/protobuf/CMakeLists.txt +++ b/subt-communication/subt_communication_broker/src/protobuf/CMakeLists.txt @@ -10,6 +10,7 @@ include(${PROJECT_SOURCE_DIR}/cmake/Protobuf.cmake) set(PROTO_MESSAGES datagram.proto + endpoint_registration.proto neighbor_m.proto ) diff --git a/subt-communication/subt_communication_broker/src/protobuf/endpoint_registration.proto b/subt-communication/subt_communication_broker/src/protobuf/endpoint_registration.proto new file mode 100644 index 00000000..07a545c1 --- /dev/null +++ b/subt-communication/subt_communication_broker/src/protobuf/endpoint_registration.proto @@ -0,0 +1,32 @@ +/* + * Copyright (C) 2021 Open Source Robotics Foundation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * +*/ + +syntax = "proto3"; +package subt.msgs; + +/// \ingroup subt_msgs +/// \interface EndpointRegistration +/// \brief A message containing data needed for registering an endpoint. + +message EndpointRegistration +{ + /// \brief ID of the client this endpoint should be bound to. + uint32 client_id = 1; + + /// \brief The endpoint to register ("address:port"). + string endpoint = 2; +} diff --git a/subt-communication/subt_communication_broker/src/subt_communication_broker.cpp b/subt-communication/subt_communication_broker/src/subt_communication_broker.cpp index 1b6a5a25..c9bd014e 100644 --- a/subt-communication/subt_communication_broker/src/subt_communication_broker.cpp +++ b/subt-communication/subt_communication_broker/src/subt_communication_broker.cpp @@ -30,12 +30,223 @@ namespace subt namespace communication_broker { +/// \brief Helper class for managing bidirectional mapping of client IDs and +/// addresses. The class is not thread-safe, so callers must ensure that none +/// of the public methods get called simultaneously. +struct ClientIDs +{ + /// \brief Number of active clients for each address. This structure can be + /// accessed by outer code for reading, but not for modification. + std::unordered_map numActiveClients; + + /// \brief Map of client IDs to addresses. This structure can be accessed by + /// outer code for reading, but not for modification. + std::unordered_map idToAddress; + + /// \brief Add a new client and generate its ID. + /// \param _address Address of the client. + /// \return ID of the client. This method should always succeed and return + /// an ID different from invalidClientID. + ClientID Add(const std::string& _address) + { + const auto clientId = this->NextID(); + this->idToAddress[clientId] = _address; + if (this->numActiveClients.find(_address) == this->numActiveClients.end()) + this->numActiveClients[_address] = 0; + this->numActiveClients[_address]++; + return clientId; + } + + /// \brief Unregister a client. + /// \param _id ID of the client. + /// \return Success of the unregistration. The method can fail e.g. when + /// trying to unregister a client which has not been registered. + bool Remove(const ClientID _id) + { + if (!this->Valid(_id)) + return false; + this->numActiveClients[this->idToAddress[_id]]--; + this->idToAddress.erase(_id); + return true; + } + + /// \brief Clear/reset the structure to be able to work as new. + /// \note This cancels all registrations of all clients and resets the client + /// ID numbering, so it is not valid to mix IDs of clients obtained before and + /// after a Clear() call. + void Clear() + { + this->numActiveClients.clear(); + this->idToAddress.clear(); + this->lastId = invalidClientId; + } + + /// \brief Check validity of a client ID. + /// \param _id ID to check. + /// \return Whether a client with the given ID has been registered. + bool Valid(const ClientID _id) const + { + return _id != invalidClientId && + this->idToAddress.find(_id) != this->idToAddress.end(); + } + + /// \brief Return an ID for a new client. + /// \return The ID. + private: ClientID NextID() + { + return ++this->lastId; + } + + /// \brief Last ID given to a client. + private: ClientID lastId {invalidClientId}; +}; + +/// \brief Helper class for managing mappings between endpoint names, their IDs, +/// related clients and so on. The class is not thread-safe, so callers must +/// ensure that none of the public methods get called simultaneously. +struct EndpointIDs +{ + /// \brief Maps endpoint IDs to endpoint names. This structure can be accessed + /// by outer code for reading, but not for modification. + std::unordered_map idToEndpoint; + + /// \brief Maps endpoint names to related clients. Each client ID stores + /// information about the number of "connections" of this client ID to the + /// endpoint. This structure can be accessed by outer code for reading, but + /// not for modification. + std::unordered_map> + endpointToClientIds; + + /// \brief Maps endpoint IDs to their related client IDs. This structure can + /// be accessed by outer code for reading, but not for modification. + std::unordered_map endpointIdToClientId; + + /// \brief Maps client IDs to all their endpoint IDs. This structure can be + /// accessed by outer code for reading, but not for modification. + std::unordered_map> + clientIdToEndpointIds; + + /// \brief Add new endpoint related to the given client ID. + /// \param _endpoint Endpoint name. + /// \param _clientId Client ID. This method does no validation of the IDs. + /// \return ID of the new endpoint. Should always succeed and return an ID + /// not equal to invalidEndpointID. + EndpointID Add(const std::string& _endpoint, const ClientID _clientId) + { + const auto endpointId = this->NextID(); + + this->idToEndpoint[endpointId] = _endpoint; + this->endpointIdToClientId[endpointId] = _clientId; + + if (this->endpointToClientIds[_endpoint].find(_clientId) == + this->endpointToClientIds[_endpoint].end()) + this->endpointToClientIds[_endpoint][_clientId] = 0; + this->endpointToClientIds[_endpoint][_clientId]++; + + this->clientIdToEndpointIds[_clientId].insert(endpointId); + + return endpointId; + } + + /// \brief Remove the given endpoint ID from this structure. + /// \param _id ID of the endpoint to remove. + /// \return Whether the removal succeeded or not. It may fail e.g. when the + /// removed endpoint doesn't exist. + bool Remove(const EndpointID _id) + { + if (!this->Valid(_id)) + return false; + + const auto endpointName = this->idToEndpoint[_id]; + + if (this->endpointIdToClientId.find(_id) == + this->endpointIdToClientId.end()) + return false; + + const auto clientId = this->endpointIdToClientId[_id]; + this->endpointIdToClientId.erase(_id); + + if (this->endpointToClientIds.find(endpointName) == + this->endpointToClientIds.end()) + return false; + + if (this->endpointToClientIds[endpointName].find(clientId) == + this->endpointToClientIds[endpointName].end()) + return false; + + this->endpointToClientIds[endpointName][clientId]--; + if (this->endpointToClientIds[endpointName][clientId] == 0u) + this->endpointToClientIds[endpointName].erase(clientId); + + if (this->clientIdToEndpointIds.find(clientId) == + this->clientIdToEndpointIds.end()) + return false; + + if (this->clientIdToEndpointIds[clientId].find(_id) == + this->clientIdToEndpointIds[clientId].end()) + return false; + + this->clientIdToEndpointIds[clientId].erase(_id); + + return true; + } + + /// \brief Clear/reset the structure to be able to work as new. + /// \note This cancels all registrations of all endpoints and resets the + /// endpoint ID numbering, so it is not valid to mix IDs of endpoints obtained + /// before and after a Clear() call. + void Clear() + { + this->idToEndpoint.clear(); + this->clientIdToEndpointIds.clear(); + this->endpointToClientIds.clear(); + this->endpointIdToClientId.clear(); + this->lastId = invalidEndpointId; + } + + /// \brief Check validity of an endpoint ID. + /// \param _id ID to check. + /// \return Whether an endpoint with the given ID has been registered. + bool Valid(const EndpointID _id) const + { + return _id != invalidEndpointId && + this->idToEndpoint.find(_id) != this->idToEndpoint.end(); + } + + /// \brief Return an ID for a new client. + /// \return The ID. + private: EndpointID NextID() + { + return ++this->lastId; + } + + /// \brief Last ID given to a client. + private: EndpointID lastId {invalidEndpointId}; +}; + +/// \brief PIMPL structure. +struct BrokerPrivate +{ + /// \brief IDs of registered clients. + ClientIDs clientIDs; + + /// \brief IDs of registered endpoints. + EndpointIDs endpointIDs; +}; + ////////////////////////////////////////////////// Broker::Broker() - : team(std::make_shared()) + : team(std::make_shared()), + dataPtr(std::make_unique()) { } +////////////////////////////////////////////////// +Broker::~Broker() +{ + // cannot use default destructor because of dataPtr +} + ////////////////////////////////////////////////// void Broker::Start() { @@ -66,6 +277,15 @@ void Broker::Start() return; } + // Advertise the service for unregistering end points. + if (!this->node.Advertise(kEndPointUnregistrationSrv, + &Broker::OnEndPointUnregistration, this)) + { + std::cerr << "Error advertising srv [" << kEndPointUnregistrationSrv << "]" + << std::endl; + return; + } + // Advertise a oneway service for centralizing all message requests. if (!this->node.Advertise(kBrokerSrv, &Broker::OnMessage, this)) { @@ -82,6 +302,9 @@ void Broker::Start() << std::endl; return; } + + std::cout << "Started communication broker in Ignition partition " + << this->IgnPartition() << std::endl; } ////////////////////////////////////////////////// @@ -90,6 +313,9 @@ void Broker::Reset() std::lock_guard lk(this->mutex); this->incomingMsgs.clear(); this->endpoints.clear(); + this->team->clear(); + this->dataPtr->clientIDs.Clear(); + this->dataPtr->endpointIDs.Clear(); } ////////////////////////////////////////////////// @@ -124,12 +350,12 @@ void Broker::NotifyNeighbors() } ////////////////////////////////////////////////// -void Broker::DispatchMessages() +bool Broker::DispatchMessages() { std::lock_guard lk(this->mutex); if(this->incomingMsgs.empty()) - return; + return true; // Cannot dispatch messages if we don't have function handles for // pathloss and communication @@ -137,14 +363,14 @@ void Broker::DispatchMessages() { std::cerr << "[Broker::DispatchMessages()] Missing function handle for " << "communication" << std::endl; - return; + return false; } if(!pose_update_f) { std::cerr << "[Broker::DispatchMessages()]: Missing function for updating " << "pose" << std::endl; - return; + return false; } // Update state for all members in team (only do this for members @@ -160,10 +386,11 @@ void Broker::DispatchMessages() { std::cerr << "Problem getting state for " << t.second->name << ", skipping DispatchMessages()" << std::endl; - return; + return false; } } + bool allSucceeded = true; while (!this->incomingMsgs.empty()) { // Get the next message to dispatch. @@ -177,6 +404,7 @@ void Broker::DispatchMessages() std::cerr << "Broker::DispatchMessages(): Discarding message. Robot [" << msg.src_address() << "] is not registered as a member of the" << " team" << std::endl; + allSucceeded = false; continue; } @@ -202,6 +430,7 @@ void Broker::DispatchMessages() << "Robot [" << client.address << "] is not registered as a member of the" << " team" << std::endl; + allSucceeded = false; continue; } @@ -211,6 +440,7 @@ void Broker::DispatchMessages() { std::cerr << "No pathloss function defined for " << msg.src_address() << std::endl; + allSucceeded = false; continue; } @@ -231,6 +461,7 @@ void Broker::DispatchMessages() std::cerr << "[CommsBrokerPlugin::DispatchMessages()]: Error " << "sending message to [" << client.address << "]" << std::endl; + allSucceeded = false; } } } @@ -239,116 +470,192 @@ void Broker::DispatchMessages() { std::cerr << "[Broker::DispatchMessages()]: Could not find endpoint " << dstEndPoint << std::endl; + allSucceeded = false; } } + return allSucceeded; } ////////////////////////////////////////////////// -bool Broker::Bind(const std::string &_clientAddress, - const std::string &_endpoint) +EndpointID Broker::Bind(const ClientID _clientId, const std::string &_endpoint) { std::lock_guard lk(this->mutex); - // Make sure that the same client didn't bind the same end point before. + + if (!this->dataPtr->clientIDs.Valid(_clientId)) + { + std::cerr << "Broker::Bind() error: Client ID [" << _clientId + << "] is invalid." << std::endl; + return invalidEndpointId; + } + + const auto& clientAddress = this->dataPtr->clientIDs.idToAddress[_clientId]; + + auto clientFound = false; + if (this->endpoints.find(_endpoint) != this->endpoints.end()) { const auto &clientsV = this->endpoints[_endpoint]; for (const auto &client : clientsV) { - if (client.address == _clientAddress) + if (client.address == clientAddress) { - std::cerr << "Broker::Bind() error: Address [" << _clientAddress - << "] already used in a previous Bind()" << std::endl; - return false; + clientFound = true; + break; } } } - BrokerClientInfo clientInfo; - clientInfo.address = _clientAddress; - this->endpoints[_endpoint].push_back(clientInfo); + if (!clientFound) + { + BrokerClientInfo clientInfo; + clientInfo.address = clientAddress; + this->endpoints[_endpoint].push_back(clientInfo); + } + + const auto endpointId = this->dataPtr->endpointIDs.Add(_endpoint, _clientId); - return true; + return endpointId; } ////////////////////////////////////////////////// -bool Broker::Register(const std::string &_id) +bool Broker::Unbind(EndpointID _endpointId) { std::lock_guard lk(this->mutex); - auto kvp = this->team->find(_id); - if (kvp != this->team->end()) + + if (!this->dataPtr->endpointIDs.Valid(_endpointId)) + return false; + + const auto endpointName = + this->dataPtr->endpointIDs.idToEndpoint[_endpointId]; + + const auto clientId = + this->dataPtr->endpointIDs.endpointIdToClientId[_endpointId]; + const auto clientAddress = this->dataPtr->clientIDs.idToAddress[clientId]; + + if (this->endpoints.find(endpointName) == this->endpoints.end()) + return false; + + bool success = this->dataPtr->endpointIDs.Remove(_endpointId); + if (!success) + return false; + + if (this->dataPtr->endpointIDs.endpointToClientIds.find(endpointName) == + this->dataPtr->endpointIDs.endpointToClientIds.end()) + return false; + + bool hasOtherClientsOnTheSameAddress = false; + for (const auto clientKV : + this->dataPtr->endpointIDs.endpointToClientIds[endpointName]) { - std::cerr << "Broker::Register() warning: ID [" << _id << "] already exists" - << std::endl; + if (this->dataPtr->clientIDs.idToAddress[clientKV.first] == clientAddress && + clientKV.second > 0u) + { + hasOtherClientsOnTheSameAddress = true; + break; + } + } + + if (hasOtherClientsOnTheSameAddress) + return true; + + auto& clientsV = this->endpoints[endpointName]; + + auto i = std::begin(clientsV); + while (i != std::end(clientsV)) + { + if (i->address == clientAddress) + { + clientsV.erase(i); + break; + } + else + { + ++i; + } } - else + + return true; +} + +////////////////////////////////////////////////// +ClientID Broker::Register(const std::string &_clientAddress) +{ + std::lock_guard lk(this->mutex); + auto kvp = this->team->find(_clientAddress); + if (kvp == this->team->end()) { auto newMember = std::make_shared(); // Name and address are the same in SubT. - newMember->address = _id; - newMember->name = _id; + newMember->address = _clientAddress; + newMember->name = _clientAddress; newMember->radio = default_radio_configuration; - (*this->team)[_id] = newMember; + (*this->team)[_clientAddress] = newMember; } - return true; + const auto clientId = this->dataPtr->clientIDs.Add(_clientAddress); + + return clientId; } ////////////////////////////////////////////////// -bool Broker::Unregister(const std::string &_id) +bool Broker::Unregister(const ClientID _clientId) { - std::lock_guard lk(this->mutex); - // Sanity check: Make sure that the ID exists. - if (this->team->find(_id) == this->team->end()) + if (!this->dataPtr->clientIDs.Valid(_clientId)) { - std::cerr << "Broker::Unregister() error: ID [" << _id << "] doesn't exist" - << std::endl; + std::cerr << "Broker::Unregister() error: Client ID [" << _clientId + << "] is invalid." << std::endl; return false; } - this->team->erase(_id); + bool success = true; - // Unbind. - for (auto &endpointKv : this->endpoints) + std::unordered_set endpointIds; { - auto &clientsV = endpointKv.second; + // make a copy because Unbind() calls will alter the structure + std::lock_guard lk(this->mutex); + endpointIds = this->dataPtr->endpointIDs.clientIdToEndpointIds[_clientId]; + } - auto i = std::begin(clientsV); - while (i != std::end(clientsV)) - { - if (i->address == _id) - i = clientsV.erase(i); - else - ++i; - } + for (const auto endpointId : endpointIds) + success = success && this->Unbind(endpointId); + + { + std::lock_guard lk(this->mutex); + + const auto& clientAddress = this->dataPtr->clientIDs.idToAddress[_clientId]; + success = success && this->dataPtr->clientIDs.Remove(_clientId); + + if (this->dataPtr->clientIDs.numActiveClients[clientAddress] == 0u) + this->team->erase(clientAddress); } - return true; + return success; } ///////////////////////////////////////////////// bool Broker::OnAddrRegistration(const ignition::msgs::StringMsg &_req, - ignition::msgs::Boolean &_rep) + ignition::msgs::UInt32 &_rep) { - std::string address = _req.data(); - bool result; + const auto& address = _req.data(); - result = this->Register(address); + const ClientID result = this->Register(address); _rep.set_data(result); - return result; + return result != invalidClientId; } ///////////////////////////////////////////////// -bool Broker::OnAddrUnregistration(const ignition::msgs::StringMsg &_req, +bool Broker::OnAddrUnregistration(const ignition::msgs::UInt32 &_req, ignition::msgs::Boolean &_rep) { - std::string address = _req.data(); + uint32_t clientId = _req.data(); + bool result; - result = this->Unregister(address); + result = this->Unregister(clientId); _rep.set_data(result); @@ -356,21 +663,28 @@ bool Broker::OnAddrUnregistration(const ignition::msgs::StringMsg &_req, } ///////////////////////////////////////////////// -bool Broker::OnEndPointRegistration(const ignition::msgs::StringMsg_V &_req, - ignition::msgs::Boolean &_rep) +bool Broker::OnEndPointRegistration( + const subt::msgs::EndpointRegistration &_req, + ignition::msgs::UInt32 &_rep) { - if (_req.data().size() != 2) - { - std::cerr << "[Broker::OnEndPointRegistration()] Expected two strings and " - << "got " << _req.data().size() << " instead" << std::endl; - return false; - } + ClientID clientId = _req.client_id(); + const auto& endpoint = _req.endpoint(); - bool result; - std::string clientAddress = _req.data(0); - std::string endpoint = _req.data(1); + EndpointID result = this->Bind(clientId, endpoint); + + _rep.set_data(result); + + return result != invalidEndpointId; +} - result = this->Bind(clientAddress, endpoint); +///////////////////////////////////////////////// +bool Broker::OnEndPointUnregistration( + const ignition::msgs::UInt32 &_req, + ignition::msgs::Boolean &_rep) +{ + const EndpointID endpointId = _req.data(); + + bool result = this->Unbind(endpointId); _rep.set_data(result); @@ -387,6 +701,7 @@ void Broker::OnMessage(const subt::msgs::Datagram &_req) this->incomingMsgs.push_back(_req); } +////////////////////////////////////////////////// void Broker::SetRadioConfiguration(const std::string& address, communication_model::radio_configuration config) { @@ -433,5 +748,11 @@ void Broker::SetPoseUpdateFunction(pose_update_function f) pose_update_f = f; } +////////////////////////////////////////////////// +const std::string& Broker::IgnPartition() const +{ + return this->node.Options().Partition(); +} + } } diff --git a/subt-communication/subt_communication_broker/src/subt_communication_client.cpp b/subt-communication/subt_communication_broker/src/subt_communication_client.cpp index d3b24baf..73c51d5e 100644 --- a/subt-communication/subt_communication_broker/src/subt_communication_client.cpp +++ b/subt-communication/subt_communication_broker/src/subt_communication_client.cpp @@ -24,13 +24,15 @@ #include #include +#include using namespace subt; using namespace subt::communication_broker; ////////////////////////////////////////////////// CommsClient::CommsClient(const std::string &_localAddress, - const bool _isPrivate, const bool _useIgnition) + const bool _isPrivate, const bool _useIgnition, const bool _listenBeacons, + ros::NodeHandle* _rosNh) : localAddress(_localAddress), isPrivate(_isPrivate), useIgnition(_useIgnition) @@ -65,6 +67,17 @@ CommsClient::CommsClient(const std::string &_localAddress, std::this_thread::sleep_for(std::chrono::milliseconds(100)); this->enabled = this->Register(); elapsed = std::chrono::steady_clock::now() - kStart; + + // give Ctrl-C a chance + if (!_useIgnition) + { + if (!ros::ok()) + return; + } + else + { + std::this_thread::sleep_for(std::chrono::milliseconds (5)); + } } if (!this->enabled) @@ -74,12 +87,31 @@ CommsClient::CommsClient(const std::string &_localAddress, return; } - // Bind to be sure we receive beacon packets - auto cb = [] (const std::string&, - const std::string&, - const uint32_t, - const std::string&) { }; - this->Bind(cb, "", kBeaconPort); + if (!this->useIgnition) + { + ros::NodeHandle nh; + if (_rosNh != nullptr) + nh = *_rosNh; + + this->commsSub = nh.subscribe( + "/" + _localAddress + "/comms",1000, + &CommsClient::OnMessageRos, this); + } + + if (_listenBeacons) + { + // Bind to be sure we receive beacon packets + auto cb = [] (const std::string&, + const std::string&, + const uint32_t, + const std::string&) { }; + if (!this->Bind(cb, "", kBeaconPort)) + { + std::cerr << "[CommsClient] Could bind the beacon responder" << std::endl; + this->Unregister(); + return; + } + } this->enabled = true; } @@ -104,20 +136,23 @@ std::string CommsClient::Host() const } ////////////////////////////////////////////////// -bool CommsClient::Bind(std::function _cb, - const std::string &_address, - const int _port) +std::vector> +CommsClient::Bind(std::function _cb, + const std::string &_address, + const int _port) { + std::vector> endpoints; + // Sanity check: Make sure that the communications are enabled. - if (!this->enabled) + if (!this->enabled || this->clientId == invalidClientId) { std::cerr << "[" << this->Host() << "] Bind() error: Trying to bind before communications are enabled!" << std::endl; - return false; + return endpoints; } // Use current address if _address is not provided. @@ -131,100 +166,87 @@ bool CommsClient::Bind(std::functionHost() << "] Bind() error: Address [" << address << "] is not your local address" << std::endl; - return false; + return endpoints; } - // Mapping the "unicast socket" to a topic name. + // Mapping the "unicast socket" to an endpoint name. const auto unicastEndPoint = address + ":" + std::to_string(_port); const auto bcastEndpoint = communication_broker::kBroadcast + ":" + std::to_string(_port); - bool bcastAdvertiseNeeded; - - { - std::lock_guard lock(this->mutex); - - // Sanity check: Make sure that this address is not already used. - if (this->callbacks.find(unicastEndPoint) != this->callbacks.end()) - { - std::cerr << "[" << this->Host() << "] Bind() error: Address [" - << address << "] already used" << std::endl; - return false; - } - - bcastAdvertiseNeeded = - this->callbacks.find(bcastEndpoint) == this->callbacks.end(); - } - // Register the endpoints in the broker. - // Note that the broadcast endpoint will only be registered once. for (const std::string &endpoint : {unicastEndPoint, bcastEndpoint}) { - if (endpoint != bcastEndpoint || bcastAdvertiseNeeded) + // If this is the basestation, then we need to use ignition transport. + // Otherwise, the client is on a robot and needs to use ROS. + if (this->useIgnition) { - // If this is the basestation, then we need to use ignition transport. - // Otherwise, the client is on a robot and needs to use ROS. - if (this->useIgnition) + subt::msgs::EndpointRegistration req; + req.set_client_id(this->clientId); + req.set_endpoint(endpoint); + + const unsigned int timeout = 3000u; + ignition::msgs::UInt32 rep; + bool result; + bool executed = this->node.Request( + communication_broker::kEndPointRegistrationSrv, + req, timeout, rep, result); + + if (!executed) { - ignition::msgs::StringMsg_V req; - req.add_data(address); - req.add_data(endpoint); - - const unsigned int timeout = 3000u; - ignition::msgs::Boolean rep; - bool result; - bool executed = this->node.Request( - communication_broker::kEndPointRegistrationSrv, - req, timeout, rep, result); - - if (!executed) - { - std::cerr << "[CommsClient] Endpoint registration srv not available" - << std::endl; - return false; - } - - if (!result) - { - std::cerr << "[CommsClient] Invalid data. Did you send the address " - << "followed by the endpoint?" << std::endl; - return false; - } + std::cerr << "[CommsClient] Endpoint registration srv not available" + << std::endl; + return endpoints; } - else + + const auto endpointID = rep.data(); + + if (!result || endpointID == invalidEndpointId) { - subt_msgs::Bind::Request req; - req.address = address; - req.endpoint = endpoint; + std::cerr << "[CommsClient] Invalid endpoint registration data." + << std::endl; + return endpoints; + } - subt_msgs::Bind::Response rep; + endpoints.emplace_back(std::make_pair(endpointID, endpoint)); + } + else + { + subt_msgs::Bind::Request req; + req.client_id = this->clientId; + req.endpoint = endpoint; - bool executed = ros::service::call( - communication_broker::kEndPointRegistrationSrv, req, rep); + subt_msgs::Bind::Response rep; - if (!executed) - { - std::cerr << "[CommsClient] Endpoint registration srv not available" - << std::endl; - return false; - } + bool executed = ros::service::call( + communication_broker::kEndPointRegistrationSrv, req, rep); - if (!rep.success) - { - std::cerr << "[CommsClient] Invalid data. Did you send the address " - << "followed by the endpoint?" << std::endl; - return false; - } + if (!executed) + { + std::cerr << "[CommsClient] Endpoint registration srv not available" + << std::endl; + return endpoints; + } + + if (rep.endpoint_id == communication_broker::invalidEndpointId) + { + std::cerr << "[CommsClient] Invalid endpoint registration data." + << std::endl; + return endpoints; } + + endpoints.emplace_back(std::make_pair(rep.endpoint_id, endpoint)); } } - if (!this->advertised) + if (this->useIgnition) { - // Use ignition transport if this is the basestation. Otherwise use ros. - if (this->useIgnition) + // Ignition transport registers a datagram receive service, so we have to + // make sure we only advertise it once + if (!this->advertised) { - // Advertise a oneway service for receiving message requests. + // Advertise a oneway service for receiving message requests. This assumes + // there is only a single node running the client in useIgnition mode. ignition::transport::AdvertiseServiceOptions opts; if (this->isPrivate) opts.SetScope(ignition::transport::Scope_t::PROCESS); @@ -233,16 +255,19 @@ bool CommsClient::Bind(std::functionHost() << "] Bind Error: could not advertise " << address << std::endl; - return false; + + // if we cannot advertise but we have already bound the endpoints, + // we need to unbind them before exiting with error + for (const auto& endpoint : endpoints) + { + ignition::msgs::UInt32 req; + req.set_data(endpoint.first); + this->node.Request(kEndPointUnregistrationSrv, req); + } + endpoints.clear(); + return endpoints; } } - else - { - ros::NodeHandle nh; - // Advertise on the global namespace - this->commsModelOnMessageService = nh.advertiseService( - "/" + address, &CommsClient::OnMessageRos, this); - } this->advertised = true; } @@ -250,33 +275,97 @@ bool CommsClient::Bind(std::function lock(this->mutex); - for (std::string endpoint : {unicastEndPoint, bcastEndpoint}) + for (const auto& endpoint : endpoints) { - if (endpoint != bcastEndpoint || bcastAdvertiseNeeded) - { - ignmsg << "Storing callback for " << endpoint << std::endl; - this->callbacks[endpoint] = std::bind(_cb, - std::placeholders::_1, std::placeholders::_2, - std::placeholders::_3, std::placeholders::_4); - } - else + const auto& endpointID = endpoint.first; + const auto& endpointName = endpoint.second; + + ignmsg << "Storing callback for " << endpointName << std::endl; + + this->callbacks[endpointName][endpointID] = std::bind(_cb, + std::placeholders::_1, std::placeholders::_2, + std::placeholders::_3, std::placeholders::_4); + } + } + + return endpoints; +} + +////////////////////////////////////////////////// +bool CommsClient::Unbind(communication_broker::EndpointID _endpointId) +{ + if (!this->enabled || this->clientId == invalidClientId) + { + std::cerr << "[" << this->localAddress << "] CommsClient::Unbind:" + << "Calling Unregister() before registering the client." + << std::endl; + return false; + } + + std::string endpoint; + { + std::unique_lock lock(this->mutex); + for (const auto& callbackKV : this->callbacks) + { + const auto& endpointName = callbackKV.first; + const auto& endpointCallbacks = callbackKV.second; + + for (const auto& endpointKV : endpointCallbacks) { - ignwarn << "Skipping callback register for " << endpoint << std::endl; + const auto& endpointID = endpointKV.first; + if (endpointID == _endpointId) + { + endpoint = endpointName; + break; + } } + if (!endpoint.empty()) + break; } } - return true; -} + if (endpoint.empty()) + { + std::cerr << "Trying to unbind an endpoint that is not bound." << std::endl; + return false; + } + + ignition::msgs::UInt32 req; + req.set_data(_endpointId); + const unsigned int timeout = 3000u; + ignition::msgs::Boolean rep; + bool result; + + bool executed = this->node.Request( + kEndPointUnregistrationSrv, req, timeout, rep, result); + + if (!executed) + return false; + + if (result && rep.data()) + { + std::unique_lock lock(this->mutex); + this->callbacks[endpoint].erase(_endpointId); + if (this->callbacks[endpoint].empty()) + this->callbacks.erase(endpoint); + return true; + } + return false; +} ////////////////////////////////////////////////// bool CommsClient::SendTo(const std::string &_data, const std::string &_dstAddress, const uint32_t _port) { // Sanity check: Make sure that the communications are enabled. - if (!this->enabled) + if (!this->enabled || this->clientId == invalidClientId) + { + std::cerr << "[" << this->localAddress << "] CommsClient::SendTo:" + << "Calling Unregister() before registering the client." + << std::endl; return false; + } // Restrict the maximum size of a message. if (_data.size() > this->kMtu) @@ -354,6 +443,15 @@ void CommsClient::StartBeaconInterval(ros::Duration _period) ////////////////////////////////////////////////// bool CommsClient::Register() { + if (this->enabled || this->clientId != invalidClientId) + { + std::cerr << "[" << this->localAddress + << "] CommsClient::Register: Calling Register() on an already " + << "registered client." + << std::endl; + return false; + } + bool executed; bool result; @@ -363,11 +461,14 @@ bool CommsClient::Register() ignition::msgs::StringMsg req; req.set_data(this->localAddress); - ignition::msgs::Boolean rep; + ignition::msgs::UInt32 rep; const unsigned int timeout = 3000u; executed = this->node.Request( kAddrRegistrationSrv, req, timeout, rep, result); + + if (executed && result) + this->clientId = rep.data(); } else { @@ -377,17 +478,20 @@ bool CommsClient::Register() req.local_address = this->localAddress; executed = ros::service::call(kAddrRegistrationSrv, req, rep); - result = rep.success; + result = executed; + if (executed) + this->clientId = rep.client_id; } - if (!executed) + if (!executed || !result || this->clientId == invalidClientId) { std::cerr << "[" << this->localAddress << "] CommsClient::Register: Problem registering with broker" << std::endl; + return false; } - return executed && result; + return true; } ////////////////////////////////////////////////// @@ -396,11 +500,37 @@ bool CommsClient::Unregister() bool executed; bool result; + if (!this->enabled || this->clientId == invalidClientId) + { + std::cerr << "[" << this->localAddress << "] CommsClient::Unregister:" + << "Calling Unregister() before registering the client." + << std::endl; + return false; + } + + // unbind all endpoints + + // copy the callbacks array because Unbind() removes items from it + decltype(this->callbacks) callbacksCopy; + { + std::unique_lock lock(this->mutex); + callbacksCopy = this->callbacks; + } + + for (const auto& callbackKV : callbacksCopy) + { + // we intentionally ignore failures in unbind as there's nothing to do + for (const auto& endpointKV : callbackKV.second) + this->Unbind(endpointKV.first); + } + // to be sure none are left there in case a later Register() is called + this->callbacks.clear(); + // Use ignition transport if this is the base station. Otherwise, use ROS. if (this->useIgnition) { - ignition::msgs::StringMsg req; - req.set_data(this->localAddress); + ignition::msgs::UInt32 req; + req.set_data(this->clientId); ignition::msgs::Boolean rep; const unsigned int timeout = 3000u; @@ -413,12 +543,15 @@ bool CommsClient::Unregister() subt_msgs::Unregister::Request req; subt_msgs::Unregister::Response rep; - req.local_address = this->localAddress; + req.client_id = this->clientId; executed = ros::service::call(kAddrUnregistrationSrv, req, rep); result = rep.success; } + if (executed && result) + this->clientId = invalidClientId; + return executed && result; } @@ -434,19 +567,22 @@ void CommsClient::OnMessage(const msgs::Datagram &_msg) this->clockMsg.sim().nsec() * 1e-9; this->neighbors[_msg.src_address()] = std::make_pair(time, _msg.rssi()); - for (auto cb : this->callbacks) + if (this->callbacks.find(endPoint) == this->callbacks.end()) + return; + + for (const auto& endpointCallbacksKV : this->callbacks[endPoint]) { - if (cb.first == endPoint && cb.second) + auto& callback = endpointCallbacksKV.second; + if (callback) { - cb.second(_msg.src_address(), _msg.dst_address(), - _msg.dst_port(), _msg.data()); + callback(_msg.src_address(), _msg.dst_address(), + _msg.dst_port(), _msg.data()); } } } ////////////////////////////////////////////////// -bool CommsClient::OnMessageRos(subt_msgs::DatagramRos::Request &_req, - subt_msgs::DatagramRos::Response &_res) +void CommsClient::OnMessageRos(const subt_msgs::DatagramRos::Request &_req) { auto endPoint = _req.dst_address + ":" + std::to_string(_req.dst_port); @@ -455,16 +591,18 @@ bool CommsClient::OnMessageRos(subt_msgs::DatagramRos::Request &_req, this->neighbors[_req.src_address] = std::make_pair(ros::Time::now().toSec(), _req.rssi); - for (auto cb : this->callbacks) + if (this->callbacks.find(endPoint) == this->callbacks.end()) + return; + + for (const auto& endpointCallbacksKV : this->callbacks[endPoint]) { - if (cb.first == endPoint && cb.second) + auto& callback = endpointCallbacksKV.second; + if (callback) { - cb.second(_req.src_address, _req.dst_address, - _req.dst_port, _req.data); + callback(_req.src_address, _req.dst_address, + _req.dst_port, _req.data); } } - - return true; } ////////////////////////////////////////////////// diff --git a/subt-communication/subt_communication_broker/tests/unit_test.cpp b/subt-communication/subt_communication_broker/tests/unit_test.cpp index 865b79e2..84db730a 100644 --- a/subt-communication/subt_communication_broker/tests/unit_test.cpp +++ b/subt-communication/subt_communication_broker/tests/unit_test.cpp @@ -18,11 +18,15 @@ #include #include +#define private public +#define protected public #include #include #include #include #include +#undef protected +#undef private using namespace subt; using namespace subt::communication_model; @@ -30,6 +34,22 @@ using namespace subt::rf_interface; using namespace subt::rf_interface::range_model; using namespace subt::communication_broker; +void setDummyComms(Broker& broker) +{ + struct radio_configuration radio; + radio.pathloss_f = [](const double&, radio_state&, radio_state&) { + return rf_power(); + }; + + broker.SetDefaultRadioConfiguration(radio); + broker.SetCommunicationFunction(&subt::communication_model::attempt_send); + broker.SetPoseUpdateFunction( + [](const std::string& name) + { + return std::make_tuple(true, ignition::math::Pose3d::Zero, 0.0); + }); +} + TEST(broker, instatiate) { Broker broker; @@ -86,28 +106,40 @@ TEST(broker, communicate) broker.SetPoseUpdateFunction(pose_update_func); broker.Start(); - CommsClient c1("1", false, true); - CommsClient c2("2", false , true); + CommsClient c1("1", false, true, false); + CommsClient c2("2", false , true, false); - auto c2_cb = [=](const std::string& src, + std::vector receivedData; + auto c2_cb = [=,&receivedData](const std::string& src, const std::string& dst, const uint32_t port, const std::string& data) { std::cout << "Received " << data.size() << "(" << data << ") bytes from " << src << std::endl; + receivedData.emplace_back(data); }; c2.Bind(c2_cb); + std::unordered_set sentData; for(unsigned int i=0; i < 15; ++i) { std::ostringstream oss; oss << "Hello c2, " << i; c1.SendTo(oss.str(), "2"); + sentData.insert(oss.str()); broker.DispatchMessages(); } + // there is some packet loss on the path + ASSERT_LT(8u, receivedData.size()); + + for (const auto& data : receivedData) + { + EXPECT_NE(sentData.end(), sentData.find(data)); + } + // geometry_msgs::PoseStamped a, b; // a.header.frame_id = "world"; // b = a; @@ -119,6 +151,636 @@ TEST(broker, communicate) // ); } +TEST(broker, broadcast) +{ + Broker broker; + setDummyComms(broker); + + broker.Start(); + + CommsClient broadcaster("b", false, true, false); + CommsClient c1("1", false, true, false); + CommsClient c2("2", false , true, false); + CommsClient c3("3", false , true, false); + + std::vector receivedData1, receivedData2; + auto c1_cb = [=,&receivedData1](const std::string& src, + const std::string& dst, + const uint32_t port, + const std::string& data) + { + receivedData1.emplace_back(data); + }; + auto c2_cb = [=,&receivedData2](const std::string& src, + const std::string& dst, + const uint32_t port, + const std::string& data) + { + receivedData2.emplace_back(data); + }; + auto c3_cb = [](const std::string& src, + const std::string& dst, + const uint32_t port, + const std::string& data) + { + GTEST_NONFATAL_FAILURE_("c3 callback should not have been called"); + }; + + // intentionally bind c1 twice to get twice as many data and verify one client + // can bind one port multiple times + EXPECT_FALSE(!c1.Bind(c1_cb, "", 123)); + EXPECT_FALSE(!c1.Bind(c1_cb, "", 123)); + EXPECT_FALSE(!c2.Bind(c2_cb, "", 123)); + // c3 binds to a different port and thus should receive nothing + EXPECT_FALSE(!c3.Bind(c3_cb, "", 124)); + + std::vector sentData; + for(unsigned int i=0; i < 15; ++i) + { + std::ostringstream oss; + oss << "Hello clients, " << i; + broadcaster.SendTo(oss.str(), kBroadcast, 123); + sentData.emplace_back(oss.str()); + broker.DispatchMessages(); + } + + ASSERT_EQ(2u * sentData.size(), receivedData1.size()); + ASSERT_EQ(sentData.size(), receivedData2.size()); + + for (size_t i = 0; i < sentData.size(); ++i) + { + EXPECT_EQ(sentData[i], receivedData1[2*i]); + EXPECT_EQ(sentData[i], receivedData1[2*i + 1]); + EXPECT_EQ(sentData[i], receivedData2[i]); + } +} + +// multicast is broken in ignition-only setup, because it advertises a service +// called "/multicast" in each client... obviously, this can't work with more +// than one client... fortunately, in case SubtRosRelay is used, this relay +// is the only client that "physically" interacts with the ignition layer, so +// that should work +// this could be fixed by rewriting the ignition layer to work via messages +// rather than services +#if 0 +TEST(broker, multicast) +{ + Broker broker; + setDummyComms(broker); + + broker.Start(); + + CommsClient broadcaster("b", false, true, false); + CommsClient c1("1", false, true, false); + CommsClient c2("2", false , true, false); + CommsClient c3("3", false , true, false); + + std::vector receivedData1, receivedData2; + auto c1_cb = [=,&receivedData1](const std::string& src, + const std::string& dst, + const uint32_t port, + const std::string& data) + { + receivedData1.emplace_back(data); + }; + auto c2_cb = [=,&receivedData2](const std::string& src, + const std::string& dst, + const uint32_t port, + const std::string& data) + { + receivedData2.emplace_back(data); + }; + auto c3_cb = [](const std::string& src, + const std::string& dst, + const uint32_t port, + const std::string& data) + { + GTEST_NONFATAL_FAILURE_("c3 callback should not have been called"); + }; + + // intentionally bind c1 twice, but once to multicast address and once on + // unicast + EXPECT_FALSE(!c1.Bind(c1_cb, kMulticast, 123)); + EXPECT_FALSE(!c1.Bind(c1_cb, "", 123)); + EXPECT_FALSE(!c2.Bind(c2_cb, kMulticast, 123)); + EXPECT_FALSE(!c3.Bind(c3_cb, "", 123)); + + std::vector sentData; + for(unsigned int i=0; i < 15; ++i) + { + std::ostringstream oss; + oss << "Hello clients, " << i; + broadcaster.SendTo(oss.str(), kMulticast, 123); + sentData.emplace_back(oss.str()); + broker.DispatchMessages(); + std::this_thread::sleep_for(std::chrono::seconds(1)); + } + + ASSERT_EQ(sentData.size(), receivedData1.size()); + ASSERT_EQ(sentData.size(), receivedData2.size()); + + for (size_t i = 0; i < sentData.size(); ++i) + { + EXPECT_EQ(sentData[i], receivedData1[i]); + EXPECT_EQ(sentData[i], receivedData2[i]); + } +} +#endif + +TEST(broker, unicast) +{ + Broker broker; + setDummyComms(broker); + + broker.Start(); + + CommsClient broadcaster("b", false, true, false); + CommsClient c1("1", false, true, false); + CommsClient c2("2", false , true, false); + CommsClient c3("3", false , true, false); + + std::vector receivedData1, receivedData2; + auto c1_cb = [=,&receivedData1](const std::string& src, + const std::string& dst, + const uint32_t port, + const std::string& data) + { + receivedData1.emplace_back(data); + }; + auto c2_cb = [=,&receivedData2](const std::string& src, + const std::string& dst, + const uint32_t port, + const std::string& data) + { + receivedData2.emplace_back(data); + }; + auto c3_cb = [](const std::string& src, + const std::string& dst, + const uint32_t port, + const std::string& data) + { + GTEST_NONFATAL_FAILURE_("c3 callback should not have been called"); + }; + + // intentionally bind c1 twice to get twice as many data and verify one client + // can bind one port multiple times + EXPECT_FALSE(!c1.Bind(c1_cb, "", 123)); + EXPECT_FALSE(!c1.Bind(c1_cb, "", 123)); + EXPECT_FALSE(!c2.Bind(c2_cb, "", 123)); + // c3 binds to a different port and thus should receive nothing + EXPECT_FALSE(!c3.Bind(c3_cb, "", 124)); + + std::vector sentData; + for(unsigned int i=0; i < 15; ++i) + { + std::ostringstream oss; + oss << "Hello clients, " << i; + broadcaster.SendTo(oss.str(), "1", 123); + sentData.emplace_back(oss.str()); + broker.DispatchMessages(); + } + + ASSERT_EQ(2u * sentData.size(), receivedData1.size()); + ASSERT_EQ(0u, receivedData2.size()); + + for (size_t i = 0; i < sentData.size(); ++i) + { + EXPECT_EQ(sentData[i], receivedData1[2*i]); + EXPECT_EQ(sentData[i], receivedData1[2*i + 1]); + } + + receivedData1.clear(); + receivedData2.clear(); + sentData.clear(); + + for(unsigned int i=0; i < 15; ++i) + { + std::ostringstream oss; + oss << "Hello clients, " << i; + broadcaster.SendTo(oss.str(), "2", 123); + sentData.emplace_back(oss.str()); + broker.DispatchMessages(); + } + + ASSERT_EQ(0u, receivedData1.size()); + ASSERT_EQ(sentData.size(), receivedData2.size()); + + for (size_t i = 0; i < sentData.size(); ++i) + { + EXPECT_EQ(sentData[i], receivedData2[i]); + } +} + +TEST(broker, notTwoClientsForSameAddressWithIgnition) +{ + Broker broker; + setDummyComms(broker); + + broker.Start(); + + CommsClient c1("1", false, true, false); + CommsClient c2("1", false, true, false); + + auto cb = [](const std::string& src, const std::string& dst, + const uint32_t port, const std::string& data) + { + }; + + EXPECT_LT(0u, c1.Bind(cb).size()); + // this should be 0 because the ign service /1 has already been advertised + // that doesn't happen due to + // https://github.com/ignitionrobotics/ign-transport/issues/217 + // EXPECT_EQ(0u, c2.Bind(cb).size()); +} + +TEST(brokerUnit, registerOnce) +{ + Broker broker; + broker.Start(); + + ClientID client1, client2; + + EXPECT_EQ(0, broker.Team()->size()); + EXPECT_NE(invalidClientId, client1 = broker.Register("1")); + EXPECT_EQ(1, broker.Team()->size()); + EXPECT_NE(invalidClientId, client2 = broker.Register("2")); + EXPECT_EQ(2, broker.Team()->size()); + + EXPECT_NE(client1, client2); + + EXPECT_FALSE(broker.Unregister(3)); + EXPECT_EQ(2, broker.Team()->size()); + + EXPECT_TRUE(broker.Unregister(client1)); + EXPECT_EQ(1, broker.Team()->size()); + EXPECT_FALSE(broker.Unregister(client1)); + EXPECT_EQ(1, broker.Team()->size()); + EXPECT_TRUE(broker.Unregister(client2)); + EXPECT_EQ(0, broker.Team()->size()); +} + +TEST(brokerUnit, registerTwice) +{ + Broker broker; + broker.Start(); + + ClientID client11, client12, client21, client22; + + EXPECT_EQ(0, broker.Team()->size()); + EXPECT_NE(invalidClientId, client11 = broker.Register("1")); + EXPECT_EQ(1, broker.Team()->size()); + EXPECT_NE(invalidClientId, client12 = broker.Register("1")); + EXPECT_EQ(1, broker.Team()->size()); + EXPECT_NE(invalidClientId, client21 = broker.Register("2")); + EXPECT_EQ(2, broker.Team()->size()); + EXPECT_NE(invalidClientId, client22 = broker.Register("2")); + EXPECT_EQ(2, broker.Team()->size()); + + EXPECT_NE(client11, client22); + EXPECT_NE(client12, client21); + EXPECT_NE(client12, client22); + EXPECT_NE(client21, client22); + + EXPECT_TRUE(broker.Unregister(client11)); + EXPECT_EQ(2, broker.Team()->size()); + EXPECT_FALSE(broker.Unregister(8)); + EXPECT_EQ(2, broker.Team()->size()); + EXPECT_TRUE(broker.Unregister(client21)); + EXPECT_EQ(2, broker.Team()->size()); + EXPECT_TRUE(broker.Unregister(client22)); + EXPECT_EQ(1, broker.Team()->size()); + EXPECT_TRUE(broker.Unregister(client12)); + EXPECT_EQ(0, broker.Team()->size()); + EXPECT_FALSE(broker.Unregister(client11)); +} + +TEST(brokerUnit, registrationWorks) +{ + Broker broker; + setDummyComms(broker); + + broker.Start(); + + ASSERT_TRUE(broker.Team()->empty()); + ASSERT_TRUE(broker.endpoints.empty()); + + ClientID client1; + ASSERT_NE(invalidClientId, client1 = broker.Register("1")); + + ASSERT_EQ(1u, broker.Team()->size()); + ASSERT_TRUE(broker.endpoints.empty()); + ASSERT_NE(broker.Team()->end(), broker.Team()->find("1")); + EXPECT_EQ("1", broker.Team()->at("1")->address); + EXPECT_EQ("1", broker.Team()->at("1")->name); + + ClientID client2; + ASSERT_NE(invalidClientId, client2 = broker.Register("2")); + + ASSERT_EQ(2u, broker.Team()->size()); + ASSERT_TRUE(broker.endpoints.empty()); + ASSERT_NE(broker.Team()->end(), broker.Team()->find("2")); + EXPECT_EQ("2", broker.Team()->at("2")->address); + EXPECT_EQ("2", broker.Team()->at("2")->name); + + EndpointID endpoint11; + ASSERT_NE(invalidEndpointId, endpoint11 = broker.Bind(client1, "1:1")); + + ASSERT_EQ(2u, broker.Team()->size()); + ASSERT_EQ(1u, broker.endpoints.size()); + ASSERT_NE(broker.endpoints.end(), broker.endpoints.find("1:1")); + ASSERT_EQ(1u, broker.endpoints.at("1:1").size()); + EXPECT_EQ("1", broker.endpoints.at("1:1")[0].address); + + EndpointID endpoint12; + ASSERT_NE(invalidEndpointId, endpoint12 = broker.Bind(client1, "1:2")); + EXPECT_NE(endpoint12, endpoint11); + + ASSERT_EQ(2u, broker.Team()->size()); + ASSERT_EQ(2u, broker.endpoints.size()); + ASSERT_NE(broker.endpoints.end(), broker.endpoints.find("1:2")); + ASSERT_EQ(1u, broker.endpoints.at("1:2").size()); + EXPECT_EQ("1", broker.endpoints.at("1:2")[0].address); + + EndpointID endpoint21; + ASSERT_NE(invalidEndpointId, endpoint21 = broker.Bind(client2, "2:1")); + EXPECT_NE(endpoint12, endpoint21); + + ASSERT_EQ(2u, broker.Team()->size()); + ASSERT_EQ(3u, broker.endpoints.size()); + ASSERT_NE(broker.endpoints.end(), broker.endpoints.find("2:1")); + ASSERT_EQ(1u, broker.endpoints.at("2:1").size()); + EXPECT_EQ("2", broker.endpoints.at("2:1")[0].address); + + EndpointID endpoint22; + ASSERT_NE(invalidEndpointId, endpoint22 = broker.Bind(client2, "2:2")); + EXPECT_NE(endpoint22, endpoint11); + + ASSERT_EQ(2u, broker.Team()->size()); + ASSERT_EQ(4u, broker.endpoints.size()); + ASSERT_NE(broker.endpoints.end(), broker.endpoints.find("2:2")); + ASSERT_EQ(1u, broker.endpoints.at("2:2").size()); + EXPECT_EQ("2", broker.endpoints.at("2:2")[0].address); + + // The broker library doesn't check whether endpoints and client addresses + // match + EndpointID endpointWeird; + ASSERT_NE(invalidEndpointId, endpointWeird = broker.Bind(client1, "2:2")); + EXPECT_NE(endpoint22, endpointWeird); + + ASSERT_EQ(2u, broker.Team()->size()); + ASSERT_EQ(4u, broker.endpoints.size()); + ASSERT_NE(broker.endpoints.end(), broker.endpoints.find("2:2")); + ASSERT_EQ(2u, broker.endpoints.at("2:2").size()); + EXPECT_EQ("1", broker.endpoints.at("2:2")[1].address); + + // make sure broadcast endpoints work + EndpointID endpointBcast11; + ASSERT_NE(invalidEndpointId, + endpointBcast11 = broker.Bind(client1, "broadcast:1")); + EXPECT_NE(endpointWeird, endpointBcast11); + + ASSERT_EQ(2u, broker.Team()->size()); + ASSERT_EQ(5u, broker.endpoints.size()); + ASSERT_NE(broker.endpoints.end(), broker.endpoints.find("broadcast:1")); + ASSERT_EQ(1u, broker.endpoints.at("broadcast:1").size()); + EXPECT_EQ("1", broker.endpoints.at("broadcast:1")[0].address); + + EndpointID endpointBcast21; + ASSERT_NE(invalidEndpointId, + endpointBcast21 = broker.Bind(client2, "broadcast:1")); + EXPECT_NE(endpointBcast21, endpointBcast11); + + ASSERT_EQ(2u, broker.Team()->size()); + ASSERT_EQ(5u, broker.endpoints.size()); + ASSERT_NE(broker.endpoints.end(), broker.endpoints.find("broadcast:1")); + ASSERT_EQ(2u, broker.endpoints.at("broadcast:1").size()); + EXPECT_EQ("2", broker.endpoints.at("broadcast:1")[1].address); + + // test multiple binds to the same endpoint from the same client + EndpointID endpointBcast111; + ASSERT_NE(invalidEndpointId, + endpointBcast111 = broker.Bind(client1, "broadcast:1")); + EXPECT_NE(endpointBcast21, endpointBcast111); + + EXPECT_NE(endpointBcast11, endpointBcast111); + ASSERT_EQ(2u, broker.Team()->size()); + ASSERT_EQ(5u, broker.endpoints.size()); + ASSERT_NE(broker.endpoints.end(), broker.endpoints.find("broadcast:1")); + ASSERT_EQ(2u, broker.endpoints.at("broadcast:1").size()); + EXPECT_EQ("1", broker.endpoints.at("broadcast:1")[0].address); +} + +TEST(brokerUnit, unbind) +{ + Broker broker; + setDummyComms(broker); + + broker.Start(); + + ClientID client11 = broker.Register("1"); + ClientID client2 = broker.Register("2"); + ClientID client12 = broker.Register("1"); + + ASSERT_NE(invalidClientId, client11); + ASSERT_NE(invalidClientId, client2); + ASSERT_NE(invalidClientId, client12); + + EndpointID client11_ep11_1 = broker.Bind(client11, "1:1"); + EndpointID client11_ep11_2 = broker.Bind(client11, "1:1"); + EndpointID client12_ep11_1 = broker.Bind(client12, "1:1"); + EndpointID client12_ep11_2 = broker.Bind(client12, "1:1"); + EndpointID client2_ep21_1 = broker.Bind(client2, "2:1"); + EndpointID client2_ep21_2 = broker.Bind(client2, "2:1"); + EndpointID client2_ep11_1 = broker.Bind(client2, "1:1"); + + ASSERT_EQ(2u, broker.endpoints.size()); + ASSERT_NE(broker.endpoints.end(), broker.endpoints.find("1:1")); + ASSERT_NE(broker.endpoints.end(), broker.endpoints.find("2:1")); + EXPECT_EQ(2u, broker.endpoints["1:1"].size()); + EXPECT_EQ(1u, broker.endpoints["2:1"].size()); + + EXPECT_TRUE(broker.Unbind(client2_ep21_1)); + + ASSERT_EQ(2u, broker.endpoints.size()); + ASSERT_NE(broker.endpoints.end(), broker.endpoints.find("1:1")); + ASSERT_NE(broker.endpoints.end(), broker.endpoints.find("2:1")); + EXPECT_EQ(2u, broker.endpoints["1:1"].size()); + EXPECT_EQ(1u, broker.endpoints["2:1"].size()); + + // cannot unbind it twice + EXPECT_FALSE(broker.Unbind(client2_ep21_1)); + + EXPECT_TRUE(broker.Unbind(client2_ep21_2)); + + // the endpoint key remains, but it is empty + ASSERT_EQ(2u, broker.endpoints.size()); + ASSERT_NE(broker.endpoints.end(), broker.endpoints.find("1:1")); + ASSERT_NE(broker.endpoints.end(), broker.endpoints.find("2:1")); + EXPECT_EQ(2u, broker.endpoints["1:1"].size()); + EXPECT_EQ(0u, broker.endpoints["2:1"].size()); + + EXPECT_TRUE(broker.Unbind(client11_ep11_1)); + + ASSERT_EQ(2u, broker.endpoints.size()); + ASSERT_NE(broker.endpoints.end(), broker.endpoints.find("1:1")); + ASSERT_NE(broker.endpoints.end(), broker.endpoints.find("2:1")); + EXPECT_EQ(2u, broker.endpoints["1:1"].size()); + EXPECT_EQ(0u, broker.endpoints["2:1"].size()); + + EXPECT_TRUE(broker.Unbind(client11_ep11_2)); + + ASSERT_EQ(2u, broker.endpoints.size()); + ASSERT_NE(broker.endpoints.end(), broker.endpoints.find("1:1")); + ASSERT_NE(broker.endpoints.end(), broker.endpoints.find("2:1")); + EXPECT_EQ(2u, broker.endpoints["1:1"].size()); + EXPECT_EQ(0u, broker.endpoints["2:1"].size()); + + EXPECT_TRUE(broker.Unbind(client2_ep11_1)); + + ASSERT_EQ(2u, broker.endpoints.size()); + ASSERT_NE(broker.endpoints.end(), broker.endpoints.find("1:1")); + ASSERT_NE(broker.endpoints.end(), broker.endpoints.find("2:1")); + EXPECT_EQ(1u, broker.endpoints["1:1"].size()); + EXPECT_EQ(0u, broker.endpoints["2:1"].size()); + + EXPECT_TRUE(broker.Unbind(client12_ep11_1)); + + ASSERT_EQ(2u, broker.endpoints.size()); + ASSERT_NE(broker.endpoints.end(), broker.endpoints.find("1:1")); + ASSERT_NE(broker.endpoints.end(), broker.endpoints.find("2:1")); + EXPECT_EQ(1u, broker.endpoints["1:1"].size()); + EXPECT_EQ(0u, broker.endpoints["2:1"].size()); + + EXPECT_TRUE(broker.Unbind(client12_ep11_2)); + + ASSERT_EQ(2u, broker.endpoints.size()); + ASSERT_NE(broker.endpoints.end(), broker.endpoints.find("1:1")); + ASSERT_NE(broker.endpoints.end(), broker.endpoints.find("2:1")); + EXPECT_EQ(0u, broker.endpoints["1:1"].size()); + EXPECT_EQ(0u, broker.endpoints["2:1"].size()); + + EXPECT_FALSE(broker.Unbind(client11_ep11_1)); + EXPECT_FALSE(broker.Unbind(client11_ep11_2)); + EXPECT_FALSE(broker.Unbind(client12_ep11_1)); + EXPECT_FALSE(broker.Unbind(client12_ep11_2)); + EXPECT_FALSE(broker.Unbind(client2_ep11_1)); + EXPECT_FALSE(broker.Unbind(client2_ep21_1)); + EXPECT_FALSE(broker.Unbind(client2_ep21_2)); +} + + +TEST(brokerUnit, unregisterUnbinds) +{ + Broker broker; + setDummyComms(broker); + + broker.Start(); + + ClientID client11 = broker.Register("1"); + ClientID client2 = broker.Register("2"); + ClientID client12 = broker.Register("1"); + + ASSERT_NE(invalidClientId, client11); + ASSERT_NE(invalidClientId, client2); + ASSERT_NE(invalidClientId, client12); + + EndpointID client11_ep11_1 = broker.Bind(client11, "1:1"); + EndpointID client11_ep11_2 = broker.Bind(client11, "1:1"); + EndpointID client12_ep11_1 = broker.Bind(client12, "1:1"); + EndpointID client12_ep11_2 = broker.Bind(client12, "1:1"); + EndpointID client2_ep21_1 = broker.Bind(client2, "2:1"); + EndpointID client2_ep21_2 = broker.Bind(client2, "2:1"); + EndpointID client2_ep11_1 = broker.Bind(client2, "1:1"); + + ASSERT_EQ(2u, broker.endpoints.size()); + ASSERT_NE(broker.endpoints.end(), broker.endpoints.find("1:1")); + ASSERT_NE(broker.endpoints.end(), broker.endpoints.find("2:1")); + EXPECT_EQ(2u, broker.endpoints["1:1"].size()); + EXPECT_EQ(1u, broker.endpoints["2:1"].size()); + EXPECT_NE(broker.Team()->end(), broker.Team()->find("1")); + EXPECT_NE(broker.Team()->end(), broker.Team()->find("2")); + + EXPECT_TRUE(broker.Unregister(client12)); + + ASSERT_EQ(2u, broker.endpoints.size()); + ASSERT_NE(broker.endpoints.end(), broker.endpoints.find("1:1")); + ASSERT_NE(broker.endpoints.end(), broker.endpoints.find("2:1")); + EXPECT_EQ(2u, broker.endpoints["1:1"].size()); + EXPECT_EQ(1u, broker.endpoints["2:1"].size()); + EXPECT_NE(broker.Team()->end(), broker.Team()->find("1")); + EXPECT_NE(broker.Team()->end(), broker.Team()->find("2")); + + EXPECT_TRUE(broker.Unregister(client11)); + + ASSERT_EQ(2u, broker.endpoints.size()); + ASSERT_NE(broker.endpoints.end(), broker.endpoints.find("1:1")); + ASSERT_NE(broker.endpoints.end(), broker.endpoints.find("2:1")); + EXPECT_EQ(1u, broker.endpoints["1:1"].size()); + EXPECT_EQ(1u, broker.endpoints["2:1"].size()); + EXPECT_EQ(broker.Team()->end(), broker.Team()->find("1")); + EXPECT_NE(broker.Team()->end(), broker.Team()->find("2")); + + EXPECT_TRUE(broker.Unregister(client2)); + + ASSERT_EQ(2u, broker.endpoints.size()); + ASSERT_NE(broker.endpoints.end(), broker.endpoints.find("1:1")); + ASSERT_NE(broker.endpoints.end(), broker.endpoints.find("2:1")); + EXPECT_EQ(0u, broker.endpoints["1:1"].size()); + EXPECT_EQ(0u, broker.endpoints["2:1"].size()); + EXPECT_EQ(broker.Team()->end(), broker.Team()->find("1")); + EXPECT_EQ(broker.Team()->end(), broker.Team()->find("2")); + + EXPECT_FALSE(broker.Unregister(client11)); + EXPECT_FALSE(broker.Unregister(client12)); + EXPECT_FALSE(broker.Unregister(client2)); + + EXPECT_FALSE(broker.Unbind(client11_ep11_1)); + EXPECT_FALSE(broker.Unbind(client11_ep11_2)); + EXPECT_FALSE(broker.Unbind(client12_ep11_1)); + EXPECT_FALSE(broker.Unbind(client12_ep11_2)); + EXPECT_FALSE(broker.Unbind(client2_ep11_1)); + EXPECT_FALSE(broker.Unbind(client2_ep21_1)); + EXPECT_FALSE(broker.Unbind(client2_ep21_2)); +} + +TEST(brokerUnit, sendRequiresRegisterAndBind) +{ + Broker broker; + setDummyComms(broker); + + broker.Start(); + + ClientID client2; + ASSERT_NE(invalidClientId, client2 = broker.Register("2")); + + subt::msgs::Datagram msg; + msg.set_src_address("1"); + msg.set_dst_address("2"); + msg.set_dst_port(42); + msg.set_rssi(-30); + + // sender has to be registered + broker.OnMessage(msg); + EXPECT_FALSE(broker.DispatchMessages()); + + // endpoint has to be registered + msg.set_src_address("2"); + msg.set_dst_address("1"); + broker.OnMessage(msg); + EXPECT_FALSE(broker.DispatchMessages()); + + // destination has to be registered + ClientID client1; + EXPECT_NE(invalidClientId, client1 = broker.Register("1")); + broker.OnMessage(msg); + EXPECT_FALSE(broker.DispatchMessages()); + + // endpoint has to be bound + ASSERT_NE(invalidEndpointId, broker.Bind(client1, "1:42")); + broker.OnMessage(msg); + EXPECT_TRUE(broker.DispatchMessages()); +} + int main(int argc, char **argv) { testing::InitGoogleTest(&argc, argv); diff --git a/subt_msgs/srv/Bind.srv b/subt_msgs/srv/Bind.srv index 7f8a4e9f..af2edae6 100644 --- a/subt_msgs/srv/Bind.srv +++ b/subt_msgs/srv/Bind.srv @@ -1,6 +1,6 @@ # Bind -string address +uint32 client_id string endpoint ---- -bool success +uint32 endpoint_id diff --git a/subt_msgs/srv/Register.srv b/subt_msgs/srv/Register.srv index cb7a0f7d..d2a9d946 100644 --- a/subt_msgs/srv/Register.srv +++ b/subt_msgs/srv/Register.srv @@ -2,4 +2,4 @@ string local_address ---- -bool success +uint32 client_id diff --git a/subt_msgs/srv/Unregister.srv b/subt_msgs/srv/Unregister.srv index d7b7bb5b..9a2e25be 100644 --- a/subt_msgs/srv/Unregister.srv +++ b/subt_msgs/srv/Unregister.srv @@ -1,5 +1,5 @@ # Unregister -string local_address +uint32 client_id ---- bool success diff --git a/subt_ros/CMakeLists.txt b/subt_ros/CMakeLists.txt index aa4d25f5..661fe207 100644 --- a/subt_ros/CMakeLists.txt +++ b/subt_ros/CMakeLists.txt @@ -108,3 +108,39 @@ install(DIRECTORY launch install(PROGRAMS scripts/rostopic_stats_logger.sh DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION} ) + +if(CATKIN_ENABLE_TESTING) + add_executable(test_broker test/test_broker.cc) + target_link_libraries(test_broker ${catkin_LIBRARIES} ignition-msgs6) + add_dependencies(tests test_broker) + + find_package(rostest REQUIRED) + find_package(rosgraph_msgs REQUIRED) + include_directories(${rosgraph_msgs_INCLUDE_DIRS}) + include_directories(${GTEST_INCLUDE_DIRS}) + + # we do not use add_rostest_gtest() because that doesn't allow for multiple + # independent targets/binaries to be built... and it is desirable that the + # tests are in separate binaries, because then rostest tears down all the + # helper nodes between switching to the other test node... on the other hand, + # keeping them all under one rostest command serializes the execution of the + # tests (while three separate rostests would run in parallel). + set(comms_types unicast multicast broadcast) + set(comms_targets "") + foreach(comms_type IN LISTS comms_types) + set(target test_comms_client_${comms_type}) + add_executable(${target} EXCLUDE_FROM_ALL test/comms_relay_${comms_type}.cc) + target_link_libraries(${target} + ${catkin_LIBRARIES} + ${rosgraph_msgs_LIBRARIES} + ${GTEST_LIBRARIES} + ignition-common3::ignition-common3 + ) + add_dependencies(tests ${target}) + list(APPEND comms_targets ${target}) + endforeach() + + add_rostest(test/comms_relay.test + DEPENDENCIES subt_ros_relay test_broker ${comms_targets} + ) +endif() diff --git a/subt_ros/package.xml b/subt_ros/package.xml index a759b1c2..b646806c 100644 --- a/subt_ros/package.xml +++ b/subt_ros/package.xml @@ -27,6 +27,9 @@ topic_tools std_msgs + rosgraph_msgs + rostest + diff --git a/subt_ros/src/SubtRosRelay.cc b/subt_ros/src/SubtRosRelay.cc index e8cc7657..7383caf4 100644 --- a/subt_ros/src/SubtRosRelay.cc +++ b/subt_ros/src/SubtRosRelay.cc @@ -14,6 +14,9 @@ * limitations under the License. * */ + +#include + #include #include #include @@ -29,6 +32,7 @@ #include #include #include +#include #include #include #include @@ -105,12 +109,18 @@ class SubtRosRelay /// The received message is added to a message queue to be handled by a /// separate thread. /// \param[in] _msg The message. - public: void OnMessage(const subt::msgs::Datagram &_msg); + /// \param[in] _resolvedAddress The real destination of the message (i.e. + /// broadcast and multicast addresses resolved to a real address). + public: void OnMessage(const subt::msgs::Datagram &_msg, + const std::string& _resolvedAddress); /// \brief Process messages in consumed from the message queue - /// The message is forwarded via a ROS service call. - /// \param[in] _req The message. - public: void ProcessMessage(const subt::msgs::Datagram &_req); + /// The message is forwarded via a ROS topic. + /// \param[in] _msg A pair containing the message and the real destination of + /// the message (i.e. broadcast and multicast addresses resolved to a real + /// address). + public: void ProcessMessage( + std::pair &_msg); /// \brief Creates an AsyncSpinner and handles received messages public: void Spin(); @@ -149,6 +159,14 @@ class SubtRosRelay /// the origin artifact. public: ros::ServiceServer poseFromArtifactService; + /// \brief This is a mutex protecting registeredClients and boundAddresses. + public: std::mutex clientsMutex; + + /// \brief List of clients that called the registration service and did not + /// (yet) unregister. + public: std::unordered_map + registeredClients; + /// \brief The set of bound address. This is bookkeeping that helps /// to reduce erroneous error output in the ::Bind function. public: std::set boundAddresses; @@ -156,7 +174,8 @@ class SubtRosRelay /// \brief Lock free queue for holding msgs from Transport. This is needed to /// avoid deadlocks between the Transport thread that invokes callbacks and /// the main thread that handles messages. - public: boost::lockfree::spsc_queue msgQueue{10}; + public: boost::lockfree::spsc_queue< + std::pair> msgQueue{10}; /// \brief This mutex is used in conjunction with notifyCond to notify the /// main thread the arrival of new messages. @@ -173,6 +192,9 @@ class SubtRosRelay /// \brief Pointer to the ROS bag recorder. public: std::unique_ptr rosRecorder; + + /// \brief Publishers to /address/comms for each robot. Indexed by robot name. + public: std::unordered_map commsPublishers; }; ////////////////////////////////////////////////// @@ -249,11 +271,17 @@ SubtRosRelay::SubtRosRelay() this->bagThread.reset(new std::thread([&](){ this->rosRecorder->run(); })); + + ROS_INFO_STREAM("Running SubT ROS relay on Ign Partition '" + << this->node.Options().Partition() << "' and ROS master '" + << ros::master::getURI() << "'."); } ////////////////////////////////////////////////// SubtRosRelay::~SubtRosRelay() { + if (this->bagThread->joinable()) + this->bagThread->join(); } ///////////////////////////////////////////////// @@ -380,40 +408,73 @@ bool SubtRosRelay::OnPoseFromArtifact( bool SubtRosRelay::OnBind(subt_msgs::Bind::Request &_req, subt_msgs::Bind::Response &_res) { - if (std::find(this->robotNames.begin(), this->robotNames.end(), - _req.address) ==this->robotNames.end()) + std::string address; + { + std::unique_lock lock(this->clientsMutex); + if (this->registeredClients.find(_req.client_id) == + this->registeredClients.end()) + { + ROS_ERROR("Trying to bind on a client that has not been registered"); + return false; + } + else + { + address = this->registeredClients[_req.client_id]; + } + } + + if (address.empty()) + { + ROS_ERROR_STREAM("OnBind requested to bind on an invalid client.\n"); + return false; + } + + if (std::find(this->robotNames.begin(), this->robotNames.end(), address) == + this->robotNames.end()) { ROS_ERROR_STREAM("OnBind address does not match origination. Attempted " - << "impersonation of a robot as robot[" << _req.address << "].\n"); + << "impersonation of a robot as robot[" << address << "].\n"); return false; } - ignition::msgs::StringMsg_V req; - req.add_data(_req.address); - req.add_data(_req.endpoint); + subt::msgs::EndpointRegistration req; + req.set_client_id(_req.client_id); + req.set_endpoint(_req.endpoint); const unsigned int timeout = 3000u; - ignition::msgs::Boolean rep; + ignition::msgs::UInt32 rep; bool result; bool executed = this->node.Request( subt::communication_broker::kEndPointRegistrationSrv, req, timeout, rep, result); - _res.success = result; + _res.endpoint_id = subt::communication_broker::invalidEndpointId; + if (executed && result) + _res.endpoint_id = rep.data(); if (executed && result && // Only establish the Ignition service once per client. - this->boundAddresses.find(_req.address) == this->boundAddresses.end()) + this->boundAddresses.find(address) == this->boundAddresses.end()) { - if (!this->node.Advertise(_req.address, &SubtRosRelay::OnMessage, this)) + std::function cb = + [this, address] (const subt::msgs::Datagram& _msg) + { + this->OnMessage(_msg, address); + }; + + if (!this->node.Advertise(address, cb)) { - std::cerr << "Bind Error: could not advertise " - << _req.address << std::endl; + std::cerr << "Bind Error: could not advertise " << address << std::endl; return false; } else { - this->boundAddresses.insert(_req.address); + this->boundAddresses.insert(address); + + // Prepare the ROS comms publisher + this->commsPublishers[address] = + this->rosnode->advertise( + "/" + address + "/comms", 1000); } } @@ -449,7 +510,7 @@ bool SubtRosRelay::OnRegister(subt_msgs::Register::Request &_req, ignition::msgs::StringMsg req; req.set_data(_req.local_address); - ignition::msgs::Boolean rep; + ignition::msgs::UInt32 rep; bool result; const unsigned int timeout = 3000u; @@ -457,16 +518,37 @@ bool SubtRosRelay::OnRegister(subt_msgs::Register::Request &_req, subt::communication_broker::kAddrRegistrationSrv, req, timeout, rep, result); - _res.success = result; - return executed; + if (executed && result && rep.data() != + subt::communication_broker::invalidClientId) + { + _res.client_id = rep.data(); + { + std::unique_lock lock(this->clientsMutex); + this->registeredClients[_res.client_id] = _req.local_address; + } + return true; + } + + _res.client_id = subt::communication_broker::invalidClientId; + return false; } ////////////////////////////////////////////////// bool SubtRosRelay::OnUnregister(subt_msgs::Unregister::Request &_req, subt_msgs::Unregister::Response &_res) { - ignition::msgs::StringMsg req; - req.set_data(_req.local_address); + { + std::unique_lock lock(this->clientsMutex); + if (this->registeredClients.find(_req.client_id) == + this->registeredClients.end()) + { + ROS_ERROR("Trying to unregister a client that has not been registered"); + return false; + } + } + + ignition::msgs::UInt32 req; + req.set_data(_req.client_id); ignition::msgs::Boolean rep; bool result; @@ -476,41 +558,47 @@ bool SubtRosRelay::OnUnregister(subt_msgs::Unregister::Request &_req, subt::communication_broker::kAddrUnregistrationSrv, req, timeout, rep, result); - _res.success = result; + _res.success = executed && result && rep.data(); + + if (_res.success) + { + std::unique_lock lock(this->clientsMutex); + this->registeredClients.erase(_req.client_id); + } return executed; } ////////////////////////////////////////////////// -void SubtRosRelay::OnMessage(const subt::msgs::Datagram &_req) +void SubtRosRelay::OnMessage(const subt::msgs::Datagram &_req, + const std::string& _resolvedAddress) { - this->msgQueue.push(_req); + this->msgQueue.push(std::make_pair(_req, _resolvedAddress)); // Notify the main thread this->notifyCond.notify_one(); } ////////////////////////////////////////////////// -void SubtRosRelay::ProcessMessage(const subt::msgs::Datagram &_req) +void SubtRosRelay::ProcessMessage( + std::pair &_msg) { - subt_msgs::DatagramRos::Request req; - subt_msgs::DatagramRos::Response res; - req.src_address = _req.src_address(); - req.dst_address = _req.dst_address(); - req.dst_port = _req.dst_port(); - req.data = _req.data(); - req.rssi = _req.rssi(); - - if (_req.dst_address() == subt::communication_broker::kBroadcast) - { - for (const std::string &dest : this->robotNames) - { - ros::service::call(dest, req, res); - } - } - else - { - ros::service::call(_req.dst_address(), req, res); - } + const auto& datagram = _msg.first; + const auto& resolvedAddress = _msg.second; + + subt_msgs::DatagramRos::Request rosMsg; + rosMsg.src_address = datagram.src_address(); + rosMsg.dst_address = datagram.dst_address(); + rosMsg.dst_port = datagram.dst_port(); + rosMsg.data = datagram.data(); + rosMsg.rssi = datagram.rssi(); + + // We can be sure resolvedAddress exists in the map because + // it was initialized in OnBind(), where OnMessage() is also registered + // as a service handler that fills the message queue, which, in turn, gets + // processed by this function. + // Broadcast messages get handled by binding a broadcast endpoint by + // each respective client. + this->commsPublishers[resolvedAddress].publish(rosMsg); } ////////////////////////////////////////////////// @@ -527,7 +615,9 @@ void SubtRosRelay::Spin() // that the lock is released before calling `ProcessMessage` to avoid // deadlocks. std::unique_lock lock(this->notifyMutex); - this->notifyCond.wait(lock, + // add timeout to the wait to allow graceful exit when no messages are + // coming + this->notifyCond.wait_for(lock, std::chrono::seconds(1), [this] { return this->msgQueue.read_available(); }); } diff --git a/subt_ros/test/comms_relay.test b/subt_ros/test/comms_relay.test new file mode 100644 index 00000000..8f051d20 --- /dev/null +++ b/subt_ros/test/comms_relay.test @@ -0,0 +1,35 @@ + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/subt_ros/test/comms_relay_broadcast.cc b/subt_ros/test/comms_relay_broadcast.cc new file mode 100644 index 00000000..bdbe70f4 --- /dev/null +++ b/subt_ros/test/comms_relay_broadcast.cc @@ -0,0 +1,159 @@ +/* + * Copyright (C) 2021 Open Source Robotics Foundation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * +*/ + +#include +#include +#include +#include +#include + +using namespace std; +using namespace subt; +using namespace subt::communication_broker; + +std::shared_ptr nh; +ros::Publisher clockPub; +ros::Time currentTime {0, 0}; + +///////////////////////////////////////////////// +void advanceTime() +{ + currentTime += ros::Duration(1); + ros::Time::setNow(currentTime); + rosgraph_msgs::Clock msg; + msg.clock = currentTime; + clockPub.publish(msg); + ros::spinOnce(); + ros::WallDuration(0.01).sleep(); +} + +///////////////////////////////////////////////// +TEST(relay, broadcast) +{ + + // Test communication on broadcast utilizing SubtRosRelay and a dummy broker. + + ros::Time::init(); + advanceTime(); + + // give other nodes time to spin up + ros::WallDuration(1).sleep(); + + CommsClient broadcaster("b", false, false, false, nh.get()); + advanceTime(); + CommsClient c1("c1", false, false, false, nh.get()); + advanceTime(); + CommsClient c2("c2", false , false, false, nh.get()); + advanceTime(); + CommsClient c3("c3", false , false, false, nh.get()); + advanceTime(); + + ROS_INFO_STREAM("Test running on master " << ros::master::getURI()); + + vector receivedData1, receivedData2; + auto c1_cb = [=,&receivedData1](const std::string& src, + const std::string& dst, + const uint32_t port, + const std::string& data) + { + receivedData1.emplace_back(data); + }; + auto c2_cb = [=,&receivedData2](const std::string& src, + const std::string& dst, + const uint32_t port, + const std::string& data) + { + receivedData2.emplace_back(data); + }; + auto c3_cb = [](const std::string& src, + const std::string& dst, + const uint32_t port, + const std::string& data) + { + GTEST_NONFATAL_FAILURE_("c3 callback should not have been called"); + }; + + // intentionally bind c1 twice to get twice as many data and verify one client + // can bind one port multiple times + ASSERT_FALSE(!c1.Bind(c1_cb, "", 123)); advanceTime(); + ASSERT_FALSE(!c1.Bind(c1_cb, "", 123)); advanceTime(); + ASSERT_FALSE(!c2.Bind(c2_cb, "", 123)); advanceTime(); + // c3 binds to a different port and thus should receive nothing + ASSERT_FALSE(!c3.Bind(c3_cb, "", 124)); advanceTime(); + + // give the bound endpoints time to set up + advanceTime(); // advance the test broker + ros::WallDuration(1).sleep(); + ros::spinOnce(); + + vector sentData; + for(unsigned int i=0; i < 15; ++i) + { + std::ostringstream oss; + oss << "Hello clients, " << i; + ASSERT_TRUE(broadcaster.SendTo(oss.str(), kBroadcast, 123)); + sentData.emplace_back(oss.str()); + advanceTime(); // advance the test broker + ros::spinOnce(); + } + + ros::spinOnce(); + advanceTime(); // advance the test broker + ros::spinOnce(); + + ASSERT_EQ(2u * sentData.size(), receivedData1.size()); + ASSERT_EQ(sentData.size(), receivedData2.size()); + + for (size_t i = 0; i < sentData.size(); ++i) + { + EXPECT_EQ(sentData[i], receivedData1[2*i]); + EXPECT_EQ(sentData[i], receivedData1[2*i + 1]); + EXPECT_EQ(sentData[i], receivedData2[i]); + } + + // selftest - test situation when a client sends a message to an endpoint it + // has also bound + + receivedData1.clear(); + receivedData2.clear(); + sentData.clear(); + + ros::spinOnce(); + advanceTime(); // advance the test broker + ros::spinOnce(); + + ASSERT_TRUE(c1.SendTo("Selftest", kBroadcast, 123)); + + sentData.emplace_back("Selftest"); + advanceTime(); // advance the test broker + ros::spinOnce(); + + EXPECT_EQ(2u, receivedData1.size()); + EXPECT_EQ(1u, receivedData2.size()); +} + +///////////////////////////////////////////////// +int main(int argc, char **argv) +{ + ros::init(argc, argv, "test_comms_relay"); + nh.reset(new ros::NodeHandle); + clockPub = nh->advertise("/clock", 1, true); + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} + + diff --git a/subt_ros/test/comms_relay_multicast.cc b/subt_ros/test/comms_relay_multicast.cc new file mode 100644 index 00000000..c4b06994 --- /dev/null +++ b/subt_ros/test/comms_relay_multicast.cc @@ -0,0 +1,156 @@ +/* + * Copyright (C) 2021 Open Source Robotics Foundation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * +*/ + +#include +#include +#include +#include +#include + +using namespace std; +using namespace subt; +using namespace subt::communication_broker; + +std::shared_ptr nh; +ros::Publisher clockPub; +ros::Time currentTime {0, 0}; + +///////////////////////////////////////////////// +void advanceTime() +{ + currentTime += ros::Duration(1); + ros::Time::setNow(currentTime); + rosgraph_msgs::Clock msg; + msg.clock = currentTime; + clockPub.publish(msg); + ros::spinOnce(); + ros::WallDuration(0.01).sleep(); +} + +///////////////////////////////////////////////// +TEST(relay, multicast) +{ + // Test communication on multicast utilizing SubtRosRelay and a dummy broker. + + ros::Time::init(); + advanceTime(); + + // give other nodes time to spin up + ros::WallDuration(1).sleep(); + + CommsClient sender("b", false, false, false, nh.get()); + advanceTime(); + CommsClient c1("c1", false, false, false, nh.get()); + advanceTime(); + CommsClient c2("c2", false , false, false, nh.get()); + advanceTime(); + CommsClient c3("c3", false , false, false, nh.get()); + advanceTime(); + + ROS_INFO_STREAM("Test running on master " << ros::master::getURI()); + + vector receivedData1, receivedData2; + auto c1_cb = [=,&receivedData1](const std::string& src, + const std::string& dst, + const uint32_t port, + const std::string& data) + { + receivedData1.emplace_back(data); + }; + auto c2_cb = [=,&receivedData2](const std::string& src, + const std::string& dst, + const uint32_t port, + const std::string& data) + { + receivedData2.emplace_back(data); + }; + auto c3_cb = [](const std::string& src, + const std::string& dst, + const uint32_t port, + const std::string& data) + { + GTEST_NONFATAL_FAILURE_("c3 callback should not have been called"); + }; + + // intentionally bind c1 twice, but once to multicast address and once on + // unicast + ASSERT_FALSE(!c1.Bind(c1_cb, kMulticast, 123)); advanceTime(); + ASSERT_FALSE(!c1.Bind(c1_cb, "", 123)); advanceTime(); + ASSERT_FALSE(!c2.Bind(c2_cb, kMulticast, 123)); advanceTime(); + ASSERT_FALSE(!c3.Bind(c3_cb, "", 123)); advanceTime(); + + // give the bound endpoints time to set up + advanceTime(); // advance the test broker + ros::WallDuration(1).sleep(); + ros::spinOnce(); + + vector sentData; + for(unsigned int i=0; i < 15; ++i) + { + std::ostringstream oss; + oss << "Hello clients, " << i; + EXPECT_TRUE(sender.SendTo(oss.str(), kMulticast, 123)); + sentData.emplace_back(oss.str()); + advanceTime(); // advance the test broker + ros::spinOnce(); + } + + ros::spinOnce(); + advanceTime(); // advance the test broker + ros::spinOnce(); + + ASSERT_EQ(sentData.size(), receivedData1.size()); + ASSERT_EQ(sentData.size(), receivedData2.size()); + + for (size_t i = 0; i < sentData.size(); ++i) + { + EXPECT_EQ(sentData[i], receivedData1[i]); + EXPECT_EQ(sentData[i], receivedData2[i]); + } + + // selftest - test situation when a client sends a message to an endpoint it + // has also bound + + receivedData1.clear(); + receivedData2.clear(); + sentData.clear(); + + ros::spinOnce(); + advanceTime(); // advance the test broker + ros::spinOnce(); + + ASSERT_TRUE(c1.SendTo("Selftest", kMulticast, 123)); + + sentData.emplace_back("Selftest"); + advanceTime(); // advance the test broker + ros::spinOnce(); + + EXPECT_EQ(1u, receivedData1.size()); + EXPECT_EQ(1u, receivedData2.size()); +} + +///////////////////////////////////////////////// +int main(int argc, char **argv) +{ + ros::init(argc, argv, "test_comms_relay"); + nh.reset(new ros::NodeHandle); + clockPub = nh->advertise("/clock", 1, true); + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} + + diff --git a/subt_ros/test/comms_relay_unicast.cc b/subt_ros/test/comms_relay_unicast.cc new file mode 100644 index 00000000..3c91d8b4 --- /dev/null +++ b/subt_ros/test/comms_relay_unicast.cc @@ -0,0 +1,188 @@ +/* + * Copyright (C) 2021 Open Source Robotics Foundation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * +*/ + +#include +#include +#include +#include +#include + +using namespace std; +using namespace subt; +using namespace subt::communication_broker; + +std::shared_ptr nh; +ros::Publisher clockPub; +ros::Time currentTime {0, 0}; + +///////////////////////////////////////////////// +void advanceTime() +{ + currentTime += ros::Duration(1); + ros::Time::setNow(currentTime); + rosgraph_msgs::Clock msg; + msg.clock = currentTime; + clockPub.publish(msg); + ros::spinOnce(); + ros::WallDuration(0.01).sleep(); +} + +///////////////////////////////////////////////// +TEST(relay, unicast) +{ + + // Test communication on broadcast utilizing SubtRosRelay and a dummy broker. + + ros::Time::init(); + advanceTime(); + + // give other nodes time to spin up + ros::WallDuration(1).sleep(); + + CommsClient broadcaster("b", false, false, false, nh.get()); + advanceTime(); + CommsClient c1("c1", false, false, false, nh.get()); + advanceTime(); + CommsClient c2("c2", false , false, false, nh.get()); + advanceTime(); + CommsClient c3("c3", false , false, false, nh.get()); + advanceTime(); + + ROS_INFO_STREAM("Test running on master " << ros::master::getURI()); + + vector receivedData1, receivedData2; + auto c1_cb = [=,&receivedData1](const std::string& src, + const std::string& dst, + const uint32_t port, + const std::string& data) + { + receivedData1.emplace_back(data); + }; + auto c2_cb = [=,&receivedData2](const std::string& src, + const std::string& dst, + const uint32_t port, + const std::string& data) + { + receivedData2.emplace_back(data); + }; + auto c3_cb = [](const std::string& src, + const std::string& dst, + const uint32_t port, + const std::string& data) + { + GTEST_NONFATAL_FAILURE_("c3 callback should not have been called"); + }; + + // intentionally bind c1 twice to get twice as many data and verify one client + // can bind one port multiple times + ASSERT_FALSE(!c1.Bind(c1_cb, "", 123)); advanceTime(); + ASSERT_FALSE(!c1.Bind(c1_cb, "", 123)); advanceTime(); + ASSERT_FALSE(!c2.Bind(c2_cb, "", 123)); advanceTime(); + // c3 binds to a different port and thus should receive nothing + ASSERT_FALSE(!c3.Bind(c3_cb, "", 124)); advanceTime(); + + // give the bound endpoints time to set up + advanceTime(); // advance the test broker + ros::WallDuration(1).sleep(); + ros::spinOnce(); + + vector sentData; + for(unsigned int i=0; i < 15; ++i) + { + std::ostringstream oss; + oss << "Hello clients, " << i; + ASSERT_TRUE(broadcaster.SendTo(oss.str(), "c1", 123)); + sentData.emplace_back(oss.str()); + advanceTime(); // advance the test broker + ros::spinOnce(); + } + + ros::spinOnce(); + advanceTime(); // advance the test broker + ros::spinOnce(); + + ASSERT_EQ(2u * sentData.size(), receivedData1.size()); + ASSERT_EQ(0u, receivedData2.size()); + + for (size_t i = 0; i < sentData.size(); ++i) + { + EXPECT_EQ(sentData[i], receivedData1[2*i]); + EXPECT_EQ(sentData[i], receivedData1[2*i + 1]); + } + + receivedData1.clear(); + receivedData2.clear(); + sentData.clear(); + + ros::spinOnce(); + advanceTime(); // advance the test broker + ros::spinOnce(); + + for(unsigned int i=0; i < 15; ++i) + { + std::ostringstream oss; + oss << "Hello clients, " << i; + ASSERT_TRUE(broadcaster.SendTo(oss.str(), "c2", 123)); + sentData.emplace_back(oss.str()); + advanceTime(); // advance the test broker + ros::spinOnce(); + } + + ros::spinOnce(); + advanceTime(); // advance the test broker + ros::spinOnce(); + + ASSERT_EQ(0u, receivedData1.size()); + ASSERT_EQ(sentData.size(), receivedData2.size()); + + for (size_t i = 0; i < sentData.size(); ++i) + { + EXPECT_EQ(sentData[i], receivedData2[i]); + } + + // selftest - test situation when a client sends a message to an endpoint it + // has also bound + + receivedData1.clear(); + receivedData2.clear(); + sentData.clear(); + + ros::spinOnce(); + advanceTime(); // advance the test broker + ros::spinOnce(); + + ASSERT_TRUE(c2.SendTo("Selftest", "c2", 123)); + + sentData.emplace_back("Selftest"); + advanceTime(); // advance the test broker + ros::spinOnce(); + + EXPECT_EQ(0u, receivedData1.size()); + EXPECT_EQ(1u, receivedData2.size()); +} + +///////////////////////////////////////////////// +int main(int argc, char **argv) +{ + ros::init(argc, argv, "test_comms_relay"); + nh.reset(new ros::NodeHandle); + clockPub = nh->advertise("/clock", 1, true); + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} + + diff --git a/subt_ros/test/test_broker.cc b/subt_ros/test/test_broker.cc new file mode 100644 index 00000000..778a6624 --- /dev/null +++ b/subt_ros/test/test_broker.cc @@ -0,0 +1,80 @@ +/* + * Copyright (C) 2021 Open Source Robotics Foundation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * +*/ + +/// \brief This is a dummy message broker for tests. It has no message loss. + +#include +#include +#include +#include +#include + +using namespace subt; +using namespace subt::communication_model; +using namespace subt::rf_interface; +using namespace subt::rf_interface::range_model; +using namespace subt::communication_broker; + +///////////////////////////////////////////////// +void setDummyComms(Broker& broker) +{ + struct radio_configuration radio; + radio.pathloss_f = [](const double&, radio_state&, radio_state&) { + return rf_power(); + }; + + broker.SetDefaultRadioConfiguration(radio); + broker.SetCommunicationFunction(&subt::communication_model::attempt_send); + broker.SetPoseUpdateFunction( + [](const std::string& name) + { + return std::make_tuple(true, ignition::math::Pose3d::Zero, 0.0); + }); +} + +///////////////////////////////////////////////// +int main(int argc, char** argv) +{ + ros::init(argc, argv, "test_broker"); + // subscribe to /clock messages; this is needed because we do not have any + // NodeHandle + ros::start(); + + Broker broker; + setDummyComms(broker); + + broker.Start(); + + ROS_INFO("Broker is running in partition '%s'", broker.IgnPartition().c_str()); + + while (!ros::Time::waitForValid(ros::WallDuration(1.0))) + { + ros::spinOnce(); + ROS_WARN("Waiting for valid ROS time"); + } + + ros::Rate rate(1.0); + + while (ros::ok()) + { + broker.DispatchMessages(); + rate.sleep(); + ROS_DEBUG("Broker dispatched"); + } + + ROS_INFO("Broker exiting"); +} \ No newline at end of file From 20f3a10eccb5bb826a3a439084a896331fe06328 Mon Sep 17 00:00:00 2001 From: Martin Pecka Date: Tue, 2 Feb 2021 20:50:53 +0100 Subject: [PATCH 2/2] Fixed concurrency issues during Bind(). --- .../src/subt_communication_client.cpp | 11 ++--- subt_ros/src/SubtRosRelay.cc | 43 +++++++++++-------- 2 files changed, 30 insertions(+), 24 deletions(-) diff --git a/subt-communication/subt_communication_broker/src/subt_communication_client.cpp b/subt-communication/subt_communication_broker/src/subt_communication_client.cpp index 73c51d5e..b73634b0 100644 --- a/subt-communication/subt_communication_broker/src/subt_communication_client.cpp +++ b/subt-communication/subt_communication_broker/src/subt_communication_client.cpp @@ -241,6 +241,8 @@ CommsClient::Bind(std::functionuseIgnition) { + std::lock_guard lock(this->mutex); + // Ignition transport registers a datagram receive service, so we have to // make sure we only advertise it once if (!this->advertised) @@ -254,7 +256,7 @@ CommsClient::Bind(std::functionnode.Advertise(address, &CommsClient::OnMessage, this, opts)) { std::cerr << "[" << this->Host() << "] Bind Error: could not advertise " - << address << std::endl; + << address << " when binding " << unicastEndPoint << std::endl; // if we cannot advertise but we have already bound the endpoints, // we need to unbind them before exiting with error @@ -267,9 +269,8 @@ CommsClient::Bind(std::functionadvertised = true; } - - this->advertised = true; } // Register the callbacks. @@ -361,8 +362,8 @@ bool CommsClient::SendTo(const std::string &_data, // Sanity check: Make sure that the communications are enabled. if (!this->enabled || this->clientId == invalidClientId) { - std::cerr << "[" << this->localAddress << "] CommsClient::SendTo:" - << "Calling Unregister() before registering the client." + std::cerr << "[" << this->localAddress << "] CommsClient::SendTo: " + << "Calling SendTo() before registering the client." << std::endl; return false; } diff --git a/subt_ros/src/SubtRosRelay.cc b/subt_ros/src/SubtRosRelay.cc index a029207f..8f82d729 100644 --- a/subt_ros/src/SubtRosRelay.cc +++ b/subt_ros/src/SubtRosRelay.cc @@ -462,29 +462,34 @@ bool SubtRosRelay::OnBind(subt_msgs::Bind::Request &_req, if (executed && result) _res.endpoint_id = rep.data(); - if (executed && result && - // Only establish the Ignition service once per client. - this->boundAddresses.find(address) == this->boundAddresses.end()) { - std::function cb = - [this, address] (const subt::msgs::Datagram& _msg) - { - this->OnMessage(_msg, address); - }; + std::unique_lock lock(this->clientsMutex); - if (!this->node.Advertise(address, cb)) - { - std::cerr << "Bind Error: could not advertise " << address << std::endl; - return false; - } - else + if (executed && result && + // Only establish the Ignition service once per client. + this->boundAddresses.find(address) == this->boundAddresses.end()) { - this->boundAddresses.insert(address); + std::function cb = + [this, address](const subt::msgs::Datagram& _msg) + { + this->OnMessage(_msg, address); + }; - // Prepare the ROS comms publisher - this->commsPublishers[address] = - this->rosnode->advertise( - "/" + address + "/comms", 1000); + if (!this->node.Advertise(address, cb)) + { + ROS_ERROR_STREAM("Bind Error: could not advertise " << address << + " while binding " << _req.endpoint); + return false; + } + else + { + this->boundAddresses.insert(address); + + // Prepare the ROS comms publisher + this->commsPublishers[address] = + this->rosnode->advertise( + "/" + address + "/comms", 1000); + } } }