diff --git a/CMakeLists.txt b/CMakeLists.txt index 4a6f9c1..7df3cd3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,6 +1,6 @@ cmake_minimum_required (VERSION 3.11) -project(Homa VERSION 0.1.1.0 LANGUAGES CXX) +project(Homa VERSION 0.1.2.0 LANGUAGES CXX) ################################################################################ ## Dependency Configuration #################################################### @@ -72,9 +72,10 @@ endif() ## lib Homa #################################################################### add_library(Homa + src/Bindings/CHoma.cc src/CodeLocation.cc src/Debug.cc - src/Homa.cc + src/Driver.cc src/Perf.cc src/Policy.cc src/Receiver.cc @@ -82,6 +83,8 @@ add_library(Homa src/StringUtil.cc src/ThreadId.cc src/TransportImpl.cc + src/Transports/PollModeTransportImpl.cc + src/Transports/Shenango.cc src/Util.cc ) add_library(Homa::Homa ALIAS Homa) @@ -261,6 +264,7 @@ add_executable(unit_test src/ThreadIdTest.cc src/TimeoutTest.cc src/TransportImplTest.cc + src/Transports/PollModeTransportImplTest.cc src/TubTest.cc src/UtilTest.cc ) diff --git a/include/Homa/Bindings/CHoma.h b/include/Homa/Bindings/CHoma.h new file mode 100644 index 0000000..bd307cb --- /dev/null +++ b/include/Homa/Bindings/CHoma.h @@ -0,0 +1,200 @@ +/* Copyright (c) 2020 Stanford University + * + * Permission to use, copy, modify, and distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR(S) DISCLAIM ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS. 
IN NO EVENT SHALL AUTHORS BE LIABLE FOR + * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + */ + +/** + * @file CHoma.h + * + * Contains C-bindings for the Homa Transport API. + */ + +#ifndef HOMA_INCLUDE_HOMA_BINDINGS_CHOMA_H +#define HOMA_INCLUDE_HOMA_BINDINGS_CHOMA_H + +#include "Homa/OutMessageStatus.h" + +#ifdef __cplusplus +#include +#include +extern "C" { +#else +#include +#include +#endif + +/** + * Define handle types for various Homa objects. + * + * A handle type is essentially a thin wrapper around an opaque pointer. + * Compared to generic pointers, using handle types in the C API enables + * some type safety. + */ +#define DEFINE_HOMA_OBJ_HANDLE(x) \ + typedef struct { \ + void* p; \ + } homa_##x; + +DEFINE_HOMA_OBJ_HANDLE(callbacks) /* Homa::Callbacks */ +DEFINE_HOMA_OBJ_HANDLE(driver) /* Homa::Driver */ +DEFINE_HOMA_OBJ_HANDLE(inmsg) /* Homa::InMessage */ +DEFINE_HOMA_OBJ_HANDLE(outmsg) /* Homa::OutMessage */ +DEFINE_HOMA_OBJ_HANDLE(trans) /* Homa::Transport */ + +/* ============================ */ +/* Homa::InMessage API */ +/* ============================ */ + +/** + * homa_inmsg_ack - C-binding for Homa::InMessage::acknowledge + */ +extern void homa_inmsg_ack(homa_inmsg in_msg); + +/** + * homa_inmsg_dropped - C-binding for Homa::InMessage::dropped + */ +extern bool homa_inmsg_dropped(homa_inmsg in_msg); + +/** + * homa_inmsg_fail - C-binding for Homa::InMessage::fail + */ +extern void homa_inmsg_fail(homa_inmsg in_msg); + +/** + * homa_inmsg_get - C-binding for Homa::InMessage::get + */ +extern size_t homa_inmsg_get(homa_inmsg in_msg, size_t ofs, void* dst, + size_t len); + +/** + * homa_inmsg_src_addr - C-binding for Homa::InMessage::getSourceAddress + */ +extern void homa_inmsg_src_addr(homa_inmsg in_msg, 
uint32_t* ip, + uint16_t* port); + +/** + * homa_inmsg_len - C-binding for Homa::InMessage::length + */ +extern size_t homa_inmsg_len(homa_inmsg in_msg); + +/** + * homa_inmsg_release - C-binding for Homa::InMessage::release + */ +extern void homa_inmsg_release(homa_inmsg in_msg); + +/** + * homa_inmsg_strip - C-binding for Homa::InMessage::strip + */ +extern void homa_inmsg_strip(homa_inmsg in_msg, size_t n); + +/* ============================ */ +/* Homa::OutMessage API */ +/* ============================ */ + +/** + * homa_outmsg_append - C-binding for Homa::OutMessage::append + */ +extern void homa_outmsg_append(homa_outmsg out_msg, const void* buf, + size_t len); + +/** + * homa_outmsg_cancel - C-binding for Homa::OutMessage::cancel + */ +extern void homa_outmsg_cancel(homa_outmsg out_msg); + +/** + * homa_outmsg_status - C-binding for Homa::OutMessage::getStatus + */ +extern int homa_outmsg_status(homa_outmsg out_msg); + +/** + * homa_outmsg_prepend - C-binding for Homa::OutMessage::prepend + */ +extern void homa_outmsg_prepend(homa_outmsg out_msg, const void* buf, + size_t len); + +/** + * homa_outmsg_reserve - C-binding for Homa::OutMessage::reserve + */ +extern void homa_outmsg_reserve(homa_outmsg out_msg, size_t n); + +/** + * homa_outmsg_send - C-binding for Homa::OutMessage::send + */ +extern void homa_outmsg_send(homa_outmsg out_msg, uint32_t ip, uint16_t port); + +/** + * homa_outmsg_release - C-binding for Homa::OutMessage::release + */ +extern void homa_outmsg_release(homa_outmsg out_msg); + +/* ================================ */ +/* Homa::TransportBase API */ +/* ================================ */ + +/** + * homa_trans_create - C-binding for Homa::TransportBase::create + */ +extern homa_trans homa_trans_create(homa_driver drv, homa_callbacks cbs, + uint64_t id); + +/** + * homa_trans_free - C-binding for Homa::TransportBase::free + */ +extern void homa_trans_free(homa_trans trans); + +/** + * homa_trans_alloc - C-binding for 
Homa::TransportBase::alloc + */ +extern homa_outmsg homa_trans_alloc(homa_trans trans, uint16_t port); + +/** + * homa_trans_get_drv - C-binding for Homa::TransportBase::getDriver + */ +extern homa_driver homa_trans_get_drv(homa_trans trans); + +/** + * homa_trans_id - C-binding for Homa::TransportBase::getId + */ +extern uint64_t homa_trans_id(homa_trans trans); + +/* ================================ */ +/* Homa::Core::Transport API */ +/* ================================ */ + +/** + * homa_trans_check_timeouts - C-binding for Core::Transport::checkTimeouts + */ +extern uint64_t homa_trans_check_timeouts(homa_trans trans); + +/** + * homa_trans_proc - C-binding for Core::Transport::processPacket + */ +extern void homa_trans_proc(homa_trans trans, uintptr_t desc, void* payload, + int32_t len, uint32_t src_ip); + +/** + * homa_trans_try_send - C-binding for Core::Transport::trySend + */ +extern uint64_t homa_trans_try_send(homa_trans trans); + +/** + * homa_trans_try_grant - C-binding for Core::Transport::trySendGrants + */ +extern bool homa_trans_try_grant(homa_trans trans); + +#ifdef __cplusplus +} +#endif + +#endif // HOMA_INCLUDE_HOMA_BINDINGS_CHOMA_H \ No newline at end of file diff --git a/include/Homa/Core/Transport.h b/include/Homa/Core/Transport.h new file mode 100644 index 0000000..f7b06d8 --- /dev/null +++ b/include/Homa/Core/Transport.h @@ -0,0 +1,173 @@ +/* Copyright (c) 2020, Stanford University + * + * Permission to use, copy, modify, and/or distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS. 
IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + */ + +/** + * @file Homa/Core/Transport.h + * + * Contains the low-level Homa Transport API. Advanced users of the Homa + * Transport library should include this header. + */ + +#ifndef HOMA_INCLUDE_HOMA_CORE_TRANSPORT_H +#define HOMA_INCLUDE_HOMA_CORE_TRANSPORT_H + +#include + +namespace Homa::Core { + +/** + * Minimal set of low-level API that can be used to create Homa-based transports + * for different runtime environments (e.g. polling, kernel threading, + * green threads, etc). + * + * The execution of a transport is driven through repeated calls to methods + * like checkTimeouts(), processPacket(), trySend(), and trySendGrants(); the + * transport will not make any progress otherwise. Advanced users can compose + * these methods in a way that suits them best. + * + * This class is thread-safe. + */ +class Transport : public TransportBase { + public: + /** + * Collection of user-defined transport callbacks. + */ + class Callbacks { + public: + /** + * Destructor. + */ + virtual ~Callbacks() = default; + + /** + * Invoked when an incoming message arrives and needs to be dispatched to + * its destination in the user application for processing. + * + * Here are a few example use cases of this callback: + *
<ul> + *
<li> + * Interaction with the user's thread scheduler: e.g., an application + * may want to block on receive until a message is delivered, so this + * method can be used to wake up blocking threads. + *
<li> + * High-performance message dispatch: e.g., an application may choose + * to implement the message receive queue with a concurrent MPMC queue + * as opposed to a linked-list protected by a mutex. + *
<li> + * Lightweight synchronization: e.g., the socket table that maps port + * numbers to sockets is a read-mostly data structure, so lookup + * operations can benefit from synchronization schemes such as RCU. + * </ul>
+ * + * @param port + * Destination port number of the message. + * @param message + * Incoming message to dispatch. + * @return + * True if the message is delivered successfully; false, otherwise. + */ + virtual bool deliver(uint16_t port, + Homa::unique_ptr message) = 0; + + /** + * Invoked when some packets just became ready to be sent (and there were + * none before). + * + * This callback allows the transport library to notify the users that + * trySend() should be invoked again as soon as possible. For example, + * the callback can be used to implement wakeup signals for the thread + * that is responsible for calling trySend(), if this thread decides to + * sleep when there are no packets to send. + */ + virtual void notifySendReady() {} + }; + + /** + * Return a new instance of a Homa-based transport. + * + * @param driver + * Driver with which this transport should send and receive packets. + * @param callbacks + * Collection of user-defined callbacks to customize the behavior of + * the transport. + * @param transportId + * This transport's unique identifier in the group of transports among + * which this transport will communicate. + * @return + * Pointer to the new transport instance. + */ + static Homa::unique_ptr create(Driver* driver, + Callbacks* callbacks, + uint64_t transportId); + + /** + * Process any timeouts that have expired. + * + * This method must be called periodically to ensure timely handling of + * expired timeouts. + * + * @return + * The rdtsc cycle time when this method should be called again. + */ + virtual uint64_t checkTimeouts() = 0; + + /** + * Return the driver that this transport uses to send and receive packets. + */ + virtual Driver* getDriver() = 0; + + /** + * Handle an ingress packet by running it through the transport protocol + * stack. + * + * @param packet + * The ingress packet. + * @param source + * IpAddress of the socket from which the packet is sent.
+ */ + virtual void processPacket(Driver::Packet* packet, IpAddress source) = 0; + + /** + * Attempt to send out packets for any messages with unscheduled/granted + * bytes in a way that limits queue buildup in the NIC. + * + * This method must be called eagerly to allow the Transport to make + * progress toward sending outgoing messages. + * + * @return + * The rdtsc cycle time when this method should be called again to + * transmit the rest of the packets (this allows the NIC to drain its + * transmit queue first), or zero if there are no more packets to send. + */ + virtual uint64_t trySend() = 0; + + /** + * Attempt to grant to incoming messages according to the Homa protocol. + * + * This method must be called eagerly to allow the Transport to make + * progress toward receiving incoming messages. For example, a user may + * invoke this method every time the transport finishes processing a batch + * of incoming packets. + * + * @return + * True if the method has found some messages to grant; false, + * otherwise. + */ + virtual bool trySendGrants() = 0; +}; + +} // namespace Homa::Core + +#endif // HOMA_INCLUDE_HOMA_CORE_TRANSPORT_H \ No newline at end of file diff --git a/include/Homa/Driver.h b/include/Homa/Driver.h index ecfe666..ef32e99 100644 --- a/include/Homa/Driver.h +++ b/include/Homa/Driver.h @@ -23,83 +23,73 @@ namespace Homa { /** - * Used by Homa::Transport to send and receive unreliable datagrams. Provides - * the interface to which all Driver implementations must conform. + * A simple wrapper struct around an IP address in binary format. * - * Implementations of this class should be thread-safe. + * This struct is meant to provide some type-safety when manipulating IP + * addresses. In order to avoid any runtime overhead, this struct contains + * nothing more than the IP address, so it is trivially copyable. */ -class Driver { - public: +struct IpAddress final { + /// IPv4 address in host byte order.
+ uint32_t addr; + /** - * Represents a Network address. - * - * Each Address representation is specific to the Driver instance that - * returned the it; they cannot be use interchangeably between different - * Driver instances. + * Unbox the IP address in binary format. */ - using Address = uint64_t; + explicit operator uint32_t() + { + return addr; + } /** - * Used to hold a driver's serialized byte-format for a network address. - * - * Each driver may define its own byte-format so long as fits within the - * bytes array. + * Equality function for IpAddress, for use in std::unordered_maps etc. */ - struct WireFormatAddress { - uint8_t type; ///< Can be used to distinguish between different wire - ///< address formats. - uint8_t bytes[19]; ///< Holds an Address's serialized byte-format. - } __attribute__((packed)); + bool operator==(const IpAddress& other) const + { + return addr == other.addr; + } /** - * Represents a packet of data that can be send or is received over the - * network. A Packet logically contains only the payload and not any Driver - * specific headers. - * - * A Packet may be Driver specific and should not used interchangeably - * between Driver instances or implementations. - * - * This class is NOT thread-safe but the Transport and Driver's use of - * Packet objects should be allow the Transport and the Driver to execute on - * different threads. + * This class computes a hash of an IpAddress, so that IpAddress can be used + * as keys in unordered_maps. */ - class Packet { - public: - /// Packet's source or destination. When sending a Packet, the address - /// field will contain the destination Address. When receiving a Packet, - /// address field will contain the source Address. - Address address; + struct Hasher { + /// Return a "hash" of the given IpAddress. 
+ std::size_t operator()(const IpAddress& address) const + { + return std::hash{}(address.addr); + } + }; - /// Packet's network priority (send only); the lowest possible priority - /// is 0. The highest priority is positive number defined by the - /// Driver; the highest priority can be queried by calling the method - /// getHighestPacketPriority(). - int priority; + static std::string toString(IpAddress address); + static IpAddress fromString(const char* addressStr); +}; +static_assert(std::is_trivially_copyable()); + +/** + * Used by Homa::Core::Transport to send and receive unreliable datagrams. + * Provides the interface to which all Driver implementations must conform. + * + * Implementations of this class should be thread-safe. + */ +class Driver { + public: + /** + * Describes a packet of data that can be sent or is received over the + * network. A Packet logically contains only the transport-layer (L4) Homa + * header in addition to application data. + */ + struct Packet { + /// Unique identifier of this Packet within the Driver. This descriptor + /// is entirely opaque to the transport. + uintptr_t descriptor; /// Pointer to an array of bytes containing the payload of this Packet. /// This array is valid until the Packet is released back to the Driver. - void* const payload; + void* payload; /// Number of bytes in the payload. - int length; - - /// Return the maximum number of bytes the payload can hold. - virtual int getMaxPayloadSize() = 0; - - protected: - /** - * Construct a Packet. - */ - explicit Packet(void* payload, int length = 0) - : address() - , priority(0) - , payload(payload) - , length(length) - {} - - // DISALLOW_COPY_AND_ASSIGN - Packet(const Packet&) = delete; - Packet& operator=(const Packet&) = delete; + int32_t length; }; /** @@ -107,63 +97,16 @@ class Driver { */ virtual ~Driver() = default; - /** - * Return a Driver specific network address for the given string - * representation of the address.
- * - * @param addressString - * The string representation of the address to return. The address - * string format can be Driver specific. - * - * @return - * Address that can be the source or destination of a Packet. - * - * @throw BadAddress - * _addressString_ is malformed. - */ - virtual Address getAddress(std::string const* const addressString) = 0; - - /** - * Return a Driver specific network address for the given serialized - * byte-format of the address. - * - * @param wireAddress - * The serialized byte-format of the address to be returned. The - * format can be Driver specific. - * - * @return - * Address that can be the source or destination of a Packet. - * - * @throw BadAddress - * _rawAddress_ is malformed. - */ - virtual Address getAddress(WireFormatAddress const* const wireAddress) = 0; - - /** - * Return the string representation of a network address. - * - * @param address - * Address whose string representation should be returned. - */ - virtual std::string addressToString(const Address address) = 0; - - /** - * Serialize a network address into its Raw byte format. - * - * @param address - * Address to be serialized. - * @param[out] wireAddress - * WireFormatAddress object to which the Address is serialized. - */ - virtual void addressToWireFormat(const Address address, - WireFormatAddress* wireAddress) = 0; - /** * Allocate a new Packet object from the Driver's pool of resources. The * caller must eventually release the packet by passing it to a call to * releasePacket(). + * + * @param[out] packet + * Set to the description of the allocated packet when the method + * returns. */ - virtual Packet* allocPacket() = 0; + virtual void allocPacket(Packet* packet) = 0; /** * Send a packet over the network. @@ -187,8 +130,16 @@ class Driver { * * @param packet * Packet to be sent over the network. + * @param destination + * IP address of the packet destination. + * @param priority + * Packet's network priority; the lowest possible priority is 0. 
+ * The highest priority is a positive number defined by the Driver; + * the highest priority can be queried by calling the method + * getHighestPacketPriority(). */ - virtual void sendPacket(Packet* packet) = 0; + virtual void sendPacket(Packet* packet, IpAddress destination, + int priority) = 0; /** * Request that the Driver enter the "corked" mode where outbound packets @@ -218,9 +169,12 @@ class Driver { * * @param maxPackets * The maximum number of Packet objects that should be returned by - * this method. + * this method. * @param[out] receivedPackets * Received packets are appended to this array in order of arrival. + * @param[out] sourceAddresses + * Source IP addresses of the received packets are appended to this + * array in order of arrival. * * @return * Number of Packet objects being returned. @@ -228,7 +182,8 @@ class Driver { * @sa Driver::releasePackets() */ virtual uint32_t receivePackets(uint32_t maxPackets, - Packet* receivedPackets[]) = 0; + Packet receivedPackets[], + IpAddress sourceAddresses[]) = 0; /** * Release a collection of Packet objects back to the Driver. Every @@ -245,7 +200,7 @@ class Driver { * @param numPackets * Number of Packet objects in _packets_. */ - virtual void releasePackets(Packet* packets[], uint16_t numPackets) = 0; + virtual void releasePackets(Packet packets[], uint16_t numPackets) = 0; /** * Returns the highest packet priority level this Driver supports (0 is @@ -273,10 +228,10 @@ class Driver { virtual uint32_t getBandwidth() = 0; /** - * Return this Driver's local network Address which it uses as the source - * Address for outgoing packets. + * Return this Driver's local IP address which it uses as the source + * address for outgoing packets.
*/ - virtual Address getLocalAddress() = 0; + virtual IpAddress getLocalAddress() = 0; /** * Return the number of bytes that have been passed to the Driver through diff --git a/include/Homa/Drivers/DPDK/DpdkDriver.h b/include/Homa/Drivers/DPDK/DpdkDriver.h index dafb05f..fd2dd85 100644 --- a/include/Homa/Drivers/DPDK/DpdkDriver.h +++ b/include/Homa/Drivers/DPDK/DpdkDriver.h @@ -53,14 +53,14 @@ class DpdkDriver : public Driver { * has exclusive access to DPDK. Note: This call will initialize the DPDK * EAL with default values. * - * @param port - * Selects which physical port to use for communication. + * @param ifname + * Selects which network interface to use for communication. * @param config * Optional configuration parameters (see Config). * @throw DriverInitFailure * Thrown if DpdkDriver fails to initialize for any reason. */ - DpdkDriver(int port, const Config* const config = nullptr); + DpdkDriver(const char* ifname, const Config* const config = nullptr); /** * Construct a DpdkDriver and initialize the DPDK EAL using the provided @@ -75,7 +75,7 @@ class DpdkDriver : public Driver { * overriding the default affinity set by rte_eal_init(). * * @param port - * Selects which physical port to use for communication. + * Selects which network interface to use for communication. * @param argc * Parameter passed to rte_eal_init(). * @param argv @@ -85,7 +85,7 @@ class DpdkDriver : public Driver { * @throw DriverInitFailure * Thrown if DpdkDriver fails to initialize for any reason. */ - DpdkDriver(int port, int argc, char* argv[], + DpdkDriver(const char* ifname, int argc, char* argv[], const Config* const config = nullptr); /// Used to signal to the DpdkDriver constructor that the DPDK EAL should @@ -101,7 +101,7 @@ class DpdkDriver : public Driver { * called before calling this constructor. * * @param port - * Selects which physical port to use for communication. + * Selects which network interface to use for communication. 
* @param _ * Parameter is used only to define this constructors alternate * signature. @@ -110,29 +110,20 @@ class DpdkDriver : public Driver { * @throw DriverInitFailure * Thrown if DpdkDriver fails to initialize for any reason. */ - DpdkDriver(int port, NoEalInit _, const Config* const config = nullptr); + DpdkDriver(const char* ifname, NoEalInit _, + const Config* const config = nullptr); /** * DpdkDriver Destructor. */ virtual ~DpdkDriver(); - /// See Driver::getAddress() - virtual Address getAddress(std::string const* const addressString); - virtual Address getAddress(WireFormatAddress const* const wireAddress); - - /// See Driver::addressToString() - virtual std::string addressToString(const Address address); - - /// See Driver::addressToWireFormat() - virtual void addressToWireFormat(const Address address, - WireFormatAddress* wireAddress); - /// See Driver::allocPacket() - virtual Packet* allocPacket(); + virtual void allocPacket(Packet* packet); /// See Driver::sendPacket() - virtual void sendPacket(Packet* packet); + virtual void sendPacket(Packet* packet, IpAddress destination, + int priority); /// See Driver::cork() virtual void cork(); @@ -142,10 +133,11 @@ class DpdkDriver : public Driver { /// See Driver::receivePackets() virtual uint32_t receivePackets(uint32_t maxPackets, - Packet* receivedPackets[]); + Packet receivedPackets[], + IpAddress sourceAddresses[]); /// See Driver::releasePackets() - virtual void releasePackets(Packet* packets[], uint16_t numPackets); + virtual void releasePackets(Packet packets[], uint16_t numPackets); /// See Driver::getHighestPacketPriority() virtual int getHighestPacketPriority(); @@ -157,7 +149,7 @@ class DpdkDriver : public Driver { virtual uint32_t getBandwidth(); /// See Driver::getLocalAddress() - virtual Driver::Address getLocalAddress(); + virtual IpAddress getLocalAddress(); /// See Driver::getQueuedBytes(); virtual uint32_t getQueuedBytes(); diff --git a/include/Homa/Drivers/Fake/FakeDriver.h 
b/include/Homa/Drivers/Fake/FakeDriver.h index 8413778..bba1fc2 100644 --- a/include/Homa/Drivers/Fake/FakeDriver.h +++ b/include/Homa/Drivers/Fake/FakeDriver.h @@ -34,7 +34,7 @@ const int NUM_PRIORITIES = 8; /// Maximum number of bytes a packet can hold. const uint32_t MAX_PAYLOAD_SIZE = 1500; -/// A set of methods to contol the underlying FakeNetwork's behavior. +/// A set of methods to control the underlying FakeNetwork's behavior. namespace FakeNetworkConfig { /** * Configure the FakeNetwork to drop packets at the specified loss rate. @@ -51,43 +51,45 @@ void setPacketLossRate(double lossRate); * * @sa Driver::Packet */ -class FakePacket : public Driver::Packet { - public: +struct FakePacket { + /// Raw storage for this packet's payload. + char buf[MAX_PAYLOAD_SIZE]; + + /// Number of bytes in the payload. + int length; + + /// Source IpAddress of the packet. + IpAddress sourceIp; + /** * FakePacket constructor. - * - * @param maxPayloadSize - * The maximum number of bytes this packet can hold. */ explicit FakePacket() - : Packet(buf, 0) + : buf() + , length() + , sourceIp() {} /** * Copy constructor. */ FakePacket(const FakePacket& other) - : Packet(buf, other.length) + : buf() + , length(other.length) + , sourceIp() { - address = other.address; - priority = other.priority; memcpy(buf, other.buf, MAX_PAYLOAD_SIZE); } - virtual ~FakePacket() {} - - /// see Driver::Packet::getMaxPayloadSize() - virtual int getMaxPayloadSize() + /** + * Convert this FakePacket to a generic Driver packet representation. + */ + Driver::Packet toPacket() { - return MAX_PAYLOAD_SIZE; + Driver::Packet packet = { + .descriptor = (uintptr_t)this, .payload = buf, .length = length}; + return packet; } - - private: - /// Raw storage for this packets payload. - char buf[MAX_PAYLOAD_SIZE]; - - // Disable Assignment - FakePacket& operator=(const FakePacket&) = delete; }; /// Holds the incoming packets for a particular driver.
@@ -117,25 +119,22 @@ class FakeDriver : public Driver { */ virtual ~FakeDriver(); - virtual Address getAddress(std::string const* const addressString); - virtual Address getAddress(WireFormatAddress const* const wireAddress); - virtual std::string addressToString(const Address address); - virtual void addressToWireFormat(const Address address, - WireFormatAddress* wireAddress); - virtual Packet* allocPacket(); - virtual void sendPacket(Packet* packet); + virtual void allocPacket(Packet* packet); + virtual void sendPacket(Packet* packet, IpAddress destination, + int priority); virtual uint32_t receivePackets(uint32_t maxPackets, - Packet* receivedPackets[]); - virtual void releasePackets(Packet* packets[], uint16_t numPackets); + Packet receivedPackets[], + IpAddress sourceAddresses[]); + virtual void releasePackets(Packet packets[], uint16_t numPackets); virtual int getHighestPacketPriority(); virtual uint32_t getMaxPayloadSize(); virtual uint32_t getBandwidth(); - virtual Address getLocalAddress(); + virtual IpAddress getLocalAddress(); virtual uint32_t getQueuedBytes(); private: /// Identifier for this driver on the fake network. - uint64_t localAddressId; + uint32_t localAddressId; /// Holds the incoming packets for this driver. FakeNIC nic; diff --git a/include/Homa/Homa.h b/include/Homa/Homa.h index dec090c..ff9e8b9 100644 --- a/include/Homa/Homa.h +++ b/include/Homa/Homa.h @@ -24,19 +24,33 @@ #define HOMA_INCLUDE_HOMA_HOMA_H #include - -#include -#include -#include +#include namespace Homa { /** * Shorthand for an std::unique_ptr with a customized deleter. + * + * This is used to implement the "borrow" semantics for interface classes like + * InMessage, OutMessage, and Transport; that is, a user can obtain pointers to + * these objects and use them to make function calls, but the user must always + * return the objects back to the transport library eventually because the user + * has no idea how to destruct the objects or reclaim memory. 
*/ template using unique_ptr = std::unique_ptr; +/** + * Represents a socket address to (from) which we can send (receive) messages. + */ +struct SocketAddress { + /// IPv4 address in host byte order. + IpAddress ip; + + /// Port number in host byte order. + uint16_t port; +}; + /** * Represents an array of bytes that has been received over the network. * @@ -92,6 +106,11 @@ class InMessage { virtual size_t get(size_t offset, void* destination, size_t count) const = 0; + /** + * Return the remote address from which this Message is sent. + */ + virtual SocketAddress getSourceAddress() const = 0; + /** * Return the number of bytes this Message contains. */ @@ -106,6 +125,12 @@ class InMessage { virtual void strip(size_t count) = 0; protected: + /** + * Use protected destructor to prevent users from calling delete on pointers + * to this interface. + */ + ~InMessage() = default; + /** * Signal that this message is no longer needed. The caller should not * access this message following this call. @@ -123,14 +148,7 @@ class OutMessage { /** * Defines the possible states of an OutMessage. */ - enum class Status { - NOT_STARTED, //< The sending of this message has not started. - IN_PROGRESS, //< The message is in the process of being sent. - CANCELED, //< The message was canceled while still IN_PROGRESS. - SENT, //< The message has been completely sent. - COMPLETED, //< The message has been received and processed. - FAILED, //< The message failed to be delivered and processed. - }; + using Status = OutMessageStatus; /** * Options with which an OutMessage can be sent. @@ -220,14 +238,21 @@ class OutMessage { * Send this message to the destination. * * @param destination - * Address of the transport to which this message will be sent. + * Network address to which this message will be sent. * @param options * Flags to request non-default sending behavior. 
*/ - virtual void send(Driver::Address destination, + virtual void send(SocketAddress destination, Options options = Options::NONE) = 0; + // FIXME: this is problematic; we can't really call send a second time... protected: + /** + * Use protected destructor to prevent users from calling delete on pointers + * to this interface. + */ + ~OutMessage() = default; + /** * Signal that this message is no longer needed. The caller should not * access this message following this call. @@ -236,66 +261,50 @@ class OutMessage { }; /** - * Provides a means of communicating across the network using the Homa protocol. - * - * The transport is used to send and receive messages across the network using - * the RemoteOp and ServerOp abstractions. The execution of the transport is - * driven through repeated calls to the Transport::poll() method; the transport - * will not make any progress otherwise. + * Basic transport APIs that are shared between the low-level and high-level + * transport interfaces. * * This class is thread-safe. */ -class Transport { +class TransportBase { public: /** - * Return a new instance of a Homa-based transport. - * - * The caller is responsible for calling free() on the returned pointer. - * - * @param driver - * Driver with which this transport should send and receive packets. - * @param transportId - * This transport's unique identifier in the group of transports among - * which this transport will communicate. - * @return - * Pointer to the new transport instance. + * Custom deleter for use with std::unique_ptr. */ - static Transport* create(Driver* driver, uint64_t transportId); + struct Deleter { + void operator()(TransportBase* transport) + { + transport->free(); + } + }; /** * Allocate Message that can be sent with this Transport. * + * @param port + * Port number of the socket from which the message will be sent. * @return * A pointer to the allocated message.
*/ - virtual Homa::unique_ptr alloc() = 0; - - /** - * Check for and return a Message sent to this Transport if available. - * - * @return - * Pointer to the received message, if any. Otherwise, nullptr is - * returned if no message has been delivered. - */ - virtual Homa::unique_ptr receive() = 0; + virtual Homa::unique_ptr alloc(uint16_t port) = 0; /** - * Make incremental progress performing all Transport functionality. - * - * This method MUST be called for the Transport to make progress and should - * be called frequently to ensure timely progress. + * Return this transport's unique identifier. */ - virtual void poll() = 0; + virtual uint64_t getId() = 0; + protected: /** - * Return the driver that this transport uses to send and receive packets. + * Use protected destructor to prevent users from calling delete on pointers + * to this interface. */ - virtual Driver* getDriver() = 0; + ~TransportBase() = default; /** - * Return this transport's unique identifier. + * Free this transport instance. No one should access this transport + * following this call. */ - virtual uint64_t getId() = 0; + virtual void free() = 0; }; /** diff --git a/include/Homa/OutMessageStatus.h b/include/Homa/OutMessageStatus.h new file mode 100644 index 0000000..941ce02 --- /dev/null +++ b/include/Homa/OutMessageStatus.h @@ -0,0 +1,35 @@ +/* Copyright (c) 2020 Stanford University + * + * Permission to use, copy, modify, and distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR(S) DISCLAIM ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS.
IN NO EVENT SHALL AUTHORS BE LIABLE FOR + * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + */ + +#ifndef HOMA_INCLUDE_HOMA_OUTMESSAGESTATUS_H +#define HOMA_INCLUDE_HOMA_OUTMESSAGESTATUS_H + +/** + * Defines the possible states of an OutMessage. + */ +#ifdef __cplusplus +enum class OutMessageStatus : int { +#else +enum homa_outmsg_status { +#endif + NOT_STARTED, //< The sending of this message has not started. + IN_PROGRESS, //< The message is in the process of being sent. + CANCELED, //< The message was canceled while still IN_PROGRESS. + SENT, //< The message has been completely sent. + COMPLETED, //< The message has been received and processed. + FAILED, //< The message failed to be delivered and processed. +}; + +#endif // HOMA_INCLUDE_HOMA_OUTMESSAGESTATUS_H \ No newline at end of file diff --git a/include/Homa/Transports/PollModeTransport.h b/include/Homa/Transports/PollModeTransport.h new file mode 100644 index 0000000..262c3d7 --- /dev/null +++ b/include/Homa/Transports/PollModeTransport.h @@ -0,0 +1,61 @@ +/* Copyright (c) 2020, Stanford University + * + * Permission to use, copy, modify, and/or distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS. 
IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + */ + +#ifndef HOMA_INCLUDE_HOMA_TRANSPORTS_POLLMODETRANSPORT_H +#define HOMA_INCLUDE_HOMA_TRANSPORTS_POLLMODETRANSPORT_H + +#include + +namespace Homa { + +/** + * A polling-based Homa transport implementation. + */ +class PollModeTransport : public TransportBase { + public: + /** + * Return a new instance of a polling-based Homa transport. + * + * @param driver + * Driver with which this transport should send and receive packets. + * @param transportId + * This transport's unique identifier in the group of transports among + * which this transport will communicate. + * @return + * Pointer to the new transport instance. + */ + static Homa::unique_ptr create(Driver* driver, + uint64_t transportId); + + /** + * Make incremental progress performing all Transport functionality. + * + * This method MUST be called for the Transport to make progress and should + * be called frequently to ensure timely progress. + */ + virtual void poll() = 0; + + /** + * Check for and return a Message sent to this transport if available. + * + * @return + * Pointer to the received message, if any; otherwise, nullptr. 
+ */ + virtual Homa::unique_ptr receive() = 0; +}; + +} // namespace Homa + +#endif // HOMA_INCLUDE_HOMA_TRANSPORTS_POLLMODETRANSPORT_H \ No newline at end of file diff --git a/include/Homa/Transports/Shenango.h b/include/Homa/Transports/Shenango.h new file mode 100644 index 0000000..40d7b5d --- /dev/null +++ b/include/Homa/Transports/Shenango.h @@ -0,0 +1,64 @@ +/* Copyright (c) 2020 Stanford University + * + * Permission to use, copy, modify, and distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR(S) DISCLAIM ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL AUTHORS BE LIABLE FOR + * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + */ + +/** + * @file Homa/Transports/Shenango.h + * + * Contains the glue code for Homa-Shenango integration. This is the only + * header Shenango needs to include in order to use Homa transport. + * + * Shenango is an experimental operating system that aims to provide low tail + * latency and high CPU efficiency simultaneously for servers in datacenters. + * See for more information. + */ + +#ifndef HOMA_INCLUDE_HOMA_TRANSPORTS_SHENANGO_H +#define HOMA_INCLUDE_HOMA_TRANSPORTS_SHENANGO_H + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * homa_create_shenango_trans - creates a transport instance that can be used by + * Shenango to send and receive messages. 
+ * @id: Unique identifier for this transport instance + * @proto: Protocol number reserved for Homa transport protocol + * @local_ip: Local IP address of the driver + * @max_payload: Maximum number of bytes carried by the packet payload + * @link_speed: Effective network bandwidth, in Mbits/second + * @cb_send_ready: Callback function to invoke in Callbacks::notifySendReady + * @cb_data: Input data for @cb_send_ready + * + * Returns a handle to the transport created. + */ +extern homa_trans homa_create_shenango_trans(uint64_t id, + uint8_t proto, uint32_t local_ip, uint32_t max_payload, uint32_t link_speed, + void (*cb_send_ready)(void*), void* cb_data); + +/** + * homa_free_shenango_trans - frees a transport created earlier with + * @homa_create_shenango_trans. + * @trans: the transport to free + */ +extern void homa_free_shenango_trans(homa_trans trans); + +#ifdef __cplusplus +} +#endif + +#endif // HOMA_INCLUDE_HOMA_TRANSPORTS_SHENANGO_H \ No newline at end of file diff --git a/include/Homa/Util.h b/include/Homa/Util.h index 121bb44..0f58cb7 100644 --- a/include/Homa/Util.h +++ b/include/Homa/Util.h @@ -1,4 +1,4 @@ -/* Copyright (c) 2009-2018, Stanford University +/* Copyright (c) 2009-2020, Stanford University * * Permission to use, copy, modify, and/or distribute this software for any * purpose with or without fee is hereby granted, provided that the above diff --git a/src/Bindings/CHoma.cc b/src/Bindings/CHoma.cc new file mode 100644 index 0000000..b75971e --- /dev/null +++ b/src/Bindings/CHoma.cc @@ -0,0 +1,178 @@ +/* Copyright (c) 2020, Stanford University + * + * Permission to use, copy, modify, and/or distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies.
+ * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + */ + +#include "Homa/Bindings/CHoma.h" +#include "Homa/Core/Transport.h" + +using namespace Homa; +using Transport = Core::Transport; + +/// Shorthand for converting C-style Homa object handle types back to C++ types. +#define deref(T, x) (*static_cast(x.p)) + +void +homa_inmsg_ack(homa_inmsg in_msg) +{ + deref(InMessage, in_msg).acknowledge(); +} + +bool +homa_inmsg_dropped(homa_inmsg in_msg) +{ + return deref(InMessage, in_msg).dropped(); +} + +void +homa_inmsg_fail(homa_inmsg in_msg) +{ + deref(InMessage, in_msg).fail(); +} + +size_t +homa_inmsg_get(homa_inmsg in_msg, size_t ofs, void* dst, size_t len) +{ + return deref(InMessage, in_msg).get(ofs, dst, len); +} + +void +homa_inmsg_src_addr(homa_inmsg in_msg, uint32_t* ip, uint16_t* port) +{ + SocketAddress src = deref(InMessage, in_msg).getSourceAddress(); + *ip = (uint32_t)src.ip; + *port = src.port; +} + +size_t +homa_inmsg_len(homa_inmsg in_msg) +{ + return deref(InMessage, in_msg).length(); +} + +void +homa_inmsg_release(homa_inmsg in_msg) +{ + InMessage::Deleter deleter; + deleter(&deref(InMessage, in_msg)); +} + +void +homa_inmsg_strip(homa_inmsg in_msg, size_t n) +{ + deref(InMessage, in_msg).strip(n); +} + +void +homa_outmsg_append(homa_outmsg out_msg, const void* buf, size_t len) +{ + deref(OutMessage, out_msg).append(buf, len); +} + +void +homa_outmsg_cancel(homa_outmsg out_msg) +{ + deref(OutMessage, out_msg).cancel(); +} + +int +homa_outmsg_status(homa_outmsg out_msg) +{ + return 
int(deref(OutMessage, out_msg).getStatus()); +} + +void +homa_outmsg_prepend(homa_outmsg out_msg, const void* buf, size_t len) +{ + deref(OutMessage, out_msg).prepend(buf, len); +} + +void +homa_outmsg_reserve(homa_outmsg out_msg, size_t n) +{ + deref(OutMessage, out_msg).reserve(n); +} + +void +homa_outmsg_send(homa_outmsg out_msg, uint32_t ip, uint16_t port) +{ + deref(OutMessage, out_msg).send({IpAddress{ip}, port}); +} + +void +homa_outmsg_release(homa_outmsg out_msg) +{ + OutMessage::Deleter deleter; + deleter(&deref(OutMessage, out_msg)); +} + +homa_trans +homa_trans_create(homa_driver drv, homa_callbacks cbs, uint64_t id) +{ + unique_ptr trans = Transport::create( + &deref(Driver, drv), &deref(Transport::Callbacks, cbs), id); + return homa_trans{trans.release()}; +} + +void +homa_trans_free(homa_trans trans) +{ + Transport::Deleter deleter; + deleter(&deref(Transport, trans)); +} + +homa_outmsg +homa_trans_alloc(homa_trans trans, uint16_t port) +{ + unique_ptr out_msg = deref(Transport, trans).alloc(port); + return homa_outmsg{out_msg.release()}; +} + +uint64_t +homa_trans_check_timeouts(homa_trans trans) +{ + return deref(Transport, trans).checkTimeouts(); +} + +uint64_t +homa_trans_id(homa_trans trans) +{ + return deref(Transport, trans).getId(); +} + +homa_driver homa_trans_get_drv(homa_trans trans) +{ + Driver* drv = deref(Transport, trans).getDriver(); + return homa_driver{drv}; +} + +void +homa_trans_proc(homa_trans trans, uintptr_t desc, void* payload, int32_t len, + uint32_t src_ip) +{ + Driver::Packet packet = { + .descriptor = desc, .payload = payload, .length = len}; + deref(Transport, trans).processPacket(&packet, IpAddress{src_ip}); +} + +uint64_t +homa_trans_try_send(homa_trans trans) +{ + return deref(Transport, trans).trySend(); +} + +bool +homa_trans_try_grant(homa_trans trans) +{ + return deref(Transport, trans).trySendGrants(); +} diff --git a/src/ControlPacket.h b/src/ControlPacket.h index a8da070..9c557ef 100644 --- a/src/ControlPacket.h 
+++ b/src/ControlPacket.h @@ -1,4 +1,4 @@ -/* Copyright (c) 2019, Stanford University +/* Copyright (c) 2019-2020, Stanford University * * Permission to use, copy, modify, and/or distribute this software for any * purpose with or without fee is hereby granted, provided that the above @@ -31,21 +31,20 @@ namespace ControlPacket { * @param driver * Driver with which to send the packet. * @param address - * Destination address for the packet to be sent. + * Destination IP address for the packet to be sent. * @param args * Arguments to PacketHeaderType's constructor. */ template void -send(Driver* driver, Driver::Address address, Args&&... args) +send(Driver* driver, IpAddress address, Args&&... args) { - Driver::Packet* packet = driver->allocPacket(); - new (packet->payload) PacketHeaderType(static_cast(args)...); - packet->length = sizeof(PacketHeaderType); - packet->address = address; - packet->priority = driver->getHighestPacketPriority(); - Perf::counters.tx_bytes.add(packet->length); - driver->sendPacket(packet); + Driver::Packet packet; + driver->allocPacket(&packet); + new (packet.payload) PacketHeaderType(static_cast(args)...); + packet.length = sizeof(PacketHeaderType); + Perf::counters.tx_bytes.add(packet.length); + driver->sendPacket(&packet, address, driver->getHighestPacketPriority()); driver->releasePackets(&packet, 1); } diff --git a/src/Homa.cc b/src/Driver.cc similarity index 59% rename from src/Homa.cc rename to src/Driver.cc index e03b55e..b29c828 100644 --- a/src/Homa.cc +++ b/src/Driver.cc @@ -1,4 +1,4 @@ -/* Copyright (c) 2018-2019, Stanford University +/* Copyright (c) 2018-2020, Stanford University * * Permission to use, copy, modify, and/or distribute this software for any * purpose with or without fee is hereby granted, provided that the above @@ -13,16 +13,26 @@ * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
*/ -#include +#include -#include "TransportImpl.h" +#include "StringUtil.h" namespace Homa { -Transport* -Transport::create(Driver* driver, uint64_t transportId) +std::string +IpAddress::toString(IpAddress address) { - return new Core::TransportImpl(driver, transportId); + uint32_t ip = address.addr; + return StringUtil::format("%d.%d.%d.%d", (ip >> 24) & 0xff, + (ip >> 16) & 0xff, (ip >> 8) & 0xff, ip & 0xff); +} + +IpAddress +IpAddress::fromString(const char* addressStr) +{ + unsigned int b0, b1, b2, b3; + sscanf(addressStr, "%u.%u.%u.%u", &b0, &b1, &b2, &b3); + return IpAddress{(b0 << 24u) | (b1 << 16u) | (b2 << 8u) | b3}; } } // namespace Homa diff --git a/src/Drivers/DPDK/DpdkDriver.cc b/src/Drivers/DPDK/DpdkDriver.cc index c536159..16a7016 100644 --- a/src/Drivers/DPDK/DpdkDriver.cc +++ b/src/Drivers/DPDK/DpdkDriver.cc @@ -21,62 +21,34 @@ namespace Homa { namespace Drivers { namespace DPDK { -DpdkDriver::DpdkDriver(int port, const Config* const config) - : pImpl(new Impl(port, config)) +DpdkDriver::DpdkDriver(const char* ifname, const Config* const config) + : pImpl(new Impl(ifname, config)) {} -DpdkDriver::DpdkDriver(int port, int argc, char* argv[], +DpdkDriver::DpdkDriver(const char* ifname, int argc, char* argv[], const Config* const config) - : pImpl(new Impl(port, argc, argv, config)) + : pImpl(new Impl(ifname, argc, argv, config)) {} -DpdkDriver::DpdkDriver(int port, NoEalInit _, const Config* const config) - : pImpl(new Impl(port, _, config)) +DpdkDriver::DpdkDriver(const char* ifname, NoEalInit _, + const Config* const config) + : pImpl(new Impl(ifname, _, config)) {} DpdkDriver::~DpdkDriver() = default; -/// See Driver::getAddress() -Driver::Address -DpdkDriver::getAddress(std::string const* const addressString) -{ - return pImpl->getAddress(addressString); -} - -/// See Driver::getAddress() -Driver::Address -DpdkDriver::getAddress(WireFormatAddress const* const wireAddress) -{ - return pImpl->getAddress(wireAddress); -} - -/// See 
Driver::addressToString() -std::string -DpdkDriver::addressToString(const Address address) -{ - return pImpl->addressToString(address); -} - -/// See Driver::addressToWireFormat() -void -DpdkDriver::addressToWireFormat(const Address address, - WireFormatAddress* wireAddress) -{ - pImpl->addressToWireFormat(address, wireAddress); -} - /// See Driver::allocPacket() -Driver::Packet* -DpdkDriver::allocPacket() +void +DpdkDriver::allocPacket(Packet* packet) { - return pImpl->allocPacket(); + return pImpl->allocPacket(packet); } /// See Driver::sendPacket() void -DpdkDriver::sendPacket(Packet* packet) +DpdkDriver::sendPacket(Packet* packet, IpAddress destination, int priority) { - return pImpl->sendPacket(packet); + return pImpl->sendPacket(packet, destination, priority); } /// See Driver::cork() @@ -95,13 +67,14 @@ DpdkDriver::uncork() /// See Driver::receivePackets() uint32_t -DpdkDriver::receivePackets(uint32_t maxPackets, Packet* receivedPackets[]) +DpdkDriver::receivePackets(uint32_t maxPackets, Packet receivedPackets[], + IpAddress sourceAddresses[]) { - return pImpl->receivePackets(maxPackets, receivedPackets); + return pImpl->receivePackets(maxPackets, receivedPackets, sourceAddresses); } /// See Driver::releasePackets() void -DpdkDriver::releasePackets(Packet* packets[], uint16_t numPackets) +DpdkDriver::releasePackets(Packet packets[], uint16_t numPackets) { pImpl->releasePackets(packets, numPackets); } @@ -128,7 +101,7 @@ DpdkDriver::getBandwidth() } /// See Driver::getLocalAddress() -Driver::Address +IpAddress DpdkDriver::getLocalAddress() { return pImpl->getLocalAddress(); diff --git a/src/Drivers/DPDK/DpdkDriverImpl.cc b/src/Drivers/DPDK/DpdkDriverImpl.cc index e658ccb..a797f18 100644 --- a/src/Drivers/DPDK/DpdkDriverImpl.cc +++ b/src/Drivers/DPDK/DpdkDriverImpl.cc @@ -17,12 +17,18 @@ * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
*/ +#include +#include +#include +#include +#include + #include "DpdkDriverImpl.h" #include -#include #include "CodeLocation.h" +#include "Homa/Util.h" #include "StringUtil.h" namespace Homa { @@ -37,49 +43,63 @@ const int default_eal_argc = 1; const char* default_eal_argv[] = {"homa", NULL}; /** - * Construct a DPDK Packet backed by a DPDK mbuf. + * Construct a DPDK PacketBuf backed by a DPDK mbuf. * * @param mbuf * Pointer to the DPDK mbuf that holds this packet. * @param data * Memory location in the mbuf where the packet data should be stored. */ -DpdkDriver::Impl::Packet::Packet(struct rte_mbuf* mbuf, void* data) - : Driver::Packet(data, 0) +DpdkDriver::Impl::PacketBuf::PacketBuf(struct rte_mbuf* mbuf, void* data) + : payload(data) , bufType(MBUF) - , bufRef() -{ - bufRef.mbuf = mbuf; -} + , bufRef{.mbuf = mbuf} +{} /** - * Construct a DPDK Packet backed by an OverflowBuffer. + * Construct a DPDK PacketBuf backed by an OverflowBuffer. * * @param overflowBuf * Overflow buffer that holds this packet. */ -DpdkDriver::Impl::Packet::Packet(OverflowBuffer* overflowBuf) - : Driver::Packet(overflowBuf->data, 0) +DpdkDriver::Impl::PacketBuf::PacketBuf(OverflowBuffer* overflowBuf) + : payload(overflowBuf->data) , bufType(OVERFLOW_BUF) - , bufRef() + , bufRef{.overflowBuf = overflowBuf} +{} + +/** + * Convert this DPDK PacketBuf into the generic Driver::Packet representation. + * + * @param length + * Number of bytes used in the payload buffer. 
+ */ +Driver::Packet +DpdkDriver::Impl::PacketBuf::toPacket(int length) { - bufRef.overflowBuf = overflowBuf; + Driver::Packet packet = { + .descriptor = (uintptr_t)this, .payload = payload, .length = length}; + return packet; } /** * See DpdkDriver::DpdkDriver() */ -DpdkDriver::Impl::Impl(int port, const Config* const config) - : Impl(port, default_eal_argc, const_cast(default_eal_argv), config) +DpdkDriver::Impl::Impl(const char* ifname, const Config* const config) + : Impl(ifname, default_eal_argc, const_cast(default_eal_argv), + config) {} /** * See DpdkDriver::DpdkDriver() */ -DpdkDriver::Impl::Impl(int port, int argc, char* argv[], +DpdkDriver::Impl::Impl(const char* ifname, int argc, char* argv[], const Config* const config) - : port(port) - , localMac(Driver::Address(0)) + : ifname(ifname) + , port() + , arpTable() + , localIp() + , localMac("00:00:00:00:00:00") , HIGHEST_PACKET_PRIORITY( (config == nullptr || config->HIGHEST_PACKET_PRIORITY_OVERRIDE < 0) ? Homa::Util::arrayLength(PRIORITY_TO_PCP) - 1 @@ -124,10 +144,14 @@ DpdkDriver::Impl::Impl(int port, int argc, char* argv[], /** * See DpdkDriver::DpdkDriver() */ -DpdkDriver::Impl::Impl(int port, __attribute__((__unused__)) NoEalInit _, +DpdkDriver::Impl::Impl(const char* ifname, + __attribute__((__unused__)) NoEalInit _, const Config* const config) - : port(port) - , localMac(Driver::Address(0)) + : ifname(ifname) + , port() + , arpTable() + , localIp() + , localMac("00:00:00:00:00:00") , HIGHEST_PACKET_PRIORITY( (config == nullptr || config->HIGHEST_PACKET_PRIORITY_OVERRIDE < 0) ? 
Homa::Util::arrayLength(PRIORITY_TO_PCP) - 1 @@ -159,61 +183,32 @@ DpdkDriver::Impl::~Impl() rte_mempool_free(mbufPool); } -// See Driver::getAddress() -Driver::Address -DpdkDriver::Impl::getAddress(std::string const* const addressString) -{ - return MacAddress(addressString->c_str()).toAddress(); -} - -// See Driver::getAddress() -Driver::Address -DpdkDriver::Impl::getAddress(Driver::WireFormatAddress const* const wireAddress) -{ - return MacAddress(wireAddress).toAddress(); -} - -/// See Driver::addressToString() -std::string -DpdkDriver::Impl::addressToString(const Driver::Address address) -{ - return MacAddress(address).toString(); -} - -/// See Driver::addressToWireFormat() -void -DpdkDriver::Impl::addressToWireFormat(const Driver::Address address, - Driver::WireFormatAddress* wireAddress) -{ - MacAddress(address).toWireFormat(wireAddress); -} - // See Driver::allocPacket() -DpdkDriver::Impl::Packet* -DpdkDriver::Impl::allocPacket() +void +DpdkDriver::Impl::allocPacket(Driver::Packet* packet) { - DpdkDriver::Impl::Packet* packet = _allocMbufPacket(); - if (unlikely(packet == nullptr)) { + PacketBuf* packetBuf = _allocMbufPacket(); + if (unlikely(packetBuf == nullptr)) { SpinLock::Lock lock(packetLock); OverflowBuffer* buf = overflowBufferPool.construct(); - packet = packetPool.construct(buf); + packetBuf = packetPool.construct(buf); NOTICE("OverflowBuffer used."); } - return packet; + *packet = packetBuf->toPacket(0); } // See Driver::sendPacket() void -DpdkDriver::Impl::sendPacket(Driver::Packet* packet) +DpdkDriver::Impl::sendPacket(Driver::Packet* packet, IpAddress destination, + int priority) { - DpdkDriver::Impl::Packet* pkt = - static_cast(packet); - struct rte_mbuf* mbuf = nullptr; + auto* packetBuf = (PacketBuf*)packet->descriptor; // If the packet is held in an Overflow buffer, we need to copy it out // into a new mbuf. 
- if (unlikely(pkt->bufType == DpdkDriver::Impl::Packet::OVERFLOW_BUF)) { + struct rte_mbuf* mbuf = nullptr; + if (unlikely(packetBuf->bufType == PacketBuf::OVERFLOW_BUF)) { mbuf = rte_pktmbuf_alloc(mbufPool); - if (unlikely(NULL == mbuf)) { + if (unlikely(nullptr == mbuf)) { uint32_t numMbufsAvail = rte_mempool_avail_count(mbufPool); uint32_t numMbufsInUse = rte_mempool_in_use_count(mbufPool); WARNING( @@ -224,16 +219,17 @@ DpdkDriver::Impl::sendPacket(Driver::Packet* packet) return; } char* buf = rte_pktmbuf_append( - mbuf, Homa::Util::downCast(PACKET_HDR_LEN + pkt->length)); + mbuf, + Homa::Util::downCast(PACKET_HDR_LEN + packet->length)); if (unlikely(NULL == buf)) { WARNING("rte_pktmbuf_append call failed; dropping packet"); rte_pktmbuf_free(mbuf); return; } char* data = buf + PACKET_HDR_LEN; - rte_memcpy(data, pkt->payload, pkt->length); + rte_memcpy(data, packetBuf->payload, packet->length); } else { - mbuf = pkt->bufRef.mbuf; + mbuf = packetBuf->bufRef.mbuf; // If the mbuf is still transmitting from a previous call to send, // we don't want to modify the buffer when the send is occuring. @@ -246,9 +242,14 @@ DpdkDriver::Impl::sendPacket(Driver::Packet* packet) // Fill out the destination and source MAC addresses plus the Ethernet // frame type (i.e., IEEE 802.1Q VLAN tagging). - MacAddress macAddr(pkt->address); + auto it = arpTable.find(destination); + if (it == arpTable.end()) { + WARNING("Failed to find ARP record for packet; dropping packet"); + return; + } + MacAddress& destMac = it->second; struct ether_hdr* ethHdr = rte_pktmbuf_mtod(mbuf, struct ether_hdr*); - rte_memcpy(ðHdr->d_addr, macAddr.address, ETHER_ADDR_LEN); + rte_memcpy(ðHdr->d_addr, destMac.address, ETHER_ADDR_LEN); rte_memcpy(ðHdr->s_addr, localMac.address, ETHER_ADDR_LEN); ethHdr->ether_type = rte_cpu_to_be_16(ETHER_TYPE_VLAN); @@ -256,13 +257,17 @@ DpdkDriver::Impl::sendPacket(Driver::Packet* packet) // encapsulated frame (DEI and VLAN ID are not relevant and trivially // set to 0). 
struct vlan_hdr* vlanHdr = reinterpret_cast(ethHdr + 1); - vlanHdr->vlan_tci = rte_cpu_to_be_16(PRIORITY_TO_PCP[pkt->priority]); + vlanHdr->vlan_tci = rte_cpu_to_be_16(PRIORITY_TO_PCP[priority]); vlanHdr->eth_proto = rte_cpu_to_be_16(EthPayloadType::HOMA); + // Store our local IP address right before the payload. + *rte_pktmbuf_mtod_offset(mbuf, uint32_t*, PACKET_HDR_LEN - 4) = + (uint32_t)localIp; + // In the normal case, we pre-allocate a pakcet's mbuf with enough // storage to hold the MAX_PAYLOAD_SIZE. If the actual payload is // smaller, trim the mbuf to size to avoid sending unecessary bits. - uint32_t actualLength = PACKET_HDR_LEN + pkt->length; + uint32_t actualLength = PACKET_HDR_LEN + packet->length; uint32_t mbufDataLength = rte_pktmbuf_pkt_len(mbuf); if (actualLength < mbufDataLength) { if (rte_pktmbuf_trim(mbuf, mbufDataLength - actualLength) < 0) { @@ -274,7 +279,7 @@ DpdkDriver::Impl::sendPacket(Driver::Packet* packet) } // loopback if src mac == dst mac - if (localMac.toAddress() == pkt->address) { + if (localMac == destMac) { struct rte_mbuf* mbuf_clone = rte_pktmbuf_clone(mbuf, mbufPool); if (unlikely(mbuf_clone == NULL)) { WARNING("Failed to clone packet for loopback; dropping packet"); @@ -289,7 +294,7 @@ DpdkDriver::Impl::sendPacket(Driver::Packet* packet) // If the packet is held in an mbuf, retain access to it so that the // processing of sending the mbuf won't free it. 
- if (likely(pkt->bufType == DpdkDriver::Impl::Packet::MBUF)) { + if (likely(packetBuf->bufType == PacketBuf::MBUF)) { rte_pktmbuf_refcnt_update(mbuf, 1); } @@ -327,7 +332,8 @@ DpdkDriver::Impl::uncork() // See Driver::receivePackets() uint32_t DpdkDriver::Impl::receivePackets(uint32_t maxPackets, - Driver::Packet* receivedPackets[]) + Driver::Packet receivedPackets[], + IpAddress sourceAddresses[]) { uint32_t numPacketsReceived = 0; @@ -390,19 +396,22 @@ DpdkDriver::Impl::receivePackets(uint32_t maxPackets, } } + uint32_t srcIp = *rte_pktmbuf_mtod_offset(m, uint32_t*, headerLength); + headerLength += sizeof(srcIp); + payload += sizeof(srcIp); assert(rte_pktmbuf_pkt_len(m) >= headerLength); uint32_t length = rte_pktmbuf_pkt_len(m) - headerLength; assert(length <= MAX_PAYLOAD_SIZE); - DpdkDriver::Impl::Packet* packet = nullptr; + PacketBuf* packetBuf = nullptr; { SpinLock::Lock lock(packetLock); - packet = packetPool.construct(m, payload); + packetBuf = packetPool.construct(m, payload); } - packet->address = MacAddress(ethHdr->s_addr.addr_bytes).toAddress(); - packet->length = length; - receivedPackets[numPacketsReceived++] = packet; + receivedPackets[numPacketsReceived] = packetBuf->toPacket(length); + sourceAddresses[numPacketsReceived] = {srcIp}; + ++numPacketsReceived; } return numPacketsReceived; @@ -410,18 +419,17 @@ DpdkDriver::Impl::receivePackets(uint32_t maxPackets, // See Driver::releasePackets() void -DpdkDriver::Impl::releasePackets(Driver::Packet* packets[], uint16_t numPackets) +DpdkDriver::Impl::releasePackets(Driver::Packet packets[], uint16_t numPackets) { for (uint16_t i = 0; i < numPackets; ++i) { SpinLock::Lock lock(packetLock); - DpdkDriver::Impl::Packet* packet = - static_cast(packets[i]); - if (likely(packet->bufType == DpdkDriver::Impl::Packet::MBUF)) { - rte_pktmbuf_free(packet->bufRef.mbuf); + auto* packetBuf = (PacketBuf*)packets[i].descriptor; + if (likely(packetBuf->bufType == PacketBuf::MBUF)) { + 
rte_pktmbuf_free(packetBuf->bufRef.mbuf); } else { - overflowBufferPool.destroy(packet->bufRef.overflowBuf); + overflowBufferPool.destroy(packetBuf->bufRef.overflowBuf); } - packetPool.destroy(packet); + packetPool.destroy(packetBuf); } } @@ -447,10 +455,10 @@ DpdkDriver::Impl::getBandwidth() } // See Driver::getLocalAddress() -Driver::Address +IpAddress DpdkDriver::Impl::getLocalAddress() { - return localMac.toAddress(); + return localIp; } // See Driver::getQueuedBytes(); @@ -490,11 +498,77 @@ DpdkDriver::Impl::_eal_init(int argc, char* argv[]) void DpdkDriver::Impl::_init() { - struct ether_addr mac; struct rte_eth_conf portConf; int ret; uint16_t mtu; + // Populate the ARP table with records in /proc/net/arp (inspired by + // net-tools/arp.c) + std::ifstream input("/proc/net/arp"); + for (std::string line; getline(input, line);) { + char ip[100]; + char hwa[100]; + char mask[100]; + char dev[100]; + int type, flags; + int cols = sscanf(line.c_str(), "%s 0x%x 0x%x %99s %99s %99s\n", ip, + &type, &flags, hwa, mask, dev); + if (cols != 6) + continue; + arpTable.emplace(IpAddress::fromString(ip), hwa); + } + + // Use ioctl to obtain the IP and MAC addresses of the network interface. 
+ struct ifreq ifr; + ifname.copy(ifr.ifr_name, ifname.length()); + ifr.ifr_name[ifname.length() + 1] = 0; + if (ifname.length() >= sizeof(ifr.ifr_name)) { + throw DriverInitFailure( + HERE_STR, + StringUtil::format("Interface name %s too long", ifname.c_str())); + } + + int fd = socket(AF_INET, SOCK_DGRAM, 0); + if (fd == -1) { + throw DriverInitFailure( + HERE_STR, + StringUtil::format("Failed to create socket: %s", strerror(errno))); + } + + if (ioctl(fd, SIOCGIFADDR, &ifr) == -1) { + char* error = strerror(errno); + close(fd); + throw DriverInitFailure( + HERE_STR, + StringUtil::format("Failed to obtain IP address: %s", error)); + } + localIp = {be32toh(((struct sockaddr_in*)&ifr.ifr_addr)->sin_addr.s_addr)}; + + if (ioctl(fd, SIOCGIFHWADDR, &ifr) == -1) { + char* error = strerror(errno); + close(fd); + throw DriverInitFailure( + HERE_STR, + StringUtil::format("Failed to obtain MAC address: %s", error)); + } + close(fd); + memcpy(localMac.address, ifr.ifr_hwaddr.sa_data, 6); + + // Iterate over ethernet devices to locate the port identifier. + int p; + RTE_ETH_FOREACH_DEV(p) + { + struct ether_addr mac; + rte_eth_macaddr_get(p, &mac); + if (MacAddress(mac.addr_bytes) == localMac) { + port = p; + break; + } + } + NOTICE("Using interface %s, ip %s, mac %s, port %u", ifname.c_str(), + IpAddress::toString(localIp).c_str(), localMac.toString().c_str(), + port); + std::string poolName = StringUtil::format("homa_mbuf_pool_%u", port); std::string ringName = StringUtil::format("homa_loopback_ring_%u", port); @@ -518,10 +592,6 @@ DpdkDriver::Impl::_init() StringUtil::format("Ethernet port %u doesn't exist", port)); } - // Read the MAC address from the NIC via DPDK. 
- rte_eth_macaddr_get(port, &mac); - new (const_cast(&localMac)) MacAddress(mac.addr_bytes); - // configure some default NIC port parameters memset(&portConf, 0, sizeof(portConf)); portConf.rxmode.max_rx_pkt_len = ETHER_MAX_VLAN_FRAME_LEN; @@ -651,10 +721,9 @@ DpdkDriver::Impl::_init() * The newly allocated Dpdk Packet; nullptr if the mbuf allocation * failed. */ -DpdkDriver::Impl::Packet* +DpdkDriver::Impl::PacketBuf* DpdkDriver::Impl::_allocMbufPacket() { - DpdkDriver::Impl::Packet* packet = nullptr; uint32_t numMbufsAvail = rte_mempool_avail_count(mbufPool); if (unlikely(numMbufsAvail <= NB_MBUF_RESERVED)) { uint32_t numMbufsInUse = rte_mempool_in_use_count(mbufPool); @@ -688,11 +757,8 @@ DpdkDriver::Impl::_allocMbufPacket() } // Perform packet operations with the lock held. - { - SpinLock::Lock _(packetLock); - packet = packetPool.construct(mbuf, buf + PACKET_HDR_LEN); - } - return packet; + SpinLock::Lock _(packetLock); + return packetPool.construct(mbuf, buf + PACKET_HDR_LEN); } /** diff --git a/src/Drivers/DPDK/DpdkDriverImpl.h b/src/Drivers/DPDK/DpdkDriverImpl.h index 289e83f..ac9188c 100644 --- a/src/Drivers/DPDK/DpdkDriverImpl.h +++ b/src/Drivers/DPDK/DpdkDriverImpl.h @@ -28,6 +28,7 @@ #include #include +#include #include "MacAddress.h" #include "ObjectPool.h" @@ -65,8 +66,13 @@ const uint16_t MAX_PKT_BURST = 32; /// field defined in the VLAN tag to specify the packet priority. const uint32_t VLAN_TAG_LEN = 4; -// Size of Ethernet header including VLAN tag, in bytes. -const uint32_t PACKET_HDR_LEN = ETHER_HDR_LEN + VLAN_TAG_LEN; +/// Strictly speaking, this DPDK driver is supposed to send/receive IP packets; +/// however, it currently only records the source IP address right after the +/// Ethernet header for simplicity. +const uint32_t IP_HDR_LEN = sizeof(IpAddress); + +// Size of Ethernet header including VLAN tag plus IP header, in bytes. 
+const uint32_t PACKET_HDR_LEN = ETHER_HDR_LEN + VLAN_TAG_LEN + IP_HDR_LEN; // The MTU (Maximum Transmission Unit) size of an Ethernet frame, which is the // maximum size of the packet an Ethernet frame can carry in its payload. This @@ -101,19 +107,16 @@ struct OverflowBuffer { class DpdkDriver::Impl { public: /** - * Dpdk specific Packet object used to track a its lifetime and + * DPDK specific Packet object used to track a its lifetime and * contents. */ - class Packet : public Driver::Packet { - public: - explicit Packet(struct rte_mbuf* mbuf, void* data); - explicit Packet(OverflowBuffer* overflowBuf); + struct PacketBuf { + explicit PacketBuf(struct rte_mbuf* mbuf, void* data); + explicit PacketBuf(OverflowBuffer* overflowBuf); + Driver::Packet toPacket(int length); - /// see Driver::Packet::getMaxPayloadSize() - virtual int getMaxPayloadSize() - { - return MAX_PAYLOAD_SIZE; - } + /// Memory location where the packet data should be stored. + void* const payload; /// Used to indicate whether the packet is backed by an DPDK mbuf or a /// driver-level OverflowBuffer. @@ -124,57 +127,55 @@ class DpdkDriver::Impl { struct rte_mbuf* mbuf; OverflowBuffer* overflowBuf; } bufRef; - - /// The memory location of this packet's header. The header should be - /// PACKET_HDR_LEN in length. 
- void* header; - - private: - Packet(const Packet&) = delete; - Packet& operator=(const Packet&) = delete; }; - Impl(int port, const Config* const config = nullptr); - Impl(int port, int argc, char* argv[], + Impl(const char* ifname, const Config* const config = nullptr); + Impl(const char* ifname, int argc, char* argv[], const Config* const config = nullptr); - Impl(int port, NoEalInit _, const Config* const config = nullptr); + Impl(const char* ifname, NoEalInit _, const Config* const config = nullptr); virtual ~Impl(); // Interface Methods - Driver::Address getAddress(std::string const* const addressString); - Driver::Address getAddress(WireFormatAddress const* const wireAddress); - std::string addressToString(const Address address); - void addressToWireFormat(const Address address, - WireFormatAddress* wireAddress); - Packet* allocPacket(); - void sendPacket(Driver::Packet* packet); + void allocPacket(Driver::Packet* packet); + void sendPacket(Driver::Packet* packet, IpAddress destination, + int priority); void cork(); void uncork(); uint32_t receivePackets(uint32_t maxPackets, - Driver::Packet* receivedPackets[]); - void releasePackets(Driver::Packet* packets[], uint16_t numPackets); + Driver::Packet receivedPackets[], + IpAddress sourceAddresses[]); + void releasePackets(Driver::Packet packets[], uint16_t numPackets); int getHighestPacketPriority(); uint32_t getMaxPayloadSize(); uint32_t getBandwidth(); - Driver::Address getLocalAddress(); + IpAddress getLocalAddress(); uint32_t getQueuedBytes(); private: void _eal_init(int argc, char* argv[]); void _init(); - Packet* _allocMbufPacket(); + PacketBuf* _allocMbufPacket(); static uint16_t txBurstCallback(uint16_t port_id, uint16_t queue, struct rte_mbuf* pkts[], uint16_t nb_pkts, void* user_param); static void txBurstErrorCallback(struct rte_mbuf* pkts[], uint16_t unsent, void* userdata); + /// Name of the Linux network interface to be used by DPDK. 
+ std::string ifname; + /// Stores the NIC's physical port id addressed by the instantiated /// driver. - const uint16_t port; + uint16_t port; + + /// Address resolution table that translates IP addresses to MAC addresses. + std::unordered_map arpTable; + + /// Stores the IpAddress of the driver. + IpAddress localIp; - /// Stores the address of the NIC (either native or set by override). - const MacAddress localMac; + /// Stores the HW address of the NIC (either native or set by override). + MacAddress localMac; /// Stores the driver's maximum network packet priority (either default or /// set by override). @@ -185,7 +186,7 @@ class DpdkDriver::Impl { /// Provides memory allocation for the DPDK specific implementation of a /// Driver Packet. - ObjectPool packetPool; + ObjectPool packetPool; /// Provides memory allocation for packet storage when mbuf are running out. ObjectPool overflowBufferPool; diff --git a/src/Drivers/DPDK/MacAddress.cc b/src/Drivers/DPDK/MacAddress.cc index 0178851..e47f27a 100644 --- a/src/Drivers/DPDK/MacAddress.cc +++ b/src/Drivers/DPDK/MacAddress.cc @@ -1,4 +1,4 @@ -/* Copyright (c) 2011-2019, Stanford University +/* Copyright (c) 2011-2020, Stanford University * * Permission to use, copy, modify, and/or distribute this software for any * purpose with or without fee is hereby granted, provided that the above @@ -18,7 +18,6 @@ #include "StringUtil.h" #include "../../CodeLocation.h" -#include "../RawAddressType.h" namespace Homa { namespace Drivers { @@ -55,33 +54,6 @@ MacAddress::MacAddress(const char* macStr) address[i] = Util::downCast(bytes[i]); } -/** - * Create a new address from a given address in its raw byte format. - * @param raw - * The raw bytes format. 
- * - * @sa Driver::Address::Raw - */ -MacAddress::MacAddress(const Driver::WireFormatAddress* const wireAddress) -{ - if (wireAddress->type != RawAddressType::MAC) { - throw BadAddress(HERE_STR, "Bad address: Raw format is not type MAC"); - } - static_assert(sizeof(wireAddress->bytes) >= 6); - memcpy(address, wireAddress->bytes, 6); -} - -/** - * Create a new address given the Driver::Address representation. - * - * @param addr - * The Driver::Address representation of an address. - */ -MacAddress::MacAddress(const Driver::Address addr) -{ - memcpy(address, &addr, 6); -} - /** * Return the string representation of this address. */ @@ -94,31 +66,6 @@ MacAddress::toString() const return buf; } -/** - * Serialized this address into a wire format. - * - * @param[out] wireAddress - * WireFormatAddress object to which the this address is serialized. - */ -void -MacAddress::toWireFormat(Driver::WireFormatAddress* wireAddress) const -{ - static_assert(sizeof(wireAddress->bytes) >= 6); - memcpy(wireAddress->bytes, address, 6); - wireAddress->type = RawAddressType::MAC; -} - -/** - * Return a Driver::Address representation of this address. - */ -Driver::Address -MacAddress::toAddress() const -{ - Driver::Address addr = 0; - memcpy(&addr, address, 6); - return addr; -} - /** * @return * True if the MacAddress consists of all zero bytes, false if not. 
diff --git a/src/Drivers/DPDK/MacAddress.h b/src/Drivers/DPDK/MacAddress.h index 1106eec..33f47a5 100644 --- a/src/Drivers/DPDK/MacAddress.h +++ b/src/Drivers/DPDK/MacAddress.h @@ -1,4 +1,4 @@ -/* Copyright (c) 2011-2019, Stanford University +/* Copyright (c) 2011-2020, Stanford University * * Permission to use, copy, modify, and/or distribute this software for any * purpose with or without fee is hereby granted, provided that the above @@ -28,14 +28,19 @@ namespace DPDK { struct MacAddress { explicit MacAddress(const uint8_t raw[6]); explicit MacAddress(const char* macStr); - explicit MacAddress(const Driver::WireFormatAddress* const wireAddress); - explicit MacAddress(const Driver::Address addr); MacAddress(const MacAddress&) = default; std::string toString() const; - void toWireFormat(Driver::WireFormatAddress* wireAddress) const; - Driver::Address toAddress() const; bool isNull() const; + /** + * Equality function for MacAddress, for use in std::unordered_maps etc. + */ + bool operator==(const MacAddress& other) const + { + return (*(uint32_t*)(address + 0) == *(uint32_t*)(other.address + 0)) && + (*(uint16_t*)(address + 4) == *(uint16_t*)(other.address + 4)); + } + /// The raw bytes of the MAC address. 
uint8_t address[6]; }; diff --git a/src/Drivers/DPDK/MacAddressTest.cc b/src/Drivers/DPDK/MacAddressTest.cc index 329c309..9b8b8ae 100644 --- a/src/Drivers/DPDK/MacAddressTest.cc +++ b/src/Drivers/DPDK/MacAddressTest.cc @@ -1,4 +1,4 @@ -/* Copyright (c) 2011-2019, Stanford University +/* Copyright (c) 2011-2020, Stanford University * * Permission to use, copy, modify, and/or distribute this software for any * purpose with or without fee is hereby granted, provided that the above @@ -15,8 +15,6 @@ #include "MacAddress.h" -#include "../RawAddressType.h" - #include namespace Homa { @@ -35,26 +33,6 @@ TEST(MacAddressTest, constructorString) EXPECT_EQ("de:ad:be:ef:98:76", MacAddress("de:ad:be:ef:98:76").toString()); } -TEST(MacAddressTest, constructorWireFormatAddress) -{ - uint8_t bytes[] = {0xde, 0xad, 0xbe, 0xef, 0x98, 0x76}; - Driver::WireFormatAddress wireformatAddress; - wireformatAddress.type = RawAddressType::MAC; - memcpy(wireformatAddress.bytes, bytes, 6); - EXPECT_EQ("de:ad:be:ef:98:76", MacAddress(&wireformatAddress).toString()); - - wireformatAddress.type = RawAddressType::FAKE; - EXPECT_THROW(MacAddress address(&wireformatAddress), BadAddress); -} - -TEST(MacAddressTest, constructorAddress) -{ - uint8_t raw[] = {0xde, 0xad, 0xbe, 0xef, 0x98, 0x76}; - MacAddress(raw).toString(); - Driver::Address addr = MacAddress("de:ad:be:ef:98:76").toAddress(); - EXPECT_EQ("de:ad:be:ef:98:76", MacAddress(addr).toString()); -} - TEST(MacAddressTest, construct_DefaultCopy) { MacAddress source("de:ad:be:ef:98:76"); @@ -67,24 +45,6 @@ TEST(MacAddressTest, toString) // tested sufficiently in constructor tests } -TEST(MacAddressTest, toWireFormat) -{ - Driver::WireFormatAddress wireformatAddress; - MacAddress("de:ad:be:ef:98:76").toWireFormat(&wireformatAddress); - EXPECT_EQ(RawAddressType::MAC, wireformatAddress.type); - EXPECT_EQ(0xde, wireformatAddress.bytes[0]); - EXPECT_EQ(0xad, wireformatAddress.bytes[1]); - EXPECT_EQ(0xbe, wireformatAddress.bytes[2]); - EXPECT_EQ(0xef, 
wireformatAddress.bytes[3]); - EXPECT_EQ(0x98, wireformatAddress.bytes[4]); - EXPECT_EQ(0x76, wireformatAddress.bytes[5]); -} - -TEST(MacAddressTest, toAddress) -{ - // Tested in constructorAddress -} - TEST(MacAddressTest, isNull) { uint8_t rawNull[] = {0x0, 0x0, 0x0, 0x0, 0x0, 0x0}; diff --git a/src/Drivers/Fake/FakeAddressTest.cc b/src/Drivers/Fake/FakeAddressTest.cc deleted file mode 100644 index 67cef78..0000000 --- a/src/Drivers/Fake/FakeAddressTest.cc +++ /dev/null @@ -1,75 +0,0 @@ -/* Copyright (c) 2019, Stanford University - * - * Permission to use, copy, modify, and/or distribute this software for any - * purpose with or without fee is hereby granted, provided that the above - * copyright notice and this permission notice appear in all copies. - * - * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES - * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF - * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR - * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES - * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN - * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF - * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
- */ - -#include - -#include "FakeAddress.h" - -#include "../RawAddressType.h" - -namespace Homa { -namespace Drivers { -namespace Fake { -namespace { - -TEST(FakeAddressTest, constructor_id) -{ - FakeAddress address(42); - EXPECT_EQ("42", address.toString()); -} - -TEST(FakeAddressTest, constructor_str) -{ - FakeAddress address("42"); - EXPECT_EQ("42", address.toString()); -} - -TEST(FakeAddressTest, constructor_str_bad) -{ - EXPECT_THROW(FakeAddress address("D42"), BadAddress); -} - -TEST(FakeAddressTest, constructor_raw) -{ - Driver::Address::Raw raw; - raw.type = RawAddressType::FAKE; - *reinterpret_cast(raw.bytes) = 42; - - FakeAddress address(&raw); - EXPECT_EQ("42", address.toString()); -} - -TEST(FakeAddressTest, constructor_raw_bad) -{ - Driver::Address::Raw raw; - raw.type = !RawAddressType::FAKE; - - EXPECT_THROW(FakeAddress address(&raw), BadAddress); -} - -TEST(FakeAddressTest, toString) -{ - // tested sufficiently in constructor tests -} - -TEST(FakeAddressTest, toAddressId) -{ - EXPECT_THROW(FakeAddress::toAddressId("D42"), BadAddress); -} - -} // namespace -} // namespace Fake -} // namespace Drivers -} // namespace Homa diff --git a/src/Drivers/Fake/FakeDriver.cc b/src/Drivers/Fake/FakeDriver.cc index 6200a49..10c5e0c 100644 --- a/src/Drivers/Fake/FakeDriver.cc +++ b/src/Drivers/Fake/FakeDriver.cc @@ -1,4 +1,4 @@ -/* Copyright (c) 2019, Stanford University +/* Copyright (c) 2019-2020, Stanford University * * Permission to use, copy, modify, and/or distribute this software for any * purpose with or without fee is hereby granted, provided that the above @@ -56,24 +56,26 @@ static class FakeNetwork { /// Register the FakeNIC so it can receive packets. Returns the newly /// registered FakeNIC's addressId. 
- uint64_t registerNIC(FakeNIC* nic) + uint32_t registerNIC(FakeNIC* nic) { std::lock_guard lock(mutex); - uint64_t addressId = nextAddressId.fetch_add(1); - network.insert({addressId, nic}); + uint32_t addressId = nextAddressId.fetch_add(1); + IpAddress ipAddress{addressId}; + network.insert({ipAddress, nic}); return addressId; } /// Remove the FakeNIC from the network. - void deregisterNIC(uint64_t addressId) + void deregisterNIC(uint32_t addressId) { std::lock_guard lock(mutex); - network.erase(addressId); + IpAddress ipAddress{addressId}; + network.erase(ipAddress); } /// Deliver the provide packet to the specified destination. - void sendPacket(FakePacket* packet, Driver::Address src, - Driver::Address dst) + void sendPacket(FakePacket* packet, int priority, IpAddress src, + IpAddress dst) { FakeNIC* nic = nullptr; { @@ -92,10 +94,10 @@ static class FakeNetwork { assert(nic != nullptr); std::lock_guard lock_nic(nic->mutex, std::adopt_lock); FakePacket* dstPacket = new FakePacket(*packet); - dstPacket->address = src; - assert(dstPacket->priority < NUM_PRIORITIES); - assert(dstPacket->priority >= 0); - nic->priorityQueue.at(dstPacket->priority).push_back(dstPacket); + dstPacket->sourceIp = src; + assert(priority < NUM_PRIORITIES); + assert(priority >= 0); + nic->priorityQueue.at(priority).push_back(dstPacket); } void setPacketLossRate(double lossRate) @@ -115,11 +117,10 @@ static class FakeNetwork { std::mutex mutex; /// Holds all the packets being sent through the fake network. - std::unordered_map network; + std::unordered_map network; - /// The FakeAddress identifier for the next FakeDriver that "connects" to - /// the FakeNetwork. - std::atomic nextAddressId; + /// Identifier for the next FakeDriver that "connects" to the FakeNetwork. + std::atomic nextAddressId; /// Rate at which packets should be dropped when sent over this network. 
double packetLossRate; @@ -177,73 +178,27 @@ FakeDriver::~FakeDriver() fakeNetwork.deregisterNIC(localAddressId); } -/** - * See Driver::getAddress() - */ -Driver::Address -FakeDriver::getAddress(std::string const* const addressString) -{ - char* end; - uint64_t address = std::strtoul(addressString->c_str(), &end, 10); - if (address == 0) { - throw BadAddress(HERE_STR, StringUtil::format("Bad address string: %s", - addressString->c_str())); - } - return address; -} - -/** - * See Driver::getAddress() - */ -Driver::Address -FakeDriver::getAddress(Driver::WireFormatAddress const* const wireAddress) -{ - const Address* address = - reinterpret_cast(wireAddress->bytes); - return *address; -} - -/** - * See Driver::addressToString() - */ -std::string -FakeDriver::addressToString(const Address address) -{ - char buf[21]; - snprintf(buf, sizeof(buf), "%lu", address); - return buf; -} - -/** - * See Driver::addressToWireFormat() - */ -void -FakeDriver::addressToWireFormat(const Address address, - WireFormatAddress* wireAddress) -{ - new (reinterpret_cast(wireAddress->bytes)) Address(address); -} - /** * See Driver::allocPacket() */ -Driver::Packet* -FakeDriver::allocPacket() +void +FakeDriver::allocPacket(Packet* packet) { - FakePacket* packet = new FakePacket(); - return packet; + FakePacket* fakePacket = new FakePacket(); + *packet = fakePacket->toPacket(); } /** * See Driver::sendPacket() */ void -FakeDriver::sendPacket(Packet* packet) +FakeDriver::sendPacket(Packet* packet, IpAddress destination, int priority) { - FakePacket* srcPacket = static_cast(packet); - Address srcAddress = getLocalAddress(); - Address dstAddress = srcPacket->address; - fakeNetwork.sendPacket(srcPacket, srcAddress, dstAddress); + FakePacket* srcPacket = (FakePacket*)packet->descriptor; + srcPacket->length = packet->length; + IpAddress srcAddress = getLocalAddress(); + IpAddress dstAddress = destination; + fakeNetwork.sendPacket(srcPacket, priority, srcAddress, dstAddress); 
queueEstimator.signalBytesSent(packet->length); } @@ -251,14 +206,17 @@ FakeDriver::sendPacket(Packet* packet) * See Driver::receivePackets() */ uint32_t -FakeDriver::receivePackets(uint32_t maxPackets, Packet* receivedPackets[]) +FakeDriver::receivePackets(uint32_t maxPackets, Packet receivedPackets[], + IpAddress sourceAddresses[]) { std::lock_guard lock_nic(nic.mutex); uint32_t numReceived = 0; for (int i = NUM_PRIORITIES - 1; i >= 0; --i) { while (numReceived < maxPackets && !nic.priorityQueue.at(i).empty()) { - receivedPackets[numReceived] = nic.priorityQueue.at(i).front(); + FakePacket* fakePacket = nic.priorityQueue.at(i).front(); nic.priorityQueue.at(i).pop_front(); + receivedPackets[numReceived] = fakePacket->toPacket(); + sourceAddresses[numReceived] = fakePacket->sourceIp; numReceived++; } } @@ -269,11 +227,10 @@ FakeDriver::receivePackets(uint32_t maxPackets, Packet* receivedPackets[]) * See Driver::releasePackets() */ void -FakeDriver::releasePackets(Packet* packets[], uint16_t numPackets) +FakeDriver::releasePackets(Packet packets[], uint16_t numPackets) { for (uint16_t i = 0; i < numPackets; ++i) { - FakePacket* packet = static_cast(packets[i]); - delete packet; + delete (FakePacket*)packets[i].descriptor; } } @@ -308,10 +265,10 @@ FakeDriver::getBandwidth() /** * See Driver::getLocalAddress() */ -Driver::Address +IpAddress FakeDriver::getLocalAddress() { - return localAddressId; + return IpAddress{localAddressId}; } /** diff --git a/src/Drivers/Fake/FakeDriverTest.cc b/src/Drivers/Fake/FakeDriverTest.cc index e410119..2f02f07 100644 --- a/src/Drivers/Fake/FakeDriverTest.cc +++ b/src/Drivers/Fake/FakeDriverTest.cc @@ -1,4 +1,4 @@ -/* Copyright (c) 2019, Stanford University +/* Copyright (c) 2019-2020, Stanford University * * Permission to use, copy, modify, and/or distribute this software for any * purpose with or without fee is hereby granted, provided that the above @@ -18,7 +18,6 @@ #include -#include "../RawAddressType.h" #include "StringUtil.h" 
namespace Homa { @@ -28,52 +27,18 @@ namespace { TEST(FakeDriverTest, constructor) { - uint64_t nextAddressId = FakeDriver().localAddressId + 1; + uint32_t nextAddressId = FakeDriver().localAddressId + 1; FakeDriver driver; EXPECT_EQ(nextAddressId, driver.localAddressId); } -TEST(FakeDriverTest, getAddress_string) -{ - FakeDriver driver; - std::string addressStr("42"); - Driver::Address address = driver.getAddress(&addressStr); - EXPECT_EQ("42", driver.addressToString(address)); -} - -TEST(FakeDriverTest, getAddress_wireformat) -{ - FakeDriver driver; - Driver::WireFormatAddress wireformatAddress; - wireformatAddress.type = RawAddressType::FAKE; - *reinterpret_cast(wireformatAddress.bytes) = 42; - Driver::Address address = driver.getAddress(&wireformatAddress); - EXPECT_EQ("42", driver.addressToString(address)); -} - -TEST(FakeDriverTest, addressToString) -{ - FakeDriver driver; - Driver::Address address = 42; - EXPECT_EQ("42", driver.addressToString(address)); -} - -TEST(FakeDriverTest, addressToWireFormat) -{ - FakeDriver driver; - Driver::WireFormatAddress wireformatAddress; - driver.addressToWireFormat(42, &wireformatAddress); - EXPECT_EQ("42", - driver.addressToString(driver.getAddress(&wireformatAddress))); -} - TEST(FakeDriverTest, allocPacket) { - FakeDriver driver; - Driver::Packet* packet = driver.allocPacket(); // allocPacket doesn't do much so we just need to make sure we can call it. 
- delete packet; + FakeDriver driver; + Driver::Packet packet; + driver.allocPacket(&packet); } TEST(FakeDriverTest, sendPackets) @@ -81,14 +46,15 @@ TEST(FakeDriverTest, sendPackets) FakeDriver driver1; FakeDriver driver2; - Driver::Packet* packets[4]; + Driver::Packet packets[4]; + IpAddress destinations[4]; + int prio[4]; for (int i = 0; i < 4; ++i) { - packets[i] = driver1.allocPacket(); - packets[i]->address = driver2.getLocalAddress(); - packets[i]->priority = i; + driver1.allocPacket(&packets[i]); + destinations[i] = driver2.getLocalAddress(); + prio[i] = i; } - std::string addressStr("42"); - packets[2]->address = driver1.getAddress(&addressStr); + destinations[2] = IpAddress{42}; EXPECT_EQ(0U, driver2.nic.priorityQueue.at(0).size()); EXPECT_EQ(0U, driver2.nic.priorityQueue.at(1).size()); @@ -99,7 +65,7 @@ TEST(FakeDriverTest, sendPackets) EXPECT_EQ(0U, driver2.nic.priorityQueue.at(6).size()); EXPECT_EQ(0U, driver2.nic.priorityQueue.at(7).size()); - driver1.sendPacket(packets[0]); + driver1.sendPacket(&packets[0], destinations[0], prio[0]); EXPECT_EQ(1U, driver2.nic.priorityQueue.at(0).size()); EXPECT_EQ(0U, driver2.nic.priorityQueue.at(1).size()); @@ -110,13 +76,12 @@ TEST(FakeDriverTest, sendPackets) EXPECT_EQ(0U, driver2.nic.priorityQueue.at(6).size()); EXPECT_EQ(0U, driver2.nic.priorityQueue.at(7).size()); { - Driver::Packet* packet = static_cast( - driver2.nic.priorityQueue.at(0).front()); - EXPECT_EQ(driver1.getLocalAddress(), packet->address); + FakePacket* packet = driver2.nic.priorityQueue.at(0).front(); + EXPECT_EQ(driver1.getLocalAddress(), packet->sourceIp); } for (int i = 0; i < 4; ++i) { - driver1.sendPacket(packets[i]); + driver1.sendPacket(&packets[i], destinations[i], prio[i]); } EXPECT_EQ(2U, driver2.nic.priorityQueue.at(0).size()); @@ -127,8 +92,6 @@ TEST(FakeDriverTest, sendPackets) EXPECT_EQ(0U, driver2.nic.priorityQueue.at(5).size()); EXPECT_EQ(0U, driver2.nic.priorityQueue.at(6).size()); EXPECT_EQ(0U, 
driver2.nic.priorityQueue.at(7).size()); - - delete packets[2]; } TEST(FakeDriverTest, receivePackets) @@ -136,7 +99,8 @@ TEST(FakeDriverTest, receivePackets) std::string addressStr("42"); FakeDriver driver; - Driver::Packet* packets[4]; + Driver::Packet packets[4]; + IpAddress srcAddrs[4]; // 3 packets at priority 7 for (int i = 0; i < 3; ++i) @@ -158,7 +122,7 @@ TEST(FakeDriverTest, receivePackets) EXPECT_EQ(0U, driver.nic.priorityQueue.at(6).size()); EXPECT_EQ(3U, driver.nic.priorityQueue.at(7).size()); - EXPECT_EQ(4U, driver.receivePackets(4, packets)); + EXPECT_EQ(4U, driver.receivePackets(4, packets, srcAddrs)); driver.releasePackets(packets, 4); EXPECT_EQ(0U, driver.nic.priorityQueue.at(0).size()); @@ -170,7 +134,7 @@ TEST(FakeDriverTest, receivePackets) EXPECT_EQ(0U, driver.nic.priorityQueue.at(6).size()); EXPECT_EQ(0U, driver.nic.priorityQueue.at(7).size()); - EXPECT_EQ(1U, driver.receivePackets(1, packets)); + EXPECT_EQ(1U, driver.receivePackets(1, packets, srcAddrs)); driver.releasePackets(packets, 1); EXPECT_EQ(0U, driver.nic.priorityQueue.at(0).size()); @@ -193,7 +157,7 @@ TEST(FakeDriverTest, receivePackets) EXPECT_EQ(0U, driver.nic.priorityQueue.at(6).size()); EXPECT_EQ(1U, driver.nic.priorityQueue.at(7).size()); - EXPECT_EQ(1U, driver.receivePackets(1, packets)); + EXPECT_EQ(1U, driver.receivePackets(1, packets, srcAddrs)); driver.releasePackets(packets, 1); EXPECT_EQ(0U, driver.nic.priorityQueue.at(0).size()); @@ -205,7 +169,7 @@ TEST(FakeDriverTest, receivePackets) EXPECT_EQ(0U, driver.nic.priorityQueue.at(6).size()); EXPECT_EQ(0U, driver.nic.priorityQueue.at(7).size()); - EXPECT_EQ(3U, driver.receivePackets(4, packets)); + EXPECT_EQ(3U, driver.receivePackets(4, packets, srcAddrs)); driver.releasePackets(packets, 3); } @@ -234,11 +198,9 @@ TEST(FakeDriverTest, getBandwidth) TEST(FakeDriverTest, getLocalAddress) { - uint64_t nextAddressId = FakeDriver().localAddressId + 1; - std::string addressStr = StringUtil::format("%lu", nextAddressId); - + 
uint32_t nextAddressId = FakeDriver().localAddressId + 1; FakeDriver driver; - EXPECT_EQ(driver.getAddress(&addressStr), driver.getLocalAddress()); + EXPECT_EQ(nextAddressId, (uint32_t)driver.getLocalAddress()); } } // namespace diff --git a/src/Drivers/RawAddressType.h b/src/Drivers/RawAddressType.h deleted file mode 100644 index 1def76d..0000000 --- a/src/Drivers/RawAddressType.h +++ /dev/null @@ -1,38 +0,0 @@ -/* Copyright (c) 2019, Stanford University - * - * Permission to use, copy, modify, and/or distribute this software for any - * purpose with or without fee is hereby granted, provided that the above - * copyright notice and this permission notice appear in all copies. - * - * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES - * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF - * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR - * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES - * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN - * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF - * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. - */ - -#ifndef HOMA_DRIVERS_RAWADDRESSTYPE_H -#define HOMA_DRIVERS_RAWADDRESSTYPE_H - -namespace Homa { -namespace Drivers { - -/** - * Identifies a particular raw serialized byte-format for a Driver::Address - * supported by this project. The types are enumerated here in one place to - * ensure drivers do have overlapping type identifiers. New drivers that wish - * to claim a type id should add an entry to this enum. 
- * - * @sa Driver::Address::Raw - */ -enum RawAddressType { - FAKE = 0, - MAC = 1, -}; - -} // namespace Drivers -} // namespace Homa - -#endif // HOMA_DRIVERS_RAWADDRESSTYPE_H diff --git a/src/Mock/MockDriver.h b/src/Mock/MockDriver.h index 6cc5ea7..5a29bc9 100644 --- a/src/Mock/MockDriver.h +++ b/src/Mock/MockDriver.h @@ -31,36 +31,41 @@ namespace Mock { class MockDriver : public Driver { public: /** - * Used in unit tests to mock calls to Driver::Packet. - * - * @sa Driver::Packet. + * Used in unit tests to mock driver-specific packet buffers. */ - class MockPacket : public Driver::Packet { - public: - MockPacket(void* payload, uint16_t length = 0) - : Packet(payload, length) - {} + struct PacketBuf { + /// External buffer which stores the packet data. + void* buffer; - MOCK_METHOD0(getMaxPayloadSize, int()); + /** + * Convert this packet buffer to the generic Driver::Packet + * representation. + */ + Driver::Packet toPacket(int length = 0) + { + Driver::Packet packet = {.descriptor = (uintptr_t)this, + .payload = buffer, + .length = length}; + return packet; + } }; - MOCK_METHOD1(getAddress, Address(std::string const* const addressString)); - MOCK_METHOD1(getAddress, - Address(WireFormatAddress const* const wireAddress)); - MOCK_METHOD1(addressToString, std::string(Address address)); - MOCK_METHOD2(addressToWireFormat, - void(Address address, WireFormatAddress* wireAddress)); - MOCK_METHOD0(allocPacket, Packet*()); - MOCK_METHOD1(sendPacket, void(Packet* packet)); - MOCK_METHOD0(flushPackets, void()); - MOCK_METHOD2(receivePackets, - uint32_t(uint32_t maxPackets, Packet* receivedPackets[])); - MOCK_METHOD2(releasePackets, void(Packet* packets[], uint16_t numPackets)); - MOCK_METHOD0(getHighestPacketPriority, int()); - MOCK_METHOD0(getMaxPayloadSize, uint32_t()); - MOCK_METHOD0(getBandwidth, uint32_t()); - MOCK_METHOD0(getLocalAddress, Address()); - MOCK_METHOD0(getQueuedBytes, uint32_t()); + MOCK_METHOD(void, allocPacket, (Packet * packet), (override)); + 
MOCK_METHOD(void, sendPacket, + (Packet * packet, IpAddress destination, int priority), + (override)); + MOCK_METHOD(void, flushPackets, ()); + MOCK_METHOD(uint32_t, receivePackets, + (uint32_t maxPackets, Packet receivedPackets[], + IpAddress sourceAddresses[]), + (override)); + MOCK_METHOD(void, releasePackets, (Packet packets[], uint16_t numPackets), + (override)); + MOCK_METHOD(int, getHighestPacketPriority, (), (override)); + MOCK_METHOD(uint32_t, getMaxPayloadSize, (), (override)); + MOCK_METHOD(uint32_t, getBandwidth, (), (override)); + MOCK_METHOD(IpAddress, getLocalAddress, (), (override)); + MOCK_METHOD(uint32_t, getQueuedBytes, (), (override)); }; } // namespace Mock diff --git a/src/Mock/MockPolicy.h b/src/Mock/MockPolicy.h index 0595f25..32e7be8 100644 --- a/src/Mock/MockPolicy.h +++ b/src/Mock/MockPolicy.h @@ -1,4 +1,4 @@ -/* Copyright (c) 2019, Stanford University +/* Copyright (c) 2019-2020, Stanford University * * Permission to use, copy, modify, and/or distribute this software for any * purpose with or without fee is hereby granted, provided that the above @@ -36,10 +36,10 @@ class MockPolicyManager : public Core::Policy::Manager { MOCK_METHOD0(getResendPriority, int()); MOCK_METHOD0(getScheduledPolicy, Core::Policy::Scheduled()); MOCK_METHOD2(getUnscheduledPolicy, - Core::Policy::Unscheduled(const Driver::Address destination, + Core::Policy::Unscheduled(const IpAddress destination, const uint32_t messageLength)); MOCK_METHOD3(signalNewMessage, - void(const Driver::Address source, uint8_t policyVersion, + void(const IpAddress source, uint8_t policyVersion, uint32_t messageLength)); MOCK_METHOD0(poll, void()); }; diff --git a/src/Mock/MockReceiver.h b/src/Mock/MockReceiver.h index fc0fa13..e8e4e1d 100644 --- a/src/Mock/MockReceiver.h +++ b/src/Mock/MockReceiver.h @@ -1,4 +1,4 @@ -/* Copyright (c) 2019, Stanford University +/* Copyright (c) 2019-2020, Stanford University * * Permission to use, copy, modify, and/or distribute this software for any * 
purpose with or without fee is hereby granted, provided that the above @@ -33,18 +33,17 @@ class MockReceiver : public Core::Receiver { public: MockReceiver(Driver* driver, uint64_t messageTimeoutCycles, uint64_t resendIntervalCycles) - : Receiver(driver, nullptr, messageTimeoutCycles, resendIntervalCycles) + : Receiver(driver, nullptr, nullptr, messageTimeoutCycles, + resendIntervalCycles) {} - MOCK_METHOD2(handleDataPacket, - void(Driver::Packet* packet, Driver* driver)); - MOCK_METHOD2(handleBusyPacket, - void(Driver::Packet* packet, Driver* driver)); - MOCK_METHOD2(handlePingPacket, - void(Driver::Packet* packet, Driver* driver)); - MOCK_METHOD0(receiveMessage, Homa::InMessage*()); - MOCK_METHOD0(poll, void()); - MOCK_METHOD0(checkTimeouts, uint64_t()); + MOCK_METHOD(void, handleDataPacket, + (Driver::Packet * packet, IpAddress sourceIp), (override)); + MOCK_METHOD(void, handleBusyPacket, (Driver::Packet * packet), (override)); + MOCK_METHOD(void, handlePingPacket, + (Driver::Packet * packet, IpAddress sourceIp), (override)); + MOCK_METHOD(uint64_t, checkTimeouts, (), (override)); + MOCK_METHOD(bool, trySendGrants, (), (override)); }; } // namespace Mock diff --git a/src/Mock/MockSender.h b/src/Mock/MockSender.h index b67152b..faa1291 100644 --- a/src/Mock/MockSender.h +++ b/src/Mock/MockSender.h @@ -1,4 +1,4 @@ -/* Copyright (c) 2019, Stanford University +/* Copyright (c) 2019-2020, Stanford University * * Permission to use, copy, modify, and/or distribute this software for any * purpose with or without fee is hereby granted, provided that the above @@ -33,23 +33,20 @@ class MockSender : public Core::Sender { public: MockSender(uint64_t transportId, Driver* driver, uint64_t messageTimeoutCycles, uint64_t pingIntervalCycles) - : Sender(transportId, driver, nullptr, messageTimeoutCycles, + : Sender(transportId, driver, nullptr, nullptr, messageTimeoutCycles, pingIntervalCycles) {} - MOCK_METHOD0(allocMessage, Homa::OutMessage*()); - 
MOCK_METHOD2(handleDonePacket, - void(Driver::Packet* packet, Driver* driver)); - MOCK_METHOD2(handleGrantPacket, - void(Driver::Packet* packet, Driver* driver)); - MOCK_METHOD2(handleResendPacket, - void(Driver::Packet* packet, Driver* driver)); - MOCK_METHOD2(handleUnknownPacket, - void(Driver::Packet* packet, Driver* driver)); - MOCK_METHOD2(handleErrorPacket, - void(Driver::Packet* packet, Driver* driver)); - MOCK_METHOD0(poll, void()); - MOCK_METHOD0(checkTimeouts, uint64_t()); + MOCK_METHOD(Homa::OutMessage*, allocMessage, (uint16_t sport), (override)); + MOCK_METHOD(void, handleDonePacket, (Driver::Packet * packet), (override)); + MOCK_METHOD(void, handleGrantPacket, (Driver::Packet * packet), (override)); + MOCK_METHOD(void, handleResendPacket, (Driver::Packet * packet), + (override)); + MOCK_METHOD(void, handleUnknownPacket, (Driver::Packet * packet), + (override)); + MOCK_METHOD(void, handleErrorPacket, (Driver::Packet * packet), (override)); + MOCK_METHOD(uint64_t, checkTimeouts, (), (override)); + MOCK_METHOD(uint64_t, trySend, (), (override)); }; } // namespace Mock diff --git a/src/Perf.cc b/src/Perf.cc index 155802c..faf154a 100644 --- a/src/Perf.cc +++ b/src/Perf.cc @@ -15,8 +15,6 @@ #include "Perf.h" -#include - #include #include diff --git a/src/Perf.h b/src/Perf.h index bf9668d..2349b01 100644 --- a/src/Perf.h +++ b/src/Perf.h @@ -17,7 +17,7 @@ #define HOMA_PERF_H #include -#include +#include #include diff --git a/src/Policy.cc b/src/Policy.cc index cf0e62e..12e7e16 100644 --- a/src/Policy.cc +++ b/src/Policy.cc @@ -97,14 +97,14 @@ Manager::getScheduledPolicy() * unilaterally "granted" (unscheduled) bytes for a new Message to be sent. * * @param destination - * The policy for the Transport at this Address will be returned. + * The policy for the Transport at this IpAddress will be returned. * @param messageLength * The policy for message containing this many bytes will be returned. 
* * @sa Policy::Unscheduled */ Unscheduled -Manager::getUnscheduledPolicy(const Driver::Address destination, +Manager::getUnscheduledPolicy(const IpAddress destination, const uint32_t messageLength) { SpinLock::Lock lock(mutex); @@ -140,14 +140,14 @@ Manager::getUnscheduledPolicy(const Driver::Address destination, * Called by the Receiver when a new Message has started to arrive. * * @param source - * Address of the Transport from which the new Message was received. + * IpAddress of the Transport from which the new Message was received. * @param policyVersion * Version of the policy the Sender used when sending the Message. * @param messageLength * Number of bytes the new incoming Message contains. */ void -Manager::signalNewMessage(const Driver::Address source, uint8_t policyVersion, +Manager::signalNewMessage(const IpAddress source, uint8_t policyVersion, uint32_t messageLength) { SpinLock::Lock lock(mutex); diff --git a/src/Policy.h b/src/Policy.h index c32bf66..0be1eb2 100644 --- a/src/Policy.h +++ b/src/Policy.h @@ -75,10 +75,9 @@ class Manager { virtual ~Manager() = default; virtual int getResendPriority(); virtual Scheduled getScheduledPolicy(); - virtual Unscheduled getUnscheduledPolicy(const Driver::Address destination, + virtual Unscheduled getUnscheduledPolicy(const IpAddress destination, const uint32_t messageLength); - virtual void signalNewMessage(const Driver::Address source, - uint8_t policyVersion, + virtual void signalNewMessage(const IpAddress source, uint8_t policyVersion, uint32_t messageLength); virtual void poll(); @@ -107,7 +106,8 @@ class Manager { /// The scheduled policy for the Transport that owns this Policy::Manager. Scheduled localScheduledPolicy; /// Collection of the known Policies for each peered Homa::Transport; - std::unordered_map peerPolicies; + std::unordered_map + peerPolicies; /// Number of bytes that can be transmitted in one round-trip-time. 
const uint32_t RTT_BYTES; /// The highest network packet priority that the driver supports. diff --git a/src/PolicyTest.cc b/src/PolicyTest.cc index ee0dde5..44b8829 100644 --- a/src/PolicyTest.cc +++ b/src/PolicyTest.cc @@ -1,4 +1,4 @@ -/* Copyright (c) 2019, Stanford University +/* Copyright (c) 2019-2020, Stanford University * * Permission to use, copy, modify, and/or distribute this software for any * purpose with or without fee is hereby granted, provided that the above @@ -59,7 +59,7 @@ TEST(PolicyManagerTest, getUnscheduledPolicy) EXPECT_CALL(mockDriver, getBandwidth).WillOnce(Return(8000)); EXPECT_CALL(mockDriver, getHighestPacketPriority).WillOnce(Return(7)); Policy::Manager manager(&mockDriver); - Driver::Address dest(22); + IpAddress dest{22}; { Policy::Unscheduled policy = manager.getUnscheduledPolicy(dest, 1); diff --git a/src/Protocol.h b/src/Protocol.h index f83725e..fe2cec0 100644 --- a/src/Protocol.h +++ b/src/Protocol.h @@ -104,19 +104,24 @@ enum Opcode { /** * This is the first part of the Homa packet header and is common to all - * versions of the protocol. The struct contains version information about the + * versions of the protocol. The first four bytes of the header store the source + * and destination ports, which is common for many transport layer protocols + * (e.g., TCP, UDP, etc.) The struct also contains version information about the * protocol used in the encompassing packet. The Transport should always send * this prefix and can always expect it when receiving a Homa packet. The prefix * is separated into its own struct because the Transport may need to know the * protocol version before interpreting the rest of the packet. */ struct HeaderPrefix { - uint8_t version; ///< The version of the protocol being used by this - ///< packet. + uint16_t sport; ///< Transport layer (L4) source and destination ports + uint16_t dport; ///< in network byte order; only used by DataHeader. 
+ uint8_t version; ///< The version of the protocol being used by this packet /// HeaderPrefix constructor. - HeaderPrefix(uint8_t version) - : version(version) + HeaderPrefix(uint16_t sport, uint16_t dport, uint8_t version) + : sport(sport) + , dport(dport) + , version(version) {} } __attribute__((packed)); @@ -131,7 +136,7 @@ struct CommonHeader { /// CommonHeader constructor. CommonHeader(Opcode opcode, MessageId messageId) - : prefix(1) + : prefix(0, 0, 1) , opcode(opcode) , messageId(messageId) {} @@ -157,14 +162,18 @@ struct DataHeader { // starting at the offset corresponding to the given packet index. /// DataHeader constructor. - DataHeader(MessageId messageId, uint32_t totalLength, uint8_t policyVersion, + DataHeader(uint16_t sport, uint16_t dport, MessageId messageId, + uint32_t totalLength, uint8_t policyVersion, uint16_t unscheduledIndexLimit, uint16_t index) : common(Opcode::DATA, messageId) , totalLength(totalLength) , policyVersion(policyVersion) , unscheduledIndexLimit(unscheduledIndexLimit) , index(index) - {} + { + common.prefix.sport = htobe16(sport); + common.prefix.dport = htobe16(dport); + } } __attribute__((packed)); /** diff --git a/src/Receiver.cc b/src/Receiver.cc index 25e0619..2c266ba 100644 --- a/src/Receiver.cc +++ b/src/Receiver.cc @@ -18,6 +18,7 @@ #include #include "Perf.h" +#include "Tub.h" #include "Util.h" namespace Homa { @@ -28,6 +29,8 @@ namespace Core { * * @param driver * The driver used to send and receive packets. + * @param callbacks + * User-defined transport callbacks. * @param policyManager * Provides information about the grant and network priority policies. * @param messageTimeoutCycles @@ -37,15 +40,16 @@ namespace Core { * Number of cycles of inactivity to wait between requesting retransmission * of un-received parts of a message. 
*/ -Receiver::Receiver(Driver* driver, Policy::Manager* policyManager, +Receiver::Receiver(Driver* driver, Transport::Callbacks* callbacks, + Policy::Manager* policyManager, uint64_t messageTimeoutCycles, uint64_t resendIntervalCycles) - : driver(driver) + : callbacks(callbacks) + , driver(driver) , policyManager(policyManager) , messageBuckets(messageTimeoutCycles, resendIntervalCycles) , schedulerMutex() , scheduledPeers() - , receivedMessages() - , granting() + , dontNeedGrants() , messageAllocator() {} @@ -57,8 +61,6 @@ Receiver::~Receiver() schedulerMutex.lock(); scheduledPeers.clear(); peerTable.clear(); - receivedMessages.mutex.lock(); - receivedMessages.queue.clear(); for (auto it = messageBuckets.buckets.begin(); it != messageBuckets.buckets.end(); ++it) { MessageBucket* bucket = *it; @@ -82,11 +84,11 @@ Receiver::~Receiver() * * @param packet * The incoming packet to be processed. - * @param driver - * The driver from which the packet was received. + * @param sourceIp + * Source IP address of the packet. 
*/ void -Receiver::handleDataPacket(Driver::Packet* packet, Driver* driver) +Receiver::handleDataPacket(Driver::Packet* packet, IpAddress sourceIp) { Protocol::Packet::DataHeader* header = static_cast(packet->payload); @@ -94,22 +96,25 @@ Receiver::handleDataPacket(Driver::Packet* packet, Driver* driver) Protocol::MessageId id = header->common.messageId; MessageBucket* bucket = messageBuckets.getBucket(id); - SpinLock::Lock lock_bucket(bucket->mutex); - Message* message = bucket->findMessage(id, lock_bucket); + Tub lock_bucket; + lock_bucket.construct(bucket->mutex); + Message* message = bucket->findMessage(id, *lock_bucket); if (message == nullptr) { // New message int messageLength = header->totalLength; int numUnscheduledPackets = header->unscheduledIndexLimit; { SpinLock::Lock lock_allocator(messageAllocator.mutex); + SocketAddress srcAddress = { + .ip = sourceIp, .port = be16toh(header->common.prefix.sport)}; message = messageAllocator.pool.construct( - this, driver, dataHeaderLength, messageLength, id, - packet->address, numUnscheduledPackets); + this, driver, dataHeaderLength, messageLength, id, srcAddress, + numUnscheduledPackets); } bucket->messages.push_back(&message->bucketNode); - policyManager->signalNewMessage(message->source, header->policyVersion, - header->totalLength); + policyManager->signalNewMessage( + message->source.ip, header->policyVersion, header->totalLength); if (message->scheduled) { // Message needs to be scheduled. 
@@ -121,7 +126,8 @@ Receiver::handleDataPacket(Driver::Packet* packet, Driver* driver) // Things that must be true (sanity check) assert(id == message->id); assert(message->driver == driver); - assert(message->source == packet->address); + assert(message->source.ip == sourceIp); + assert(message->source.port == be16toh(header->common.prefix.sport)); assert(message->messageLength == Util::downCast(header->totalLength)); // Add the packet @@ -140,6 +146,10 @@ Receiver::handleDataPacket(Driver::Packet* packet, Driver* driver) info->bytesRemaining -= packetDataBytes; updateSchedule(message, lock_scheduler); } + + // Non-duplicate DATA packets from scheduled messages can change + // the state of scheduledPeers; time to run trySendGrants() again + signalNeedGrants(lock_scheduler); } // Receiving a new packet means the message is still active so it @@ -152,16 +162,22 @@ Receiver::handleDataPacket(Driver::Packet* packet, Driver* driver) bucket->resendTimeouts.setTimeout(&message->resendTimeout); } else { // All message packets have been received. - message->state.store(Message::State::COMPLETED); + message->setState(Message::State::COMPLETED); bucket->resendTimeouts.cancelTimeout(&message->resendTimeout); - SpinLock::Lock lock_received_messages(receivedMessages.mutex); - receivedMessages.queue.push_back(&message->receivedMessageNode); + lock_bucket.destroy(); + + uint16_t dport = be16toh(header->common.prefix.dport); + bool success = + callbacks->deliver(dport, Homa::unique_ptr(message)); + if (!success) { + ERROR("Unable to deliver the message; message dropped"); + dropMessage(message); + } } } else { // must be a duplicate packet; drop packet. - driver->releasePackets(&packet, 1); + driver->releasePackets(packet, 1); } - return; } /** @@ -169,11 +185,9 @@ Receiver::handleDataPacket(Driver::Packet* packet, Driver* driver) * * @param packet * The incoming BUSY packet to be processed. - * @param driver - * The driver from which the BUSY packet was received. 
*/ void -Receiver::handleBusyPacket(Driver::Packet* packet, Driver* driver) +Receiver::handleBusyPacket(Driver::Packet* packet) { Protocol::Packet::BusyHeader* header = static_cast(packet->payload); @@ -186,11 +200,11 @@ Receiver::handleBusyPacket(Driver::Packet* packet, Driver* driver) // Sender has replied BUSY to our RESEND request; consider this message // still active. bucket->messageTimeouts.setTimeout(&message->messageTimeout); - if (message->state == Message::State::IN_PROGRESS) { + if (message->getState() == Message::State::IN_PROGRESS) { bucket->resendTimeouts.setTimeout(&message->resendTimeout); } } - driver->releasePackets(&packet, 1); + driver->releasePackets(packet, 1); } /** @@ -198,11 +212,11 @@ Receiver::handleBusyPacket(Driver::Packet* packet, Driver* driver) * * @param packet * The incoming PING packet to be processed. - * @param driver - * The driver from which the PING packet was received. + * @param sourceIp + * Source IP address of the packet. */ void -Receiver::handlePingPacket(Driver::Packet* packet, Driver* driver) +Receiver::handlePingPacket(Driver::Packet* packet, IpAddress sourceIp) { Protocol::Packet::PingHeader* header = static_cast(packet->payload); @@ -236,49 +250,15 @@ Receiver::handlePingPacket(Driver::Packet* packet, Driver* driver) Perf::counters.tx_grant_pkts.add(1); ControlPacket::send( - driver, message->source, message->id, bytesGranted, priority); + driver, message->source.ip, message->id, bytesGranted, priority); } else { // We are here because we have no knowledge of the message the Sender is // asking about. Reply UNKNOWN so the Sender can react accordingly. Perf::counters.tx_unknown_pkts.add(1); - ControlPacket::send( - driver, packet->address, id); + ControlPacket::send(driver, sourceIp, + id); } - driver->releasePackets(&packet, 1); -} - -/** - * Return a handle to a new received Message. - * - * The Transport should regularly call this method to insure incoming messages - * are processed. 
- * - * @return - * A new Message which has been received, if available; otherwise, nullptr. - * - * @sa dropMessage() - */ -Homa::InMessage* -Receiver::receiveMessage() -{ - SpinLock::Lock lock_received_messages(receivedMessages.mutex); - Message* message = nullptr; - if (!receivedMessages.queue.empty()) { - message = &receivedMessages.queue.front(); - receivedMessages.queue.pop_front(); - } - return message; -} - -/** - * Allow the Receiver to make progress toward receiving incoming messages. - * - * This method must be called eagerly to ensure messages are received. - */ -void -Receiver::poll() -{ - trySendGrants(); + driver->releasePackets(packet, 1); } /** @@ -293,16 +273,13 @@ Receiver::poll() uint64_t Receiver::checkTimeouts() { - uint64_t nextTimeout; - // Ping Timeout - nextTimeout = checkResendTimeouts(); + uint64_t resendTimeout = checkResendTimeouts(); // Message Timeout uint64_t messageTimeout = checkMessageTimeouts(); - nextTimeout = nextTimeout < messageTimeout ? nextTimeout : messageTimeout; - return nextTimeout; + return std::min(resendTimeout, messageTimeout); } /** @@ -346,7 +323,7 @@ Receiver::Message::acknowledge() const MessageBucket* bucket = receiver->messageBuckets.getBucket(id); SpinLock::Lock lock(bucket->mutex); Perf::counters.tx_done_pkts.add(1); - ControlPacket::send(driver, source, id); + ControlPacket::send(driver, source.ip, id); } /** @@ -355,7 +332,7 @@ Receiver::Message::acknowledge() const bool Receiver::Message::dropped() const { - return state.load() == State::DROPPED; + return getState() == State::DROPPED; } /** @@ -367,7 +344,7 @@ Receiver::Message::fail() const MessageBucket* bucket = receiver->messageBuckets.getBucket(id); SpinLock::Lock lock(bucket->mutex); Perf::counters.tx_error_pkts.add(1); - ControlPacket::send(driver, source, id); + ControlPacket::send(driver, source.ip, id); } /** @@ -397,7 +374,7 @@ Receiver::Message::get(size_t offset, void* destination, size_t count) const while (bytesCopied < _count) { uint32_t 
bytesToCopy = std::min(_count - bytesCopied, PACKET_DATA_LENGTH - packetOffset); - Driver::Packet* packet = getPacket(packetIndex); + const Driver::Packet* packet = getPacket(packetIndex); if (packet != nullptr) { char* source = static_cast(packet->payload); source += packetOffset + TRANSPORT_HEADER_LENGTH; @@ -415,6 +392,15 @@ Receiver::Message::get(size_t offset, void* destination, size_t count) const return bytesCopied; } +/** + * @copydoc Homa::InMessage::getSourceAddress() + */ +SocketAddress +Receiver::Message::getSourceAddress() const +{ + return source; +} + /** * @copydoc Homa::InMessage::length() */ @@ -451,11 +437,11 @@ Receiver::Message::release() * @return * Pointer to a Packet at the given index if it exists; nullptr otherwise. */ -Driver::Packet* +const Driver::Packet* Receiver::Message::getPacket(size_t index) const { if (occupied.test(index)) { - return packets[index]; + return &packets[index]; } return nullptr; } @@ -471,7 +457,7 @@ Receiver::Message::getPacket(size_t index) const * The Packet's index in the array of packets that form the message. * "packet index = "packet message offset" / PACKET_DATA_LENGTH * @param packet - * The packet pointer that should be stored. + * The packet that should be stored. * @return * True if the packet was stored; false if a packet already exists (the new * packet is not stored). @@ -482,12 +468,31 @@ Receiver::Message::setPacket(size_t index, Driver::Packet* packet) if (occupied.test(index)) { return false; } - packets[index] = packet; + packets[index] = *packet; occupied.set(index); numPackets++; return true; } +/** + * Clear the atomic _dontNeedGrants_ flag to indicate that trySendGrants() + * needs to run again. This method is called when the state of active messages + * in Receiver::scheduledPeers might have changed. 
+ * + * Note: we require the caller to hold the schedulerMutex during this call + * because it becomes much easier to reason about the interaction between + * the atomic flag and the mutex this way (and it's essentially free). + * + * @param lockHeld + * Reminder to hold the Receiver::schedulerMutex during this call. + */ +void +Receiver::signalNeedGrants(const SpinLock::Lock& lockHeld) +{ + (void)lockHeld; + dontNeedGrants.clear(std::memory_order_release); +} + /** * Inform the Receiver that an Message returned by receiveMessage() is not * needed and can be dropped. @@ -563,7 +568,7 @@ Receiver::checkMessageTimeouts() bucket->messageTimeouts.cancelTimeout(&message->messageTimeout); bucket->resendTimeouts.cancelTimeout(&message->resendTimeout); - if (message->state == Message::State::IN_PROGRESS) { + if (message->getState() == Message::State::IN_PROGRESS) { // Message timed out before being fully received; drop the // message. @@ -586,7 +591,7 @@ Receiver::checkMessageTimeouts() } else { // Message timed out but we already made it available to the // Transport; let the Transport know. - message->state.store(Message::State::DROPPED); + message->setState(Message::State::DROPPED); } } globalNextTimeout = std::min(globalNextTimeout, nextTimeout); @@ -628,7 +633,7 @@ Receiver::checkResendTimeouts() } // Found expired timeout. 
- assert(message->state == Message::State::IN_PROGRESS); + assert(message->getState() == Message::State::IN_PROGRESS); bucket->resendTimeouts.setTimeout(&message->resendTimeout); // This Receiver expected to have heard from the Sender within the @@ -678,7 +683,7 @@ Receiver::checkResendTimeouts() SpinLock::Lock lock_scheduler(schedulerMutex); Perf::counters.tx_resend_pkts.add(1); ControlPacket::send( - message->driver, message->source, message->id, + message->driver, message->source.ip, message->id, Util::downCast(index), Util::downCast(num), message->scheduledMessageInfo.priority); @@ -691,7 +696,7 @@ Receiver::checkResendTimeouts() SpinLock::Lock lock_scheduler(schedulerMutex); Perf::counters.tx_resend_pkts.add(1); ControlPacket::send( - message->driver, message->source, message->id, + message->driver, message->source.ip, message->id, Util::downCast(index), Util::downCast(num), message->scheduledMessageInfo.priority); @@ -704,23 +709,30 @@ Receiver::checkResendTimeouts() /** * Send GRANTs to incoming Message according to the Receiver's policy. + * + * This method must be called eagerly to allow the Receiver to make progress + * toward receiving incoming messages. + * + * @return + * True if the method has found some messages to grant; false, otherwise. */ -void +bool Receiver::trySendGrants() { uint64_t start_tsc = PerfUtils::Cycles::rdtsc(); - bool idle = true; - // Skip scheduling if another poller is already working on it. - if (granting.test_and_set()) { - return; + // Fast path: skip if no message is waiting for grants + bool needGrants = !dontNeedGrants.test_and_set(std::memory_order_acquire); + if (!needGrants) { + return false; } + /* It's possible to have a benign race-condition here when another thread + * acquires the schedulerMutex before us and sets _dontNeedGrants_ back to + * false via signalNeedGrants. As a result, _dontNeedGrants_ will stay false + * when the method returns although all messages have been granted. 
+ */ SpinLock::Lock lock(schedulerMutex); - if (scheduledPeers.empty()) { - granting.clear(); - return; - } /* The overall goal is to grant up to policy.degreeOvercommitment number of * scheduled messages simultaneously. Each of these messages should always @@ -742,13 +754,14 @@ Receiver::trySendGrants() auto it = scheduledPeers.begin(); int slot = 0; + bool foundWork = false; while (it != scheduledPeers.end() && slot < policy.degreeOvercommitment) { assert(!it->scheduledMessages.empty()); Message* message = &it->scheduledMessages.front(); ScheduledMessageInfo* info = &message->scheduledMessageInfo; // Access message const variables without message mutex. const Protocol::MessageId id = message->id; - const Driver::Address source = message->source; + const IpAddress sourceIp = message->source.ip; // Recalculate message priority info->priority = @@ -757,7 +770,6 @@ Receiver::trySendGrants() // Send a GRANT if there are too few bytes granted and unreceived. int receivedBytes = info->messageLength - info->bytesRemaining; if (info->bytesGranted - receivedBytes < policy.minScheduledBytes) { - idle = false; // Calculate new grant limit int newGrantLimit = std::min( receivedBytes + policy.maxScheduledBytes, info->messageLength); @@ -765,8 +777,9 @@ Receiver::trySendGrants() info->bytesGranted = newGrantLimit; Perf::counters.tx_grant_pkts.add(1); ControlPacket::send( - driver, source, id, + driver, sourceIp, id, Util::downCast(info->bytesGranted), info->priority); + foundWork = true; } // Update the iterator first since calling unschedule() may cause the @@ -781,14 +794,13 @@ Receiver::trySendGrants() ++slot; } - granting.clear(); - uint64_t elapsed_cycles = PerfUtils::Cycles::rdtsc() - start_tsc; - if (!idle) { + if (foundWork) { Perf::counters.active_cycles.add(elapsed_cycles); } else { Perf::counters.idle_cycles.add(elapsed_cycles); } + return foundWork; } /** @@ -806,7 +818,7 @@ Receiver::schedule(Receiver::Message* message, const SpinLock::Lock& lock) { (void)lock; 
ScheduledMessageInfo* info = &message->scheduledMessageInfo; - Peer* peer = &peerTable[message->source]; + Peer* peer = &peerTable[message->source.ip]; // Insert the Message peer->scheduledMessages.push_front(&info->scheduledMessageNode); Intrusive::deprioritize(&peer->scheduledMessages, @@ -871,6 +883,9 @@ Receiver::unschedule(Receiver::Message* message, const SpinLock::Lock& lock) Intrusive::deprioritize(&scheduledPeers, &peer->scheduledPeerNode, comp); } + + // scheduledPeers has been updated; time to run trySendGrants() again + signalNeedGrants(lock); } /** diff --git a/src/Receiver.h b/src/Receiver.h index 444e1aa..1a425f1 100644 --- a/src/Receiver.h +++ b/src/Receiver.h @@ -16,10 +16,11 @@ #ifndef HOMA_CORE_RECEIVER_H #define HOMA_CORE_RECEIVER_H +#include #include -#include #include +#include #include #include @@ -30,6 +31,7 @@ #include "Protocol.h" #include "SpinLock.h" #include "Timeout.h" +#include "Util.h" namespace Homa { namespace Core { @@ -42,16 +44,16 @@ namespace Core { */ class Receiver { public: - explicit Receiver(Driver* driver, Policy::Manager* policyManager, + explicit Receiver(Driver* driver, Transport::Callbacks* callbacks, + Policy::Manager* policyManager, uint64_t messageTimeoutCycles, uint64_t resendIntervalCycles); virtual ~Receiver(); - virtual void handleDataPacket(Driver::Packet* packet, Driver* driver); - virtual void handleBusyPacket(Driver::Packet* packet, Driver* driver); - virtual void handlePingPacket(Driver::Packet* packet, Driver* driver); - virtual Homa::InMessage* receiveMessage(); - virtual void poll(); + virtual void handleDataPacket(Driver::Packet* packet, IpAddress sourceIp); + virtual void handleBusyPacket(Driver::Packet* packet); + virtual void handlePingPacket(Driver::Packet* packet, IpAddress sourceIp); virtual uint64_t checkTimeouts(); + virtual bool trySendGrants(); private: // Forward declaration @@ -117,7 +119,7 @@ class Receiver { * Represents an incoming message that is being assembled or being processed * by 
the application. */ - class Message : public Homa::InMessage { + class Message final : public Homa::InMessage { public: /** * Defines the possible states of this Message. @@ -132,7 +134,7 @@ class Receiver { explicit Message(Receiver* receiver, Driver* driver, size_t packetHeaderLength, size_t messageLength, - Protocol::MessageId id, Driver::Address source, + Protocol::MessageId id, SocketAddress source, int numUnscheduledPackets) : receiver(receiver) , driver(driver) @@ -153,7 +155,6 @@ class Receiver { // construction. See Message::occupied. , state(Message::State::IN_PROGRESS) , bucketNode(this) - , receivedMessageNode(this) , messageTimeout(this) , resendTimeout(this) , scheduledMessageInfo(this, messageLength) @@ -165,23 +166,35 @@ class Receiver { virtual void fail() const; virtual size_t get(size_t offset, void* destination, size_t count) const; + virtual SocketAddress getSourceAddress() const; virtual size_t length() const; virtual void strip(size_t count); virtual void release(); + private: /** * Return the current state of this message. */ State getState() const { - return state.load(); + return state.load(std::memory_order_acquire); + } + + /** + * Change the current state of this message. + * + * @param newState + * The new state of the message + */ + void setState(State newState) + { + state.store(newState, std::memory_order_release); } - private: /// Define the maximum number of packets that a message can hold. static const int MAX_MESSAGE_PACKETS = 1024; - Driver::Packet* getPacket(size_t index) const; + const Driver::Packet* getPacket(size_t index) const; bool setPacket(size_t index, Driver::Packet* packet); /// The Receiver responsible for this message. @@ -195,7 +208,7 @@ class Receiver { const Protocol::MessageId id; /// Contains source address this message. - const Driver::Address source; + const SocketAddress source; /// Number of bytes at the beginning of each Packet that should be /// reserved for the Homa transport header. 
@@ -229,7 +242,7 @@ class Receiver { /// Collection of Packet objects that make up this context's Message. /// These Packets will be released when this context is destroyed. - Driver::Packet* packets[MAX_MESSAGE_PACKETS]; + Driver::Packet packets[MAX_MESSAGE_PACKETS]; /// This message's current state. std::atomic state; @@ -239,10 +252,6 @@ class Receiver { /// is protected by the associated MessageBucket::mutex; Intrusive::List::Node bucketNode; - /// Intrusive structure used by the Receiver to keep track of this - /// message when it has been completely received. - Intrusive::List::Node receivedMessageNode; - /// Intrusive structure used by the Receiver to keep track when the /// receiving of this message should be considered failed. Timeout messageTimeout; @@ -449,14 +458,17 @@ class Receiver { Intrusive::List::Node scheduledPeerNode; }; + void signalNeedGrants(const SpinLock::Lock& lockHeld); void dropMessage(Receiver::Message* message); uint64_t checkMessageTimeouts(); uint64_t checkResendTimeouts(); - void trySendGrants(); void schedule(Message* message, const SpinLock::Lock& lock); void unschedule(Message* message, const SpinLock::Lock& lock); void updateSchedule(Message* message, const SpinLock::Lock& lock); + /// User-defined transport callbacks. Not owned by this class. + Transport::Callbacks* const callbacks; + /// Driver with which all packets will be sent and received. This driver /// is chosen by the Transport that owns this Sender. Driver* const driver; @@ -473,24 +485,16 @@ class Receiver { /// Collection of all peers; used for fast access. Access is protected by /// the schedulerMutex. - std::unordered_map peerTable; + std::unordered_map peerTable; /// List of peers with inbound messages that require grants to complete. /// Access is protected by the schedulerMutex. Intrusive::List scheduledPeers; - /// Message objects to be processed by the transport. 
- struct { - /// Protects the receivedMessage.queue - SpinLock mutex; - /// List of completely received messages. - Intrusive::List queue; - } receivedMessages; - - /// True if the Receiver is executing trySendGrants(); false, otherwise. - /// Used to prevent concurrent calls to trySendGrants() from blocking on - /// each other. - std::atomic_flag granting = ATOMIC_FLAG_INIT; + /// Hint whether there MIGHT be messages that need to be granted. Encoded + /// into an atomic flag so that checking if there is work to do can be done + /// efficiently without acquiring the schedulerMutex first. + std::atomic_flag dontNeedGrants; /// Used to allocate Message objects. struct { diff --git a/src/ReceiverTest.cc b/src/ReceiverTest.cc index a49aee2..0db4d1d 100644 --- a/src/ReceiverTest.cc +++ b/src/ReceiverTest.cc @@ -34,24 +34,68 @@ using ::testing::InSequence; using ::testing::Matcher; using ::testing::Mock; using ::testing::NiceMock; -using ::testing::Pointee; using ::testing::Return; +/// Helper macro to construct an IpAddress from a numeric value. +#define IP(x) \ + IpAddress \ + { \ + x \ + } + +/** + * Defines a matcher EqPacket(p) to match two Driver::Packet* by their + * underlying packet buffer descriptors. + */ +MATCHER_P(EqPacket, p, "") +{ + return arg->descriptor == p->descriptor; +} + +/** + * Defines a matcher EqPacketLen(p) to match a Driver::Packet* by its length. 
+ */ +MATCHER_P(EqPacketLen, length, "") +{ + return arg->length == length; +} + +class MockCallbacks : public Transport::Callbacks { + public: + explicit MockCallbacks() + : receivedMessage() + {} + + bool deliver(uint16_t port, Homa::unique_ptr message) override + { + if (port != 60001) { + return false; + } + receivedMessage = message.release(); + return true; + } + + InMessage* receivedMessage; +}; + class ReceiverTest : public ::testing::Test { public: ReceiverTest() - : mockDriver() - , mockPacket(&payload) + : mockCallbacks() + , mockDriver() + , mockPacket() , mockPolicyManager(&mockDriver) , payload() + , packetBuf{&payload} , receiver() , savedLogPolicy(Debug::getLogPolicy()) { ON_CALL(mockDriver, getBandwidth).WillByDefault(Return(8000)); ON_CALL(mockDriver, getMaxPayloadSize).WillByDefault(Return(1027)); + mockPacket = packetBuf.toPacket(); Debug::setLogPolicy( Debug::logPolicyFromString("src/ObjectPool@SILENT")); - receiver = new Receiver(&mockDriver, &mockPolicyManager, + receiver = new Receiver(&mockDriver, &mockCallbacks, &mockPolicyManager, messageTimeoutCycles, resendIntervalCycles); PerfUtils::Cycles::mockTscValue = 10000; } @@ -67,10 +111,12 @@ class ReceiverTest : public ::testing::Test { static const uint64_t messageTimeoutCycles = 1000; static const uint64_t resendIntervalCycles = 100; + MockCallbacks mockCallbacks; NiceMock mockDriver; - NiceMock mockPacket; + Driver::Packet mockPacket; NiceMock mockPolicyManager; char payload[1028]; + Homa::Mock::MockDriver::PacketBuf packetBuf; Receiver* receiver; std::vector> savedLogPolicy; }; @@ -98,28 +144,24 @@ TEST_F(ReceiverTest, handleDataPacket) Receiver::ScheduledMessageInfo* info = nullptr; Receiver::MessageBucket* bucket = receiver->messageBuckets.getBucket(id); + new (mockPacket.payload) Protocol::Packet::DataHeader( + 0, 60001, id, totalMessageLength, policyVersion, 1, 0); Protocol::Packet::DataHeader* header = static_cast(mockPacket.payload); - header->common.opcode = Protocol::Packet::DATA; 
- header->common.messageId = id; - header->totalLength = totalMessageLength; - header->policyVersion = policyVersion; - header->unscheduledIndexLimit = 1; - mockPacket.address = Driver::Address(22); + IpAddress sourceIp{22}; // ------------------------------------------------------------------------- // Receive packet[1]. New message. header->index = 1; mockPacket.length = HEADER_SIZE + 1000; EXPECT_CALL(mockPolicyManager, - signalNewMessage(Eq(mockPacket.address), Eq(policyVersion), + signalNewMessage(Eq(sourceIp), Eq(policyVersion), Eq(totalMessageLength))) .Times(1); - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) - .Times(0); + EXPECT_CALL(mockDriver, releasePackets(Eq(&mockPacket), Eq(1))).Times(0); // TEST CALL - receiver->handleDataPacket(&mockPacket, &mockDriver); + receiver->handleDataPacket(&mockPacket, sourceIp); // --------- { @@ -144,11 +186,10 @@ TEST_F(ReceiverTest, handleDataPacket) // ------------------------------------------------------------------------- // Receive packet[1]. Duplicate. - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) - .Times(1); + EXPECT_CALL(mockDriver, releasePackets(Eq(&mockPacket), Eq(1))).Times(1); // TEST CALL - receiver->handleDataPacket(&mockPacket, &mockDriver); + receiver->handleDataPacket(&mockPacket, sourceIp); // --------- EXPECT_EQ(1U, message->numPackets); @@ -158,11 +199,10 @@ TEST_F(ReceiverTest, handleDataPacket) // Receive packet[2]. header->index = 2; mockPacket.length = HEADER_SIZE + 1000; - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) - .Times(0); + EXPECT_CALL(mockDriver, releasePackets(Eq(&mockPacket), Eq(1))).Times(0); // TEST CALL - receiver->handleDataPacket(&mockPacket, &mockDriver); + receiver->handleDataPacket(&mockPacket, sourceIp); // --------- EXPECT_EQ(2U, message->numPackets); @@ -173,11 +213,10 @@ TEST_F(ReceiverTest, handleDataPacket) // Receive packet[3]. 
header->index = 3; mockPacket.length = HEADER_SIZE + 500; - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) - .Times(0); + EXPECT_CALL(mockDriver, releasePackets(Eq(&mockPacket), Eq(1))).Times(0); // TEST CALL - receiver->handleDataPacket(&mockPacket, &mockDriver); + receiver->handleDataPacket(&mockPacket, sourceIp); // --------- EXPECT_EQ(3U, message->numPackets); @@ -188,26 +227,24 @@ TEST_F(ReceiverTest, handleDataPacket) // Receive packet[0]. Finished. header->index = 0; mockPacket.length = HEADER_SIZE + 1000; - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) - .Times(0); + EXPECT_CALL(mockDriver, releasePackets(Eq(&mockPacket), Eq(1))).Times(0); // TEST CALL - receiver->handleDataPacket(&mockPacket, &mockDriver); + receiver->handleDataPacket(&mockPacket, sourceIp); // --------- EXPECT_EQ(4U, message->numPackets); EXPECT_EQ(0U, info->bytesRemaining); EXPECT_EQ(Receiver::Message::State::COMPLETED, message->state); - EXPECT_EQ(message, &receiver->receivedMessages.queue.back()); + EXPECT_EQ(message, mockCallbacks.receivedMessage); Mock::VerifyAndClearExpectations(&mockDriver); // ------------------------------------------------------------------------- // Receive packet[0]. Already finished. 
- EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) - .Times(1); + EXPECT_CALL(mockDriver, releasePackets(Eq(&mockPacket), Eq(1))).Times(1); // TEST CALL - receiver->handleDataPacket(&mockPacket, &mockDriver); + receiver->handleDataPacket(&mockPacket, sourceIp); // --------- Mock::VerifyAndClearExpectations(&mockDriver); @@ -217,7 +254,7 @@ TEST_F(ReceiverTest, handleBusyPacket_basic) { Protocol::MessageId id(42, 32); Receiver::Message* message = receiver->messageAllocator.pool.construct( - receiver, &mockDriver, 0, 0, id, Driver::Address(0), 0); + receiver, &mockDriver, 0, 0, id, SocketAddress{0, 60001}, 0); Receiver::MessageBucket* bucket = receiver->messageBuckets.getBucket(id); bucket->messages.push_back(&message->bucketNode); @@ -225,10 +262,9 @@ TEST_F(ReceiverTest, handleBusyPacket_basic) (Protocol::Packet::BusyHeader*)mockPacket.payload; busyHeader->common.messageId = id; - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) - .Times(1); + EXPECT_CALL(mockDriver, releasePackets(Eq(&mockPacket), Eq(1))).Times(1); - receiver->handleBusyPacket(&mockPacket, &mockDriver); + receiver->handleBusyPacket(&mockPacket); EXPECT_EQ(11000U, message->messageTimeout.expirationCycleTime); EXPECT_EQ(10100U, message->resendTimeout.expirationCycleTime); @@ -242,18 +278,17 @@ TEST_F(ReceiverTest, handleBusyPacket_unknown) (Protocol::Packet::BusyHeader*)mockPacket.payload; busyHeader->common.messageId = id; - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) - .Times(1); + EXPECT_CALL(mockDriver, releasePackets(Eq(&mockPacket), Eq(1))).Times(1); - receiver->handleBusyPacket(&mockPacket, &mockDriver); + receiver->handleBusyPacket(&mockPacket); } TEST_F(ReceiverTest, handlePingPacket_basic) { Protocol::MessageId id(42, 32); - Driver::Address mockAddress = 22; + IpAddress mockAddress{22}; Receiver::Message* message = receiver->messageAllocator.pool.construct( - receiver, &mockDriver, 0, 20000, id, mockAddress, 0); + receiver, 
&mockDriver, 0, 20000, id, SocketAddress{mockAddress, 0}, 0); ASSERT_TRUE(message->scheduled); Receiver::ScheduledMessageInfo* info = &message->scheduledMessageInfo; info->bytesGranted = 500; @@ -263,25 +298,28 @@ TEST_F(ReceiverTest, handlePingPacket_basic) bucket->messages.push_back(&message->bucketNode); char pingPayload[1028]; - Homa::Mock::MockDriver::MockPacket pingPacket(pingPayload); - pingPacket.address = mockAddress; + Homa::Mock::MockDriver::PacketBuf pingPacketBuf{pingPayload}; + Driver::Packet pingPacket = pingPacketBuf.toPacket(); + IpAddress sourceIp = mockAddress; Protocol::Packet::PingHeader* pingHeader = - (Protocol::Packet::PingHeader*)pingPacket.payload; + (Protocol::Packet::PingHeader*)pingPacketBuf.buffer; pingHeader->common.messageId = id; - EXPECT_CALL(mockDriver, allocPacket()).WillOnce(Return(&mockPacket)); - EXPECT_CALL(mockDriver, sendPacket(Eq(&mockPacket))).Times(1); - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) + EXPECT_CALL(mockDriver, allocPacket(_)) + .WillOnce([this](Driver::Packet* packet) { *packet = mockPacket; }); + EXPECT_CALL(mockDriver, + sendPacket(EqPacket(&mockPacket), Eq(mockAddress), _)) + .Times(1); + EXPECT_CALL(mockDriver, releasePackets(EqPacket(&mockPacket), Eq(1))) .Times(1); - EXPECT_CALL(mockDriver, releasePackets(Pointee(&pingPacket), Eq(1))) + EXPECT_CALL(mockDriver, releasePackets(EqPacket(&pingPacket), Eq(1))) .Times(1); - receiver->handlePingPacket(&pingPacket, &mockDriver); + receiver->handlePingPacket(&pingPacket, sourceIp); EXPECT_EQ(11000U, message->messageTimeout.expirationCycleTime); EXPECT_EQ(0U, message->resendTimeout.expirationCycleTime); - EXPECT_EQ(mockAddress, mockPacket.address); Protocol::Packet::GrantHeader* header = (Protocol::Packet::GrantHeader*)payload; EXPECT_EQ(Protocol::Packet::GRANT, header->common.opcode); @@ -295,61 +333,36 @@ TEST_F(ReceiverTest, handlePingPacket_unknown) Protocol::MessageId id(42, 32); char pingPayload[1028]; - 
Homa::Mock::MockDriver::MockPacket pingPacket(pingPayload); - pingPacket.address = (Driver::Address)22; + Homa::Mock::MockDriver::PacketBuf pingPacketBuf{pingPayload}; + Driver::Packet pingPacket = pingPacketBuf.toPacket(); + IpAddress mockAddress{22}; Protocol::Packet::PingHeader* pingHeader = (Protocol::Packet::PingHeader*)pingPacket.payload; pingHeader->common.messageId = id; - EXPECT_CALL(mockDriver, allocPacket()).WillOnce(Return(&mockPacket)); - EXPECT_CALL(mockDriver, sendPacket(Eq(&mockPacket))).Times(1); - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) + EXPECT_CALL(mockDriver, allocPacket(_)) + .WillOnce([this](Driver::Packet* packet) { *packet = mockPacket; }); + EXPECT_CALL(mockDriver, + sendPacket(EqPacket(&mockPacket), Eq(mockAddress), _)) .Times(1); - EXPECT_CALL(mockDriver, releasePackets(Pointee(&pingPacket), Eq(1))) + EXPECT_CALL(mockDriver, releasePackets(EqPacket(&mockPacket), Eq(1))) + .Times(1); + EXPECT_CALL(mockDriver, releasePackets(EqPacket(&pingPacket), Eq(1))) .Times(1); - receiver->handlePingPacket(&pingPacket, &mockDriver); + receiver->handlePingPacket(&pingPacket, mockAddress); - EXPECT_EQ(pingPacket.address, mockPacket.address); Protocol::Packet::UnknownHeader* header = (Protocol::Packet::UnknownHeader*)payload; EXPECT_EQ(Protocol::Packet::UNKNOWN, header->common.opcode); EXPECT_EQ(id, header->common.messageId); } -TEST_F(ReceiverTest, receiveMessage) -{ - Receiver::Message* msg0 = receiver->messageAllocator.pool.construct( - receiver, &mockDriver, 0, 0, Protocol::MessageId(42, 0), - Driver::Address(22), 0); - Receiver::Message* msg1 = receiver->messageAllocator.pool.construct( - receiver, &mockDriver, 0, 0, Protocol::MessageId(42, 0), - Driver::Address(22), 0); - - receiver->receivedMessages.queue.push_back(&msg0->receivedMessageNode); - receiver->receivedMessages.queue.push_back(&msg1->receivedMessageNode); - EXPECT_FALSE(receiver->receivedMessages.queue.empty()); - - EXPECT_EQ(msg0, 
receiver->receiveMessage()); - EXPECT_FALSE(receiver->receivedMessages.queue.empty()); - - EXPECT_EQ(msg1, receiver->receiveMessage()); - EXPECT_TRUE(receiver->receivedMessages.queue.empty()); - - EXPECT_EQ(nullptr, receiver->receiveMessage()); - EXPECT_TRUE(receiver->receivedMessages.queue.empty()); -} - -TEST_F(ReceiverTest, poll) -{ - // Nothing to test - receiver->poll(); -} - TEST_F(ReceiverTest, checkTimeouts) { Receiver::Message message(receiver, &mockDriver, 0, 0, - Protocol::MessageId(0, 0), Driver::Address(0), 0); + Protocol::MessageId(0, 0), + SocketAddress{0, 60001}, 0); Receiver::MessageBucket* bucket = receiver->messageBuckets.buckets.at(0); bucket->resendTimeouts.setTimeout(&message.resendTimeout); bucket->messageTimeouts.setTimeout(&message.messageTimeout); @@ -373,7 +386,7 @@ TEST_F(ReceiverTest, Message_destructor_basic) { Protocol::MessageId id = {42, 32}; Receiver::Message* message = receiver->messageAllocator.pool.construct( - receiver, &mockDriver, 0, 0, id, Driver::Address(22), 0); + receiver, &mockDriver, 0, 0, id, SocketAddress{22, 60001}, 0); const uint16_t NUM_PKTS = 5; @@ -392,7 +405,7 @@ TEST_F(ReceiverTest, Message_destructor_holes) { Protocol::MessageId id = {42, 32}; Receiver::Message* message = receiver->messageAllocator.pool.construct( - receiver, &mockDriver, 0, 0, id, Driver::Address(22), 0); + receiver, &mockDriver, 0, 0, id, SocketAddress{22, 60001}, 0); const uint16_t NUM_PKTS = 4; @@ -414,11 +427,15 @@ TEST_F(ReceiverTest, Message_acknowledge) { Protocol::MessageId id = {42, 32}; Receiver::Message* message = receiver->messageAllocator.pool.construct( - receiver, &mockDriver, 0, 0, id, Driver::Address(22), 0); + receiver, &mockDriver, 0, 0, id, SocketAddress{22, 60001}, 0); - EXPECT_CALL(mockDriver, allocPacket()).WillOnce(Return(&mockPacket)); - EXPECT_CALL(mockDriver, sendPacket(Eq(&mockPacket))).Times(1); - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) + EXPECT_CALL(mockDriver, allocPacket(_)) + 
.WillOnce([this](Driver::Packet* packet) { *packet = mockPacket; }); + EXPECT_CALL(mockDriver, + sendPacket(EqPacketLen(sizeof(Protocol::Packet::DoneHeader)), + Eq(message->source.ip), _)) + .Times(1); + EXPECT_CALL(mockDriver, releasePackets(EqPacket(&mockPacket), Eq(1))) .Times(1); message->acknowledge(); @@ -427,15 +444,13 @@ TEST_F(ReceiverTest, Message_acknowledge) static_cast(mockPacket.payload); EXPECT_EQ(Protocol::Packet::DONE, header->opcode); EXPECT_EQ(id, header->messageId); - EXPECT_EQ(sizeof(Protocol::Packet::DoneHeader), mockPacket.length); - EXPECT_EQ(message->source, mockPacket.address); } TEST_F(ReceiverTest, Message_dropped) { Protocol::MessageId id = {42, 32}; Receiver::Message* message = receiver->messageAllocator.pool.construct( - receiver, &mockDriver, 0, 0, id, Driver::Address(22), 0); + receiver, &mockDriver, 0, 0, id, SocketAddress{22, 60001}, 0); message->state = Receiver::Message::State::IN_PROGRESS; @@ -450,11 +465,15 @@ TEST_F(ReceiverTest, Message_fail) { Protocol::MessageId id = {42, 32}; Receiver::Message* message = receiver->messageAllocator.pool.construct( - receiver, &mockDriver, 0, 0, id, Driver::Address(22), 0); + receiver, &mockDriver, 0, 0, id, SocketAddress{22, 60001}, 0); - EXPECT_CALL(mockDriver, allocPacket()).WillOnce(Return(&mockPacket)); - EXPECT_CALL(mockDriver, sendPacket(Eq(&mockPacket))).Times(1); - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) + EXPECT_CALL(mockDriver, allocPacket(_)) + .WillOnce([this](Driver::Packet* packet) { *packet = mockPacket; }); + EXPECT_CALL(mockDriver, + sendPacket(EqPacketLen(sizeof(Protocol::Packet::ErrorHeader)), + Eq(message->source.ip), _)) + .Times(1); + EXPECT_CALL(mockDriver, releasePackets(EqPacket(&mockPacket), Eq(1))) .Times(1); message->fail(); @@ -463,8 +482,6 @@ TEST_F(ReceiverTest, Message_fail) static_cast(mockPacket.payload); EXPECT_EQ(Protocol::Packet::ERROR, header->opcode); EXPECT_EQ(id, header->messageId); - 
EXPECT_EQ(sizeof(Protocol::Packet::ErrorHeader), mockPacket.length); - EXPECT_EQ(message->source, mockPacket.address); } TEST_F(ReceiverTest, Message_get_basic) @@ -472,10 +489,12 @@ TEST_F(ReceiverTest, Message_get_basic) ON_CALL(mockDriver, getMaxPayloadSize).WillByDefault(Return(2048)); Protocol::MessageId id = {42, 32}; Receiver::Message* message = receiver->messageAllocator.pool.construct( - receiver, &mockDriver, 24, 24 + 2007, id, Driver::Address(22), 0); + receiver, &mockDriver, 24, 24 + 2007, id, SocketAddress{22, 60001}, 0); char buf[4096]; - Homa::Mock::MockDriver::MockPacket packet0(buf + 0); - Homa::Mock::MockDriver::MockPacket packet1(buf + 2048); + Homa::Mock::MockDriver::PacketBuf packetBuf0{buf + 0}; + Homa::Mock::MockDriver::PacketBuf packetBuf1{buf + 2048}; + Driver::Packet packet0 = packetBuf0.toPacket(); + Driver::Packet packet1 = packetBuf1.toPacket(); char source[] = "Hello, world!"; message->setPacket(0, &packet0); @@ -499,10 +518,12 @@ TEST_F(ReceiverTest, Message_get_offsetTooLarge) ON_CALL(mockDriver, getMaxPayloadSize).WillByDefault(Return(2048)); Protocol::MessageId id = {42, 32}; Receiver::Message* message = receiver->messageAllocator.pool.construct( - receiver, &mockDriver, 24, 24 + 2007, id, Driver::Address(22), 0); + receiver, &mockDriver, 24, 24 + 2007, id, SocketAddress{22, 60001}, 0); char buf[4096]; - Homa::Mock::MockDriver::MockPacket packet0(buf + 0); - Homa::Mock::MockDriver::MockPacket packet1(buf + 2048); + Homa::Mock::MockDriver::PacketBuf packetBuf0{buf + 0}; + Homa::Mock::MockDriver::PacketBuf packetBuf1{buf + 2048}; + Driver::Packet packet0 = packetBuf0.toPacket(); + Driver::Packet packet1 = packetBuf1.toPacket(); message->setPacket(0, &packet0); message->setPacket(1, &packet1); @@ -525,10 +546,12 @@ TEST_F(ReceiverTest, Message_get_missingPacket) ON_CALL(mockDriver, getMaxPayloadSize).WillByDefault(Return(2048)); Protocol::MessageId id = {42, 32}; Receiver::Message* message = receiver->messageAllocator.pool.construct( 
- receiver, &mockDriver, 24, 24 + 2007, id, Driver::Address(22), 0); + receiver, &mockDriver, 24, 24 + 2007, id, SocketAddress{22, 60001}, 0); char buf[4096]; - Homa::Mock::MockDriver::MockPacket packet0(buf + 0); - Homa::Mock::MockDriver::MockPacket packet1(buf + 2048); + Homa::Mock::MockDriver::PacketBuf packetBuf0{buf + 0}; + Homa::Mock::MockDriver::PacketBuf packetBuf1{buf + 2048}; + Driver::Packet packet0 = packetBuf0.toPacket(); + Driver::Packet packet1 = packetBuf1.toPacket(); char source[] = "Hello,"; message->setPacket(0, &packet0); @@ -557,7 +580,7 @@ TEST_F(ReceiverTest, Message_length) { Protocol::MessageId id = {42, 32}; Receiver::Message* message = receiver->messageAllocator.pool.construct( - receiver, &mockDriver, 0, 0, id, Driver::Address(22), 0); + receiver, &mockDriver, 0, 0, id, SocketAddress{22, 60001}, 0); message->messageLength = 200; message->start = 20; EXPECT_EQ(180U, message->length()); @@ -567,7 +590,7 @@ TEST_F(ReceiverTest, Message_strip) { Protocol::MessageId id = {42, 32}; Receiver::Message* message = receiver->messageAllocator.pool.construct( - receiver, &mockDriver, 0, 0, id, Driver::Address(22), 0); + receiver, &mockDriver, 0, 0, id, SocketAddress{22, 60001}, 0); message->messageLength = 30; message->start = 0; @@ -589,10 +612,10 @@ TEST_F(ReceiverTest, Message_getPacket) { Protocol::MessageId id = {42, 32}; Receiver::Message* message = receiver->messageAllocator.pool.construct( - receiver, &mockDriver, 0, 0, id, Driver::Address(22), 0); + receiver, &mockDriver, 0, 0, id, SocketAddress{22, 60001}, 0); - Driver::Packet* packet = (Driver::Packet*)42; - message->packets[0] = packet; + message->packets[0] = {}; + Driver::Packet* packet = &message->packets[0]; EXPECT_EQ(nullptr, message->getPacket(0)); @@ -605,15 +628,16 @@ TEST_F(ReceiverTest, Message_setPacket) { Protocol::MessageId id = {42, 32}; Receiver::Message* message = receiver->messageAllocator.pool.construct( - receiver, &mockDriver, 0, 0, id, Driver::Address(22), 0); - 
Driver::Packet* packet = (Driver::Packet*)42; + receiver, &mockDriver, 0, 0, id, SocketAddress{22, 60001}, 0); + Driver::Packet* packet = &mockPacket; EXPECT_FALSE(message->occupied.test(0)); EXPECT_EQ(0U, message->numPackets); EXPECT_TRUE(message->setPacket(0, packet)); - EXPECT_EQ(packet, message->packets[0]); + EXPECT_EQ(packet->descriptor, message->packets[0].descriptor); + EXPECT_EQ(packet->payload, message->packets[0].payload); EXPECT_TRUE(message->occupied.test(0)); EXPECT_EQ(1U, message->numPackets); @@ -626,12 +650,12 @@ TEST_F(ReceiverTest, MessageBucket_findMessage) Protocol::MessageId id0 = {42, 0}; Receiver::Message* msg0 = receiver->messageAllocator.pool.construct( - receiver, &mockDriver, sizeof(Protocol::Packet::DataHeader), 0, id0, 0, - 0); + receiver, &mockDriver, sizeof(Protocol::Packet::DataHeader), 0, id0, + SocketAddress{0, 60001}, 0); Protocol::MessageId id1 = {42, 1}; Receiver::Message* msg1 = receiver->messageAllocator.pool.construct( receiver, &mockDriver, sizeof(Protocol::Packet::DataHeader), 0, id1, - Driver::Address(0), 0); + SocketAddress{0, 60001}, 0); Protocol::MessageId id_none = {42, 42}; bucket->messages.push_back(&msg0->bucketNode); @@ -659,7 +683,7 @@ TEST_F(ReceiverTest, dropMessage) SpinLock::Lock dummy(dummyMutex); Protocol::MessageId id = {42, 32}; Receiver::Message* message = receiver->messageAllocator.pool.construct( - receiver, &mockDriver, 0, 1000, id, Driver::Address(22), 0); + receiver, &mockDriver, 0, 1000, id, SocketAddress{22, 60001}, 0); ASSERT_TRUE(message->scheduled); Receiver::MessageBucket* bucket = receiver->messageBuckets.getBucket(id); @@ -670,7 +694,7 @@ TEST_F(ReceiverTest, dropMessage) EXPECT_EQ(1U, receiver->messageAllocator.pool.outstandingObjects); EXPECT_EQ(message, bucket->findMessage(id, dummy)); - EXPECT_EQ(&receiver->peerTable[message->source], + EXPECT_EQ(&receiver->peerTable[message->source.ip], message->scheduledMessageInfo.peer); EXPECT_FALSE(bucket->messageTimeouts.list.empty()); 
EXPECT_FALSE(bucket->resendTimeouts.list.empty()); @@ -693,7 +717,7 @@ TEST_F(ReceiverTest, checkMessageTimeouts_basic) Protocol::MessageId id = {42, 10 + i}; op[i] = reinterpret_cast(i); message[i] = receiver->messageAllocator.pool.construct( - receiver, &mockDriver, 0, 1000, id, 0, 0); + receiver, &mockDriver, 0, 1000, id, SocketAddress{0, 60001}, 0); bucket->messages.push_back(&message[i]->bucketNode); bucket->messageTimeouts.setTimeout(&message[i]->messageTimeout); bucket->resendTimeouts.setTimeout(&message[i]->resendTimeout); @@ -767,7 +791,7 @@ TEST_F(ReceiverTest, checkResendTimeouts_basic) for (uint64_t i = 0; i < 3; ++i) { Protocol::MessageId id = {42, 10 + i}; message[i] = receiver->messageAllocator.pool.construct( - receiver, &mockDriver, 0, 10000, id, Driver::Address(22), 5); + receiver, &mockDriver, 0, 10000, id, SocketAddress{22, 60001}, 5); bucket->resendTimeouts.setTimeout(&message[i]->resendTimeout); } @@ -803,17 +827,21 @@ TEST_F(ReceiverTest, checkResendTimeouts_basic) char buf1[1024]; char buf2[1024]; - Homa::Mock::MockDriver::MockPacket mockResendPacket1(buf1); - Homa::Mock::MockDriver::MockPacket mockResendPacket2(buf2); - - EXPECT_CALL(mockDriver, allocPacket()) - .WillOnce(Return(&mockResendPacket1)) - .WillOnce(Return(&mockResendPacket2)); - EXPECT_CALL(mockDriver, sendPacket(Eq(&mockResendPacket1))).Times(1); - EXPECT_CALL(mockDriver, sendPacket(Eq(&mockResendPacket2))).Times(1); - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockResendPacket1), Eq(1))) + Homa::Mock::MockDriver::PacketBuf packetBuf0{buf1}; + Homa::Mock::MockDriver::PacketBuf packetBuf1{buf2}; + Driver::Packet mockResendPacket1 = packetBuf0.toPacket(); + Driver::Packet mockResendPacket2 = packetBuf1.toPacket(); + const size_t RESEND_HEADER_LEN = sizeof(Protocol::Packet::ResendHeader); + + EXPECT_CALL(mockDriver, allocPacket(_)) + .WillOnce([&mockResendPacket1](auto p) { *p = mockResendPacket1; }) + .WillOnce([&mockResendPacket2](auto p) { *p = mockResendPacket2; }); + 
EXPECT_CALL(mockDriver, sendPacket(EqPacketLen(RESEND_HEADER_LEN), + Eq(message[0]->source.ip), _)) + .Times(2); + EXPECT_CALL(mockDriver, releasePackets(EqPacket(&mockResendPacket1), Eq(1))) .Times(1); - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockResendPacket2), Eq(1))) + EXPECT_CALL(mockDriver, releasePackets(EqPacket(&mockResendPacket2), Eq(1))) .Times(1); // TEST CALL @@ -829,16 +857,12 @@ TEST_F(ReceiverTest, checkResendTimeouts_basic) EXPECT_EQ(message[0]->id, header1->common.messageId); EXPECT_EQ(2U, header1->index); EXPECT_EQ(4U, header1->num); - EXPECT_EQ(sizeof(Protocol::Packet::ResendHeader), mockResendPacket1.length); - EXPECT_EQ(message[0]->source, mockResendPacket1.address); Protocol::Packet::ResendHeader* header2 = static_cast(mockResendPacket2.payload); EXPECT_EQ(Protocol::Packet::RESEND, header2->common.opcode); EXPECT_EQ(message[0]->id, header2->common.messageId); EXPECT_EQ(8U, header2->index); EXPECT_EQ(2U, header2->num); - EXPECT_EQ(sizeof(Protocol::Packet::ResendHeader), mockResendPacket2.length); - EXPECT_EQ(message[0]->source, mockResendPacket2.address); // Message[1]: Blocked on grants EXPECT_EQ(10100, message[1]->resendTimeout.expirationCycleTime); @@ -863,11 +887,12 @@ TEST_F(ReceiverTest, trySendGrants) { Receiver::Message* message[4]; Receiver::ScheduledMessageInfo* info[4]; - for (uint64_t i = 0; i < 4; ++i) { + for (uint32_t i = 0; i < 4; ++i) { Protocol::MessageId id = {42, 10 + i}; message[i] = receiver->messageAllocator.pool.construct( receiver, &mockDriver, sizeof(Protocol::Packet::DataHeader), - 10000 * (i + 1), id, Driver::Address(100 + i), 10 * (i + 1)); + 10000 * (i + 1), id, SocketAddress{IP(100 + i), 60001}, + 10 * (i + 1)); { SpinLock::Lock lock_scheduler(receiver->schedulerMutex); receiver->schedule(message[i], lock_scheduler); @@ -893,11 +918,13 @@ TEST_F(ReceiverTest, trySendGrants) info[0]->bytesRemaining -= 1000; EXPECT_CALL(mockPolicyManager, getScheduledPolicy()) .WillOnce(Return(policy)); - 
EXPECT_CALL(mockDriver, allocPacket).WillOnce(Return(&mockPacket)); - EXPECT_CALL(mockDriver, sendPacket(Eq(&mockPacket))).Times(1); - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) + EXPECT_CALL(mockDriver, allocPacket(_)) + .WillOnce([this](Driver::Packet* packet) { *packet = mockPacket; }); + EXPECT_CALL(mockDriver, sendPacket(EqPacket(&mockPacket), _, _)).Times(1); + EXPECT_CALL(mockDriver, releasePackets(EqPacket(&mockPacket), Eq(1))) .Times(1); + receiver->dontNeedGrants.clear(); receiver->trySendGrants(); EXPECT_EQ(1, info[0]->priority); @@ -919,11 +946,13 @@ TEST_F(ReceiverTest, trySendGrants) info[1]->bytesRemaining -= 1000; EXPECT_CALL(mockPolicyManager, getScheduledPolicy()) .WillOnce(Return(policy)); - EXPECT_CALL(mockDriver, allocPacket).WillOnce(Return(&mockPacket)); - EXPECT_CALL(mockDriver, sendPacket(Eq(&mockPacket))).Times(1); - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) + EXPECT_CALL(mockDriver, allocPacket(_)) + .WillOnce([this](Driver::Packet* packet) { *packet = mockPacket; }); + EXPECT_CALL(mockDriver, sendPacket(EqPacket(&mockPacket), _, _)).Times(1); + EXPECT_CALL(mockDriver, releasePackets(EqPacket(&mockPacket), Eq(1))) .Times(1); + receiver->dontNeedGrants.clear(); receiver->trySendGrants(); EXPECT_EQ(0, info[1]->priority); @@ -941,8 +970,9 @@ TEST_F(ReceiverTest, trySendGrants) policy.maxScheduledBytes = 10000; EXPECT_CALL(mockPolicyManager, getScheduledPolicy()) .WillOnce(Return(policy)); - EXPECT_CALL(mockDriver, sendPacket(_)).Times(0); + EXPECT_CALL(mockDriver, sendPacket(_, _, _)).Times(0); + receiver->dontNeedGrants.clear(); receiver->trySendGrants(); EXPECT_EQ(1, info[1]->priority); @@ -960,8 +990,9 @@ TEST_F(ReceiverTest, trySendGrants) policy.maxScheduledBytes = 10000; EXPECT_CALL(mockPolicyManager, getScheduledPolicy()) .WillOnce(Return(policy)); - EXPECT_CALL(mockDriver, sendPacket(_)).Times(0); + EXPECT_CALL(mockDriver, sendPacket(_, _, _)).Times(0); + 
receiver->dontNeedGrants.clear(); receiver->trySendGrants(); EXPECT_EQ(2, info[1]->priority); @@ -975,13 +1006,13 @@ TEST_F(ReceiverTest, schedule) { Receiver::Message* message[4]; Receiver::ScheduledMessageInfo* info[4]; - Driver::Address address[4] = {22, 33, 33, 22}; + IpAddress address[4] = {22, 33, 33, 22}; int messageLength[4] = {2000, 3000, 1000, 4000}; for (uint64_t i = 0; i < 4; ++i) { Protocol::MessageId id = {42, 10 + i}; message[i] = receiver->messageAllocator.pool.construct( receiver, &mockDriver, sizeof(Protocol::Packet::DataHeader), - messageLength[i], id, address[i], 0); + messageLength[i], id, SocketAddress{address[i], 60001}, 0); info[i] = &message[i]->scheduledMessageInfo; } @@ -994,7 +1025,7 @@ TEST_F(ReceiverTest, schedule) receiver->schedule(message[0], lock); - EXPECT_EQ(&receiver->peerTable.at(22), info[0]->peer); + EXPECT_EQ(&receiver->peerTable.at(IP(22)), info[0]->peer); EXPECT_EQ(message[0], &info[0]->peer->scheduledMessages.front()); EXPECT_EQ(info[0]->peer, &receiver->scheduledPeers.front()); @@ -1006,7 +1037,7 @@ TEST_F(ReceiverTest, schedule) receiver->schedule(message[1], lock); - EXPECT_EQ(&receiver->peerTable.at(33), info[1]->peer); + EXPECT_EQ(&receiver->peerTable.at(IP(33)), info[1]->peer); EXPECT_EQ(message[1], &info[1]->peer->scheduledMessages.front()); EXPECT_EQ(info[1]->peer, &receiver->scheduledPeers.back()); @@ -1018,7 +1049,7 @@ TEST_F(ReceiverTest, schedule) receiver->schedule(message[2], lock); - EXPECT_EQ(&receiver->peerTable.at(33), info[2]->peer); + EXPECT_EQ(&receiver->peerTable.at(IP(33)), info[2]->peer); EXPECT_EQ(message[2], &info[2]->peer->scheduledMessages.front()); EXPECT_EQ(info[2]->peer, &receiver->scheduledPeers.front()); @@ -1030,7 +1061,7 @@ TEST_F(ReceiverTest, schedule) receiver->schedule(message[3], lock); - EXPECT_EQ(&receiver->peerTable.at(22), info[3]->peer); + EXPECT_EQ(&receiver->peerTable.at(IP(22)), info[3]->peer); EXPECT_EQ(message[3], &info[3]->peer->scheduledMessages.back()); 
EXPECT_EQ(info[3]->peer, &receiver->scheduledPeers.back()); } @@ -1041,23 +1072,24 @@ TEST_F(ReceiverTest, unschedule) Receiver::ScheduledMessageInfo* info[5]; SpinLock::Lock lock(receiver->schedulerMutex); int messageLength[5] = {10, 20, 30, 10, 20}; - for (uint64_t i = 0; i < 5; ++i) { + for (uint32_t i = 0; i < 5; ++i) { Protocol::MessageId id = {42, 10 + i}; - Driver::Address source = Driver::Address((i / 3) + 10); + IpAddress source = IP((i / 3) + 10); message[i] = receiver->messageAllocator.pool.construct( receiver, &mockDriver, sizeof(Protocol::Packet::DataHeader), - messageLength[i], id, source, 0); + messageLength[i], id, SocketAddress{source, 60001}, 0); info[i] = &message[i]->scheduledMessageInfo; receiver->schedule(message[i], lock); } + auto& scheduledPeers = receiver->scheduledPeers; - ASSERT_EQ(Driver::Address(10), message[0]->source); - ASSERT_EQ(Driver::Address(10), message[1]->source); - ASSERT_EQ(Driver::Address(10), message[2]->source); - ASSERT_EQ(Driver::Address(11), message[3]->source); - ASSERT_EQ(Driver::Address(11), message[4]->source); - ASSERT_EQ(&receiver->scheduledPeers.front(), &receiver->peerTable.at(10)); - ASSERT_EQ(&receiver->scheduledPeers.back(), &receiver->peerTable.at(11)); + ASSERT_EQ(IP(10), message[0]->source.ip); + ASSERT_EQ(IP(10), message[1]->source.ip); + ASSERT_EQ(IP(10), message[2]->source.ip); + ASSERT_EQ(IP(11), message[3]->source.ip); + ASSERT_EQ(IP(11), message[4]->source.ip); + ASSERT_EQ(&scheduledPeers.front(), &receiver->peerTable.at(IP(10))); + ASSERT_EQ(&scheduledPeers.back(), &receiver->peerTable.at(IP(11))); // <10>: [0](10) -> [1](20) -> [2](30) // <11>: [3](10) -> [4](20) @@ -1075,10 +1107,10 @@ TEST_F(ReceiverTest, unschedule) receiver->unschedule(message[4], lock); EXPECT_EQ(nullptr, info[4]->peer); - EXPECT_EQ(&receiver->scheduledPeers.front(), &receiver->peerTable.at(10)); - EXPECT_EQ(&receiver->scheduledPeers.back(), &receiver->peerTable.at(11)); - EXPECT_EQ(3U, 
receiver->peerTable.at(10).scheduledMessages.size()); - EXPECT_EQ(1U, receiver->peerTable.at(11).scheduledMessages.size()); + EXPECT_EQ(&scheduledPeers.front(), &receiver->peerTable.at(IP(10))); + EXPECT_EQ(&scheduledPeers.back(), &receiver->peerTable.at(IP(11))); + EXPECT_EQ(3U, receiver->peerTable.at(IP(10)).scheduledMessages.size()); + EXPECT_EQ(1U, receiver->peerTable.at(IP(11)).scheduledMessages.size()); //-------------------------------------------------------------------------- // Remove message[1]; peer in correct position. @@ -1088,10 +1120,10 @@ TEST_F(ReceiverTest, unschedule) receiver->unschedule(message[1], lock); EXPECT_EQ(nullptr, info[1]->peer); - EXPECT_EQ(&receiver->scheduledPeers.front(), &receiver->peerTable.at(10)); - EXPECT_EQ(&receiver->scheduledPeers.back(), &receiver->peerTable.at(11)); - EXPECT_EQ(2U, receiver->peerTable.at(10).scheduledMessages.size()); - EXPECT_EQ(1U, receiver->peerTable.at(11).scheduledMessages.size()); + EXPECT_EQ(&scheduledPeers.front(), &receiver->peerTable.at(IP(10))); + EXPECT_EQ(&scheduledPeers.back(), &receiver->peerTable.at(IP(11))); + EXPECT_EQ(2U, receiver->peerTable.at(IP(10)).scheduledMessages.size()); + EXPECT_EQ(1U, receiver->peerTable.at(IP(11)).scheduledMessages.size()); //-------------------------------------------------------------------------- // Remove message[0]; peer needs to be reordered. 
@@ -1101,10 +1133,10 @@ TEST_F(ReceiverTest, unschedule) receiver->unschedule(message[0], lock); EXPECT_EQ(nullptr, info[0]->peer); - EXPECT_EQ(&receiver->scheduledPeers.front(), &receiver->peerTable.at(11)); - EXPECT_EQ(&receiver->scheduledPeers.back(), &receiver->peerTable.at(10)); - EXPECT_EQ(1U, receiver->peerTable.at(11).scheduledMessages.size()); - EXPECT_EQ(1U, receiver->peerTable.at(10).scheduledMessages.size()); + EXPECT_EQ(&scheduledPeers.front(), &receiver->peerTable.at(IP(11))); + EXPECT_EQ(&scheduledPeers.back(), &receiver->peerTable.at(IP(10))); + EXPECT_EQ(1U, receiver->peerTable.at(IP(11)).scheduledMessages.size()); + EXPECT_EQ(1U, receiver->peerTable.at(IP(10)).scheduledMessages.size()); //-------------------------------------------------------------------------- // Remove message[3]; peer needs to be removed. @@ -1113,10 +1145,10 @@ TEST_F(ReceiverTest, unschedule) receiver->unschedule(message[3], lock); EXPECT_EQ(nullptr, info[3]->peer); - EXPECT_EQ(&receiver->scheduledPeers.front(), &receiver->peerTable.at(10)); - EXPECT_EQ(&receiver->scheduledPeers.back(), &receiver->peerTable.at(10)); - EXPECT_EQ(1U, receiver->peerTable.at(10).scheduledMessages.size()); - EXPECT_EQ(0U, receiver->peerTable.at(11).scheduledMessages.size()); + EXPECT_EQ(&scheduledPeers.front(), &receiver->peerTable.at(IP(10))); + EXPECT_EQ(&scheduledPeers.back(), &receiver->peerTable.at(IP(10))); + EXPECT_EQ(1U, receiver->peerTable.at(IP(10)).scheduledMessages.size()); + EXPECT_EQ(0U, receiver->peerTable.at(IP(11)).scheduledMessages.size()); } TEST_F(ReceiverTest, updateSchedule) @@ -1125,25 +1157,26 @@ TEST_F(ReceiverTest, updateSchedule) // 11 : [20][30] SpinLock::Lock lock(receiver->schedulerMutex); Receiver::Message* other[3]; - for (uint64_t i = 0; i < 3; ++i) { + for (uint32_t i = 0; i < 3; ++i) { Protocol::MessageId id = {42, 10 + i}; int messageLength = 10 * (i + 1); - Driver::Address source = Driver::Address(((i + 1) / 2) + 10); + IpAddress source = IP(((i + 1) / 2) + 
10); other[i] = receiver->messageAllocator.pool.construct( receiver, &mockDriver, sizeof(Protocol::Packet::DataHeader), - 10 * (i + 1), id, source, 0); + 10 * (i + 1), id, SocketAddress{source, 60001}, 0); receiver->schedule(other[i], lock); } Receiver::Message* message = receiver->messageAllocator.pool.construct( receiver, &mockDriver, sizeof(Protocol::Packet::DataHeader), 100, - Protocol::MessageId(42, 1), Driver::Address(11), 0); + Protocol::MessageId(42, 1), SocketAddress{11, 60001}, 0); receiver->schedule(message, lock); - ASSERT_EQ(&receiver->peerTable.at(10), other[0]->scheduledMessageInfo.peer); - ASSERT_EQ(&receiver->peerTable.at(11), other[1]->scheduledMessageInfo.peer); - ASSERT_EQ(&receiver->peerTable.at(11), other[2]->scheduledMessageInfo.peer); - ASSERT_EQ(&receiver->peerTable.at(11), message->scheduledMessageInfo.peer); - ASSERT_EQ(&receiver->scheduledPeers.front(), &receiver->peerTable.at(10)); - ASSERT_EQ(&receiver->scheduledPeers.back(), &receiver->peerTable.at(11)); + auto& peerTable = receiver->peerTable; + ASSERT_EQ(&peerTable.at(IP(10)), other[0]->scheduledMessageInfo.peer); + ASSERT_EQ(&peerTable.at(IP(11)), other[1]->scheduledMessageInfo.peer); + ASSERT_EQ(&peerTable.at(IP(11)), other[2]->scheduledMessageInfo.peer); + ASSERT_EQ(&peerTable.at(IP(11)), message->scheduledMessageInfo.peer); + ASSERT_EQ(&receiver->scheduledPeers.front(), &peerTable.at(IP(10))); + ASSERT_EQ(&receiver->scheduledPeers.back(), &peerTable.at(IP(11))); //-------------------------------------------------------------------------- // Move message up within peer. 
@@ -1153,11 +1186,12 @@ TEST_F(ReceiverTest, updateSchedule) receiver->updateSchedule(message, lock); - EXPECT_EQ(&receiver->scheduledPeers.back(), &receiver->peerTable.at(11)); + auto& scheduledPeers = receiver->scheduledPeers; + EXPECT_EQ(&scheduledPeers.back(), &receiver->peerTable.at(IP(11))); Receiver::Peer* peer = &receiver->scheduledPeers.back(); auto it = peer->scheduledMessages.begin(); EXPECT_TRUE( - std::next(receiver->peerTable.at(11).scheduledMessages.begin()) == + std::next(receiver->peerTable.at(IP(11)).scheduledMessages.begin()) == message->scheduledMessageInfo.peer->scheduledMessages.get( &message->scheduledMessageInfo.scheduledMessageNode)); @@ -1169,8 +1203,8 @@ TEST_F(ReceiverTest, updateSchedule) receiver->updateSchedule(message, lock); - EXPECT_EQ(&receiver->scheduledPeers.back(), &receiver->peerTable.at(11)); - EXPECT_EQ(receiver->peerTable.at(11).scheduledMessages.begin(), + EXPECT_EQ(&scheduledPeers.back(), &receiver->peerTable.at(IP(11))); + EXPECT_EQ(receiver->peerTable.at(IP(11)).scheduledMessages.begin(), message->scheduledMessageInfo.peer->scheduledMessages.get( &message->scheduledMessageInfo.scheduledMessageNode)); @@ -1182,8 +1216,8 @@ TEST_F(ReceiverTest, updateSchedule) receiver->updateSchedule(message, lock); - EXPECT_EQ(&receiver->scheduledPeers.front(), &receiver->peerTable.at(11)); - EXPECT_EQ(receiver->peerTable.at(11).scheduledMessages.begin(), + EXPECT_EQ(&scheduledPeers.front(), &receiver->peerTable.at(IP(11))); + EXPECT_EQ(receiver->peerTable.at(IP(11)).scheduledMessages.begin(), message->scheduledMessageInfo.peer->scheduledMessages.get( &message->scheduledMessageInfo.scheduledMessageNode)); } diff --git a/src/Sender.cc b/src/Sender.cc index c2d0c3f..43c7600 100644 --- a/src/Sender.cc +++ b/src/Sender.cc @@ -33,6 +33,8 @@ namespace Core { * Unique identifier for the Transport that owns this Sender. * @param driver * The driver used to send and receive packets. 
+ * @param callbacks + * Collections of user-defined transport callbacks. * @param policyManager * Provides information about the network packet priority policies. * @param messageTimeoutCycles @@ -43,18 +45,20 @@ namespace Core { * of an Sender::Message. */ Sender::Sender(uint64_t transportId, Driver* driver, - Policy::Manager* policyManager, uint64_t messageTimeoutCycles, - uint64_t pingIntervalCycles) + Transport::Callbacks* callbacks, Policy::Manager* policyManager, + uint64_t messageTimeoutCycles, uint64_t pingIntervalCycles) : transportId(transportId) + , callbacks(callbacks) , driver(driver) , policyManager(policyManager) , nextMessageSequenceNumber(1) , DRIVER_QUEUED_BYTE_LIMIT(2 * driver->getMaxPayloadSize()) + , DRIVER_CYCLES_TO_DRAIN_1MB(PerfUtils::Cycles::fromSeconds(1) * 8 / + driver->getBandwidth()) , messageBuckets(messageTimeoutCycles, pingIntervalCycles) , queueMutex() - , sendQueue() - , sending() , sendReady(false) + , sendQueue() , messageAllocator() {} @@ -67,10 +71,10 @@ Sender::~Sender() {} * Allocate an OutMessage that can be sent with this Sender. */ Homa::OutMessage* -Sender::allocMessage() +Sender::allocMessage(uint16_t sourcePort) { SpinLock::Lock lock_allocator(messageAllocator.mutex); - return messageAllocator.pool.construct(this, driver); + return messageAllocator.pool.construct(this, sourcePort); } /** @@ -78,12 +82,9 @@ Sender::allocMessage() * * @param packet * Incoming DONE packet to be processed. - * @param driver - * Driver from which the packet was received and to which it should be - * returned after the packet has been processed. */ void -Sender::handleDonePacket(Driver::Packet* packet, Driver* driver) +Sender::handleDonePacket(Driver::Packet* packet) { Protocol::Packet::DoneHeader* header = static_cast(packet->payload); @@ -95,7 +96,7 @@ Sender::handleDonePacket(Driver::Packet* packet, Driver* driver) if (message == nullptr) { // No message for this DONE packet; must be old. Just drop it. 
- driver->releasePackets(&packet, 1); + driver->releasePackets(packet, 1); return; } @@ -106,7 +107,7 @@ Sender::handleDonePacket(Driver::Packet* packet, Driver* driver) // Expected behavior bucket->messageTimeouts.cancelTimeout(&message->messageTimeout); bucket->pingTimeouts.cancelTimeout(&message->pingTimeout); - message->state.store(OutMessage::Status::COMPLETED); + message->setStatus(OutMessage::Status::COMPLETED); break; case OutMessage::Status::CANCELED: // Canceled by the the application; just ignore the DONE. @@ -144,7 +145,7 @@ Sender::handleDonePacket(Driver::Packet* packet, Driver* driver) break; } - driver->releasePackets(&packet, 1); + driver->releasePackets(packet, 1); } /** @@ -152,12 +153,9 @@ Sender::handleDonePacket(Driver::Packet* packet, Driver* driver) * * @param packet * Incoming RESEND packet to be processed. - * @param driver - * Driver from which the packet was received and to which it should be - * returned after the packet has been processed. */ void -Sender::handleResendPacket(Driver::Packet* packet, Driver* driver) +Sender::handleResendPacket(Driver::Packet* packet) { Protocol::Packet::ResendHeader* header = static_cast(packet->payload); @@ -173,7 +171,7 @@ Sender::handleResendPacket(Driver::Packet* packet, Driver* driver) if (message == nullptr) { // No message for this RESEND; RESEND must be old. Just ignore it; this // case should be pretty rare and the Receiver will timeout eventually. - driver->releasePackets(&packet, 1); + driver->releasePackets(packet, 1); return; } else if (message->numPackets < 2) { // We should never get a RESEND for a single packet message. 
Just @@ -182,14 +180,15 @@ Sender::handleResendPacket(Driver::Packet* packet, Driver* driver) "Message (%lu, %lu) with only 1 packet received unexpected RESEND " "request; peer Transport may be confused.", msgId.transportId, msgId.sequence); - driver->releasePackets(&packet, 1); + driver->releasePackets(packet, 1); return; } bucket->messageTimeouts.setTimeout(&message->messageTimeout); bucket->pingTimeouts.setTimeout(&message->pingTimeout); - SpinLock::Lock lock_queue(queueMutex); + bool notifySendReady = false; + SpinLock::UniqueLock lock_queue(queueMutex); QueuedMessageInfo* info = &message->queuedMessageInfo; // Check if RESEND request is out of range. @@ -201,7 +200,7 @@ Sender::handleResendPacket(Driver::Packet* packet, Driver* driver) "may be confused.", msgId.transportId, msgId.sequence, index, resendEnd, info->packets->numPackets); - driver->releasePackets(&packet, 1); + driver->releasePackets(packet, 1); return; } @@ -212,7 +211,8 @@ Sender::handleResendPacket(Driver::Packet* packet, Driver* driver) // will never be overridden since the resend index will not exceed the // preset packetsGranted. info->priority = header->priority; - sendReady.store(true); + sendReady = true; + notifySendReady = true; } if (index >= info->packetsSent) { @@ -222,7 +222,7 @@ Sender::handleResendPacket(Driver::Packet* packet, Driver* driver) // when it's ready. Perf::counters.tx_busy_pkts.add(1); ControlPacket::send( - driver, info->destination, info->id); + driver, info->destination.ip, info->id); } else { // There are some packets to resend but only resend packets that have // already been sent. @@ -230,15 +230,20 @@ Sender::handleResendPacket(Driver::Packet* packet, Driver* driver) int resendPriority = policyManager->getResendPriority(); for (uint16_t i = index; i < resendEnd; ++i) { Driver::Packet* packet = info->packets->getPacket(i); - packet->priority = resendPriority; // Packets will be sent at the priority their original priority. 
Perf::counters.tx_data_pkts.add(1); Perf::counters.tx_bytes.add(packet->length); - driver->sendPacket(packet); + driver->sendPacket(packet, message->destination.ip, resendPriority); } } - driver->releasePackets(&packet, 1); + // Only invoke the callback after unlocking queueMutex. + lock_queue.unlock(); + if (notifySendReady) { + callbacks->notifySendReady(); + } + + driver->releasePackets(packet, 1); } /** @@ -246,12 +251,9 @@ Sender::handleResendPacket(Driver::Packet* packet, Driver* driver) * * @param packet * Incoming GRANT packet to be processed. - * @param driver - * Driver from which the packet was received and to which it should be - * returned after the packet has been processed. */ void -Sender::handleGrantPacket(Driver::Packet* packet, Driver* driver) +Sender::handleGrantPacket(Driver::Packet* packet) { Protocol::Packet::GrantHeader* header = static_cast(packet->payload); @@ -262,14 +264,15 @@ Sender::handleGrantPacket(Driver::Packet* packet, Driver* driver) Message* message = bucket->findMessage(msgId, lock); if (message == nullptr) { // No message for this grant; grant must be old. Just drop it. - driver->releasePackets(&packet, 1); + driver->releasePackets(packet, 1); return; } bucket->messageTimeouts.setTimeout(&message->messageTimeout); bucket->pingTimeouts.setTimeout(&message->pingTimeout); - if (message->state.load() == OutMessage::Status::IN_PROGRESS) { + bool notifySendReady = false; + if (message->getStatus() == OutMessage::Status::IN_PROGRESS) { SpinLock::Lock lock_queue(queueMutex); QueuedMessageInfo* info = &message->queuedMessageInfo; @@ -298,11 +301,17 @@ Sender::handleGrantPacket(Driver::Packet* packet, Driver* driver) // limit will never be overridden since the incomingGrantIndex will // not exceed the preset packetsGranted. info->priority = header->priority; - sendReady.store(true); + sendReady = true; + notifySendReady = true; } } - driver->releasePackets(&packet, 1); + // Only invoke the callback after unlocking queueMutex. 
+ if (notifySendReady) { + callbacks->notifySendReady(); + } + + driver->releasePackets(packet, 1); } /** @@ -310,12 +319,9 @@ Sender::handleGrantPacket(Driver::Packet* packet, Driver* driver) * * @param packet * Incoming UNKNOWN packet to be processed. - * @param driver - * Driver from which the packet was received and to which it should be - * returned after the packet has been processed. */ void -Sender::handleUnknownPacket(Driver::Packet* packet, Driver* driver) +Sender::handleUnknownPacket(Driver::Packet* packet) { Protocol::Packet::UnknownHeader* header = static_cast(packet->payload); @@ -327,7 +333,7 @@ Sender::handleUnknownPacket(Driver::Packet* packet, Driver* driver) if (message == nullptr) { // No message was found. Just drop the packet. - driver->releasePackets(&packet, 1); + driver->releasePackets(packet, 1); return; } @@ -347,7 +353,7 @@ Sender::handleUnknownPacket(Driver::Packet* packet, Driver* driver) if (message->numPackets > 1) { SpinLock::Lock lock_queue(queueMutex); QueuedMessageInfo* info = &message->queuedMessageInfo; - if (message->state == OutMessage::Status::IN_PROGRESS) { + if (message->getStatus() == OutMessage::Status::IN_PROGRESS) { assert(sendQueue.contains(&info->sendQueueNode)); sendQueue.remove(&info->sendQueueNode); } @@ -356,7 +362,7 @@ Sender::handleUnknownPacket(Driver::Packet* packet, Driver* driver) bucket->messageTimeouts.cancelTimeout(&message->messageTimeout); bucket->pingTimeouts.cancelTimeout(&message->pingTimeout); - message->state.store(OutMessage::Status::FAILED); + message->setStatus(OutMessage::Status::FAILED); } else { // Message isn't done yet so we will restart sending the message. 
@@ -365,18 +371,18 @@ Sender::handleUnknownPacket(Driver::Packet* packet, Driver* driver) if (message->numPackets > 1) { SpinLock::Lock lock_queue(queueMutex); QueuedMessageInfo* info = &message->queuedMessageInfo; - if (message->state == OutMessage::Status::IN_PROGRESS) { + if (message->getStatus() == OutMessage::Status::IN_PROGRESS) { assert(sendQueue.contains(&info->sendQueueNode)); sendQueue.remove(&info->sendQueueNode); } assert(!sendQueue.contains(&info->sendQueueNode)); } - message->state.store(OutMessage::Status::IN_PROGRESS); + message->setStatus(OutMessage::Status::IN_PROGRESS); // Get the current policy for unscheduled bytes. Policy::Unscheduled policy = policyManager->getUnscheduledPolicy( - message->destination, message->messageLength); + message->destination.ip, message->messageLength); int unscheduledIndexLimit = ((policy.unscheduledByteLimit + message->PACKET_DATA_LENGTH - 1) / message->PACKET_DATA_LENGTH); @@ -397,15 +403,16 @@ Sender::handleUnknownPacket(Driver::Packet* packet, Driver* driver) bucket->pingTimeouts.setTimeout(&message->pingTimeout); assert(message->numPackets > 0); + bool notifySendReady = false; if (message->numPackets == 1) { // If there is only one packet in the message, send it right away. Driver::Packet* dataPacket = message->getPacket(0); assert(dataPacket != nullptr); - dataPacket->priority = policy.priority; Perf::counters.tx_data_pkts.add(1); Perf::counters.tx_bytes.add(dataPacket->length); - driver->sendPacket(dataPacket); - message->state.store(OutMessage::Status::SENT); + driver->sendPacket(dataPacket, message->destination.ip, + policy.priority); + message->setStatus(OutMessage::Status::SENT); } else { // Otherwise, queue the message to be sent in SRPT order. SpinLock::Lock lock_queue(queueMutex); @@ -413,7 +420,8 @@ Sender::handleUnknownPacket(Driver::Packet* packet, Driver* driver) // Some of these values should still be set from when the message // was first queued. 
assert(info->id == message->id); - assert(info->destination == message->destination); + assert(!memcmp(&info->destination, &message->destination, + sizeof(info->destination))); assert(info->packets == message); // Some values need to be updated info->unsentBytes = message->messageLength; @@ -427,11 +435,17 @@ Sender::handleUnknownPacket(Driver::Packet* packet, Driver* driver) Intrusive::deprioritize( &sendQueue, &info->sendQueueNode, QueuedMessageInfo::ComparePriority()); - sendReady.store(true); + sendReady = true; + notifySendReady = true; + } + + // Only invoke the callback after unlocking queueMutex. + if (notifySendReady) { + callbacks->notifySendReady(); } } - driver->releasePackets(&packet, 1); + driver->releasePackets(packet, 1); } /** @@ -439,12 +453,9 @@ Sender::handleUnknownPacket(Driver::Packet* packet, Driver* driver) * * @param packet * Incoming ERROR packet to be processed. - * @param driver - * Driver from which the packet was received and to which it should be - * returned after the packet has been processed. */ void -Sender::handleErrorPacket(Driver::Packet* packet, Driver* driver) +Sender::handleErrorPacket(Driver::Packet* packet) { Protocol::Packet::ErrorHeader* header = static_cast(packet->payload); @@ -455,7 +466,7 @@ Sender::handleErrorPacket(Driver::Packet* packet, Driver* driver) Message* message = bucket->findMessage(msgId, lock); if (message == nullptr) { // No message for this ERROR packet; must be old. Just drop it. - driver->releasePackets(&packet, 1); + driver->releasePackets(packet, 1); return; } @@ -465,7 +476,7 @@ Sender::handleErrorPacket(Driver::Packet* packet, Driver* driver) // Message was sent and a failure notification was received. 
bucket->messageTimeouts.cancelTimeout(&message->messageTimeout); bucket->pingTimeouts.cancelTimeout(&message->pingTimeout); - message->state.store(OutMessage::Status::FAILED); + message->setStatus(OutMessage::Status::FAILED); break; case OutMessage::Status::CANCELED: // Canceled by the the application; just ignore the ERROR. @@ -503,18 +514,7 @@ Sender::handleErrorPacket(Driver::Packet* packet, Driver* driver) break; } - driver->releasePackets(&packet, 1); -} - -/** - * Allow the Sender to make progress toward sending outgoing messages. - * - * This method must be called eagerly to ensure messages are sent. - */ -void -Sender::poll() -{ - trySend(); + driver->releasePackets(packet, 1); } /** @@ -603,7 +603,19 @@ Sender::Message::cancel() OutMessage::Status Sender::Message::getStatus() const { - return state.load(); + return state.load(std::memory_order_acquire); +} + +/** + * Change the current state of this message and invoke callback if necessary. + * + * @param newStatus + * The new state of the message + */ +void +Sender::Message::setStatus(OutMessage::Status newStatus) +{ + state.store(newStatus, std::memory_order_release); } /** @@ -697,7 +709,7 @@ Sender::Message::reserve(size_t count) * @copydoc Homa::OutMessage::send() */ void -Sender::Message::send(Driver::Address destination, +Sender::Message::send(SocketAddress destination, Sender::Message::Options options) { sender->sendMessage(this, destination, options); @@ -713,10 +725,10 @@ Sender::Message::send(Driver::Address destination, * Pointer to a Packet at the given index if it exists; nullptr otherwise. 
*/ Driver::Packet* -Sender::Message::getPacket(size_t index) const +Sender::Message::getPacket(size_t index) { if (occupied.test(index)) { - return packets[index]; + return &packets[index]; } return nullptr; } @@ -735,14 +747,14 @@ Driver::Packet* Sender::Message::getOrAllocPacket(size_t index) { if (!occupied.test(index)) { - packets[index] = driver->allocPacket(); + driver->allocPacket(&packets[index]); occupied.set(index); numPackets++; // TODO(cstlee): A Message probably shouldn't be in charge of setting // the packet length. - packets[index]->length = TRANSPORT_HEADER_LENGTH; + packets[index].length = TRANSPORT_HEADER_LENGTH; } - return packets[index]; + return &packets[index]; } /** @@ -758,7 +770,7 @@ Sender::Message::getOrAllocPacket(size_t index) * @sa dropMessage() */ void -Sender::sendMessage(Sender::Message* message, Driver::Address destination, +Sender::sendMessage(Sender::Message* message, SocketAddress destination, Sender::Message::Options options) { // Prepare the message @@ -767,7 +779,7 @@ Sender::sendMessage(Sender::Message* message, Driver::Address destination, Protocol::MessageId id(transportId, nextMessageSequenceNumber++); Policy::Unscheduled policy = policyManager->getUnscheduledPolicy( - destination, message->messageLength); + destination.ip, message->messageLength); int unscheduledPacketLimit = ((policy.unscheduledByteLimit + message->PACKET_DATA_LENGTH - 1) / message->PACKET_DATA_LENGTH); @@ -775,7 +787,7 @@ Sender::sendMessage(Sender::Message* message, Driver::Address destination, message->id = id; message->destination = destination; message->options = options; - message->state.store(OutMessage::Status::IN_PROGRESS); + message->setStatus(OutMessage::Status::IN_PROGRESS); int actualMessageLen = 0; // fill out metadata. 
@@ -789,10 +801,10 @@ Sender::sendMessage(Sender::Message* message, Driver::Address destination, i * message->PACKET_DATA_LENGTH); } - packet->address = message->destination; new (packet->payload) Protocol::Packet::DataHeader( - message->id, Util::downCast(message->messageLength), - policy.version, Util::downCast(unscheduledPacketLimit), + message->source.port, destination.port, message->id, + Util::downCast(message->messageLength), policy.version, + Util::downCast(unscheduledPacketLimit), Util::downCast(i)); actualMessageLen += (packet->length - message->TRANSPORT_HEADER_LENGTH); } @@ -811,16 +823,16 @@ Sender::sendMessage(Sender::Message* message, Driver::Address destination, bucket->messageTimeouts.setTimeout(&message->messageTimeout); bucket->pingTimeouts.setTimeout(&message->pingTimeout); + bool notifySendReady = false; assert(message->numPackets > 0); if (message->numPackets == 1) { // If there is only one packet in the message, send it right away. Driver::Packet* packet = message->getPacket(0); assert(packet != nullptr); - packet->priority = policy.priority; Perf::counters.tx_data_pkts.add(1); Perf::counters.tx_bytes.add(packet->length); - driver->sendPacket(packet); - message->state.store(OutMessage::Status::SENT); + driver->sendPacket(packet, message->destination.ip, policy.priority); + message->setStatus(OutMessage::Status::SENT); } else { // Otherwise, queue the message to be sent in SRPT order. SpinLock::Lock lock_queue(queueMutex); @@ -837,7 +849,13 @@ Sender::sendMessage(Sender::Message* message, Driver::Address destination, sendQueue.push_front(&info->sendQueueNode); Intrusive::deprioritize(&sendQueue, &info->sendQueueNode, QueuedMessageInfo::ComparePriority()); - sendReady.store(true); + sendReady = true; + notifySendReady = true; + } + + // Only invoke the callback after unlocking queueMutex. 
+ if (notifySendReady) { + callbacks->notifySendReady(); } } @@ -856,19 +874,21 @@ Sender::cancelMessage(Sender::Message* message) if (bucket->messages.contains(&message->bucketNode)) { bucket->messageTimeouts.cancelTimeout(&message->messageTimeout); bucket->pingTimeouts.cancelTimeout(&message->pingTimeout); - if (message->numPackets > 1 && - message->state == OutMessage::Status::IN_PROGRESS) { - // Check to see if the message needs to be dequeued. + + // Check to see if the message needs to be dequeued. In order to reduce + // cache misses related to queueMutex, check the status first to avoid + // unnecessary locking. + OutMessage::Status status = message->getStatus(); + if ((status == OutMessage::Status::IN_PROGRESS) || + (status == OutMessage::Status::FAILED)) { SpinLock::Lock lock_queue(queueMutex); - // Recheck state with lock in case it change right before this. - if (message->state == OutMessage::Status::IN_PROGRESS) { - QueuedMessageInfo* info = &message->queuedMessageInfo; - assert(sendQueue.contains(&info->sendQueueNode)); + QueuedMessageInfo* info = &message->queuedMessageInfo; + if (sendQueue.contains(&info->sendQueueNode)) { sendQueue.remove(&info->sendQueueNode); } } bucket->messages.remove(&message->bucketNode); - message->state.store(OutMessage::Status::CANCELED); + message->setStatus(OutMessage::Status::CANCELED); } } @@ -918,13 +938,13 @@ Sender::checkMessageTimeouts() break; } // Found expired timeout. - if (message->state != OutMessage::Status::COMPLETED) { - message->state.store(OutMessage::Status::FAILED); + if (message->getStatus() != OutMessage::Status::COMPLETED) { + message->setStatus(OutMessage::Status::FAILED); // A sent NO_KEEP_ALIVE message should never reach this state // since the shorter ping timeout should have already canceled // the message timeout. 
assert( - !((message->state == OutMessage::Status::SENT) && + !((message->getStatus() == OutMessage::Status::SENT) && (message->options & OutMessage::Options::NO_KEEP_ALIVE))); } bucket->messageTimeouts.cancelTimeout(&message->messageTimeout); @@ -967,12 +987,12 @@ Sender::checkPingTimeouts() break; } // Found expired timeout. - if (message->state == OutMessage::Status::COMPLETED || - message->state == OutMessage::Status::FAILED) { + if (message->getStatus() == OutMessage::Status::COMPLETED || + message->getStatus() == OutMessage::Status::FAILED) { bucket->pingTimeouts.cancelTimeout(&message->pingTimeout); continue; } else if (message->options & OutMessage::Options::NO_KEEP_ALIVE && - message->state == OutMessage::Status::SENT) { + message->getStatus() == OutMessage::Status::SENT) { bucket->messageTimeouts.cancelTimeout(&message->messageTimeout); bucket->pingTimeouts.cancelTimeout(&message->pingTimeout); continue; @@ -984,29 +1004,24 @@ Sender::checkPingTimeouts() // the receiver to ensure it still knows about this Message. Perf::counters.tx_ping_pkts.add(1); ControlPacket::send( - message->driver, message->destination, message->id); + message->driver, message->destination.ip, message->id); } globalNextTimeout = std::min(globalNextTimeout, nextTimeout); } return globalNextTimeout; } -/** - * Send out packets for any messages with unscheduled/granted bytes. - */ -void +/// See Homa::Core::Transport::trySend() +uint64_t Sender::trySend() { uint64_t start_tsc = PerfUtils::Cycles::rdtsc(); bool idle = true; + // Skip when there are no messages to send. + SpinLock::UniqueLock lock_queue(queueMutex); if (!sendReady) { - return; - } - - // Skip sending if another thread is already working on it. 
- if (sending.test_and_set()) { - return; + return 0; } /* The goal is to send out packets for messages that have bytes that have @@ -1015,15 +1030,15 @@ Sender::trySend() * Each time this method is called we will try to send enough packet to keep * the NIC busy but not too many as to cause excessive queue in the NIC. */ - SpinLock::UniqueLock lock_queue(queueMutex); uint32_t queuedBytesEstimate = driver->getQueuedBytes(); // Optimistically assume we will finish sending every granted packet this // round; we will set again sendReady if it turns out we don't finish. sendReady = false; + uint64_t waitUntil = 0; auto it = sendQueue.begin(); while (it != sendQueue.end()) { Message& message = *it; - assert(message.state.load() == OutMessage::Status::IN_PROGRESS); + assert(message.getStatus() == OutMessage::Status::IN_PROGRESS); QueuedMessageInfo* info = &message.queuedMessageInfo; assert(info->packetsGranted <= info->packets->numPackets); while (info->packetsSent < info->packetsGranted) { @@ -1038,10 +1053,9 @@ Sender::trySend() break; } // ... if not, send away! - packet->priority = info->priority; Perf::counters.tx_data_pkts.add(1); Perf::counters.tx_bytes.add(packet->length); - driver->sendPacket(packet); + driver->sendPacket(packet, message.destination.ip, info->priority); int packetDataBytes = packet->length - info->packets->TRANSPORT_HEADER_LENGTH; assert(info->unsentBytes >= packetDataBytes); @@ -1055,26 +1069,32 @@ Sender::trySend() } if (info->packetsSent >= info->packets->numPackets) { // We have finished sending the message. - message.state.store(OutMessage::Status::SENT); + message.setStatus(OutMessage::Status::SENT); it = sendQueue.remove(it); } else if (info->packetsSent >= info->packetsGranted) { // We have sent every granted packet. ++it; } else { - // We hit the DRIVER_QUEUED_BYTES_LIMIT; stop sending for now. + // We hit the DRIVER_QUEUED_BYTE_LIMIT; stop sending for now. // We didn't finish sending all granted packets. 
sendReady = true; + // Compute how much time the driver needs to drain its queue, + // then schedule to wake up a bit earlier to avoid blowing bubbles. + static const uint64_t us = PerfUtils::Cycles::fromMicroseconds(1); + waitUntil = + PerfUtils::Cycles::rdtsc() - 1 * us + + queuedBytesEstimate * DRIVER_CYCLES_TO_DRAIN_1MB / 1000000; break; } } - sending.clear(); uint64_t elapsed_cycles = PerfUtils::Cycles::rdtsc() - start_tsc; if (!idle) { Perf::counters.active_cycles.add(elapsed_cycles); } else { Perf::counters.idle_cycles.add(elapsed_cycles); } + return waitUntil; } } // namespace Core diff --git a/src/Sender.h b/src/Sender.h index 471925a..22d5864 100644 --- a/src/Sender.h +++ b/src/Sender.h @@ -16,12 +16,12 @@ #ifndef HOMA_CORE_SENDER_H #define HOMA_CORE_SENDER_H +#include #include -#include #include #include -#include +#include #include "Intrusive.h" #include "ObjectPool.h" @@ -42,18 +42,19 @@ namespace Core { class Sender { public: explicit Sender(uint64_t transportId, Driver* driver, + Transport::Callbacks* callbacks, Policy::Manager* policyManager, uint64_t messageTimeoutCycles, uint64_t pingIntervalCycles); virtual ~Sender(); - virtual Homa::OutMessage* allocMessage(); - virtual void handleDonePacket(Driver::Packet* packet, Driver* driver); - virtual void handleResendPacket(Driver::Packet* packet, Driver* driver); - virtual void handleGrantPacket(Driver::Packet* packet, Driver* driver); - virtual void handleUnknownPacket(Driver::Packet* packet, Driver* driver); - virtual void handleErrorPacket(Driver::Packet* packet, Driver* driver); - virtual void poll(); + virtual Homa::OutMessage* allocMessage(uint16_t sourcePort); + virtual void handleDonePacket(Driver::Packet* packet); + virtual void handleResendPacket(Driver::Packet* packet); + virtual void handleGrantPacket(Driver::Packet* packet); + virtual void handleUnknownPacket(Driver::Packet* packet); + virtual void handleErrorPacket(Driver::Packet* packet); virtual uint64_t checkTimeouts(); + virtual 
uint64_t trySend(); private: /// Forward declarations @@ -96,7 +97,7 @@ class Sender { Protocol::MessageId id; /// Contains destination address this message. - Driver::Address destination; + SocketAddress destination; /// Handle to the queue Message for access to the packets that will /// be sent. This member documents that the packets are logically owned @@ -126,18 +127,19 @@ class Sender { * Sender::Message objects are contained in the Transport::Op but should * only be accessed by the Sender. */ - class Message : public Homa::OutMessage { + class Message final : public Homa::OutMessage { public: /** * Construct an Message. */ - explicit Message(Sender* sender, Driver* driver) + explicit Message(Sender* sender, uint16_t sourcePort) : sender(sender) - , driver(driver) + , driver(sender->driver) , TRANSPORT_HEADER_LENGTH(sizeof(Protocol::Packet::DataHeader)) , PACKET_DATA_LENGTH(driver->getMaxPayloadSize() - TRANSPORT_HEADER_LENGTH) , id(0, 0) + , source{driver->getLocalAddress(), sourcePort} , destination() , options(Options::NONE) , start(0) @@ -161,14 +163,16 @@ class Sender { virtual void prepend(const void* source, size_t count); virtual void release(); virtual void reserve(size_t count); - virtual void send(Driver::Address destination, + virtual void send(SocketAddress destination, Options options = Options::NONE); private: + void setStatus(Status newStatus); + /// Define the maximum number of packets that a message can hold. static const size_t MAX_MESSAGE_PACKETS = 1024; - Driver::Packet* getPacket(size_t index) const; + Driver::Packet* getPacket(size_t index); Driver::Packet* getOrAllocPacket(size_t index); /// The Sender responsible for sending this message. @@ -188,8 +192,11 @@ class Sender { /// Contains the unique identifier for this message. Protocol::MessageId id; - /// Contains destination address this message. - Driver::Address destination; + /// Contains source address of this message. 
+ SocketAddress source; + + /// Contains destination address of this message. + SocketAddress destination; /// Contains flags for any requested optional send behavior. Options options; @@ -209,7 +216,7 @@ class Sender { /// Collection of Packet objects that make up this context's Message. /// These Packets will be released when this context is destroyed. - Driver::Packet* packets[MAX_MESSAGE_PACKETS]; + Driver::Packet packets[MAX_MESSAGE_PACKETS]; /// This message's current state. std::atomic state; @@ -384,17 +391,20 @@ class Sender { Protocol::MessageId::Hasher hasher; }; - void sendMessage(Sender::Message* message, Driver::Address destination, + void sendMessage(Sender::Message* message, SocketAddress destination, Message::Options options = Message::Options::NONE); void cancelMessage(Sender::Message* message); void dropMessage(Sender::Message* message); uint64_t checkMessageTimeouts(); uint64_t checkPingTimeouts(); - void trySend(); /// Transport identifier. const uint64_t transportId; + /// User-defined transport callbacks; not owned by this class. As a general + /// rule, one should not hold any locks when invoking a callback. + Transport::Callbacks* const callbacks; + /// Driver with which all packets will be sent and received. This driver /// is chosen by the Transport that owns this Sender. Driver* const driver; @@ -408,25 +418,27 @@ class Sender { /// The maximum number of bytes that should be queued in the Driver. const uint32_t DRIVER_QUEUED_BYTE_LIMIT; + /// Rdtsc cycles for the Driver to drain one MB of data at line rate. + const uint32_t DRIVER_CYCLES_TO_DRAIN_1MB; + /// Tracks all outbound messages being sent by the Sender. MessageBucketMap messageBuckets; - /// Protects the readyQueue. + /// Protects the sendQueue and sendReady. SpinLock queueMutex; + /// Hint whether there are messages ready to be sent (i.e. granted messages + /// in the sendQueue). Encoded into a single bool so that checking if there + /// is work to do is more efficient. 
This bool can be cleared by trySend() + /// and set to true when new GRANTs arrive, when new outgoing messages + /// appear, and when retransmission is requested. Access to this field is + /// protected by queueMutex. + bool sendReady; + /// A list of outbound messages that have unsent packets. Messages are kept /// in order of priority. Intrusive::List sendQueue; - /// True if the Sender is currently executing trySend(); false, otherwise. - /// Use to prevent concurrent trySend() calls from blocking on each other. - std::atomic_flag sending = ATOMIC_FLAG_INIT; - - /// Hint whether there are messages ready to be sent (i.e. there are granted - /// messages in the sendQueue. Encoded into a single bool so that checking - /// if there is work to do is more efficient. - std::atomic sendReady; - /// Used to allocate Message objects. struct { /// Protects the messageAllocator.pool diff --git a/src/SenderTest.cc b/src/SenderTest.cc index fdae6ab..2c56a16 100644 --- a/src/SenderTest.cc +++ b/src/SenderTest.cc @@ -29,24 +29,46 @@ using ::testing::_; using ::testing::Eq; using ::testing::Mock; using ::testing::NiceMock; -using ::testing::Pointee; using ::testing::Return; +/** + * Defines a matcher EqPacket(p) to match two Driver::Packet* by their + * underlying packet buffer descriptors. 
+ */ +MATCHER_P(EqPacket, p, "") +{ + return arg->descriptor == p->descriptor; +} + +class MockCallbacks : public Transport::Callbacks { + public: + explicit MockCallbacks() = default; + + bool deliver(uint16_t port, Homa::unique_ptr message) override + { + return true; + } +}; + class SenderTest : public ::testing::Test { public: SenderTest() - : mockDriver() - , mockPacket(&payload) + : mockCallbacks() + , mockDriver() + , mockPacket() , mockPolicyManager(&mockDriver) + , payload() + , packetBuf{&payload} , sender() , savedLogPolicy(Debug::getLogPolicy()) { + mockPacket = packetBuf.toPacket(); ON_CALL(mockDriver, getBandwidth).WillByDefault(Return(8000)); - ON_CALL(mockDriver, getMaxPayloadSize).WillByDefault(Return(1027)); + ON_CALL(mockDriver, getMaxPayloadSize).WillByDefault(Return(1031)); ON_CALL(mockDriver, getQueuedBytes).WillByDefault(Return(0)); Debug::setLogPolicy( Debug::logPolicyFromString("src/ObjectPool@SILENT")); - sender = new Sender(22, &mockDriver, &mockPolicyManager, + sender = new Sender(22, &mockDriver, &mockCallbacks, &mockPolicyManager, messageTimeoutCycles, pingIntervalCycles); PerfUtils::Cycles::mockTscValue = 10000; } @@ -58,10 +80,12 @@ class SenderTest : public ::testing::Test { PerfUtils::Cycles::mockTscValue = 0; } + MockCallbacks mockCallbacks; NiceMock mockDriver; - NiceMock mockPacket; + Driver::Packet mockPacket; NiceMock mockPolicyManager; char payload[1028]; + Homa::Mock::MockDriver::PacketBuf packetBuf; Sender* sender; std::vector> savedLogPolicy; @@ -96,7 +120,7 @@ class SenderTest : public ::testing::Test { } static bool setMessagePacket(Sender::Message* message, int index, - Driver::Packet* packet) + Driver::Packet packet) { if (message->occupied.test(index)) { return false; @@ -124,7 +148,7 @@ TEST_F(SenderTest, allocMessage) { EXPECT_EQ(0U, sender->messageAllocator.pool.outstandingObjects); Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); EXPECT_EQ(1U, 
sender->messageAllocator.pool.outstandingObjects); } @@ -132,18 +156,17 @@ TEST_F(SenderTest, handleDonePacket_basic) { Protocol::MessageId id = {42, 1}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); EXPECT_NE(Homa::OutMessage::Status::COMPLETED, message->state); Protocol::Packet::DoneHeader* header = static_cast(mockPacket.payload); header->common.messageId = id; - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) - .Times(2); + EXPECT_CALL(mockDriver, releasePackets(Eq(&mockPacket), Eq(1))).Times(2); // No message. - sender->handleDonePacket(&mockPacket, &mockDriver); + sender->handleDonePacket(&mockPacket); EXPECT_NE(Homa::OutMessage::Status::COMPLETED, message->state); @@ -151,7 +174,7 @@ TEST_F(SenderTest, handleDonePacket_basic) message->state = Homa::OutMessage::Status::SENT; // Normal expected behavior. - sender->handleDonePacket(&mockPacket, &mockDriver); + sender->handleDonePacket(&mockPacket); EXPECT_EQ(nullptr, message->messageTimeout.node.list); EXPECT_EQ(nullptr, message->pingTimeout.node.list); @@ -162,7 +185,7 @@ TEST_F(SenderTest, handleDonePacket_CANCELED) { Protocol::MessageId id = {42, 1}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); addMessage(sender, id, message); message->state = Homa::OutMessage::Status::CANCELED; @@ -170,17 +193,16 @@ TEST_F(SenderTest, handleDonePacket_CANCELED) static_cast(mockPacket.payload); header->common.messageId = id; - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) - .Times(1); + EXPECT_CALL(mockDriver, releasePackets(Eq(&mockPacket), Eq(1))).Times(1); - sender->handleDonePacket(&mockPacket, &mockDriver); + sender->handleDonePacket(&mockPacket); } TEST_F(SenderTest, handleDonePacket_COMPLETED) { Protocol::MessageId id = {42, 1}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); 
addMessage(sender, id, message); message->state = Homa::OutMessage::Status::COMPLETED; @@ -188,13 +210,12 @@ TEST_F(SenderTest, handleDonePacket_COMPLETED) static_cast(mockPacket.payload); header->common.messageId = id; - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) - .Times(1); + EXPECT_CALL(mockDriver, releasePackets(Eq(&mockPacket), Eq(1))).Times(1); VectorHandler handler; Debug::setLogHandler(std::ref(handler)); - sender->handleDonePacket(&mockPacket, &mockDriver); + sender->handleDonePacket(&mockPacket); EXPECT_EQ(1U, handler.messages.size()); const Debug::DebugMessage& m = handler.messages.at(0); @@ -211,7 +232,7 @@ TEST_F(SenderTest, handleDonePacket_FAILED) { Protocol::MessageId id = {42, 1}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); addMessage(sender, id, message); message->state = Homa::OutMessage::Status::FAILED; @@ -219,13 +240,12 @@ TEST_F(SenderTest, handleDonePacket_FAILED) static_cast(mockPacket.payload); header->common.messageId = id; - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) - .Times(1); + EXPECT_CALL(mockDriver, releasePackets(Eq(&mockPacket), Eq(1))).Times(1); VectorHandler handler; Debug::setLogHandler(std::ref(handler)); - sender->handleDonePacket(&mockPacket, &mockDriver); + sender->handleDonePacket(&mockPacket); EXPECT_EQ(1U, handler.messages.size()); const Debug::DebugMessage& m = handler.messages.at(0); @@ -244,7 +264,7 @@ TEST_F(SenderTest, handleDonePacket_IN_PROGRESS) { Protocol::MessageId id = {42, 1}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); addMessage(sender, id, message); message->state = Homa::OutMessage::Status::IN_PROGRESS; @@ -252,13 +272,12 @@ TEST_F(SenderTest, handleDonePacket_IN_PROGRESS) static_cast(mockPacket.payload); header->common.messageId = id; - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) - .Times(1); + 
EXPECT_CALL(mockDriver, releasePackets(Eq(&mockPacket), Eq(1))).Times(1); VectorHandler handler; Debug::setLogHandler(std::ref(handler)); - sender->handleDonePacket(&mockPacket, &mockDriver); + sender->handleDonePacket(&mockPacket); EXPECT_EQ(1U, handler.messages.size()); const Debug::DebugMessage& m = handler.messages.at(0); @@ -277,7 +296,7 @@ TEST_F(SenderTest, handleDonePacket_NO_STARTED) { Protocol::MessageId id = {42, 1}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); addMessage(sender, id, message); message->state = Homa::OutMessage::Status::NOT_STARTED; @@ -285,13 +304,12 @@ TEST_F(SenderTest, handleDonePacket_NO_STARTED) static_cast(mockPacket.payload); header->common.messageId = id; - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) - .Times(1); + EXPECT_CALL(mockDriver, releasePackets(Eq(&mockPacket), Eq(1))).Times(1); VectorHandler handler; Debug::setLogHandler(std::ref(handler)); - sender->handleDonePacket(&mockPacket, &mockDriver); + sender->handleDonePacket(&mockPacket); EXPECT_EQ(1U, handler.messages.size()); const Debug::DebugMessage& m = handler.messages.at(0); @@ -310,10 +328,13 @@ TEST_F(SenderTest, handleResendPacket_basic) { Protocol::MessageId id = {42, 1}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); - std::vector packets; + dynamic_cast(sender->allocMessage(0)); + std::vector packets; + std::vector priorities; for (int i = 0; i < 10; ++i) { - packets.push_back(new Homa::Mock::MockDriver::MockPacket(payload)); + auto* packetBuf = new Homa::Mock::MockDriver::PacketBuf{payload}; + packets.push_back(packetBuf->toPacket()); + priorities.push_back(0); setMessagePacket(message, i, packets[i]); } SenderTest::addMessage(sender, id, message, true, 5); @@ -331,26 +352,30 @@ TEST_F(SenderTest, handleResendPacket_basic) resendHdr->priority = 4; EXPECT_CALL(mockPolicyManager, getResendPriority).WillOnce(Return(7)); - EXPECT_CALL(mockDriver, 
sendPacket(Eq(packets[3]))).Times(1); - EXPECT_CALL(mockDriver, sendPacket(Eq(packets[4]))).Times(1); - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) - .Times(1); + EXPECT_CALL(mockDriver, sendPacket(EqPacket(&packets[3]), _, _)) + .WillOnce( + [&priorities](auto _1, auto _2, int p) { priorities[3] = p; }); + EXPECT_CALL(mockDriver, sendPacket(EqPacket(&packets[4]), _, _)) + .WillOnce( + [&priorities](auto _1, auto _2, int p) { priorities[4] = p; }); + EXPECT_CALL(mockDriver, releasePackets(Eq(&mockPacket), Eq(1))).Times(1); - sender->handleResendPacket(&mockPacket, &mockDriver); + sender->handleResendPacket(&mockPacket); EXPECT_EQ(5U, info->packetsSent); EXPECT_EQ(8U, info->packetsGranted); EXPECT_EQ(4, info->priority); EXPECT_EQ(11000U, message->messageTimeout.expirationCycleTime); EXPECT_EQ(10100U, message->pingTimeout.expirationCycleTime); - EXPECT_EQ(0, packets[2]->priority); - EXPECT_EQ(7, packets[3]->priority); - EXPECT_EQ(7, packets[4]->priority); - EXPECT_EQ(0, packets[5]->priority); - EXPECT_TRUE(sender->sendReady.load()); + EXPECT_EQ(0, priorities[2]); + EXPECT_EQ(7, priorities[3]); + EXPECT_EQ(7, priorities[4]); + EXPECT_EQ(0, priorities[5]); + EXPECT_TRUE(sender->sendReady); for (int i = 0; i < 10; ++i) { - delete packets[i]; + uintptr_t packetBuf = packets[i].descriptor; + delete (Homa::Mock::MockDriver::PacketBuf*)packetBuf; } } @@ -363,10 +388,9 @@ TEST_F(SenderTest, handleResendPacket_staleResend) resendHdr->index = 3; resendHdr->num = 5; - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) - .Times(1); + EXPECT_CALL(mockDriver, releasePackets(Eq(&mockPacket), Eq(1))).Times(1); - sender->handleResendPacket(&mockPacket, &mockDriver); + sender->handleResendPacket(&mockPacket); } TEST_F(SenderTest, handleResendPacket_badRequest_singlePacketMessage) @@ -374,10 +398,11 @@ TEST_F(SenderTest, handleResendPacket_badRequest_singlePacketMessage) Protocol::MessageId id = {42, 1}; Sender::Message* message = - 
dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); SenderTest::addMessage(sender, id, message); - Homa::Mock::MockDriver::MockPacket* packet = - new Homa::Mock::MockDriver::MockPacket(payload); + Homa::Mock::MockDriver::PacketBuf* packetBuf = + new Homa::Mock::MockDriver::PacketBuf{payload}; + Driver::Packet packet = packetBuf->toPacket(); setMessagePacket(message, 0, packet); Protocol::Packet::ResendHeader* resendHdr = @@ -387,13 +412,12 @@ TEST_F(SenderTest, handleResendPacket_badRequest_singlePacketMessage) resendHdr->num = 5; resendHdr->priority = 4; - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) - .Times(1); + EXPECT_CALL(mockDriver, releasePackets(Eq(&mockPacket), Eq(1))).Times(1); VectorHandler handler; Debug::setLogHandler(std::ref(handler)); - sender->handleResendPacket(&mockPacket, &mockDriver); + sender->handleResendPacket(&mockPacket); EXPECT_EQ(1U, handler.messages.size()); const Debug::DebugMessage& m = handler.messages.at(0); @@ -407,17 +431,18 @@ TEST_F(SenderTest, handleResendPacket_badRequest_singlePacketMessage) Debug::setLogHandler(std::function()); - delete packet; + delete packetBuf; } TEST_F(SenderTest, handleResendPacket_badRequest_outOfRange) { Protocol::MessageId id = {42, 1}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); - std::vector packets; + dynamic_cast(sender->allocMessage(0)); + std::vector packets; for (int i = 0; i < 10; ++i) { - packets.push_back(new Homa::Mock::MockDriver::MockPacket(payload)); + auto* packetBuf = new Homa::Mock::MockDriver::PacketBuf{payload}; + packets.push_back(packetBuf->toPacket()); setMessagePacket(message, i, packets[i]); } SenderTest::addMessage(sender, id, message, true, 5); @@ -434,13 +459,12 @@ TEST_F(SenderTest, handleResendPacket_badRequest_outOfRange) resendHdr->num = 5; resendHdr->priority = 4; - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) - .Times(1); + EXPECT_CALL(mockDriver, 
releasePackets(Eq(&mockPacket), Eq(1))).Times(1); VectorHandler handler; Debug::setLogHandler(std::ref(handler)); - sender->handleResendPacket(&mockPacket, &mockDriver); + sender->handleResendPacket(&mockPacket); EXPECT_EQ(1U, handler.messages.size()); const Debug::DebugMessage& m = handler.messages.at(0); @@ -456,7 +480,8 @@ TEST_F(SenderTest, handleResendPacket_badRequest_outOfRange) Debug::setLogHandler(std::function()); for (int i = 0; i < 10; ++i) { - delete packets[i]; + uintptr_t packetBuf = packets[i].descriptor; + delete (Homa::Mock::MockDriver::PacketBuf*)packetBuf; } } @@ -464,11 +489,12 @@ TEST_F(SenderTest, handleResendPacket_eagerResend) { Protocol::MessageId id = {42, 1}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); char data[1028]; - Homa::Mock::MockDriver::MockPacket dataPacket(data); + Homa::Mock::MockDriver::PacketBuf dataPacketBuf{data}; + Driver::Packet dataPacket = dataPacketBuf.toPacket(); for (int i = 0; i < 10; ++i) { - setMessagePacket(message, i, &dataPacket); + setMessagePacket(message, i, dataPacket); } SenderTest::addMessage(sender, id, message, true, 5); Sender::QueuedMessageInfo* info = &message->queuedMessageInfo; @@ -484,18 +510,19 @@ TEST_F(SenderTest, handleResendPacket_eagerResend) // Expect the BUSY control packet. 
char busy[1028]; - Homa::Mock::MockDriver::MockPacket busyPacket(busy); - EXPECT_CALL(mockDriver, allocPacket()).WillOnce(Return(&busyPacket)); - EXPECT_CALL(mockDriver, sendPacket(Eq(&busyPacket))).Times(1); - EXPECT_CALL(mockDriver, releasePackets(Pointee(&busyPacket), Eq(1))) + Homa::Mock::MockDriver::PacketBuf busyPacketBuf{busy}; + Driver::Packet busyPacket = busyPacketBuf.toPacket(); + EXPECT_CALL(mockDriver, allocPacket(_)) + .WillOnce([&busyPacket](Driver::Packet* p) { *p = busyPacket; }); + EXPECT_CALL(mockDriver, sendPacket(EqPacket(&busyPacket), _, _)).Times(1); + EXPECT_CALL(mockDriver, releasePackets(EqPacket(&busyPacket), Eq(1))) .Times(1); // Expect no data to be sent but the RESEND packet to be release. - EXPECT_CALL(mockDriver, sendPacket(Eq(&dataPacket))).Times(0); - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) - .Times(1); + EXPECT_CALL(mockDriver, sendPacket(EqPacket(&dataPacket), _, _)).Times(0); + EXPECT_CALL(mockDriver, releasePackets(Eq(&mockPacket), Eq(1))).Times(1); - sender->handleResendPacket(&mockPacket, &mockDriver); + sender->handleResendPacket(&mockPacket); EXPECT_EQ(5U, info->packetsSent); EXPECT_EQ(8U, info->packetsGranted); @@ -511,7 +538,7 @@ TEST_F(SenderTest, handleGrantPacket_basic) { Protocol::MessageId id = {42, 1}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); SenderTest::addMessage(sender, id, message, true, 5); message->numPackets = 10; message->state = Homa::OutMessage::Status::IN_PROGRESS; @@ -527,23 +554,22 @@ TEST_F(SenderTest, handleGrantPacket_basic) header->byteLimit = 7000; header->priority = 6; - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) - .Times(1); + EXPECT_CALL(mockDriver, releasePackets(Eq(&mockPacket), Eq(1))).Times(1); - sender->handleGrantPacket(&mockPacket, &mockDriver); + sender->handleGrantPacket(&mockPacket); EXPECT_EQ(7, info->packetsGranted); EXPECT_EQ(6, info->priority); EXPECT_EQ(11000U, 
message->messageTimeout.expirationCycleTime); EXPECT_EQ(10100U, message->pingTimeout.expirationCycleTime); - EXPECT_TRUE(sender->sendReady.load()); + EXPECT_TRUE(sender->sendReady); } TEST_F(SenderTest, handleGrantPacket_excessiveGrant) { Protocol::MessageId id = {42, 1}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); SenderTest::addMessage(sender, id, message, true, 5); message->numPackets = 10; message->state = Homa::OutMessage::Status::IN_PROGRESS; @@ -559,13 +585,12 @@ TEST_F(SenderTest, handleGrantPacket_excessiveGrant) header->byteLimit = 11000; header->priority = 6; - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) - .Times(1); + EXPECT_CALL(mockDriver, releasePackets(Eq(&mockPacket), Eq(1))).Times(1); VectorHandler handler; Debug::setLogHandler(std::ref(handler)); - sender->handleGrantPacket(&mockPacket, &mockDriver); + sender->handleGrantPacket(&mockPacket); EXPECT_EQ(1U, handler.messages.size()); const Debug::DebugMessage& m = handler.messages.at(0); @@ -583,14 +608,14 @@ TEST_F(SenderTest, handleGrantPacket_excessiveGrant) EXPECT_EQ(6, info->priority); EXPECT_EQ(11000U, message->messageTimeout.expirationCycleTime); EXPECT_EQ(10100U, message->pingTimeout.expirationCycleTime); - EXPECT_TRUE(sender->sendReady.load()); + EXPECT_TRUE(sender->sendReady); } TEST_F(SenderTest, handleGrantPacket_staleGrant) { Protocol::MessageId id = {42, 1}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); SenderTest::addMessage(sender, id, message, true, 5); message->numPackets = 10; Sender::QueuedMessageInfo* info = &message->queuedMessageInfo; @@ -605,16 +630,15 @@ TEST_F(SenderTest, handleGrantPacket_staleGrant) header->byteLimit = 4000; header->priority = 6; - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) - .Times(1); + EXPECT_CALL(mockDriver, releasePackets(Eq(&mockPacket), Eq(1))).Times(1); - 
sender->handleGrantPacket(&mockPacket, &mockDriver); + sender->handleGrantPacket(&mockPacket); EXPECT_EQ(5, info->packetsGranted); EXPECT_EQ(2, info->priority); EXPECT_EQ(11000U, message->messageTimeout.expirationCycleTime); EXPECT_EQ(10100U, message->pingTimeout.expirationCycleTime); - EXPECT_FALSE(sender->sendReady.load()); + EXPECT_FALSE(sender->sendReady); } TEST_F(SenderTest, handleGrantPacket_dropGrant) @@ -625,28 +649,28 @@ TEST_F(SenderTest, handleGrantPacket_dropGrant) header->common.messageId = id; header->byteLimit = 4000; - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) - .Times(1); + EXPECT_CALL(mockDriver, releasePackets(Eq(&mockPacket), Eq(1))).Times(1); - sender->handleGrantPacket(&mockPacket, &mockDriver); + sender->handleGrantPacket(&mockPacket); } TEST_F(SenderTest, handleUnknownPacket_basic) { Protocol::MessageId id = {42, 1}; - Driver::Address destination = (Driver::Address)22; + SocketAddress destination = {22, 60001}; Core::Policy::Unscheduled policyOld = {1, 2000, 1}; Core::Policy::Unscheduled policyNew = {2, 3000, 2}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); - std::vector packets; + dynamic_cast(sender->allocMessage(0)); + std::vector packets; char payload[5][1028]; for (int i = 0; i < 5; ++i) { - Homa::Mock::MockDriver::MockPacket* packet = - new Homa::Mock::MockDriver::MockPacket(payload[i]); + Homa::Mock::MockDriver::PacketBuf* packetBuf = + new Homa::Mock::MockDriver::PacketBuf{payload[i]}; + Driver::Packet packet = packetBuf->toPacket(); Protocol::Packet::DataHeader* header = - static_cast(packet->payload); + static_cast(packet.payload); header->policyVersion = policyOld.version; header->unscheduledIndexLimit = 2; packets.push_back(packet); @@ -674,18 +698,16 @@ TEST_F(SenderTest, handleUnknownPacket_basic) EXPECT_CALL( mockPolicyManager, - getUnscheduledPolicy(Eq(destination), Eq(message->messageLength))) + getUnscheduledPolicy(Eq(destination.ip), Eq(message->messageLength))) 
.WillOnce(Return(policyNew)); - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) - .Times(1); + EXPECT_CALL(mockDriver, releasePackets(Eq(&mockPacket), Eq(1))).Times(1); - sender->handleUnknownPacket(&mockPacket, &mockDriver); + sender->handleUnknownPacket(&mockPacket); EXPECT_EQ(Homa::OutMessage::Status::IN_PROGRESS, message->state); for (int i = 0; i < 3; ++i) { - Homa::Mock::MockDriver::MockPacket* packet = packets[i]; Protocol::Packet::DataHeader* header = - static_cast(packet->payload); + static_cast(packets[i].payload); EXPECT_EQ(policyNew.version, header->policyVersion); EXPECT_EQ(3U, header->unscheduledIndexLimit); } @@ -696,28 +718,30 @@ TEST_F(SenderTest, handleUnknownPacket_basic) EXPECT_EQ(policyNew.priority, info->priority); EXPECT_EQ(0U, info->packetsSent); EXPECT_TRUE(sender->sendQueue.contains(&info->sendQueueNode)); - EXPECT_TRUE(sender->sendReady.load()); + EXPECT_TRUE(sender->sendReady); for (int i = 0; i < 5; ++i) { - delete packets[i]; + uintptr_t packetBuf = packets[i].descriptor; + delete (Homa::Mock::MockDriver::PacketBuf*)packetBuf; } } TEST_F(SenderTest, handleUnknownPacket_singlePacketMessage) { Protocol::MessageId id = {42, 1}; - Driver::Address destination = (Driver::Address)22; + SocketAddress destination = {22, 60001}; Core::Policy::Unscheduled policyOld = {1, 2000, 1}; Core::Policy::Unscheduled policyNew = {2, 3000, 2}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); - Homa::Mock::MockDriver::MockPacket dataPacket(payload); + dynamic_cast(sender->allocMessage(0)); + Homa::Mock::MockDriver::PacketBuf dataPacketBuf{payload}; + Driver::Packet dataPacket = dataPacketBuf.toPacket(); Protocol::Packet::DataHeader* dataHeader = static_cast(dataPacket.payload); dataHeader->policyVersion = policyOld.version; dataHeader->unscheduledIndexLimit = 2; - setMessagePacket(message, 0, &dataPacket); + setMessagePacket(message, 0, dataPacket); message->destination = destination; message->messageLength = 500; 
message->state.store(Homa::OutMessage::Status::SENT); @@ -733,13 +757,12 @@ TEST_F(SenderTest, handleUnknownPacket_singlePacketMessage) EXPECT_CALL( mockPolicyManager, - getUnscheduledPolicy(Eq(destination), Eq(message->messageLength))) + getUnscheduledPolicy(Eq(destination.ip), Eq(message->messageLength))) .WillOnce(Return(policyNew)); - EXPECT_CALL(mockDriver, sendPacket(Eq(&dataPacket))).Times(1); - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) - .Times(1); + EXPECT_CALL(mockDriver, sendPacket(Eq(message->packets), _, _)).Times(1); + EXPECT_CALL(mockDriver, releasePackets(Eq(&mockPacket), Eq(1))).Times(1); - sender->handleUnknownPacket(&mockPacket, &mockDriver); + sender->handleUnknownPacket(&mockPacket); EXPECT_EQ(Homa::OutMessage::Status::SENT, message->state); EXPECT_EQ(policyNew.version, dataHeader->policyVersion); @@ -748,23 +771,24 @@ TEST_F(SenderTest, handleUnknownPacket_singlePacketMessage) EXPECT_EQ(10100U, message->pingTimeout.expirationCycleTime); EXPECT_FALSE( sender->sendQueue.contains(&message->queuedMessageInfo.sendQueueNode)); - EXPECT_FALSE(sender->sendReady.load()); + EXPECT_FALSE(sender->sendReady); } TEST_F(SenderTest, handleUnknownPacket_NO_RETRY) { Protocol::MessageId id = {42, 1}; - Driver::Address destination = (Driver::Address)22; + SocketAddress destination = {22, 60001}; Sender::MessageBucket* bucket = sender->messageBuckets.getBucket(id); Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); message->options = OutMessage::Options::NO_RETRY; - std::vector packets; + std::vector packets; char payload[5][1028]; for (int i = 0; i < 5; ++i) { - Homa::Mock::MockDriver::MockPacket* packet = - new Homa::Mock::MockDriver::MockPacket(payload[i]); + Homa::Mock::MockDriver::PacketBuf* packetBuf = + new Homa::Mock::MockDriver::PacketBuf{payload[i]}; + Driver::Packet packet = packetBuf->toPacket(); packets.push_back(packet); setMessagePacket(message, i, packet); } @@ 
-781,10 +805,9 @@ TEST_F(SenderTest, handleUnknownPacket_NO_RETRY) static_cast(mockPacket.payload); header->common.messageId = id; - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) - .Times(1); + EXPECT_CALL(mockDriver, releasePackets(Eq(&mockPacket), Eq(1))).Times(1); - sender->handleUnknownPacket(&mockPacket, &mockDriver); + sender->handleUnknownPacket(&mockPacket); EXPECT_FALSE( sender->sendQueue.contains(&message->queuedMessageInfo.sendQueueNode)); @@ -792,7 +815,7 @@ TEST_F(SenderTest, handleUnknownPacket_NO_RETRY) EXPECT_EQ(nullptr, message->pingTimeout.node.list); EXPECT_EQ(Homa::OutMessage::Status::FAILED, message->state); EXPECT_EQ(Homa::OutMessage::Status::FAILED, message->state); - EXPECT_FALSE(sender->sendReady.load()); + EXPECT_FALSE(sender->sendReady); } TEST_F(SenderTest, handleUnknownPacket_no_message) @@ -803,17 +826,16 @@ TEST_F(SenderTest, handleUnknownPacket_no_message) static_cast(mockPacket.payload); header->common.messageId = id; - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) - .Times(1); + EXPECT_CALL(mockDriver, releasePackets(Eq(&mockPacket), Eq(1))).Times(1); - sender->handleUnknownPacket(&mockPacket, &mockDriver); + sender->handleUnknownPacket(&mockPacket); } TEST_F(SenderTest, handleUnknownPacket_done) { Protocol::MessageId id = {42, 1}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); SenderTest::addMessage(sender, id, message); message->state.store(Homa::OutMessage::Status::COMPLETED); EXPECT_EQ(0U, message->messageTimeout.expirationCycleTime); @@ -823,10 +845,9 @@ TEST_F(SenderTest, handleUnknownPacket_done) static_cast(mockPacket.payload); header->common.messageId = id; - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) - .Times(1); + EXPECT_CALL(mockDriver, releasePackets(Eq(&mockPacket), Eq(1))).Times(1); - sender->handleUnknownPacket(&mockPacket, &mockDriver); + sender->handleUnknownPacket(&mockPacket); 
EXPECT_EQ(Homa::OutMessage::Status::COMPLETED, message->state); EXPECT_EQ(0U, message->messageTimeout.expirationCycleTime); @@ -838,7 +859,7 @@ TEST_F(SenderTest, handleErrorPacket_basic) Protocol::MessageId id = {42, 1}; Sender::MessageBucket* bucket = sender->messageBuckets.getBucket(id); Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); SenderTest::addMessage(sender, id, message); bucket->messageTimeouts.setTimeout(&message->messageTimeout); bucket->pingTimeouts.setTimeout(&message->pingTimeout); @@ -848,10 +869,9 @@ TEST_F(SenderTest, handleErrorPacket_basic) static_cast(mockPacket.payload); header->common.messageId = id; - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) - .Times(1); + EXPECT_CALL(mockDriver, releasePackets(Eq(&mockPacket), Eq(1))).Times(1); - sender->handleErrorPacket(&mockPacket, &mockDriver); + sender->handleErrorPacket(&mockPacket); EXPECT_EQ(nullptr, message->messageTimeout.node.list); EXPECT_EQ(nullptr, message->pingTimeout.node.list); @@ -863,7 +883,7 @@ TEST_F(SenderTest, handleErrorPacket_CANCELED) Protocol::MessageId id = {42, 1}; Sender::MessageBucket* bucket = sender->messageBuckets.getBucket(id); Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); SenderTest::addMessage(sender, id, message); message->state.store(Homa::OutMessage::Status::CANCELED); @@ -871,10 +891,9 @@ TEST_F(SenderTest, handleErrorPacket_CANCELED) static_cast(mockPacket.payload); header->common.messageId = id; - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) - .Times(1); + EXPECT_CALL(mockDriver, releasePackets(Eq(&mockPacket), Eq(1))).Times(1); - sender->handleErrorPacket(&mockPacket, &mockDriver); + sender->handleErrorPacket(&mockPacket); EXPECT_EQ(Homa::OutMessage::Status::CANCELED, message->state); } @@ -884,7 +903,7 @@ TEST_F(SenderTest, handleErrorPacket_NOT_STARTED) Protocol::MessageId id = {42, 1}; 
Sender::MessageBucket* bucket = sender->messageBuckets.getBucket(id); Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); SenderTest::addMessage(sender, id, message); message->state.store(Homa::OutMessage::Status::NOT_STARTED); @@ -892,13 +911,12 @@ TEST_F(SenderTest, handleErrorPacket_NOT_STARTED) static_cast(mockPacket.payload); header->common.messageId = id; - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) - .Times(1); + EXPECT_CALL(mockDriver, releasePackets(Eq(&mockPacket), Eq(1))).Times(1); VectorHandler handler; Debug::setLogHandler(std::ref(handler)); - sender->handleErrorPacket(&mockPacket, &mockDriver); + sender->handleErrorPacket(&mockPacket); EXPECT_EQ(1U, handler.messages.size()); const Debug::DebugMessage& m = handler.messages.at(0); @@ -920,7 +938,7 @@ TEST_F(SenderTest, handleErrorPacket_IN_PROGRESS) Protocol::MessageId id = {42, 1}; Sender::MessageBucket* bucket = sender->messageBuckets.getBucket(id); Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); SenderTest::addMessage(sender, id, message); message->state.store(Homa::OutMessage::Status::IN_PROGRESS); @@ -928,13 +946,12 @@ TEST_F(SenderTest, handleErrorPacket_IN_PROGRESS) static_cast(mockPacket.payload); header->common.messageId = id; - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) - .Times(1); + EXPECT_CALL(mockDriver, releasePackets(Eq(&mockPacket), Eq(1))).Times(1); VectorHandler handler; Debug::setLogHandler(std::ref(handler)); - sender->handleErrorPacket(&mockPacket, &mockDriver); + sender->handleErrorPacket(&mockPacket); EXPECT_EQ(1U, handler.messages.size()); const Debug::DebugMessage& m = handler.messages.at(0); @@ -956,7 +973,7 @@ TEST_F(SenderTest, handleErrorPacket_COMPLETED) Protocol::MessageId id = {42, 1}; Sender::MessageBucket* bucket = sender->messageBuckets.getBucket(id); Sender::Message* message = - 
dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); SenderTest::addMessage(sender, id, message); message->state.store(Homa::OutMessage::Status::COMPLETED); @@ -964,13 +981,12 @@ TEST_F(SenderTest, handleErrorPacket_COMPLETED) static_cast(mockPacket.payload); header->common.messageId = id; - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) - .Times(1); + EXPECT_CALL(mockDriver, releasePackets(Eq(&mockPacket), Eq(1))).Times(1); VectorHandler handler; Debug::setLogHandler(std::ref(handler)); - sender->handleErrorPacket(&mockPacket, &mockDriver); + sender->handleErrorPacket(&mockPacket); EXPECT_EQ(1U, handler.messages.size()); const Debug::DebugMessage& m = handler.messages.at(0); @@ -992,7 +1008,7 @@ TEST_F(SenderTest, handleErrorPacket_FAILED) Protocol::MessageId id = {42, 1}; Sender::MessageBucket* bucket = sender->messageBuckets.getBucket(id); Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); SenderTest::addMessage(sender, id, message); message->state.store(Homa::OutMessage::Status::FAILED); @@ -1000,13 +1016,12 @@ TEST_F(SenderTest, handleErrorPacket_FAILED) static_cast(mockPacket.payload); header->common.messageId = id; - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) - .Times(1); + EXPECT_CALL(mockDriver, releasePackets(Eq(&mockPacket), Eq(1))).Times(1); VectorHandler handler; Debug::setLogHandler(std::ref(handler)); - sender->handleErrorPacket(&mockPacket, &mockDriver); + sender->handleErrorPacket(&mockPacket); EXPECT_EQ(1U, handler.messages.size()); const Debug::DebugMessage& m = handler.messages.at(0); @@ -1027,20 +1042,13 @@ TEST_F(SenderTest, handleErrorPacket_noMessage) Protocol::Packet::ErrorHeader* header = static_cast(mockPacket.payload); header->common.messageId = id; - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) - .Times(1); - sender->handleErrorPacket(&mockPacket, &mockDriver); -} - 
-TEST_F(SenderTest, poll) -{ - // Nothing to test. - sender->poll(); + EXPECT_CALL(mockDriver, releasePackets(Eq(&mockPacket), Eq(1))).Times(1); + sender->handleErrorPacket(&mockPacket); } TEST_F(SenderTest, checkTimeouts) { - Sender::Message message(sender, &mockDriver); + Sender::Message message(sender, 0); Sender::MessageBucket* bucket = sender->messageBuckets.buckets.at(0); bucket->pingTimeouts.setTimeout(&message.pingTimeout); bucket->messageTimeouts.setTimeout(&message.messageTimeout); @@ -1065,7 +1073,7 @@ TEST_F(SenderTest, Message_destructor) const int MAX_RAW_PACKET_LENGTH = 2000; ON_CALL(mockDriver, getMaxPayloadSize) .WillByDefault(Return(MAX_RAW_PACKET_LENGTH)); - Sender::Message* msg = new Sender::Message(sender, &mockDriver); + Sender::Message* msg = new Sender::Message(sender, 0); const uint16_t NUM_PKTS = 5; @@ -1086,10 +1094,10 @@ TEST_F(SenderTest, Message_append_basic) ON_CALL(mockDriver, getMaxPayloadSize) .WillByDefault(Return(MAX_RAW_PACKET_LENGTH)); - Sender::Message msg(sender, &mockDriver); + Sender::Message msg(sender, 0); char buf[2 * MAX_RAW_PACKET_LENGTH]; - Homa::Mock::MockDriver::MockPacket packet0(buf + 0); - Homa::Mock::MockDriver::MockPacket packet1(buf + MAX_RAW_PACKET_LENGTH); + Homa::Mock::MockDriver::PacketBuf packetBuf0{buf + 0}; + Homa::Mock::MockDriver::PacketBuf packetBuf1{buf + MAX_RAW_PACKET_LENGTH}; const int TRANSPORT_HEADER_LENGTH = msg.TRANSPORT_HEADER_LENGTH; const int PACKET_DATA_LENGTH = msg.PACKET_DATA_LENGTH; @@ -1098,19 +1106,20 @@ TEST_F(SenderTest, Message_append_basic) TRANSPORT_HEADER_LENGTH + PACKET_DATA_LENGTH); char source[] = "Hello, world!"; - setMessagePacket(&msg, 0, &packet0); - packet0.length = MAX_RAW_PACKET_LENGTH - 7; + setMessagePacket(&msg, 0, packetBuf0.toPacket(MAX_RAW_PACKET_LENGTH - 7)); msg.messageLength = PACKET_DATA_LENGTH - 7; - EXPECT_CALL(mockDriver, allocPacket).WillOnce(Return(&packet1)); + Driver::Packet packet1 = packetBuf1.toPacket(); + EXPECT_CALL(mockDriver, allocPacket(_)) + 
.WillOnce([&packet1](Driver::Packet* packet) { *packet = packet1; }); msg.append(source, 14); EXPECT_EQ(PACKET_DATA_LENGTH + 7, msg.messageLength); EXPECT_EQ(2U, msg.numPackets); - EXPECT_TRUE(msg.packets[1] == &packet1); - EXPECT_EQ(MAX_RAW_PACKET_LENGTH, packet0.length); - EXPECT_EQ(TRANSPORT_HEADER_LENGTH + 7, packet1.length); + EXPECT_EQ(msg.packets[1].payload, packetBuf1.buffer); + EXPECT_EQ(MAX_RAW_PACKET_LENGTH, msg.packets[0].length); + EXPECT_EQ(TRANSPORT_HEADER_LENGTH + 7, msg.packets[1].length); EXPECT_TRUE(std::memcmp(buf + MAX_RAW_PACKET_LENGTH - 7, source, 7) == 0); EXPECT_TRUE( std::memcmp(buf + MAX_RAW_PACKET_LENGTH + TRANSPORT_HEADER_LENGTH, @@ -1126,17 +1135,14 @@ TEST_F(SenderTest, Message_append_truncated) ON_CALL(mockDriver, getMaxPayloadSize) .WillByDefault(Return(MAX_RAW_PACKET_LENGTH)); - Sender::Message msg(sender, &mockDriver); - char buf[2 * MAX_RAW_PACKET_LENGTH]; - Homa::Mock::MockDriver::MockPacket packet0(buf + 0); - Homa::Mock::MockDriver::MockPacket packet1(buf + MAX_RAW_PACKET_LENGTH); - - const int TRANSPORT_HEADER_LENGTH = msg.TRANSPORT_HEADER_LENGTH; - const int PACKET_DATA_LENGTH = msg.PACKET_DATA_LENGTH; + Sender::Message msg(sender, 0); + char buf[MAX_RAW_PACKET_LENGTH]; + Homa::Mock::MockDriver::PacketBuf packetBuf{buf}; char source[] = "Hello, world!"; - setMessagePacket(&msg, msg.MAX_MESSAGE_PACKETS - 1, &packet0); - packet0.length = msg.TRANSPORT_HEADER_LENGTH + msg.PACKET_DATA_LENGTH - 7; + setMessagePacket(&msg, msg.MAX_MESSAGE_PACKETS - 1, packetBuf.toPacket()); + Driver::Packet& packet = msg.packets[msg.MAX_MESSAGE_PACKETS - 1]; + packet.length = msg.TRANSPORT_HEADER_LENGTH + msg.PACKET_DATA_LENGTH - 7; msg.messageLength = msg.PACKET_DATA_LENGTH * msg.MAX_MESSAGE_PACKETS - 7; EXPECT_EQ(1U, msg.numPackets); @@ -1146,7 +1152,7 @@ TEST_F(SenderTest, Message_append_truncated) msg.messageLength); EXPECT_EQ(1U, msg.numPackets); EXPECT_EQ(msg.TRANSPORT_HEADER_LENGTH + msg.PACKET_DATA_LENGTH, - packet0.length); + 
packet.length); EXPECT_TRUE(std::memcmp(buf + MAX_RAW_PACKET_LENGTH - 7, source, 7) == 0); EXPECT_EQ(1U, handler.messages.size()); @@ -1155,7 +1161,7 @@ TEST_F(SenderTest, Message_append_truncated) EXPECT_STREQ("append", m.function); EXPECT_EQ(int(Debug::LogLevel::WARNING), m.logLevel); EXPECT_EQ( - "Max message size limit (2020352B) reached; 7 of 14 bytes appended", + "Max message size limit (2016256B) reached; 7 of 14 bytes appended", m.message); Debug::setLogHandler(std::function()); @@ -1174,7 +1180,7 @@ TEST_F(SenderTest, Message_getStatus) TEST_F(SenderTest, Message_length) { ON_CALL(mockDriver, getMaxPayloadSize).WillByDefault(Return(2048)); - Sender::Message msg(sender, &mockDriver); + Sender::Message msg(sender, 0); msg.messageLength = 200; msg.start = 20; EXPECT_EQ(180U, msg.length()); @@ -1183,16 +1189,18 @@ TEST_F(SenderTest, Message_length) TEST_F(SenderTest, Message_prepend) { ON_CALL(mockDriver, getMaxPayloadSize).WillByDefault(Return(2048)); - Sender::Message msg(sender, &mockDriver); + Sender::Message msg(sender, 0); char buf[4096]; - Homa::Mock::MockDriver::MockPacket packet0(buf + 0); - Homa::Mock::MockDriver::MockPacket packet1(buf + 2048); + Homa::Mock::MockDriver::PacketBuf packetBuf0{buf + 0}; + Homa::Mock::MockDriver::PacketBuf packetBuf1{buf + 2048}; + Driver::Packet packet0 = packetBuf0.toPacket(); + Driver::Packet packet1 = packetBuf1.toPacket(); const int TRANSPORT_HEADER_LENGTH = msg.TRANSPORT_HEADER_LENGTH; const int PACKET_DATA_LENGTH = msg.PACKET_DATA_LENGTH; - EXPECT_CALL(mockDriver, allocPacket) - .WillOnce(Return(&packet0)) - .WillOnce(Return(&packet1)); + EXPECT_CALL(mockDriver, allocPacket(_)) + .WillOnce([&packet0](Driver::Packet* packet) { *packet = packet0; }) + .WillOnce([&packet1](Driver::Packet* packet) { *packet = packet1; }); msg.reserve(PACKET_DATA_LENGTH + 7); EXPECT_EQ(PACKET_DATA_LENGTH + 7, msg.start); EXPECT_EQ(PACKET_DATA_LENGTH + 7, msg.messageLength); @@ -1218,10 +1226,12 @@ TEST_F(SenderTest, Message_release) 
TEST_F(SenderTest, Message_reserve) { - Sender::Message msg(sender, &mockDriver); + Sender::Message msg(sender, 0); char buf[4096]; - Homa::Mock::MockDriver::MockPacket packet0(buf + 0); - Homa::Mock::MockDriver::MockPacket packet1(buf + 2048); + Homa::Mock::MockDriver::PacketBuf packetBuf0{buf + 0}; + Homa::Mock::MockDriver::PacketBuf packetBuf1{buf + 2048}; + Driver::Packet packet0 = packetBuf0.toPacket(); + Driver::Packet packet1 = packetBuf1.toPacket(); const int TRANSPORT_HEADER_LENGTH = msg.TRANSPORT_HEADER_LENGTH; const int PACKET_DATA_LENGTH = msg.PACKET_DATA_LENGTH; @@ -1230,26 +1240,28 @@ TEST_F(SenderTest, Message_reserve) EXPECT_EQ(0U, msg.messageLength); EXPECT_EQ(0U, msg.numPackets); - EXPECT_CALL(mockDriver, allocPacket).WillOnce(Return(&packet0)); + EXPECT_CALL(mockDriver, allocPacket(_)) + .WillOnce([&packet0](Driver::Packet* packet) { *packet = packet0; }); msg.reserve(PACKET_DATA_LENGTH - 7); EXPECT_EQ(PACKET_DATA_LENGTH - 7, msg.start); EXPECT_EQ(PACKET_DATA_LENGTH - 7, msg.messageLength); EXPECT_EQ(1U, msg.numPackets); - EXPECT_EQ(&packet0, msg.getPacket(0)); - EXPECT_EQ(TRANSPORT_HEADER_LENGTH + PACKET_DATA_LENGTH - 7, packet0.length); + EXPECT_EQ(TRANSPORT_HEADER_LENGTH + PACKET_DATA_LENGTH - 7, + msg.packets[0].length); - EXPECT_CALL(mockDriver, allocPacket).WillOnce(Return(&packet1)); + EXPECT_CALL(mockDriver, allocPacket(_)) + .WillOnce([&packet1](Driver::Packet* packet) { *packet = packet1; }); msg.reserve(14); EXPECT_EQ(PACKET_DATA_LENGTH + 7, msg.start); EXPECT_EQ(PACKET_DATA_LENGTH + 7, msg.messageLength); EXPECT_EQ(2U, msg.numPackets); - EXPECT_EQ(TRANSPORT_HEADER_LENGTH + PACKET_DATA_LENGTH, packet0.length); - EXPECT_EQ(&packet1, msg.getPacket(1)); - EXPECT_EQ(TRANSPORT_HEADER_LENGTH + 7, packet1.length); + EXPECT_EQ(TRANSPORT_HEADER_LENGTH + PACKET_DATA_LENGTH, + msg.packets[0].length); + EXPECT_EQ(TRANSPORT_HEADER_LENGTH + 7, msg.packets[1].length); } TEST_F(SenderTest, Message_send) @@ -1259,9 +1271,9 @@ TEST_F(SenderTest, 
Message_send) TEST_F(SenderTest, Message_getPacket) { - Sender::Message msg(sender, &mockDriver); - Driver::Packet* packet = (Driver::Packet*)42; - msg.packets[0] = packet; + Sender::Message msg(sender, 0); + msg.packets[0] = {}; + Driver::Packet* packet = &msg.packets[0]; EXPECT_EQ(nullptr, msg.getPacket(0)); @@ -1273,21 +1285,24 @@ TEST_F(SenderTest, Message_getPacket) TEST_F(SenderTest, Message_getOrAllocPacket) { // TODO(cstlee): cleanup - Sender::Message msg(sender, &mockDriver); + Sender::Message msg(sender, 0); char buf[4096]; - Homa::Mock::MockDriver::MockPacket packet0(buf + 0); - Homa::Mock::MockDriver::MockPacket packet1(buf + 2048); + Homa::Mock::MockDriver::PacketBuf packetBuf0{buf + 0}; + Homa::Mock::MockDriver::PacketBuf packetBuf1{buf + 2048}; + Driver::Packet packet0 = packetBuf0.toPacket(); + Driver::Packet packet1 = packetBuf1.toPacket(); EXPECT_FALSE(msg.occupied.test(0)); EXPECT_EQ(0U, msg.numPackets); - EXPECT_CALL(mockDriver, allocPacket).WillOnce(Return(&packet0)); + EXPECT_CALL(mockDriver, allocPacket(_)) + .WillOnce([&packet0](Driver::Packet* packet) { *packet = packet0; }); - EXPECT_TRUE(&packet0 == msg.getOrAllocPacket(0)); + EXPECT_EQ(packet0.descriptor, msg.getOrAllocPacket(0)->descriptor); EXPECT_TRUE(msg.occupied.test(0)); EXPECT_EQ(1U, msg.numPackets); - EXPECT_TRUE(&packet0 == msg.getOrAllocPacket(0)); + EXPECT_EQ(packet0.descriptor, msg.getOrAllocPacket(0)->descriptor); EXPECT_TRUE(msg.occupied.test(0)); EXPECT_EQ(1U, msg.numPackets); @@ -1298,9 +1313,9 @@ TEST_F(SenderTest, MessageBucket_findMessage) Sender::MessageBucket* bucket = sender->messageBuckets.buckets.at(0); Sender::Message* msg0 = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); Sender::Message* msg1 = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); msg0->id = {42, 0}; msg1->id = {42, 1}; Protocol::MessageId id_none = {42, 42}; @@ -1329,35 +1344,44 @@ TEST_F(SenderTest, sendMessage_basic) { 
Protocol::MessageId id = {sender->transportId, sender->nextMessageSequenceNumber}; + uint16_t sport = 0; + uint16_t dport = 60001; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(sport)); Sender::MessageBucket* bucket = sender->messageBuckets.getBucket(id); - setMessagePacket(message, 0, &mockPacket); + setMessagePacket(message, 0, mockPacket); + Driver::Packet& mockPacket = message->packets[0]; message->messageLength = 420; mockPacket.length = message->messageLength + message->TRANSPORT_HEADER_LENGTH; - Driver::Address destination = (Driver::Address)22; + SocketAddress destination = {22, dport}; Core::Policy::Unscheduled policy = {1, 3000, 2}; EXPECT_FALSE(bucket->messages.contains(&message->bucketNode)); EXPECT_CALL(mockPolicyManager, - getUnscheduledPolicy(Eq(destination), Eq(420))) + getUnscheduledPolicy(Eq(destination.ip), Eq(420))) .WillOnce(Return(policy)); - EXPECT_CALL(mockDriver, sendPacket(Eq(&mockPacket))).Times(1); + int mockPriority = 0; + EXPECT_CALL(mockDriver, sendPacket(Eq(&mockPacket), Eq(destination.ip), _)) + .WillOnce( + [&mockPriority](auto _1, auto _2, int p) { mockPriority = p; }); sender->sendMessage(message, destination, Sender::Message::Options::NO_RETRY); // Check Message metadata EXPECT_EQ(id, message->id); - EXPECT_EQ(destination, message->destination); + EXPECT_EQ(destination.ip, message->destination.ip); + EXPECT_EQ(destination.port, message->destination.port); EXPECT_EQ(Sender::Message::Options::NO_RETRY, message->options); // Check packet metadata Protocol::Packet::DataHeader* header = static_cast(mockPacket.payload); + EXPECT_EQ(htobe16(sport), header->common.prefix.sport); + EXPECT_EQ(htobe16(dport), header->common.prefix.dport); EXPECT_EQ(id, header->common.messageId); EXPECT_EQ(420U, header->totalLength); EXPECT_EQ(policy.version, header->policyVersion); @@ -1370,59 +1394,61 @@ TEST_F(SenderTest, sendMessage_basic) EXPECT_EQ(10100U, 
message->pingTimeout.expirationCycleTime); // Check sent packet metadata - EXPECT_EQ(22U, (uint64_t)mockPacket.address); - EXPECT_EQ(policy.priority, mockPacket.priority); + EXPECT_EQ(policy.priority, mockPriority); EXPECT_EQ(Homa::OutMessage::Status::SENT, message->state); - EXPECT_FALSE(sender->sendReady.load()); + EXPECT_FALSE(sender->sendReady); } TEST_F(SenderTest, sendMessage_multipacket) { char payload0[1027]; char payload1[1027]; - NiceMock packet0(payload0); - NiceMock packet1(payload1); + Homa::Mock::MockDriver::PacketBuf packetBuf0{payload0}; + Homa::Mock::MockDriver::PacketBuf packetBuf1{payload1}; Protocol::MessageId id = {sender->transportId, sender->nextMessageSequenceNumber}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); Sender::MessageBucket* bucket = sender->messageBuckets.getBucket(id); - setMessagePacket(message, 0, &packet0); - setMessagePacket(message, 1, &packet1); + setMessagePacket(message, 0, packetBuf0.toPacket()); + setMessagePacket(message, 1, packetBuf1.toPacket()); + Driver::Packet& packet0 = message->packets[0]; + Driver::Packet& packet1 = message->packets[1]; + message->messageLength = 1420; - packet0.length = 1000 + 27; - packet1.length = 420 + 27; - Driver::Address destination = (Driver::Address)22; + packet0.length = 1000 + 31; + packet1.length = 420 + 31; + SocketAddress destination = {22, 60001}; Core::Policy::Unscheduled policy = {1, 1000, 2}; - EXPECT_EQ(27U, sizeof(Protocol::Packet::DataHeader)); + EXPECT_EQ(31U, sizeof(Protocol::Packet::DataHeader)); EXPECT_EQ(1000U, message->PACKET_DATA_LENGTH); EXPECT_CALL(mockPolicyManager, - getUnscheduledPolicy(Eq(destination), Eq(1420))) + getUnscheduledPolicy(Eq(destination.ip), Eq(1420))) .WillOnce(Return(policy)); sender->sendMessage(message, destination); // Check Message metadata EXPECT_EQ(id, message->id); - EXPECT_EQ(destination, message->destination); + EXPECT_EQ(destination.ip, message->destination.ip); + 
EXPECT_EQ(destination.port, message->destination.port); EXPECT_EQ(Homa::OutMessage::Status::IN_PROGRESS, message->state); // Check packet metadata Protocol::Packet::DataHeader* header = nullptr; // Packet0 - EXPECT_EQ(22U, (uint64_t)packet0.address); header = static_cast(packet0.payload); EXPECT_EQ(message->id, header->common.messageId); EXPECT_EQ(message->messageLength, header->totalLength); // Packet1 - EXPECT_EQ(22U, (uint64_t)packet1.address); header = static_cast(packet1.payload); EXPECT_EQ(message->id, header->common.messageId); - EXPECT_EQ(destination, message->destination); + EXPECT_EQ(destination.ip, message->destination.ip); + EXPECT_EQ(destination.port, message->destination.port); EXPECT_EQ(message->messageLength, header->totalLength); // Check Sender metadata @@ -1433,7 +1459,7 @@ TEST_F(SenderTest, sendMessage_multipacket) // Check sendQueue metadata Sender::QueuedMessageInfo* info = &message->queuedMessageInfo; EXPECT_TRUE(sender->sendQueue.contains(&info->sendQueueNode)); - EXPECT_TRUE(sender->sendReady.load()); + EXPECT_TRUE(sender->sendReady); } TEST_F(SenderTest, sendMessage_missingPacket) @@ -1441,13 +1467,13 @@ TEST_F(SenderTest, sendMessage_missingPacket) Protocol::MessageId id = {sender->transportId, sender->nextMessageSequenceNumber}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); - setMessagePacket(message, 1, &mockPacket); + dynamic_cast(sender->allocMessage(0)); + setMessagePacket(message, 1, mockPacket); Core::Policy::Unscheduled policy = {1, 1000, 2}; ON_CALL(mockPolicyManager, getUnscheduledPolicy(_, _)) .WillByDefault(Return(policy)); - EXPECT_DEATH(sender->sendMessage(message, Driver::Address()), + EXPECT_DEATH(sender->sendMessage(message, SocketAddress{0, 0}), ".*Incomplete message with id \\(22:1\\); missing packet at " "offset 0; this shouldn't happen.*"); } @@ -1457,17 +1483,17 @@ TEST_F(SenderTest, sendMessage_unscheduledLimit) Protocol::MessageId id = {sender->transportId, sender->nextMessageSequenceNumber}; 
Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); for (int i = 0; i < 9; ++i) { - setMessagePacket(message, i, &mockPacket); + mockPacket.length = 1000 + sizeof(Protocol::Packet::DataHeader); + setMessagePacket(message, i, mockPacket); } message->messageLength = 9000; - mockPacket.length = 1000 + sizeof(Protocol::Packet::DataHeader); - Driver::Address destination = (Driver::Address)22; + SocketAddress destination = {22, 60001}; Core::Policy::Unscheduled policy = {1, 4500, 2}; EXPECT_EQ(9U, message->numPackets); EXPECT_EQ(1000U, message->PACKET_DATA_LENGTH); - EXPECT_CALL(mockPolicyManager, getUnscheduledPolicy(destination, 9000)) + EXPECT_CALL(mockPolicyManager, getUnscheduledPolicy(destination.ip, 9000)) .WillOnce(Return(policy)); sender->sendMessage(message, destination); @@ -1481,7 +1507,7 @@ TEST_F(SenderTest, cancelMessage) { Protocol::MessageId id = {42, 1}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); SenderTest::addMessage(sender, id, message, true, 5); Sender::MessageBucket* bucket = sender->messageBuckets.getBucket(id); bucket->pingTimeouts.setTimeout(&message->pingTimeout); @@ -1505,7 +1531,7 @@ TEST_F(SenderTest, cancelMessage) TEST_F(SenderTest, dropMessage) { Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); EXPECT_EQ(1U, sender->messageAllocator.pool.outstandingObjects); sender->dropMessage(message); @@ -1518,7 +1544,7 @@ TEST_F(SenderTest, checkMessageTimeouts_basic) Sender::Message* message[4]; for (uint64_t i = 0; i < 4; ++i) { Protocol::MessageId id = {42, 10 + i}; - message[i] = dynamic_cast(sender->allocMessage()); + message[i] = dynamic_cast(sender->allocMessage(0)); SenderTest::addMessage(sender, id, message[i]); Sender::MessageBucket* bucket = sender->messageBuckets.getBucket(id); bucket->messageTimeouts.setTimeout(&message[i]->messageTimeout); @@ -1581,7 +1607,7 @@ 
TEST_F(SenderTest, checkPingTimeouts_basic) Sender::Message* message[5]; for (uint64_t i = 0; i < 5; ++i) { Protocol::MessageId id = {42, 10 + i}; - message[i] = dynamic_cast(sender->allocMessage()); + message[i] = dynamic_cast(sender->allocMessage(0)); SenderTest::addMessage(sender, id, message[i]); Sender::MessageBucket* bucket = sender->messageBuckets.getBucket(id); bucket->pingTimeouts.setTimeout(&message[i]->pingTimeout); @@ -1605,9 +1631,10 @@ TEST_F(SenderTest, checkPingTimeouts_basic) EXPECT_EQ(10000U, PerfUtils::Cycles::rdtsc()); - EXPECT_CALL(mockDriver, allocPacket()).WillOnce(Return(&mockPacket)); - EXPECT_CALL(mockDriver, sendPacket(Eq(&mockPacket))).Times(1); - EXPECT_CALL(mockDriver, releasePackets(Pointee(&mockPacket), Eq(1))) + EXPECT_CALL(mockDriver, allocPacket(_)) + .WillOnce([this](Driver::Packet* packet) { *packet = mockPacket; }); + EXPECT_CALL(mockDriver, sendPacket(EqPacket(&mockPacket), _, _)).Times(1); + EXPECT_CALL(mockDriver, releasePackets(EqPacket(&mockPacket), Eq(1))) .Times(1); uint64_t nextTimeout = sender->checkPingTimeouts(); @@ -1645,16 +1672,18 @@ TEST_F(SenderTest, trySend_basic) { Protocol::MessageId id = {42, 10}; Sender::Message* message = - dynamic_cast(sender->allocMessage()); + dynamic_cast(sender->allocMessage(0)); Sender::QueuedMessageInfo* info = &message->queuedMessageInfo; SenderTest::addMessage(sender, id, message, true, 3); - Homa::Mock::MockDriver::MockPacket* packet[5]; + Driver::Packet packet[5]; + uint64_t waitUntil; const uint32_t PACKET_SIZE = sender->driver->getMaxPayloadSize(); const uint32_t PACKET_DATA_SIZE = PACKET_SIZE - message->TRANSPORT_HEADER_LENGTH; for (int i = 0; i < 5; ++i) { - packet[i] = new Homa::Mock::MockDriver::MockPacket(payload); - packet[i]->length = PACKET_SIZE; + auto* packetBuf = new Homa::Mock::MockDriver::PacketBuf{payload}; + packet[i] = packetBuf->toPacket(); + packet[i].length = PACKET_SIZE; setMessagePacket(message, i, packet[i]); info->unsentBytes += PACKET_DATA_SIZE; } @@ 
-1668,9 +1697,9 @@ TEST_F(SenderTest, trySend_basic) EXPECT_TRUE(sender->sendQueue.contains(&info->sendQueueNode)); // 3 granted packets; 2 will send; queue limit reached. - EXPECT_CALL(mockDriver, sendPacket(Eq(packet[0]))); - EXPECT_CALL(mockDriver, sendPacket(Eq(packet[1]))); - sender->trySend(); // < test call + EXPECT_CALL(mockDriver, sendPacket(EqPacket(&packet[0]), _, _)); + EXPECT_CALL(mockDriver, sendPacket(EqPacket(&packet[1]), _, _)); + waitUntil = sender->trySend(); // < test call EXPECT_TRUE(sender->sendReady); EXPECT_EQ(Homa::OutMessage::Status::IN_PROGRESS, message->state); EXPECT_EQ(3U, info->packetsGranted); @@ -1681,8 +1710,8 @@ TEST_F(SenderTest, trySend_basic) Mock::VerifyAndClearExpectations(&mockDriver); // 1 packet to be sent; grant limit reached. - EXPECT_CALL(mockDriver, sendPacket(Eq(packet[2]))); - sender->trySend(); // < test call + EXPECT_CALL(mockDriver, sendPacket(EqPacket(&packet[2]), _, _)); + waitUntil = sender->trySend(); // < test call EXPECT_FALSE(sender->sendReady); EXPECT_EQ(Homa::OutMessage::Status::IN_PROGRESS, message->state); EXPECT_EQ(3U, info->packetsGranted); @@ -1695,7 +1724,7 @@ TEST_F(SenderTest, trySend_basic) // No additional grants; spurious ready hint. EXPECT_CALL(mockDriver, sendPacket).Times(0); sender->sendReady = true; - sender->trySend(); // < test call + waitUntil = sender->trySend(); // < test call EXPECT_FALSE(sender->sendReady); EXPECT_EQ(Homa::OutMessage::Status::IN_PROGRESS, message->state); EXPECT_EQ(3U, info->packetsGranted); @@ -1708,9 +1737,9 @@ TEST_F(SenderTest, trySend_basic) // 2 more granted packets; will finish. 
info->packetsGranted = 5; sender->sendReady = true; - EXPECT_CALL(mockDriver, sendPacket(Eq(packet[3]))); - EXPECT_CALL(mockDriver, sendPacket(Eq(packet[4]))); - sender->trySend(); // < test call + EXPECT_CALL(mockDriver, sendPacket(EqPacket(&packet[3]), _, _)); + EXPECT_CALL(mockDriver, sendPacket(EqPacket(&packet[4]), _, _)); + waitUntil = sender->trySend(); // < test call EXPECT_FALSE(sender->sendReady); EXPECT_EQ(Homa::OutMessage::Status::SENT, message->state); EXPECT_EQ(5U, info->packetsGranted); @@ -1721,7 +1750,8 @@ TEST_F(SenderTest, trySend_basic) Mock::VerifyAndClearExpectations(&mockDriver); for (int i = 0; i < 5; ++i) { - delete packet[i]; + uintptr_t packetBuf = packet[i].descriptor; + delete (Homa::Mock::MockDriver::PacketBuf*)packetBuf; } } @@ -1729,17 +1759,18 @@ TEST_F(SenderTest, trySend_multipleMessages) { Sender::Message* message[3]; Sender::QueuedMessageInfo* info[3]; - Homa::Mock::MockDriver::MockPacket* packet[3]; + Driver::Packet packet[3]; for (uint64_t i = 0; i < 3; ++i) { Protocol::MessageId id = {22, 10 + i}; - message[i] = dynamic_cast(sender->allocMessage()); + message[i] = dynamic_cast(sender->allocMessage(0)); info[i] = &message[i]->queuedMessageInfo; SenderTest::addMessage(sender, id, message[i], true, 1); - packet[i] = new Homa::Mock::MockDriver::MockPacket(payload); - packet[i]->length = sender->driver->getMaxPayloadSize() / 4; + auto* packetBuf = new Homa::Mock::MockDriver::PacketBuf{payload}; + packet[i] = packetBuf->toPacket(); + packet[i].length = sender->driver->getMaxPayloadSize() / 4; setMessagePacket(message[i], 0, packet[i]); info[i]->unsentBytes += - (packet[i]->length - message[i]->TRANSPORT_HEADER_LENGTH); + (packet[i].length - message[i]->TRANSPORT_HEADER_LENGTH); message[i]->state = Homa::OutMessage::Status::IN_PROGRESS; } sender->sendReady = true; @@ -1751,19 +1782,20 @@ TEST_F(SenderTest, trySend_multipleMessages) // Message 1: Will reach grant limit EXPECT_EQ(1, info[1]->packetsGranted); info[1]->packetsSent = 0; 
- setMessagePacket(message[1], 1, nullptr); + setMessagePacket(message[1], 1, {}); EXPECT_EQ(2, message[1]->numPackets); // Message 2: Will finish EXPECT_EQ(1, info[2]->packetsGranted); info[2]->packetsSent = 0; - EXPECT_CALL(mockDriver, sendPacket(Eq(packet[0]))); - EXPECT_CALL(mockDriver, sendPacket(Eq(packet[1]))); - EXPECT_CALL(mockDriver, sendPacket(Eq(packet[2]))); + EXPECT_CALL(mockDriver, sendPacket(EqPacket(&packet[0]), _, _)); + EXPECT_CALL(mockDriver, sendPacket(EqPacket(&packet[1]), _, _)); + EXPECT_CALL(mockDriver, sendPacket(EqPacket(&packet[2]), _, _)); - sender->trySend(); + uint64_t waitUntil = sender->trySend(); + EXPECT_EQ(waitUntil, 0); EXPECT_EQ(1U, info[0]->packetsSent); EXPECT_EQ(Homa::OutMessage::Status::SENT, message[0]->state); EXPECT_FALSE(sender->sendQueue.contains(&info[0]->sendQueueNode)); @@ -1775,33 +1807,12 @@ TEST_F(SenderTest, trySend_multipleMessages) EXPECT_FALSE(sender->sendQueue.contains(&info[2]->sendQueueNode)); } -TEST_F(SenderTest, trySend_alreadyRunning) -{ - Protocol::MessageId id = {42, 1}; - Sender::Message* message = - dynamic_cast(sender->allocMessage()); - Sender::QueuedMessageInfo* info = &message->queuedMessageInfo; - SenderTest::addMessage(sender, id, message, true, 1); - setMessagePacket(message, 0, &mockPacket); - message->messageLength = 1000; - EXPECT_EQ(1U, message->numPackets); - EXPECT_EQ(1, info->packetsGranted); - EXPECT_EQ(0, info->packetsSent); - - sender->sending.test_and_set(); - - EXPECT_CALL(mockDriver, sendPacket).Times(0); - - sender->trySend(); - - EXPECT_EQ(0, info->packetsSent); -} - TEST_F(SenderTest, trySend_nothingToSend) { EXPECT_TRUE(sender->sendQueue.empty()); EXPECT_CALL(mockDriver, sendPacket).Times(0); - sender->trySend(); + uint64_t waitUntil = sender->trySend(); + EXPECT_EQ(waitUntil, 0); } } // namespace diff --git a/src/TransportImpl.cc b/src/TransportImpl.cc index 310e099..b249b22 100644 --- a/src/TransportImpl.cc +++ b/src/TransportImpl.cc @@ -14,17 +14,11 @@ */ #include 
"TransportImpl.h" - -#include -#include -#include - #include "Cycles.h" #include "Perf.h" #include "Protocol.h" -namespace Homa { -namespace Core { +namespace Homa::Core { // Basic timeout unit. const uint64_t BASE_TIMEOUT_US = 2000; @@ -35,121 +29,142 @@ const uint64_t PING_INTERVAL_US = 3 * BASE_TIMEOUT_US; /// Microseconds to wait before performing retires on inbound messages. const uint64_t RESEND_INTERVAL_US = BASE_TIMEOUT_US; +/// See Homa::Core::Transport::create() +Homa::unique_ptr +Transport::create(Driver* driver, Callbacks* callbacks, uint64_t transportId) +{ + Transport* transport = + new Core::TransportImpl(driver, callbacks, transportId); + return Homa::unique_ptr(transport); +} + /** - * Construct an instances of a Homa-based transport. + * Construct an instance of a Homa-based transport. * * @param driver * Driver with which this transport should send and receive packets. + * @param callbacks + * User-defined transport callbacks. * @param transportId * This transport's unique identifier in the group of transports among * which this transport will communicate. */ -TransportImpl::TransportImpl(Driver* driver, uint64_t transportId) +TransportImpl::TransportImpl(Driver* driver, Callbacks* callbacks, + uint64_t transportId) : transportId(transportId) + , callbacks(callbacks) , driver(driver) , policyManager(new Policy::Manager(driver)) - , sender(new Sender(transportId, driver, policyManager.get(), + , sender(new Sender(transportId, driver, callbacks, policyManager.get(), PerfUtils::Cycles::fromMicroseconds(MESSAGE_TIMEOUT_US), PerfUtils::Cycles::fromMicroseconds(PING_INTERVAL_US))) , receiver( - new Receiver(driver, policyManager.get(), + new Receiver(driver, callbacks, policyManager.get(), PerfUtils::Cycles::fromMicroseconds(MESSAGE_TIMEOUT_US), PerfUtils::Cycles::fromMicroseconds(RESEND_INTERVAL_US))) - , nextTimeoutCycles(0) {} /** - * TransportImpl Destructor. + * Construct an instance of a Homa-based transport for unit testing. 
*/ +TransportImpl::TransportImpl(Driver* driver, Callbacks* callbacks, + Sender* sender, Receiver* receiver, + uint64_t transportId) + : transportId(transportId) + , callbacks(callbacks) + , driver(driver) + , policyManager(new Policy::Manager(driver)) + , sender(sender) + , receiver(receiver) +{} -/// See Homa::Transport::poll() +/// See Homa::TransportBase::free() void -TransportImpl::poll() +TransportImpl::free() { - // Receive and dispatch incoming packets. - processPackets(); + // We simply call "delete this" here because the only way to instantiate + // a Core::TransportImpl instance is via "new" in Transport::create(). + // An alternative would be to provide a static free() method that takes + // a pointer to Transport; the downside of this approach is that we must + // cast the argument to TransportImpl* because polymorphic deletion is + // disabled on the Transport interface. + delete this; +} - // Allow sender and receiver to make incremental progress. - sender->poll(); - receiver->poll(); +/// See Homa::TransportBase::alloc() +Homa::unique_ptr +TransportImpl::alloc(uint16_t port) +{ + OutMessage* outMessage = sender->allocMessage(port); + return unique_ptr(outMessage); +} - if (PerfUtils::Cycles::rdtsc() >= nextTimeoutCycles.load()) { - uint64_t requestedTimeoutCycles; - requestedTimeoutCycles = sender->checkTimeouts(); - nextTimeoutCycles.store(requestedTimeoutCycles); - requestedTimeoutCycles = receiver->checkTimeouts(); - if (nextTimeoutCycles.load() > requestedTimeoutCycles) { - nextTimeoutCycles.store(requestedTimeoutCycles); - } - } +/// See Homa::Core::Transport::checkTimeouts() +uint64_t +TransportImpl::checkTimeouts() +{ + uint64_t requestedTimeoutCycles = + std::min(sender->checkTimeouts(), receiver->checkTimeouts()); + return requestedTimeoutCycles; } -/** - * Helper method which receives a burst of incoming packets and process them - * through the transport protocol.
Pulled out of TransportImpl::poll() to - * simplify unit testing. - */ +/// See Homa::Core::Transport::processPacket() void -TransportImpl::processPackets() +TransportImpl::processPacket(Driver::Packet* packet, IpAddress sourceIp) { - // Keep track of time spent doing active processing versus idle. - Perf::Timer activityTimer; - activityTimer.split(); - uint64_t activeTime = 0; - uint64_t idleTime = 0; - - const int MAX_BURST = 32; - Driver::Packet* packets[MAX_BURST]; - int numPackets = driver->receivePackets(MAX_BURST, packets); - for (int i = 0; i < numPackets; ++i) { - Driver::Packet* packet = packets[i]; - assert(packet->length >= - Util::downCast(sizeof(Protocol::Packet::CommonHeader))); - Perf::counters.rx_bytes.add(packet->length); - Protocol::Packet::CommonHeader* header = - static_cast(packet->payload); - switch (header->opcode) { - case Protocol::Packet::DATA: - Perf::counters.rx_data_pkts.add(1); - receiver->handleDataPacket(packet, driver); - break; - case Protocol::Packet::GRANT: - Perf::counters.rx_grant_pkts.add(1); - sender->handleGrantPacket(packet, driver); - break; - case Protocol::Packet::DONE: - Perf::counters.rx_done_pkts.add(1); - sender->handleDonePacket(packet, driver); - break; - case Protocol::Packet::RESEND: - Perf::counters.rx_resend_pkts.add(1); - sender->handleResendPacket(packet, driver); - break; - case Protocol::Packet::BUSY: - Perf::counters.rx_busy_pkts.add(1); - receiver->handleBusyPacket(packet, driver); - break; - case Protocol::Packet::PING: - Perf::counters.rx_ping_pkts.add(1); - receiver->handlePingPacket(packet, driver); - break; - case Protocol::Packet::UNKNOWN: - Perf::counters.rx_unknown_pkts.add(1); - sender->handleUnknownPacket(packet, driver); - break; - case Protocol::Packet::ERROR: - Perf::counters.rx_error_pkts.add(1); - sender->handleErrorPacket(packet, driver); - break; - } - activeTime += activityTimer.split(); + assert(packet->length >= + Util::downCast(sizeof(Protocol::Packet::CommonHeader))); + 
Perf::counters.rx_bytes.add(packet->length); + Protocol::Packet::CommonHeader* header = + static_cast(packet->payload); + switch (header->opcode) { + case Protocol::Packet::DATA: + Perf::counters.rx_data_pkts.add(1); + receiver->handleDataPacket(packet, sourceIp); + break; + case Protocol::Packet::GRANT: + Perf::counters.rx_grant_pkts.add(1); + sender->handleGrantPacket(packet); + break; + case Protocol::Packet::DONE: + Perf::counters.rx_done_pkts.add(1); + sender->handleDonePacket(packet); + break; + case Protocol::Packet::RESEND: + Perf::counters.rx_resend_pkts.add(1); + sender->handleResendPacket(packet); + break; + case Protocol::Packet::BUSY: + Perf::counters.rx_busy_pkts.add(1); + receiver->handleBusyPacket(packet); + break; + case Protocol::Packet::PING: + Perf::counters.rx_ping_pkts.add(1); + receiver->handlePingPacket(packet, sourceIp); + break; + case Protocol::Packet::UNKNOWN: + Perf::counters.rx_unknown_pkts.add(1); + sender->handleUnknownPacket(packet); + break; + case Protocol::Packet::ERROR: + Perf::counters.rx_error_pkts.add(1); + sender->handleErrorPacket(packet); + break; } - idleTime += activityTimer.split(); +} - Perf::counters.active_cycles.add(activeTime); - Perf::counters.idle_cycles.add(idleTime); +/// See Homa::Core::Transport::trySend() +uint64_t +TransportImpl::trySend() +{ + return sender->trySend(); +} + +/// See Homa::Core::Transport::trySendGrants() +bool +TransportImpl::trySendGrants() +{ + return receiver->trySendGrants(); } -} // namespace Core -} // namespace Homa +} // namespace Homa::Core diff --git a/src/TransportImpl.h b/src/TransportImpl.h index 2d559be..f083375 100644 --- a/src/TransportImpl.h +++ b/src/TransportImpl.h @@ -16,14 +16,10 @@ #ifndef HOMA_CORE_TRANSPORT_H #define HOMA_CORE_TRANSPORT_H -#include +#include #include #include -#include -#include -#include -#include #include "ObjectPool.h" #include "Policy.h" @@ -34,51 +30,46 @@ /** * Homa */ -namespace Homa { -namespace Core { +namespace Homa::Core { /** - * 
Internal implementation of Homa::Transport. - * + * Internal implementation of Homa::Core::Transport. */ -class TransportImpl : public Transport { +class TransportImpl final : public Transport { public: - explicit TransportImpl(Driver* driver, uint64_t transportId); - ~TransportImpl(); - - /// See Homa::Transport::alloc() - virtual Homa::unique_ptr alloc() - { - return Homa::unique_ptr(sender->allocMessage()); - } - - /// See Homa::Transport::receive() - virtual Homa::unique_ptr receive() - { - return Homa::unique_ptr(receiver->receiveMessage()); - } - - virtual void poll(); - - /// See Homa::Transport::getDriver() - virtual Driver* getDriver() + explicit TransportImpl(Driver* driver, Callbacks* callbacks, + uint64_t transportId); + explicit TransportImpl(Driver* driver, Callbacks* callbacks, Sender* sender, + Receiver* receiver, uint64_t transportId); + virtual ~TransportImpl() = default; + void free() override; + Homa::unique_ptr alloc(uint16_t port) override; + uint64_t checkTimeouts() override; + void processPacket(Driver::Packet* packet, IpAddress source) override; + uint64_t trySend() override; + bool trySendGrants() override; + + /// See Homa::Core::Transport::getDriver() + Driver* getDriver() override { return driver; } - /// See Homa::Transport::getId() - virtual uint64_t getId() + /// See Homa::TransportBase::getId() + uint64_t getId() override { return transportId; } private: - void processPackets(); - /// Unique identifier for this transport. - const std::atomic transportId; + const uint64_t transportId; + + /// User-defined transport callbacks. Not owned by this class. + Callbacks* const callbacks; /// Driver from which this transport will send and receive packets. + /// Not owned by this class. Driver* const driver; /// Module which manages the network packet priority policy. @@ -89,12 +80,8 @@ class TransportImpl : public Transport { /// Module which receives packets and forms them into messages. 
std::unique_ptr receiver; - - /// Caches the next cycle time that timeouts will need to rechecked. - std::atomic nextTimeoutCycles; }; -} // namespace Core -} // namespace Homa +} // namespace Homa::Core #endif // HOMA_CORE_TRANSPORT_H diff --git a/src/TransportImplTest.cc b/src/TransportImplTest.cc index 0e0ab60..f3706f5 100644 --- a/src/TransportImplTest.cc +++ b/src/TransportImplTest.cc @@ -27,6 +27,7 @@ namespace Homa { namespace Core { namespace { +using ::testing::_; using ::testing::DoAll; using ::testing::Eq; using ::testing::NiceMock; @@ -36,138 +37,37 @@ using ::testing::SetArrayArgument; class TransportImplTest : public ::testing::Test { public: TransportImplTest() - : mockDriver() - , transport(new TransportImpl(&mockDriver, 22)) - , mockSender( - new NiceMock(22, &mockDriver, 0, 0)) - , mockReceiver( - new NiceMock(&mockDriver, 0, 0)) + : mockDriver(allocMockDriver()) + , mockSender(new NiceMock(22, mockDriver, 0, 0)) + , mockReceiver(new NiceMock(mockDriver, 0, 0)) + , transport(mockDriver, nullptr, mockSender, mockReceiver, 22) { - transport->sender.reset(mockSender); - transport->receiver.reset(mockReceiver); - ON_CALL(mockDriver, getBandwidth).WillByDefault(Return(8000)); - ON_CALL(mockDriver, getMaxPayloadSize).WillByDefault(Return(1024)); PerfUtils::Cycles::mockTscValue = 10000; } ~TransportImplTest() { - delete transport; + delete mockDriver; PerfUtils::Cycles::mockTscValue = 0; } - NiceMock mockDriver; - TransportImpl* transport; + NiceMock* allocMockDriver() + { + auto driver = new NiceMock(); + ON_CALL(*driver, getBandwidth).WillByDefault(Return(8000)); + ON_CALL(*driver, getMaxPayloadSize).WillByDefault(Return(1024)); + return driver; + } + + NiceMock* mockDriver; NiceMock* mockSender; NiceMock* mockReceiver; + TransportImpl transport; }; -TEST_F(TransportImplTest, poll) -{ - EXPECT_CALL(mockDriver, receivePackets).WillOnce(Return(0)); - EXPECT_CALL(*mockSender, poll).Times(1); - EXPECT_CALL(*mockReceiver, poll).Times(1); - 
EXPECT_CALL(*mockSender, checkTimeouts).WillOnce(Return(10000)); - EXPECT_CALL(*mockReceiver, checkTimeouts).WillOnce(Return(10100)); - - transport->poll(); - - EXPECT_EQ(10000U, transport->nextTimeoutCycles); - - EXPECT_CALL(mockDriver, receivePackets).WillOnce(Return(0)); - EXPECT_CALL(*mockSender, poll).Times(1); - EXPECT_CALL(*mockReceiver, poll).Times(1); - EXPECT_CALL(*mockSender, checkTimeouts).WillOnce(Return(10200)); - EXPECT_CALL(*mockReceiver, checkTimeouts).WillOnce(Return(10100)); - - transport->poll(); - - EXPECT_EQ(10100U, transport->nextTimeoutCycles); - - EXPECT_CALL(mockDriver, receivePackets).WillOnce(Return(0)); - EXPECT_CALL(*mockSender, poll).Times(1); - EXPECT_CALL(*mockReceiver, poll).Times(1); - EXPECT_CALL(*mockSender, checkTimeouts).Times(0); - EXPECT_CALL(*mockReceiver, checkTimeouts).Times(0); - - transport->poll(); - - EXPECT_EQ(10100U, transport->nextTimeoutCycles); -} - -TEST_F(TransportImplTest, processPackets) +TEST_F(TransportImplTest, processPacket) { - char payload[8][1024]; - Homa::Driver::Packet* packets[8]; - - // Set DATA packet - Homa::Mock::MockDriver::MockPacket dataPacket(payload[0], 1024); - static_cast(dataPacket.payload) - ->common.opcode = Protocol::Packet::DATA; - packets[0] = &dataPacket; - EXPECT_CALL(*mockReceiver, - handleDataPacket(Eq(&dataPacket), Eq(&mockDriver))); - - // Set GRANT packet - Homa::Mock::MockDriver::MockPacket grantPacket(payload[1], 1024); - static_cast(grantPacket.payload) - ->common.opcode = Protocol::Packet::GRANT; - packets[1] = &grantPacket; - EXPECT_CALL(*mockSender, - handleGrantPacket(Eq(&grantPacket), Eq(&mockDriver))); - - // Set DONE packet - Homa::Mock::MockDriver::MockPacket donePacket(payload[2], 1024); - static_cast(donePacket.payload) - ->common.opcode = Protocol::Packet::DONE; - packets[2] = &donePacket; - EXPECT_CALL(*mockSender, - handleDonePacket(Eq(&donePacket), Eq(&mockDriver))); - - // Set RESEND packet - Homa::Mock::MockDriver::MockPacket resendPacket(payload[3], 1024); 
- static_cast(resendPacket.payload) - ->common.opcode = Protocol::Packet::RESEND; - packets[3] = &resendPacket; - EXPECT_CALL(*mockSender, - handleResendPacket(Eq(&resendPacket), Eq(&mockDriver))); - - // Set BUSY packet - Homa::Mock::MockDriver::MockPacket busyPacket(payload[4], 1024); - static_cast(busyPacket.payload) - ->common.opcode = Protocol::Packet::BUSY; - packets[4] = &busyPacket; - EXPECT_CALL(*mockReceiver, - handleBusyPacket(Eq(&busyPacket), Eq(&mockDriver))); - - // Set PING packet - Homa::Mock::MockDriver::MockPacket pingPacket(payload[5], 1024); - static_cast(pingPacket.payload) - ->common.opcode = Protocol::Packet::PING; - packets[5] = &pingPacket; - EXPECT_CALL(*mockReceiver, - handlePingPacket(Eq(&pingPacket), Eq(&mockDriver))); - - // Set UNKNOWN packet - Homa::Mock::MockDriver::MockPacket unknownPacket(payload[6], 1024); - static_cast(unknownPacket.payload) - ->common.opcode = Protocol::Packet::UNKNOWN; - packets[6] = &unknownPacket; - EXPECT_CALL(*mockSender, - handleUnknownPacket(Eq(&unknownPacket), Eq(&mockDriver))); - - // Set ERROR packet - Homa::Mock::MockDriver::MockPacket errorPacket(payload[7], 1024); - static_cast(errorPacket.payload) - ->common.opcode = Protocol::Packet::ERROR; - packets[7] = &errorPacket; - EXPECT_CALL(*mockSender, - handleErrorPacket(Eq(&errorPacket), Eq(&mockDriver))); - - EXPECT_CALL(mockDriver, receivePackets) - .WillOnce(DoAll(SetArrayArgument<1>(packets, packets + 8), Return(8))); - - transport->processPackets(); + // tested sufficiently in PollModeTransportImpl tests } } // namespace diff --git a/src/Transports/PollModeTransportImpl.cc b/src/Transports/PollModeTransportImpl.cc new file mode 100644 index 0000000..890e9ff --- /dev/null +++ b/src/Transports/PollModeTransportImpl.cc @@ -0,0 +1,133 @@ +/* Copyright (c) 2020, Stanford University + * + * Permission to use, copy, modify, and/or distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright 
notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + */ + +#include "PollModeTransportImpl.h" + +namespace Homa { + +Homa::unique_ptr +PollModeTransport::create(Driver* driver, uint64_t transportId) +{ + return Homa::unique_ptr( + new PollModeTransportImpl(driver, transportId)); +} + +/** + * Constructor. + * + * @param driver + * Driver with which this transport should send and receive packets. + * @param transportId + * This transport's unique identifier in the group of transports among + * which this transport will communicate. + */ +PollModeTransportImpl::PollModeTransportImpl(Driver* driver, + uint64_t transportId) + : callbacks(this) + , core(driver, &callbacks, transportId) + , nextTimeoutCycles(0) +{} + +/** + * Construct for unit testing. + */ +PollModeTransportImpl::PollModeTransportImpl(Driver* driver, + Core::Sender* sender, + Core::Receiver* receiver, + uint64_t transportId) + : callbacks(this) + , core(driver, &callbacks, sender, receiver, transportId) + , nextTimeoutCycles(0) +{} + +/// See Homa::PollModeTransport::alloc() +Homa::unique_ptr +PollModeTransportImpl::alloc(uint16_t port) +{ + return core.alloc(port); +} + +/// See Homa::PollModeTransport::free() +void +PollModeTransportImpl::free() +{ + // This instance must be allocated via new from PollModeTransport::create(). 
+ delete this; +} + +/// See Homa::PollModeTransport::getId() +uint64_t +PollModeTransportImpl::getId() +{ + return core.getId(); +} + +void +PollModeTransportImpl::poll() +{ + // Receive and dispatch incoming packets. + processPackets(); + + // Allow sender and receiver to make incremental progress. + core.trySend(); + core.trySendGrants(); + + if (PerfUtils::Cycles::rdtsc() >= nextTimeoutCycles.load()) { + uint64_t requestedTimeoutCycles = core.checkTimeouts(); + nextTimeoutCycles.store(requestedTimeoutCycles); + } +} + +/// See Homa::PollModeTransport::receive +Homa::unique_ptr +PollModeTransportImpl::receive() +{ + if (receiveQueue.empty()) { + return nullptr; + } + Homa::unique_ptr message = std::move(receiveQueue.back()); + receiveQueue.pop_back(); + return message; +} + +/** + * Helper method which receives a burst of incoming packets and process them + * through the transport protocol. Pulled out of PollModeTransportImpl::poll() + * to simplify unit testing. + */ +void +PollModeTransportImpl::processPackets() +{ + // Keep track of time spent doing active processing versus idle. 
+ uint64_t cycles = PerfUtils::Cycles::rdtsc(); + + const int MAX_BURST = 32; + Driver::Packet packets[MAX_BURST]; + IpAddress srcAddrs[MAX_BURST]; + Driver* driver = core.getDriver(); + int numPackets = driver->receivePackets(MAX_BURST, packets, srcAddrs); + for (int i = 0; i < numPackets; ++i) { + core.processPacket(&packets[i], srcAddrs[i]); + } + + cycles = PerfUtils::Cycles::rdtsc() - cycles; + if (numPackets > 0) { + Perf::counters.active_cycles.add(cycles); + } else { + Perf::counters.idle_cycles.add(cycles); + } +} + +} // namespace Homa diff --git a/src/Transports/PollModeTransportImpl.h b/src/Transports/PollModeTransportImpl.h new file mode 100644 index 0000000..d614c0c --- /dev/null +++ b/src/Transports/PollModeTransportImpl.h @@ -0,0 +1,86 @@ +/* Copyright (c) 2020, Stanford University + * + * Permission to use, copy, modify, and/or distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + */ + +#ifndef HOMA_POLLMODETRANSPORT_H +#define HOMA_POLLMODETRANSPORT_H + +#include +#include +#include "../TransportImpl.h" + +namespace Homa { + +/** + * Internal implementation of Homa::PollModeTransport. 
+ */ +class PollModeTransportImpl final : public PollModeTransport { + public: + explicit PollModeTransportImpl(Driver* driver, uint64_t transportId); + explicit PollModeTransportImpl(Driver* driver, Core::Sender* sender, + Core::Receiver* receiver, + uint64_t transportId); + virtual ~PollModeTransportImpl() = default; + Homa::unique_ptr alloc(uint16_t port) override; + void free() override; + uint64_t getId() override; + void poll() override; + Homa::unique_ptr receive() override; + + private: + /** + * Callbacks defined for the polling-based transport implementation. + */ + class PollModeCallbacks : public Core::Transport::Callbacks { + public: + explicit PollModeCallbacks(PollModeTransportImpl* owner) + : owner(owner) + {} + + ~PollModeCallbacks() override = default; + + bool deliver(uint16_t port, + Homa::unique_ptr message) override + { + (void)port; + SpinLock::Lock _(owner->mutex); + owner->receiveQueue.push_back(std::move(message)); + return true; + } + + private: + PollModeTransportImpl* owner; + }; + + void processPackets(); + + /// Transport callbacks. + PollModeCallbacks callbacks; + + /// Core transport instance. + Core::TransportImpl core; + + /// Caches the next cycle time that timeouts will need to rechecked. + std::atomic nextTimeoutCycles; + + /// Monitor-style lock which protects the receive queue. + SpinLock mutex; + + /// Queue of completed incoming messages. 
+ std::vector> receiveQueue; +}; + +} // namespace Homa + +#endif // HOMA_POLLMODETRANSPORT_H \ No newline at end of file diff --git a/src/Transports/PollModeTransportImplTest.cc b/src/Transports/PollModeTransportImplTest.cc new file mode 100644 index 0000000..286174c --- /dev/null +++ b/src/Transports/PollModeTransportImplTest.cc @@ -0,0 +1,190 @@ +/* Copyright (c) 2018-2020, Stanford University + * + * Permission to use, copy, modify, and/or distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + */ + +#include +#include + +#include "Mock/MockDriver.h" +#include "Mock/MockReceiver.h" +#include "Mock/MockSender.h" +#include "PollModeTransportImpl.h" +#include "Protocol.h" +#include "TransportImpl.h" +#include "Tub.h" + +namespace Homa { +namespace Core { +namespace { + +using ::testing::_; +using ::testing::DoAll; +using ::testing::Eq; +using ::testing::NiceMock; +using ::testing::Return; +using ::testing::SetArrayArgument; + +/** + * Defines a matcher EqPacket(p) to match two Driver::Packet* by their + * underlying packet buffer descriptors. 
+ */ +MATCHER_P(EqPacket, p, "") +{ + return arg->descriptor == p->descriptor; +} + +class PollModeTransportImplTest : public ::testing::Test { + public: + PollModeTransportImplTest() + : mockDriver(allocMockDriver()) + , mockSender(new NiceMock(22, mockDriver, 0, 0)) + , mockReceiver(new NiceMock(mockDriver, 0, 0)) + , transport(new PollModeTransportImpl(mockDriver, mockSender, + mockReceiver, 22)) + { + PerfUtils::Cycles::mockTscValue = 10000; + } + + ~PollModeTransportImplTest() + { + delete transport; + delete mockDriver; + PerfUtils::Cycles::mockTscValue = 0; + } + + NiceMock* allocMockDriver() + { + auto driver = new NiceMock(); + ON_CALL(*driver, getBandwidth).WillByDefault(Return(8000)); + ON_CALL(*driver, getMaxPayloadSize).WillByDefault(Return(1024)); + return driver; + } + + NiceMock* mockDriver; + NiceMock* mockSender; + NiceMock* mockReceiver; + PollModeTransportImpl* transport; +}; + +TEST_F(PollModeTransportImplTest, poll) +{ + EXPECT_CALL(*mockDriver, receivePackets).WillOnce(Return(0)); + EXPECT_CALL(*mockSender, trySend).Times(1); + EXPECT_CALL(*mockReceiver, trySendGrants).Times(1); + EXPECT_CALL(*mockSender, checkTimeouts).WillOnce(Return(10000)); + EXPECT_CALL(*mockReceiver, checkTimeouts).WillOnce(Return(10100)); + + transport->poll(); + + EXPECT_EQ(10000U, transport->nextTimeoutCycles); + + EXPECT_CALL(*mockDriver, receivePackets).WillOnce(Return(0)); + EXPECT_CALL(*mockSender, trySend).Times(1); + EXPECT_CALL(*mockReceiver, trySendGrants).Times(1); + EXPECT_CALL(*mockSender, checkTimeouts).WillOnce(Return(10200)); + EXPECT_CALL(*mockReceiver, checkTimeouts).WillOnce(Return(10100)); + + transport->poll(); + + EXPECT_EQ(10100U, transport->nextTimeoutCycles); + + EXPECT_CALL(*mockDriver, receivePackets).WillOnce(Return(0)); + EXPECT_CALL(*mockSender, trySend).Times(1); + EXPECT_CALL(*mockReceiver, trySendGrants).Times(1); + EXPECT_CALL(*mockSender, checkTimeouts).Times(0); + EXPECT_CALL(*mockReceiver, checkTimeouts).Times(0); + + 
transport->poll(); + + EXPECT_EQ(10100U, transport->nextTimeoutCycles); +} + +TEST_F(PollModeTransportImplTest, processPackets) +{ + char payload[8][1024]; + Homa::Driver::Packet packets[8]; + + // Set DATA packet + Homa::Mock::MockDriver::PacketBuf dataPacketBuf{payload[0]}; + Driver::Packet dataPacket = dataPacketBuf.toPacket(1024); + static_cast(dataPacket.payload) + ->common.opcode = Protocol::Packet::DATA; + packets[0] = dataPacket; + EXPECT_CALL(*mockReceiver, handleDataPacket(EqPacket(&packets[0]), _)); + + // Set GRANT packet + Homa::Mock::MockDriver::PacketBuf grantPacketBuf{payload[1]}; + Driver::Packet grantPacket = grantPacketBuf.toPacket(1024); + static_cast(grantPacket.payload) + ->common.opcode = Protocol::Packet::GRANT; + packets[1] = grantPacket; + EXPECT_CALL(*mockSender, handleGrantPacket(EqPacket(&packets[1]))); + + // Set DONE packet + Homa::Mock::MockDriver::PacketBuf donePacketBuf{payload[2]}; + Driver::Packet donePacket = donePacketBuf.toPacket(1024); + static_cast(donePacket.payload) + ->common.opcode = Protocol::Packet::DONE; + packets[2] = donePacket; + EXPECT_CALL(*mockSender, handleDonePacket(EqPacket(&packets[2]))); + + // Set RESEND packet + Homa::Mock::MockDriver::PacketBuf resendPacketBuf{payload[3]}; + Driver::Packet resendPacket = resendPacketBuf.toPacket(1024); + static_cast(resendPacket.payload) + ->common.opcode = Protocol::Packet::RESEND; + packets[3] = resendPacket; + EXPECT_CALL(*mockSender, handleResendPacket(EqPacket(&packets[3]))); + + // Set BUSY packet + Homa::Mock::MockDriver::PacketBuf busyPacketBuf{payload[4]}; + Driver::Packet busyPacket = busyPacketBuf.toPacket(1024); + static_cast(busyPacket.payload) + ->common.opcode = Protocol::Packet::BUSY; + packets[4] = busyPacket; + EXPECT_CALL(*mockReceiver, handleBusyPacket(EqPacket(&packets[4]))); + + // Set PING packet + Homa::Mock::MockDriver::PacketBuf pingPacketBuf{payload[5]}; + Driver::Packet pingPacket = pingPacketBuf.toPacket(1024); + 
static_cast(pingPacket.payload) + ->common.opcode = Protocol::Packet::PING; + packets[5] = pingPacket; + EXPECT_CALL(*mockReceiver, handlePingPacket(EqPacket(&packets[5]), _)); + + // Set UNKNOWN packet + Homa::Mock::MockDriver::PacketBuf unknownPacketBuf{payload[6]}; + Driver::Packet unknownPacket = unknownPacketBuf.toPacket(1024); + static_cast(unknownPacket.payload) + ->common.opcode = Protocol::Packet::UNKNOWN; + packets[6] = unknownPacket; + EXPECT_CALL(*mockSender, handleUnknownPacket(EqPacket(&packets[6]))); + + // Set ERROR packet + Homa::Mock::MockDriver::PacketBuf errorPacketBuf{payload[7]}; + Driver::Packet errorPacket = errorPacketBuf.toPacket(1024); + static_cast(errorPacket.payload) + ->common.opcode = Protocol::Packet::ERROR; + packets[7] = errorPacket; + EXPECT_CALL(*mockSender, handleErrorPacket(EqPacket(&packets[7]))); + + EXPECT_CALL(*mockDriver, receivePackets) + .WillOnce(DoAll(SetArrayArgument<1>(packets, packets + 8), Return(8))); + + transport->processPackets(); +} + +} // namespace +} // namespace Core +} // namespace Homa diff --git a/src/Transports/Shenango.cc b/src/Transports/Shenango.cc new file mode 100644 index 0000000..b7d05bd --- /dev/null +++ b/src/Transports/Shenango.cc @@ -0,0 +1,216 @@ +/* Copyright (c) 2020, Stanford University + * + * Permission to use, copy, modify, and/or distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS. 
IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + */ + +#include "Homa/Transports/Shenango.h" + +#include +#include "../Debug.h" +#include "Homa/Core/Transport.h" + +using namespace Homa; + +/** + * Shorthand for declaring "extern" function pointers to Shenango functions. + * These functions pointers will be initialized on the Shenango side in homa.c. + */ +#define DECLARE_SHENANGO_FUNC(ReturnType, MethodName, ...) \ + extern ReturnType (*shenango_##MethodName)(__VA_ARGS__); + +/** + * Protect RCU read-side critical sections. + */ +DECLARE_SHENANGO_FUNC(void, rcu_read_lock) +DECLARE_SHENANGO_FUNC(void, rcu_read_unlock) + +/** + * Allocate a Shenango mbuf struct to hold an egress Homa packet. + */ +DECLARE_SHENANGO_FUNC(void*, homa_tx_alloc_mbuf, void**) + +/** + * Free a packet buffer allocated earlier. + */ +DECLARE_SHENANGO_FUNC(void, mbuf_free, void*) + +/** + * Transmit an IP packet using Shenango's driver stack. + */ +DECLARE_SHENANGO_FUNC(int, homa_tx_ip, uintptr_t, void*, int32_t, uint8_t, + uint32_t, uint8_t) + +/** + * Deliver an ingress message to a homa socket in Shenango. + */ +DECLARE_SHENANGO_FUNC(void, homa_mb_deliver, void*, homa_inmsg) + +/** + * Return the number of bytes queued up in the transmit queue. + */ +DECLARE_SHENANGO_FUNC(uint32_t, homa_queued_bytes) + +/** + * Find a socket that matches the 5-tuple. + */ +DECLARE_SHENANGO_FUNC(void*, trans_table_lookup, uint8_t, SocketAddress, + SocketAddress) + +/** + * Callback functions specialized for the Shenango runtime. 
+ */ +class ShenangoCallbacks final : Core::Transport::Callbacks { + public: + explicit ShenangoCallbacks(uint8_t proto, uint32_t local_ip, + std::function notify_send_ready) + : proto(proto) + , local_ip{local_ip} + , notify_send_ready(std::move(notify_send_ready)) + {} + + ~ShenangoCallbacks() override = default; + + bool deliver(uint16_t port, Homa::unique_ptr message) override + { + // The socket table in Shenango is protected by an RCU. + shenango_rcu_read_lock(); + SocketAddress laddr = {local_ip, port}; + void* trans_entry = shenango_trans_table_lookup(proto, laddr, {}); + if (trans_entry) { + shenango_homa_mb_deliver(trans_entry, + homa_inmsg{message.release()}); + } + shenango_rcu_read_unlock(); + return trans_entry != nullptr; + } + + void notifySendReady() override + { + notify_send_ready(); + } + + /// Protocol number reserved for Homa; defined as IPPROTO_HOMA in Shenango. + const uint8_t proto; + + /// Local IP address of the transport. + const IpAddress local_ip; + + /// Callback function for notifySendReady(). + std::function notify_send_ready; +}; + +/** + * A simple shim driver that translates Driver operations to Shenango + * functions. 
+ */ +class ShenangoDriver final : public Driver { + public: + explicit ShenangoDriver(uint8_t proto, uint32_t local_ip, + uint32_t max_payload, uint32_t link_speed) + : Driver() + , proto(proto) + , local_ip{local_ip} + , max_payload(max_payload) + , link_speed(link_speed) + , callbacks() + {} + + ~ShenangoDriver() override = default; + + void allocPacket(Packet* packet) override + { + void* mbuf = shenango_homa_tx_alloc_mbuf(&packet->payload); + packet->descriptor = reinterpret_cast(mbuf); + packet->length = 0; + } + + void sendPacket(Packet* packet, IpAddress destination, + int priority) override + { + shenango_homa_tx_ip(packet->descriptor, packet->payload, packet->length, + proto, (uint32_t)destination, (uint8_t)priority); + } + + uint32_t receivePackets(uint32_t maxPackets, Packet receivedPackets[], + IpAddress sourceAddresses[]) override + { + (void)maxPackets; + (void)receivedPackets; + (void)sourceAddresses; + PANIC("receivePackets must not be called when used with Shenango"); + return 0; + } + + void releasePackets(Packet packets[], uint16_t numPackets) override + { + for (uint16_t i = 0; i < numPackets; i++) { + shenango_mbuf_free((void*)packets[i].descriptor); + } + } + + uint32_t getMaxPayloadSize() override + { + return max_payload; + } + + uint32_t getBandwidth() override + { + return link_speed; + } + + IpAddress getLocalAddress() override + { + return local_ip; + } + + uint32_t getQueuedBytes() override + { + return shenango_homa_queued_bytes(); + } + + /// Protocol number reserved for Homa; defined as IPPROTO_HOMA in Shenango. + const uint8_t proto; + + /// Local IP address of the driver. + const IpAddress local_ip; + + /// # bytes in a payload + const uint32_t max_payload; + + /// Effective network bandwidth, in Mbits/second. + const uint32_t link_speed; + + /// Callback object. Piggybacked here to allow automatic destruction. 
+ std::unique_ptr callbacks; +}; + +homa_trans +homa_create_shenango_trans(uint64_t id, uint8_t proto, uint32_t local_ip, + uint32_t max_payload, uint32_t link_speed, + void (*cb_send_ready)(void*), void* cb_data) +{ + ShenangoCallbacks* callbacks = new ShenangoCallbacks( + proto, local_ip, std::bind(cb_send_ready, cb_data)); + ShenangoDriver* drv = + new ShenangoDriver(proto, local_ip, max_payload, link_speed); + drv->callbacks.reset(callbacks); + return homa_trans_create(homa_driver{drv}, homa_callbacks{callbacks}, id); +} + +void +homa_free_shenango_trans(homa_trans trans) +{ + homa_driver drv = homa_trans_get_drv(trans); + homa_trans_free(trans); + delete static_cast(drv.p); +} diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 403a340..e01e945 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -34,6 +34,18 @@ target_link_libraries(system_test docopt ) +## dpdk_test ################################################################# + +add_executable(dpdk_test + dpdk_test.cc +) +target_link_libraries(dpdk_test + PRIVATE + Homa::DpdkDriver + docopt + PerfUtils +) + ## Perf ######################################################################## add_executable(Perf diff --git a/test/Output.h b/test/Output.h new file mode 100644 index 0000000..467280e --- /dev/null +++ b/test/Output.h @@ -0,0 +1,123 @@ +/* Copyright (c) 2020, Stanford University + * + * Permission to use, copy, modify, and/or distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS. 
#ifndef HOMA_TEST_OUTPUT_H
#define HOMA_TEST_OUTPUT_H

#include <algorithm>
#include <chrono>
#include <cstddef>
#include <cstdio>
#include <string>
#include <vector>

/**
 * Helpers used by the test programs to print latency summaries.
 *
 * Every function is defined in this header, so each one is either a template
 * or marked inline; that keeps the header safe to include from more than one
 * translation unit.
 */
namespace Output {

/// Duration of a single operation, stored as fractional seconds.
using Latency = std::chrono::duration<double>;

/// Summary statistics computed by basic().
struct TimeDist {
    Latency min;   // Fastest time seen (seconds).
    Latency p50;   // Median time per operation (seconds).
    Latency p90;   // 90th percentile time/op (seconds).
    Latency p99;   // 99th percentile time/op (seconds).
    Latency p999;  // 99.9th percentile time/op (seconds).
};

/**
 * printf-style formatting into a std::string.
 *
 * Implemented as a variadic template rather than with va_list: calling
 * va_start on a parameter of reference type is undefined behavior.
 *
 * @param fmt
 *      printf format string.
 * @param args
 *      Values referenced by @p fmt; they must be types that may legally be
 *      passed through a C varargs call (arithmetic types, pointers).
 * @return
 *      The formatted text, or an empty string if formatting fails.
 */
template <typename... Args>
std::string
format(const std::string& fmt, Args... args)
{
    // First pass only measures the output.
    const int len = std::snprintf(nullptr, 0, fmt.c_str(), args...);
    if (len < 0) {
        // Encoding error; do not turn the negative value into a huge size.
        return std::string();
    }
    std::vector<char> buf(static_cast<std::size_t>(len) + 1);
    std::snprintf(buf.data(), buf.size(), fmt.c_str(), args...);
    return std::string(buf.data());
}

/**
 * Render a latency with a unit (ns, us, ms or s) chosen by its magnitude.
 *
 * @param seconds
 *      Latency to render.
 * @return
 *      Text such as "  2.0 ns" or " 2.00 s ".
 */
inline std::string
formatTime(Latency seconds)
{
    using Nanoseconds = std::chrono::duration<double, std::nano>;
    using Microseconds = std::chrono::duration<double, std::micro>;
    using Milliseconds = std::chrono::duration<double, std::milli>;

    if (seconds < Microseconds(1)) {
        return format("%5.1f ns", Nanoseconds(seconds).count());
    }
    if (seconds < Milliseconds(1)) {
        return format("%5.1f us", Microseconds(seconds).count());
    }
    if (seconds < Latency(1)) {
        return format("%5.2f ms", Milliseconds(seconds).count());
    }
    return format("%5.2f s ", seconds.count());
}

/**
 * Column titles for the rows produced by basic().
 */
inline std::string
basicHeader()
{
    return "median min p90 p99 p999 description";
}

/**
 * Summarize a set of latency samples as one row of text: median, minimum,
 * 90th, 99th and 99.9th percentile, followed by a description.
 *
 * @param times
 *      Samples to summarize; sorted in place as a side effect.
 * @param description
 *      Label appended to the end of the row.
 * @return
 *      The formatted row. If @p times is empty every statistic column reads
 *      "N/A".
 */
inline std::string
basic(std::vector<Latency>& times, const std::string& description)
{
    if (times.empty()) {
        // No sample to index; report placeholders instead of reading
        // times[0] out of bounds.
        std::string row = format("%9s", "N/A");
        for (int column = 1; column < 5; ++column) {
            row += format(" %9s", "N/A");
        }
        row += " ";
        row += description;
        return row;
    }

    const std::size_t count = times.size();
    std::sort(times.begin(), times.end());

    // With few samples a percentile index can land one past the end; in
    // that case the next lower percentile is reused.
    auto sampleOr = [&times, count](std::size_t index, Latency fallback) {
        return index < count ? times[index] : fallback;
    };

    TimeDist dist;
    dist.min = times.front();
    dist.p50 = sampleOr(count / 2, dist.min);
    dist.p90 = sampleOr(count - (count + 5) / 10, dist.p50);
    dist.p99 = sampleOr(count - (count + 50) / 100, dist.p90);
    dist.p999 = sampleOr(count - (count + 500) / 1000, dist.p99);

    std::string output = format("%9s", formatTime(dist.p50).c_str());
    output += format(" %9s", formatTime(dist.min).c_str());
    output += format(" %9s", formatTime(dist.p90).c_str());
    output += format(" %9s", formatTime(dist.p99).c_str());
    output += format(" %9s", formatTime(dist.p999).c_str());
    output += " ";
    output += description;
    return output;
}

}  // namespace Output

#endif  // HOMA_TEST_OUTPUT_H
+ + Usage: + dpdk_test [options] (--server | ) + + Options: + -h --help Show this screen. + --version Show version. + --timetrace Enable TimeTrace output [default: false]. +)"; + +int +main(int argc, char* argv[]) +{ + std::map args = + docopt::docopt(USAGE, {argv + 1, argv + argc}, + true, // show help if requested + "DPDK Driver Test"); // version string + + std::string iface = args[""].asString(); + bool isServer = args["--server"].asBool(); + std::string server_ip_string; + if (!isServer) { + server_ip_string = args[""].asString(); + } + + Homa::Drivers::DPDK::DpdkDriver driver(iface.c_str()); + + if (isServer) { + std::cout << Homa::IpAddress::toString(driver.getLocalAddress()) + << std::endl; + while (true) { + Homa::Driver::Packet incoming[10]; + Homa::IpAddress srcAddrs[10]; + uint32_t receivedPackets; + do { + receivedPackets = driver.receivePackets(10, incoming, srcAddrs); + } while (receivedPackets == 0); + Homa::Driver::Packet pong; + driver.allocPacket(&pong); + pong.length = 100; + driver.sendPacket(&pong, srcAddrs[0], 0); + driver.releasePackets(incoming, receivedPackets); + driver.releasePackets(&pong, 1); + } + } else { + Homa::IpAddress server_ip = + Homa::IpAddress::fromString(server_ip_string.c_str()); + std::vector times; + for (int i = 0; i < 100000; ++i) { + uint64_t start = PerfUtils::Cycles::rdtsc(); + PerfUtils::TimeTrace::record(start, "START"); + Homa::Driver::Packet ping; + driver.allocPacket(&ping); + PerfUtils::TimeTrace::record("allocPacket"); + ping.length = 100; + PerfUtils::TimeTrace::record("set ping args"); + driver.sendPacket(&ping, server_ip, 0); + PerfUtils::TimeTrace::record("sendPacket"); + driver.releasePackets(&ping, 1); + PerfUtils::TimeTrace::record("releasePacket"); + Homa::Driver::Packet incoming[10]; + Homa::IpAddress srcAddrs[10]; + uint32_t receivedPackets; + do { + receivedPackets = driver.receivePackets(10, incoming, srcAddrs); + PerfUtils::TimeTrace::record("receivePackets"); + } while (receivedPackets == 0); + 
driver.releasePackets(incoming, receivedPackets); + PerfUtils::TimeTrace::record("releasePacket"); + uint64_t stop = PerfUtils::Cycles::rdtsc(); + times.emplace_back(PerfUtils::Cycles::toSeconds(stop - start)); + } + if (args["--timetrace"].asBool()) { + PerfUtils::TimeTrace::print(); + } + std::cout << Output::basicHeader() << std::endl; + std::cout << Output::basic(times, "DpdkDriver Ping-Pong") << std::endl; + } + + return 0; +} \ No newline at end of file diff --git a/test/system_test.cc b/test/system_test.cc index 8e43238..1e5b8c7 100644 --- a/test/system_test.cc +++ b/test/system_test.cc @@ -1,4 +1,4 @@ -/* Copyright (c) 2019, Stanford University +/* Copyright (c) 2019-2020, Stanford University * * Permission to use, copy, modify, and/or distribute this software for any * purpose with or without fee is hereby granted, provided that the above @@ -15,12 +15,10 @@ #include #include -#include -#include +#include #include #include -#include #include #include #include @@ -47,6 +45,7 @@ static const char USAGE[] = R"(Homa System Test. 
bool _PRINT_CLIENT_ = false; bool _PRINT_SERVER_ = false; +static const uint16_t SERVER_PORT = 60001; struct MessageHeader { uint64_t id; @@ -57,29 +56,27 @@ struct Node { explicit Node(uint64_t id) : id(id) , driver() - , transport(Homa::Transport::create(&driver, id)) + , transport(Homa::PollModeTransport::create(&driver, id)) , thread() , run(false) {} const uint64_t id; Homa::Drivers::Fake::FakeDriver driver; - Homa::Transport* transport; + Homa::unique_ptr transport; std::thread thread; std::atomic run; }; void -serverMain(Node* server, std::vector addresses) +serverMain(Node* server, std::vector addresses) { while (true) { if (server->run.load() == false) { break; } - Homa::unique_ptr message = - server->transport->receive(); - + Homa::unique_ptr message(server->transport->receive()); if (message) { MessageHeader header; message->get(0, &header, sizeof(MessageHeader)); @@ -101,7 +98,7 @@ serverMain(Node* server, std::vector addresses) * Number of Op that failed. */ int -clientMain(int count, int size, std::vector addresses) +clientMain(int count, int size, std::vector addresses) { std::random_device rd; std::mt19937 gen(rd()); @@ -115,13 +112,13 @@ clientMain(int count, int size, std::vector addresses) for (int i = 0; i < count; ++i) { uint64_t id = nextId++; char payload[size]; - for (int i = 0; i < size; ++i) { - payload[i] = randData(gen); + for (char& byte : payload) { + byte = randData(gen); } - std::string destAddress = addresses[randAddr(gen)]; + Homa::IpAddress destAddress = addresses[randAddr(gen)]; - Homa::unique_ptr message = client.transport->alloc(); + Homa::unique_ptr message = client.transport->alloc(0); { MessageHeader header; header.id = id; @@ -133,7 +130,7 @@ clientMain(int count, int size, std::vector addresses) << std::endl; } } - message->send(client.driver.getAddress(&destAddress)); + message->send(Homa::SocketAddress{destAddress, SERVER_PORT}); while (1) { Homa::OutMessage::Status status = message->getStatus(); @@ -185,25 +182,22 @@ 
main(int argc, char* argv[]) Homa::Drivers::Fake::FakeNetworkConfig::setPacketLossRate(packetLossRate); uint64_t nextServerId = 101; - std::vector addresses; + std::vector addresses; std::vector servers; for (int i = 0; i < numServers; ++i) { Node* server = new Node(nextServerId++); - addresses.emplace_back(std::string( - server->driver.addressToString(server->driver.getLocalAddress()))); + addresses.emplace_back(server->driver.getLocalAddress()); servers.push_back(server); } - for (auto it = servers.begin(); it != servers.end(); ++it) { - Node* server = *it; + for (auto server : servers) { server->run = true; server->thread = std::move(std::thread(&serverMain, server, addresses)); } int numFails = clientMain(numTests, numBytes, addresses); - for (auto it = servers.begin(); it != servers.end(); ++it) { - Node* server = *it; + for (auto server : servers) { server->run = false; server->thread.join(); delete server;