diff --git a/sycl/include/sycl/ext/oneapi/experimental/graph.hpp b/sycl/include/sycl/ext/oneapi/experimental/graph.hpp index e2e87c30ea945..ac0e713dfde95 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/graph.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/graph.hpp @@ -17,7 +17,7 @@ #ifdef __INTEL_PREVIEW_BREAKING_CHANGES #include #endif -#include // for device +#include // for device #include // for graph properties classes #include // for range, nd_range #include // for is_property, is_property_of @@ -142,6 +142,14 @@ class __SYCL_EXPORT node { /// Update the Range of this node if it is a kernel execution node template void update_range(range executionRange); + /// Common Reference Semantics + friend bool operator==(const node &LHS, const node &RHS) { + return LHS.impl == RHS.impl; + } + friend bool operator!=(const node &LHS, const node &RHS) { + return LHS.impl != RHS.impl; + } + private: node(const std::shared_ptr &Impl) : impl(Impl) {} @@ -181,6 +189,16 @@ class __SYCL_EXPORT dynamic_command_group { size_t get_active_index() const; void set_active_index(size_t Index); + /// Common Reference Semantics + friend bool operator==(const dynamic_command_group &LHS, + const dynamic_command_group &RHS) { + return LHS.impl == RHS.impl; + } + friend bool operator!=(const dynamic_command_group &LHS, + const dynamic_command_group &RHS) { + return LHS.impl != RHS.impl; + } + private: template friend const decltype(Obj::impl) & @@ -307,6 +325,16 @@ class __SYCL_EXPORT modifiable_command_graph /// Get a list of all root nodes (nodes without dependencies) in this graph. std::vector get_root_nodes() const; + /// Common Reference Semantics + friend bool operator==(const modifiable_command_graph &LHS, + const modifiable_command_graph &RHS) { + return LHS.impl == RHS.impl; + } + friend bool operator!=(const modifiable_command_graph &LHS, + const modifiable_command_graph &RHS) { + return LHS.impl != RHS.impl; + } + protected: /// Constructor used internally by the runtime. /// @param Impl Detail implementation class to construct object with. @@ -386,6 +414,16 @@ class __SYCL_EXPORT executable_command_graph /// @param Nodes The nodes to use for updating the graph. void update(const std::vector &Nodes); + /// Common Reference Semantics + friend bool operator==(const executable_command_graph &LHS, + const executable_command_graph &RHS) { + return LHS.impl == RHS.impl; + } + friend bool operator!=(const executable_command_graph &LHS, + const executable_command_graph &RHS) { + return LHS.impl != RHS.impl; + } + protected: /// Constructor used by internal runtime. /// @param Graph Detail implementation class to construct with. @@ -452,6 +490,16 @@ class __SYCL_EXPORT dynamic_parameter_base { Graph, size_t ParamSize, const void *Data); + /// Common Reference Semantics + friend bool operator==(const dynamic_parameter_base &LHS, + const dynamic_parameter_base &RHS) { + return LHS.impl == RHS.impl; + } + friend bool operator!=(const dynamic_parameter_base &LHS, + const dynamic_parameter_base &RHS) { + return LHS.impl != RHS.impl; + } + protected: void updateValue(const void *NewValue, size_t Size); @@ -512,3 +560,37 @@ command_graph(const context &SyclContext, const device &SyclDevice, } // namespace _V1 } // namespace sycl + +namespace std { +template <> struct __SYCL_EXPORT hash { + size_t operator()(const sycl::ext::oneapi::experimental::node &Node) const; +}; + +template <> +struct __SYCL_EXPORT + hash { + size_t operator()(const sycl::ext::oneapi::experimental::dynamic_command_group + &DynamicCGH) const; +}; + +template +struct __SYCL_EXPORT + hash> { + size_t operator()(const sycl::ext::oneapi::experimental::command_graph + &Graph) const { + auto ID = sycl::detail::getSyclObjImpl(Graph)->getID(); + return std::hash()(ID); + } +}; + +template +struct __SYCL_EXPORT + hash> { + size_t + operator()(const sycl::ext::oneapi::experimental::dynamic_parameter + &DynamicParam) const { + auto ID = sycl::detail::getSyclObjImpl(DynamicParam)->getID(); + return std::hash()(ID); + } +}; +} // namespace std diff --git a/sycl/source/detail/graph_impl.cpp b/sycl/source/detail/graph_impl.cpp index e6181a559d8e6..a498cc455bd85 100644 --- a/sycl/source/detail/graph_impl.cpp +++ b/sycl/source/detail/graph_impl.cpp @@ -324,7 +324,8 @@ graph_impl::graph_impl(const sycl::context &SyclContext, const sycl::device &SyclDevice, const sycl::property_list &PropList) : MContext(SyclContext), MDevice(SyclDevice), MRecordingQueues(), - MEventsMap(), MInorderQueueMap() { + MEventsMap(), MInorderQueueMap(), + MID(NextAvailableID.fetch_add(1, std::memory_order_relaxed)) { checkGraphPropertiesAndThrow(PropList); if (PropList.has_property()) { MSkipCycleChecks = true; @@ -913,7 +914,8 @@ exec_graph_impl::exec_graph_impl(sycl::context Context, MExecutionEvents(), MIsUpdatable(PropList.has_property()), MEnableProfiling( - PropList.has_property()) { + PropList.has_property()), + MID(NextAvailableID.fetch_add(1, std::memory_order_relaxed)) { checkGraphPropertiesAndThrow(PropList); // If the graph has been marked as updatable then check if the backend // actually supports that. Devices supporting aspect::ext_oneapi_graph must @@ -2035,7 +2037,8 @@ void dynamic_parameter_impl::updateCGAccessor( dynamic_command_group_impl::dynamic_command_group_impl( const command_graph &Graph) - : MGraph{sycl::detail::getSyclObjImpl(Graph)}, MActiveCGF(0) {} + : MGraph{sycl::detail::getSyclObjImpl(Graph)}, MActiveCGF(0), + MID(NextAvailableID.fetch_add(1, std::memory_order_relaxed)) {} void dynamic_command_group_impl::finalizeCGFList( const std::vector> &CGFList) { @@ -2159,3 +2162,17 @@ void dynamic_command_group::set_active_index(size_t Index) { } // namespace ext } // namespace _V1 } // namespace sycl + +size_t std::hash::operator()( + const sycl::ext::oneapi::experimental::node &Node) const { + auto ID = sycl::detail::getSyclObjImpl(Node)->getID(); + return std::hash()(ID); +} + +size_t +std::hash::operator()( + const sycl::ext::oneapi::experimental::dynamic_command_group &DynamicCGH) + const { + auto ID = sycl::detail::getSyclObjImpl(DynamicCGH)->getID(); + return std::hash()(ID); +} diff --git a/sycl/source/detail/graph_impl.hpp b/sycl/source/detail/graph_impl.hpp index e609123b4f285..114b50a81d38e 100644 --- a/sycl/source/detail/graph_impl.hpp +++ b/sycl/source/detail/graph_impl.hpp @@ -1120,6 +1120,8 @@ class graph_impl : public std::enable_shared_from_this { return MBarrierDependencyMap[Queue]; } + unsigned long long getID() { return MID; } + private: /// Iterate over the graph depth-first and run \p NodeFunc on each node. /// @param NodeFunc A function which receives as input a node in the graph to @@ -1198,6 +1200,9 @@ class graph_impl : public std::enable_shared_from_this { std::map, std::shared_ptr, std::owner_less>> MBarrierDependencyMap; + + unsigned long long MID; + inline static std::atomic NextAvailableID = 0; }; /// Class representing the implementation of command_graph. @@ -1297,6 +1302,8 @@ class exec_graph_impl { void updateImpl(std::shared_ptr NodeImpl); + unsigned long long getID() { return MID; } + private: /// Create a command-group for the node and add it to command-buffer by going /// through the scheduler. @@ -1408,13 +1415,17 @@ class exec_graph_impl { // Stores a cache of node ids from modifiable graph nodes to the companion // node(s) in this graph. Used for quick access when updating this graph. std::multimap> MIDCache; + + unsigned long long MID; + inline static std::atomic NextAvailableID = 0; }; class dynamic_parameter_impl { public: dynamic_parameter_impl(std::shared_ptr GraphImpl, size_t ParamSize, const void *Data) - : MGraph(GraphImpl), MValueStorage(ParamSize) { + : MGraph(GraphImpl), MValueStorage(ParamSize), + MID(NextAvailableID.fetch_add(1, std::memory_order_relaxed)) { std::memcpy(MValueStorage.data(), Data, ParamSize); } @@ -1422,7 +1433,8 @@ class dynamic_parameter_impl { /// Parameter size is taken from member of raw_kernel_arg object. dynamic_parameter_impl(std::shared_ptr GraphImpl, size_t, raw_kernel_arg *Data) - : MGraph(GraphImpl) { + : MGraph(GraphImpl), + MID(NextAvailableID.fetch_add(1, std::memory_order_relaxed)) { size_t RawArgSize = Data->MArgSize; const void *RawArgData = Data->MArgData; MValueStorage.reserve(RawArgSize); @@ -1493,6 +1505,8 @@ class dynamic_parameter_impl { int ArgIndex, const sycl::detail::AccessorBaseHost *Acc); + unsigned long long getID() { return MID; } + // Weak ptrs to node_impls which will be updated std::vector, int>> MNodes; // Dynamic command-groups which will be updated @@ -1500,6 +1514,10 @@ class dynamic_parameter_impl { std::shared_ptr MGraph; std::vector MValueStorage; + +private: + unsigned long long MID; + inline static std::atomic NextAvailableID = 0; }; class dynamic_command_group_impl @@ -1540,6 +1558,12 @@ class dynamic_command_group_impl /// List of nodes using this dynamic command-group. std::vector> MNodes; + + unsigned long long getID() { return MID; } + +private: + unsigned long long MID; + inline static std::atomic NextAvailableID = 0; }; } // namespace detail } // namespace experimental diff --git a/sycl/unittests/Extensions/CommandGraph/CMakeLists.txt b/sycl/unittests/Extensions/CommandGraph/CMakeLists.txt index 31f899f6a2349..0e2a8113a0ebf 100644 --- a/sycl/unittests/Extensions/CommandGraph/CMakeLists.txt +++ b/sycl/unittests/Extensions/CommandGraph/CMakeLists.txt @@ -3,6 +3,7 @@ set(CMAKE_CXX_EXTENSIONS OFF) add_sycl_unittest(CommandGraphExtensionTests OBJECT Barrier.cpp CommandGraph.cpp + CommonReferenceSemantics.cpp Exceptions.cpp InOrderQueue.cpp MultiThreaded.cpp @@ -12,3 +13,5 @@ add_sycl_unittest(CommandGraphExtensionTests OBJECT Update.cpp Properties.cpp ) + +add_executable(blablabla CommonReferenceSemantics.cpp) \ No newline at end of file diff --git a/sycl/unittests/Extensions/CommandGraph/CommonReferenceSemantics.cpp b/sycl/unittests/Extensions/CommandGraph/CommonReferenceSemantics.cpp new file mode 100644 index 0000000000000..7fd6800e6d06e --- /dev/null +++ b/sycl/unittests/Extensions/CommandGraph/CommonReferenceSemantics.cpp @@ -0,0 +1,193 @@ +//==----------------------- CommandGraph.cpp -------------------------------==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#include "Common.hpp" + +#include + +using namespace sycl; +using namespace sycl::ext::oneapi; + +/** + * Checks that the operators and constructors of graph related classes meet the + * common reference semantics. + * @param lambdaFactory A function object that returns an object to be tested. + */ +template +void testSemantics(LambdaType &&lambdaFactory) { + + T Obj1 = lambdaFactory(); + T Obj2 = lambdaFactory(); + + // Check the == and != operators. + ASSERT_FALSE(Obj1 == Obj2); + ASSERT_TRUE(Obj1 != Obj2); + + // Check Copy Constructor and Assignment operators. + T Obj1Copy = Obj1; + T Obj2CopyConstructed(Obj2); + ASSERT_TRUE(Obj1Copy == Obj1); + ASSERT_TRUE(Obj2CopyConstructed == Obj2); + + // Check Move Constructor and Move Assignment operators. + auto Obj1Move = std::move(Obj1); + auto Obj2MoveConstructed(std::move(Obj2)); + ASSERT_TRUE(Obj1Move == Obj1Copy); + ASSERT_TRUE(Obj2MoveConstructed == Obj2CopyConstructed); +} + +TEST_F(CommandGraphTest, ModifiableGraphSemantics) { + sycl::queue Queue; + auto factory = [&]() { + return experimental::command_graph(Queue.get_context(), Queue.get_device()); + }; + + ASSERT_NO_FATAL_FAILURE( + testSemantics< + experimental::command_graph>( + factory)); +} + +TEST_F(CommandGraphTest, ExecutableGraphSemantics) { + sycl::queue Queue; + + auto factory = [&]() { + experimental::command_graph Graph(Queue.get_context(), Queue.get_device()); + return Graph.finalize(); + }; + ASSERT_NO_FATAL_FAILURE( + testSemantics< + experimental::command_graph>( + factory)); +} + +TEST_F(CommandGraphTest, NodeSemantics) { + sycl::queue Queue; + experimental::command_graph Graph(Queue.get_context(), Queue.get_device()); + + auto factory = [&]() { + return Graph.add( + [&](handler &CGH) { CGH.parallel_for(1, [=](item<1> Item) {}); }); + }; + ASSERT_NO_FATAL_FAILURE(testSemantics(factory)); +} + +TEST_F(CommandGraphTest, DynamicCGSemantics) { + sycl::queue Queue; + experimental::command_graph Graph(Queue.get_context(), Queue.get_device()); + + auto CGF = [&](handler &CGH) { CGH.parallel_for(1, [=](item<1> Item) {}); }; + + auto factory = [&]() { + return experimental::dynamic_command_group(Graph, {CGF}); + }; + ASSERT_NO_FATAL_FAILURE( + testSemantics(factory)); +} + +TEST_F(CommandGraphTest, DynamicParamSemantics) { + sycl::queue Queue; + experimental::command_graph Graph(Queue.get_context(), Queue.get_device()); + + auto factory = [&]() { + return experimental::dynamic_parameter(Graph, 1); + }; + ASSERT_NO_FATAL_FAILURE( + testSemantics>(factory)); +} + +/** + * Checks for potential hash collisions in the hash implementations of graph + * related classes. + * @param lambdaFactory A function object that returns an object to be tested. + */ +template +void testHash(LambdaType &&lambdaFactory) { + + const int NumObjects = 100; + + std::unordered_map MapObjToBool{}; + T t1 = lambdaFactory(); + T t2 = lambdaFactory(); + T t3 = lambdaFactory(); + T t4 = lambdaFactory(); + + ASSERT_TRUE(MapObjToBool.insert({t1, true}).second); + ASSERT_TRUE(MapObjToBool.insert({t2, true}).second); + + // Insert objects and destroy them immediately to confirm that this doesn't + // create collisions with later insertions. + for (int i = 0; i < NumObjects; ++i) { + T instance = lambdaFactory(); + ASSERT_TRUE(MapObjToBool.insert({instance, true}).second); + } + + ASSERT_TRUE(MapObjToBool.insert({t3, true}).second); + ASSERT_TRUE(MapObjToBool.insert({t4, true}).second); + + ASSERT_TRUE(MapObjToBool.size() == (NumObjects + 4)); +} + +TEST_F(CommandGraphTest, ModifiableGraphHash) { + sycl::queue Queue; + auto factory = [&]() { + return experimental::command_graph(Queue.get_context(), Queue.get_device()); + }; + + ASSERT_NO_FATAL_FAILURE( + testHash< + experimental::command_graph>( + factory)); +} + +TEST_F(CommandGraphTest, ExecutableGraphHash) { + sycl::queue Queue; + + auto factory = [&]() { + experimental::command_graph Graph(Queue.get_context(), Queue.get_device()); + return Graph.finalize(); + }; + ASSERT_NO_FATAL_FAILURE( + testHash< + experimental::command_graph>( + factory)); +} + +TEST_F(CommandGraphTest, NodeHash) { + sycl::queue Queue; + experimental::command_graph Graph(Queue.get_context(), Queue.get_device()); + + auto factory = [&]() { + return Graph.add( + [&](handler &CGH) { CGH.parallel_for(1, [=](item<1> Item) {}); }); + }; + ASSERT_NO_FATAL_FAILURE(testHash(factory)); +} + +TEST_F(CommandGraphTest, DynamicCommandGroupHash) { + sycl::queue Queue; + experimental::command_graph Graph(Queue.get_context(), Queue.get_device()); + + auto CGF = [&](handler &CGH) { CGH.parallel_for(1, [=](item<1> Item) {}); }; + + auto factory = [&]() { + return experimental::dynamic_command_group(Graph, {CGF}); + }; + ASSERT_NO_FATAL_FAILURE( + testHash(factory)); +} + +TEST_F(CommandGraphTest, DynamicParameterHash) { + sycl::queue Queue; + experimental::command_graph Graph(Queue.get_context(), Queue.get_device()); + + auto factory = [&]() { + return experimental::dynamic_parameter(Graph, 1); + }; + ASSERT_NO_FATAL_FAILURE( + testHash>(factory)); +}