Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SYCL][Graph] Add common reference semantics #16788

Open
wants to merge 5 commits into
base: sycl
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 82 additions & 0 deletions sycl/include/sycl/ext/oneapi/experimental/graph.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,14 @@ class __SYCL_EXPORT node {
/// Update the Range of this node if it is a kernel execution node
template <int Dimensions> void update_range(range<Dimensions> executionRange);

/// Common Reference Semantics
friend bool operator==(const node &LHS, const node &RHS) {
return LHS.impl == RHS.impl;
}
Bensuo marked this conversation as resolved.
Show resolved Hide resolved
friend bool operator!=(const node &LHS, const node &RHS) {
return !operator==(LHS, RHS);
}

private:
node(const std::shared_ptr<detail::node_impl> &Impl) : impl(Impl) {}

Expand Down Expand Up @@ -181,6 +189,16 @@ class __SYCL_EXPORT dynamic_command_group {
size_t get_active_index() const;
void set_active_index(size_t Index);

/// Common Reference Semantics
friend bool operator==(const dynamic_command_group &LHS,
const dynamic_command_group &RHS) {
return LHS.impl == RHS.impl;
}
friend bool operator!=(const dynamic_command_group &LHS,
const dynamic_command_group &RHS) {
return !operator==(LHS, RHS);
}

private:
template <class Obj>
friend const decltype(Obj::impl) &
Expand Down Expand Up @@ -307,6 +325,16 @@ class __SYCL_EXPORT modifiable_command_graph
/// Get a list of all root nodes (nodes without dependencies) in this graph.
std::vector<node> get_root_nodes() const;

/// Common Reference Semantics
friend bool operator==(const modifiable_command_graph &LHS,
const modifiable_command_graph &RHS) {
return LHS.impl == RHS.impl;
}
friend bool operator!=(const modifiable_command_graph &LHS,
const modifiable_command_graph &RHS) {
return !operator==(LHS, RHS);
}

protected:
/// Constructor used internally by the runtime.
/// @param Impl Detail implementation class to construct object with.
Expand Down Expand Up @@ -386,6 +414,16 @@ class __SYCL_EXPORT executable_command_graph
/// @param Nodes The nodes to use for updating the graph.
void update(const std::vector<node> &Nodes);

/// Common Reference Semantics
friend bool operator==(const executable_command_graph &LHS,
const executable_command_graph &RHS) {
return LHS.impl == RHS.impl;
}
friend bool operator!=(const executable_command_graph &LHS,
const executable_command_graph &RHS) {
return !operator==(LHS, RHS);
}

protected:
/// Constructor used by internal runtime.
/// @param Graph Detail implementation class to construct with.
Expand Down Expand Up @@ -452,6 +490,16 @@ class __SYCL_EXPORT dynamic_parameter_base {
Graph,
size_t ParamSize, const void *Data);

/// Common Reference Semantics
friend bool operator==(const dynamic_parameter_base &LHS,
const dynamic_parameter_base &RHS) {
return LHS.impl == RHS.impl;
}
friend bool operator!=(const dynamic_parameter_base &LHS,
const dynamic_parameter_base &RHS) {
return !operator==(LHS, RHS);
}

protected:
void updateValue(const void *NewValue, size_t Size);

Expand Down Expand Up @@ -512,3 +560,37 @@ command_graph(const context &SyclContext, const device &SyclDevice,

} // namespace _V1
} // namespace sycl

namespace std {
template <> struct __SYCL_EXPORT hash<sycl::ext::oneapi::experimental::node> {
size_t operator()(const sycl::ext::oneapi::experimental::node &Node) const;
};

template <>
struct __SYCL_EXPORT
hash<sycl::ext::oneapi::experimental::dynamic_command_group> {
size_t operator()(const sycl::ext::oneapi::experimental::dynamic_command_group
&DynamicCGH) const;
};

template <sycl::ext::oneapi::experimental::graph_state State>
struct __SYCL_EXPORT
hash<sycl::ext::oneapi::experimental::command_graph<State>> {
size_t operator()(const sycl::ext::oneapi::experimental::command_graph<State>
&Graph) const {
auto ID = sycl::detail::getSyclObjImpl(Graph)->getID();
return std::hash<decltype(ID)>()(ID);
}
};

template <typename ValueT>
struct __SYCL_EXPORT
hash<sycl::ext::oneapi::experimental::dynamic_parameter<ValueT>> {
size_t
operator()(const sycl::ext::oneapi::experimental::dynamic_parameter<ValueT>
&DynamicParam) const {
auto ID = sycl::detail::getSyclObjImpl(DynamicParam)->getID();
return std::hash<decltype(ID)>()(ID);
}
};
} // namespace std
23 changes: 20 additions & 3 deletions sycl/source/detail/graph_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,8 @@ graph_impl::graph_impl(const sycl::context &SyclContext,
const sycl::device &SyclDevice,
const sycl::property_list &PropList)
: MContext(SyclContext), MDevice(SyclDevice), MRecordingQueues(),
MEventsMap(), MInorderQueueMap() {
MEventsMap(), MInorderQueueMap(),
MID(NextAvailableID.fetch_add(1, std::memory_order_relaxed)) {
checkGraphPropertiesAndThrow(PropList);
if (PropList.has_property<property::graph::no_cycle_check>()) {
MSkipCycleChecks = true;
Expand Down Expand Up @@ -913,7 +914,8 @@ exec_graph_impl::exec_graph_impl(sycl::context Context,
MExecutionEvents(),
MIsUpdatable(PropList.has_property<property::graph::updatable>()),
MEnableProfiling(
PropList.has_property<property::graph::enable_profiling>()) {
PropList.has_property<property::graph::enable_profiling>()),
MID(NextAvailableID.fetch_add(1, std::memory_order_relaxed)) {
checkGraphPropertiesAndThrow(PropList);
// If the graph has been marked as updatable then check if the backend
// actually supports that. Devices supporting aspect::ext_oneapi_graph must
Expand Down Expand Up @@ -2035,7 +2037,8 @@ void dynamic_parameter_impl::updateCGAccessor(

dynamic_command_group_impl::dynamic_command_group_impl(
const command_graph<graph_state::modifiable> &Graph)
: MGraph{sycl::detail::getSyclObjImpl(Graph)}, MActiveCGF(0) {}
: MGraph{sycl::detail::getSyclObjImpl(Graph)}, MActiveCGF(0),
MID(NextAvailableID.fetch_add(1, std::memory_order_relaxed)) {}

void dynamic_command_group_impl::finalizeCGFList(
const std::vector<std::function<void(handler &)>> &CGFList) {
Expand Down Expand Up @@ -2159,3 +2162,17 @@ void dynamic_command_group::set_active_index(size_t Index) {
} // namespace ext
} // namespace _V1
} // namespace sycl

size_t std::hash<sycl::ext::oneapi::experimental::node>::operator()(
const sycl::ext::oneapi::experimental::node &Node) const {
auto ID = sycl::detail::getSyclObjImpl(Node)->getID();
return std::hash<decltype(ID)>()(ID);
}

size_t
std::hash<sycl::ext::oneapi::experimental::dynamic_command_group>::operator()(
const sycl::ext::oneapi::experimental::dynamic_command_group &DynamicCG)
const {
auto ID = sycl::detail::getSyclObjImpl(DynamicCG)->getID();
return std::hash<decltype(ID)>()(ID);
}
32 changes: 30 additions & 2 deletions sycl/source/detail/graph_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1120,6 +1120,8 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
return MBarrierDependencyMap[Queue];
}

unsigned long long getID() const { return MID; }

private:
/// Iterate over the graph depth-first and run \p NodeFunc on each node.
/// @param NodeFunc A function which receives as input a node in the graph to
Expand Down Expand Up @@ -1198,6 +1200,10 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
std::map<std::weak_ptr<sycl::detail::queue_impl>, std::shared_ptr<node_impl>,
std::owner_less<std::weak_ptr<sycl::detail::queue_impl>>>
MBarrierDependencyMap;

unsigned long long MID;
// Used for std::hash in order to create a unique hash for the instance.
inline static std::atomic<unsigned long long> NextAvailableID = 0;
};

/// Class representing the implementation of command_graph<executable>.
Expand Down Expand Up @@ -1297,6 +1303,8 @@ class exec_graph_impl {

void updateImpl(std::shared_ptr<node_impl> NodeImpl);

unsigned long long getID() const { return MID; }

private:
/// Create a command-group for the node and add it to command-buffer by going
/// through the scheduler.
Expand Down Expand Up @@ -1408,21 +1416,27 @@ class exec_graph_impl {
// Stores a cache of node ids from modifiable graph nodes to the companion
// node(s) in this graph. Used for quick access when updating this graph.
std::multimap<node_impl::id_type, std::shared_ptr<node_impl>> MIDCache;

unsigned long long MID;
// Used for std::hash in order to create a unique hash for the instance.
inline static std::atomic<unsigned long long> NextAvailableID = 0;
};

class dynamic_parameter_impl {
public:
dynamic_parameter_impl(std::shared_ptr<graph_impl> GraphImpl,
size_t ParamSize, const void *Data)
: MGraph(GraphImpl), MValueStorage(ParamSize) {
: MGraph(GraphImpl), MValueStorage(ParamSize),
MID(NextAvailableID.fetch_add(1, std::memory_order_relaxed)) {
std::memcpy(MValueStorage.data(), Data, ParamSize);
}

/// sycl_ext_oneapi_raw_kernel_arg constructor
/// Parameter size is taken from member of raw_kernel_arg object.
dynamic_parameter_impl(std::shared_ptr<graph_impl> GraphImpl, size_t,
raw_kernel_arg *Data)
: MGraph(GraphImpl) {
: MGraph(GraphImpl),
MID(NextAvailableID.fetch_add(1, std::memory_order_relaxed)) {
size_t RawArgSize = Data->MArgSize;
const void *RawArgData = Data->MArgData;
MValueStorage.reserve(RawArgSize);
Expand Down Expand Up @@ -1493,13 +1507,20 @@ class dynamic_parameter_impl {
int ArgIndex,
const sycl::detail::AccessorBaseHost *Acc);

unsigned long long getID() const { return MID; }

// Weak ptrs to node_impls which will be updated
std::vector<std::pair<std::weak_ptr<node_impl>, int>> MNodes;
// Dynamic command-groups which will be updated
std::vector<DynamicCGInfo> MDynCGs;

std::shared_ptr<graph_impl> MGraph;
std::vector<std::byte> MValueStorage;

private:
unsigned long long MID;
// Used for std::hash in order to create a unique hash for the instance.
inline static std::atomic<unsigned long long> NextAvailableID = 0;
};

class dynamic_command_group_impl
Expand Down Expand Up @@ -1540,6 +1561,13 @@ class dynamic_command_group_impl

/// List of nodes using this dynamic command-group.
std::vector<std::weak_ptr<node_impl>> MNodes;

unsigned long long getID() const { return MID; }

private:
unsigned long long MID;
// Used for std::hash in order to create a unique hash for the instance.
inline static std::atomic<unsigned long long> NextAvailableID = 0;
};
} // namespace detail
} // namespace experimental
Expand Down
6 changes: 6 additions & 0 deletions sycl/test/abi/sycl_symbols_windows.dump
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,12 @@
??0dynamic_parameter_base@detail@experimental@oneapi@ext@_V1@sycl@@QEAA@$$QEAV0123456@@Z
??0dynamic_parameter_base@detail@experimental@oneapi@ext@_V1@sycl@@QEAA@AEBV0123456@@Z
??0dynamic_parameter_base@detail@experimental@oneapi@ext@_V1@sycl@@QEAA@V?$command_graph@$0A@@23456@_KPEBX@Z
??4?$hash@Vdynamic_command_group@experimental@oneapi@ext@_V1@sycl@@@std@@QEAAAEAU01@AEBU01@@Z
??4?$hash@Vdynamic_command_group@experimental@oneapi@ext@_V1@sycl@@@std@@QEAAAEAU01@$$QEAU01@@Z
??R?$hash@Vdynamic_command_group@experimental@oneapi@ext@_V1@sycl@@@std@@QEBA_KAEBVdynamic_command_group@experimental@oneapi@ext@_V1@sycl@@@Z
??R?$hash@Vnode@experimental@oneapi@ext@_V1@sycl@@@std@@QEBA_KAEBVnode@experimental@oneapi@ext@_V1@sycl@@@Z
??4?$hash@Vnode@experimental@oneapi@ext@_V1@sycl@@@std@@QEAAAEAU01@$$QEAU01@@Z
??4?$hash@Vnode@experimental@oneapi@ext@_V1@sycl@@@std@@QEAAAEAU01@AEBU01@@Z
??0event@_V1@sycl@@AEAA@V?$shared_ptr@Vevent_impl@detail@_V1@sycl@@@std@@@Z
??0event@_V1@sycl@@QEAA@$$QEAV012@@Z
??0event@_V1@sycl@@QEAA@AEBV012@@Z
Expand Down
1 change: 1 addition & 0 deletions sycl/unittests/Extensions/CommandGraph/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ set(CMAKE_CXX_EXTENSIONS OFF)
add_sycl_unittest(CommandGraphExtensionTests OBJECT
Barrier.cpp
CommandGraph.cpp
CommonReferenceSemantics.cpp
Exceptions.cpp
InOrderQueue.cpp
MultiThreaded.cpp
Expand Down
Loading
Loading