diff --git a/sycl/include/sycl/ext/oneapi/experimental/graph.hpp b/sycl/include/sycl/ext/oneapi/experimental/graph.hpp index e2e87c30ea945..ac0e713dfde95 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/graph.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/graph.hpp @@ -17,7 +17,7 @@ #ifdef __INTEL_PREVIEW_BREAKING_CHANGES #include #endif -#include // for device +#include // for device #include // for graph properties classes #include // for range, nd_range #include // for is_property, is_property_of @@ -142,6 +142,14 @@ class __SYCL_EXPORT node { /// Update the Range of this node if it is a kernel execution node template void update_range(range executionRange); + /// Common Reference Semantics + friend bool operator==(const node &LHS, const node &RHS) { + return LHS.impl == RHS.impl; + } + friend bool operator!=(const node &LHS, const node &RHS) { + return LHS.impl != RHS.impl; + } + private: node(const std::shared_ptr &Impl) : impl(Impl) {} @@ -181,6 +189,16 @@ class __SYCL_EXPORT dynamic_command_group { size_t get_active_index() const; void set_active_index(size_t Index); + /// Common Reference Semantics + friend bool operator==(const dynamic_command_group &LHS, + const dynamic_command_group &RHS) { + return LHS.impl == RHS.impl; + } + friend bool operator!=(const dynamic_command_group &LHS, + const dynamic_command_group &RHS) { + return LHS.impl != RHS.impl; + } + private: template friend const decltype(Obj::impl) & @@ -307,6 +325,16 @@ class __SYCL_EXPORT modifiable_command_graph /// Get a list of all root nodes (nodes without dependencies) in this graph. std::vector get_root_nodes() const; + /// Common Reference Semantics + friend bool operator==(const modifiable_command_graph &LHS, + const modifiable_command_graph &RHS) { + return LHS.impl == RHS.impl; + } + friend bool operator!=(const modifiable_command_graph &LHS, + const modifiable_command_graph &RHS) { + return LHS.impl != RHS.impl; + } + protected: /// Constructor used internally by the runtime. /// @param Impl Detail implementation class to construct object with. @@ -386,6 +414,16 @@ class __SYCL_EXPORT executable_command_graph /// @param Nodes The nodes to use for updating the graph. void update(const std::vector &Nodes); + /// Common Reference Semantics + friend bool operator==(const executable_command_graph &LHS, + const executable_command_graph &RHS) { + return LHS.impl == RHS.impl; + } + friend bool operator!=(const executable_command_graph &LHS, + const executable_command_graph &RHS) { + return LHS.impl != RHS.impl; + } + protected: /// Constructor used by internal runtime. /// @param Graph Detail implementation class to construct with. @@ -452,6 +490,16 @@ class __SYCL_EXPORT dynamic_parameter_base { Graph, size_t ParamSize, const void *Data); + /// Common Reference Semantics + friend bool operator==(const dynamic_parameter_base &LHS, + const dynamic_parameter_base &RHS) { + return LHS.impl == RHS.impl; + } + friend bool operator!=(const dynamic_parameter_base &LHS, + const dynamic_parameter_base &RHS) { + return LHS.impl != RHS.impl; + } + protected: void updateValue(const void *NewValue, size_t Size); @@ -512,3 +560,37 @@ command_graph(const context &SyclContext, const device &SyclDevice, } // namespace _V1 } // namespace sycl + +namespace std { +template <> struct __SYCL_EXPORT hash { + size_t operator()(const sycl::ext::oneapi::experimental::node &Node) const; +}; + +template <> +struct __SYCL_EXPORT + hash { + size_t operator()(const sycl::ext::oneapi::experimental::dynamic_command_group + &DynamicCGH) const; +}; + +template +struct __SYCL_EXPORT + hash> { + size_t operator()(const sycl::ext::oneapi::experimental::command_graph + &Graph) const { + auto ID = sycl::detail::getSyclObjImpl(Graph)->getID(); + return std::hash()(ID); + } +}; + +template +struct __SYCL_EXPORT + hash> { + size_t + operator()(const sycl::ext::oneapi::experimental::dynamic_parameter + &DynamicParam) const { + auto ID = sycl::detail::getSyclObjImpl(DynamicParam)->getID(); + return std::hash()(ID); + } +}; +} // namespace std diff --git a/sycl/source/detail/graph_impl.cpp b/sycl/source/detail/graph_impl.cpp index e6181a559d8e6..6df2cf848e9d3 100644 --- a/sycl/source/detail/graph_impl.cpp +++ b/sycl/source/detail/graph_impl.cpp @@ -324,7 +324,8 @@ graph_impl::graph_impl(const sycl::context &SyclContext, const sycl::device &SyclDevice, const sycl::property_list &PropList) : MContext(SyclContext), MDevice(SyclDevice), MRecordingQueues(), - MEventsMap(), MInorderQueueMap() { + MEventsMap(), MInorderQueueMap(), + MID(MNextAvailableID.fetch_add(1, std::memory_order_relaxed)) { checkGraphPropertiesAndThrow(PropList); if (PropList.has_property()) { MSkipCycleChecks = true; @@ -913,7 +914,8 @@ exec_graph_impl::exec_graph_impl(sycl::context Context, MExecutionEvents(), MIsUpdatable(PropList.has_property()), MEnableProfiling( - PropList.has_property()) { + PropList.has_property()), + MID(MNextAvailableID.fetch_add(1, std::memory_order_relaxed)) { checkGraphPropertiesAndThrow(PropList); // If the graph has been marked as updatable then check if the backend // actually supports that. Devices supporting aspect::ext_oneapi_graph must @@ -2035,7 +2037,8 @@ void dynamic_parameter_impl::updateCGAccessor( dynamic_command_group_impl::dynamic_command_group_impl( const command_graph &Graph) - : MGraph{sycl::detail::getSyclObjImpl(Graph)}, MActiveCGF(0) {} + : MGraph{sycl::detail::getSyclObjImpl(Graph)}, MActiveCGF(0), + MID(MNextAvailableID.fetch_add(1, std::memory_order_relaxed)) {} void dynamic_command_group_impl::finalizeCGFList( const std::vector> &CGFList) { @@ -2159,3 +2162,17 @@ void dynamic_command_group::set_active_index(size_t Index) { } // namespace ext } // namespace _V1 } // namespace sycl + +size_t std::hash::operator()( + const sycl::ext::oneapi::experimental::node &Node) const { + auto ID = sycl::detail::getSyclObjImpl(Node)->getID(); + return std::hash()(ID); +} + +size_t +std::hash::operator()( + const sycl::ext::oneapi::experimental::dynamic_command_group &DynamicCGH) + const { + auto ID = sycl::detail::getSyclObjImpl(DynamicCGH)->getID(); + return std::hash()(ID); +} diff --git a/sycl/source/detail/graph_impl.hpp b/sycl/source/detail/graph_impl.hpp index e609123b4f285..65a74996145da 100644 --- a/sycl/source/detail/graph_impl.hpp +++ b/sycl/source/detail/graph_impl.hpp @@ -1120,6 +1120,8 @@ class graph_impl : public std::enable_shared_from_this { return MBarrierDependencyMap[Queue]; } + unsigned long long getID() { return MID; } + private: /// Iterate over the graph depth-first and run \p NodeFunc on each node. /// @param NodeFunc A function which receives as input a node in the graph to @@ -1198,6 +1200,9 @@ class graph_impl : public std::enable_shared_from_this { std::map, std::shared_ptr, std::owner_less>> MBarrierDependencyMap; + + unsigned long long MID; + inline static std::atomic NextAvailableID = 0; }; /// Class representing the implementation of command_graph. @@ -1297,6 +1302,8 @@ class exec_graph_impl { void updateImpl(std::shared_ptr NodeImpl); + unsigned long long getID() { return MID; } + private: /// Create a command-group for the node and add it to command-buffer by going /// through the scheduler. @@ -1408,13 +1415,17 @@ class exec_graph_impl { // Stores a cache of node ids from modifiable graph nodes to the companion // node(s) in this graph. Used for quick access when updating this graph. std::multimap> MIDCache; + + unsigned long long MID; + inline static std::atomic NextAvailableID = 0; }; class dynamic_parameter_impl { public: dynamic_parameter_impl(std::shared_ptr GraphImpl, size_t ParamSize, const void *Data) - : MGraph(GraphImpl), MValueStorage(ParamSize) { + : MGraph(GraphImpl), MValueStorage(ParamSize), + MID(NextAvailableID.fetch_add(1, std::memory_order_relaxed)) { std::memcpy(MValueStorage.data(), Data, ParamSize); } @@ -1422,7 +1433,8 @@ class dynamic_parameter_impl { /// Parameter size is taken from member of raw_kernel_arg object. dynamic_parameter_impl(std::shared_ptr GraphImpl, size_t, raw_kernel_arg *Data) - : MGraph(GraphImpl) { + : MGraph(GraphImpl), + MID(NextAvailableID.fetch_add(1, std::memory_order_relaxed)) { size_t RawArgSize = Data->MArgSize; const void *RawArgData = Data->MArgData; MValueStorage.reserve(RawArgSize); @@ -1493,6 +1505,8 @@ class dynamic_parameter_impl { int ArgIndex, const sycl::detail::AccessorBaseHost *Acc); + unsigned long long getID() { return MID; } + // Weak ptrs to node_impls which will be updated std::vector, int>> MNodes; // Dynamic command-groups which will be updated @@ -1500,6 +1514,10 @@ class dynamic_parameter_impl { std::shared_ptr MGraph; std::vector MValueStorage; + +private: + unsigned long long MID; + inline static std::atomic NextAvailableID = 0; }; class dynamic_command_group_impl @@ -1540,6 +1558,12 @@ class dynamic_command_group_impl /// List of nodes using this dynamic command-group. std::vector> MNodes; + + unsigned long long getID() { return MID; } + +protected: + unsigned long long MID; + inline static std::atomic NextAvailableID = 0; }; } // namespace detail } // namespace experimental diff --git a/sycl/unittests/Extensions/CommandGraph/CommonReferenceSemantics.cpp b/sycl/unittests/Extensions/CommandGraph/CommonReferenceSemantics.cpp new file mode 100644 index 0000000000000..7681591c0ff93 --- /dev/null +++ b/sycl/unittests/Extensions/CommandGraph/CommonReferenceSemantics.cpp @@ -0,0 +1 @@ +#include "common_reference_semantics.hpp"