From 516ca32fd4eed46cd9d1d6e3fb530d2a5d131970 Mon Sep 17 00:00:00 2001 From: Andrei Elovikov Date: Tue, 22 Apr 2025 13:45:27 -0700 Subject: [PATCH] [NFCI][SYCL] Less `shared_ptr` for `platform_impl` `GlobalHandler::MPlatformCache` keeps (shared) ownership of `platform_impl` objects, so none of them could be destroyed until SYCL RT library shutdown/unload process. As such, using raw pointers/references to `platform_impl` throughout the SYCL RT is totally fine and avoids extra costs of `std::shared_ptr`. I'm relatively sure `sycl::platform` could avoid using `std::shared_ptr<platform_impl> impl` as well, but that would be a breaking change, so it is not being implemented at this moment. --- sycl/include/sycl/platform.hpp | 7 ++++ sycl/source/backend.cpp | 8 ++--- sycl/source/backend/level_zero.cpp | 4 +-- sycl/source/detail/allowlist.cpp | 7 ++-- sycl/source/detail/buffer_impl.cpp | 9 +++-- sycl/source/detail/context_impl.cpp | 18 +++++----- sycl/source/detail/context_impl.hpp | 8 +++-- sycl/source/detail/device_impl.cpp | 12 +++---- sycl/source/detail/device_impl.hpp | 9 +++-- sycl/source/detail/device_info.hpp | 13 ++++--- sycl/source/detail/global_handler.cpp | 6 ++-- sycl/source/detail/global_handler.hpp | 9 +++-- sycl/source/detail/kernel_impl.cpp | 2 +- sycl/source/detail/platform_impl.cpp | 31 ++++++++-------- sycl/source/detail/platform_impl.hpp | 35 ++++++++++++++----- .../program_manager/program_manager.cpp | 21 ++++++----- sycl/source/detail/usm/usm_impl.cpp | 4 +-- sycl/source/device.cpp | 4 +-- sycl/source/platform.cpp | 10 +++--- sycl/test/gdb/printers.cpp | 3 +- sycl/unittests/program_manager/SubDevices.cpp | 2 +- sycl/unittests/queue/DeviceCheck.cpp | 5 +-- 22 files changed, 125 insertions(+), 102 deletions(-) diff --git a/sycl/include/sycl/platform.hpp b/sycl/include/sycl/platform.hpp index 7c696a3f11bb2..0fc0a00c1bafa 100644 --- a/sycl/include/sycl/platform.hpp +++ b/sycl/include/sycl/platform.hpp @@ -37,12 +37,16 @@ inline namespace _V1 { // Forward declaration class 
device; class context; +class platform; template auto get_native(const SyclObjectT &Obj) -> backend_return_t; namespace detail { class platform_impl; +template +std::enable_if_t, platform> +createSyclObjFromImpl(platform_impl &); /// Allows to enable/disable "Default Context" extension /// @@ -231,6 +235,9 @@ class __SYCL_EXPORT platform : public detail::OwnerLessBase { template friend const decltype(Obj::impl) & detail::getSyclObjImpl(const Obj &SyclObject); + template + friend std::enable_if_t, platform> + detail::createSyclObjFromImpl(detail::platform_impl &); template friend auto get_native(const SyclObjectT &Obj) diff --git a/sycl/source/backend.cpp b/sycl/source/backend.cpp index 73857b7ab2da3..f64d2c809b3db 100644 --- a/sycl/source/backend.cpp +++ b/sycl/source/backend.cpp @@ -89,9 +89,9 @@ __SYCL_EXPORT device make_device(ur_native_handle_t NativeHandle, NativeHandle, Adapter->getUrAdapter(), nullptr, &UrDevice); // Construct the SYCL device from UR device. - auto Platform = platform_impl::getPlatformFromUrDevice(UrDevice, Adapter); + auto &Platform = platform_impl::getPlatformFromUrDevice(UrDevice, Adapter); return detail::createSyclObjFromImpl( - Platform->getOrMakeDeviceImpl(UrDevice, Platform)); + Platform.getOrMakeDeviceImpl(UrDevice, Platform)); } __SYCL_EXPORT context make_context(ur_native_handle_t NativeHandle, @@ -288,9 +288,9 @@ make_kernel_bundle(ur_native_handle_t NativeHandle, std::transform( ProgramDevices.begin(), ProgramDevices.end(), std::back_inserter(Devices), [&Adapter](const auto &Dev) { - auto Platform = + platform_impl &Platform = detail::platform_impl::getPlatformFromUrDevice(Dev, Adapter); - auto DeviceImpl = Platform->getOrMakeDeviceImpl(Dev, Platform); + auto DeviceImpl = Platform.getOrMakeDeviceImpl(Dev, Platform); return createSyclObjFromImpl(DeviceImpl); }); diff --git a/sycl/source/backend/level_zero.cpp b/sycl/source/backend/level_zero.cpp index 75a67745f6849..4536440182ac6 100644 --- a/sycl/source/backend/level_zero.cpp +++ 
b/sycl/source/backend/level_zero.cpp @@ -20,14 +20,14 @@ using namespace sycl::detail; __SYCL_EXPORT device make_device(const platform &Platform, ur_native_handle_t NativeHandle) { const auto &Adapter = ur::getAdapter(); - const auto &PlatformImpl = getSyclObjImpl(Platform); + platform_impl &PlatformImpl = *getSyclObjImpl(Platform).get(); // Create UR device first. ur_device_handle_t UrDevice; Adapter->call( NativeHandle, Adapter->getUrAdapter(), nullptr, &UrDevice); return detail::createSyclObjFromImpl( - PlatformImpl->getOrMakeDeviceImpl(UrDevice, PlatformImpl)); + PlatformImpl.getOrMakeDeviceImpl(UrDevice, PlatformImpl)); } } // namespace ext::oneapi::level_zero::detail diff --git a/sycl/source/detail/allowlist.cpp b/sycl/source/detail/allowlist.cpp index e7fa030039e1c..0867036165ad6 100644 --- a/sycl/source/detail/allowlist.cpp +++ b/sycl/source/detail/allowlist.cpp @@ -375,8 +375,9 @@ void applyAllowList(std::vector &UrDevices, // Get platform's backend and put it to DeviceDesc DeviceDescT DeviceDesc; - auto PlatformImpl = platform_impl::getOrMakePlatformImpl(UrPlatform, Adapter); - backend Backend = PlatformImpl->getBackend(); + platform_impl &PlatformImpl = + platform_impl::getOrMakePlatformImpl(UrPlatform, Adapter); + backend Backend = PlatformImpl.getBackend(); for (const auto &SyclBe : getSyclBeMap()) { if (SyclBe.second == Backend) { @@ -395,7 +396,7 @@ void applyAllowList(std::vector &UrDevices, int InsertIDx = 0; for (ur_device_handle_t Device : UrDevices) { - auto DeviceImpl = PlatformImpl->getOrMakeDeviceImpl(Device, PlatformImpl); + auto DeviceImpl = PlatformImpl.getOrMakeDeviceImpl(Device, PlatformImpl); // get DeviceType value and put it to DeviceDesc ur_device_type_t UrDevType = UR_DEVICE_TYPE_ALL; Adapter->call( diff --git a/sycl/source/detail/buffer_impl.cpp b/sycl/source/detail/buffer_impl.cpp index c39192f224760..150ed2c3d0476 100644 --- a/sycl/source/detail/buffer_impl.cpp +++ b/sycl/source/detail/buffer_impl.cpp @@ -79,12 +79,11 @@ 
buffer_impl::getNativeVector(backend BackendName) const { // doesn't have context and platform if (!Ctx) continue; - const PlatformImplPtr &Platform = Ctx->getPlatformImpl(); - assert(Platform && "Platform must be present for device context"); - if (Platform->getBackend() != BackendName) + const platform_impl &Platform = Ctx->getPlatformImpl(); + if (Platform.getBackend() != BackendName) continue; - auto Adapter = Platform->getAdapter(); + auto Adapter = Platform.getAdapter(); ur_native_handle_t Handle = 0; // When doing buffer interop we don't know what device the memory should be @@ -94,7 +93,7 @@ buffer_impl::getNativeVector(backend BackendName) const { &Handle); Handles.push_back(Handle); - if (Platform->getBackend() == backend::opencl) { + if (Platform.getBackend() == backend::opencl) { __SYCL_OCL_CALL(clRetainMemObject, ur::cast(Handle)); } } diff --git a/sycl/source/detail/context_impl.cpp b/sycl/source/detail/context_impl.cpp index 2a7d020909ea8..5881d201710d5 100644 --- a/sycl/source/detail/context_impl.cpp +++ b/sycl/source/detail/context_impl.cpp @@ -31,7 +31,7 @@ context_impl::context_impl(const device &Device, async_handler AsyncHandler, const property_list &PropList) : MOwnedByRuntime(true), MAsyncHandler(AsyncHandler), MDevices(1, Device), MContext(nullptr), - MPlatform(detail::getSyclObjImpl(Device.get_platform())), + MPlatform(detail::getSyclObjImpl(Device.get_platform()).get()), MPropList(PropList), MSupportBufferLocationByDevices(NotChecked) { verifyProps(PropList); MKernelProgramCache.setContextPtr(this); @@ -41,10 +41,10 @@ context_impl::context_impl(const std::vector Devices, async_handler AsyncHandler, const property_list &PropList) : MOwnedByRuntime(true), MAsyncHandler(AsyncHandler), MDevices(Devices), - MContext(nullptr), MPlatform(), MPropList(PropList), - MSupportBufferLocationByDevices(NotChecked) { + MContext(nullptr), + MPlatform(detail::getSyclObjImpl(MDevices[0].get_platform()).get()), + MPropList(PropList), 
MSupportBufferLocationByDevices(NotChecked) { verifyProps(PropList); - MPlatform = detail::getSyclObjImpl(MDevices[0].get_platform()); std::vector DeviceIds; for (const auto &D : MDevices) { if (D.has(aspect::ext_oneapi_is_composite)) { @@ -77,7 +77,7 @@ context_impl::context_impl(ur_context_handle_t UrContext, MDevices(DeviceList), MContext(UrContext), MPlatform(), MSupportBufferLocationByDevices(NotChecked) { if (!MDevices.empty()) { - MPlatform = detail::getSyclObjImpl(MDevices[0].get_platform()); + MPlatform = detail::getSyclObjImpl(MDevices[0].get_platform()).get(); } else { std::vector DeviceIds; uint32_t DevicesNum = 0; @@ -96,13 +96,13 @@ context_impl::context_impl(ur_context_handle_t UrContext, make_error_code(errc::invalid), "No devices in the provided device list and native context."); - std::shared_ptr Platform = + platform_impl &Platform = platform_impl::getPlatformFromUrDevice(DeviceIds[0], Adapter); for (ur_device_handle_t Dev : DeviceIds) { MDevices.emplace_back(createSyclObjFromImpl( - Platform->getOrMakeDeviceImpl(Dev, Platform))); + Platform.getOrMakeDeviceImpl(Dev, Platform))); } - MPlatform = Platform; + MPlatform = &Platform; } // TODO catch an exception and put it to list of asynchronous exceptions // getAdapter() will be the same as the Adapter passed. 
This should be taken @@ -158,7 +158,7 @@ uint32_t context_impl::get_info() const { this->getAdapter()); } template <> platform context_impl::get_info() const { - return createSyclObjFromImpl(MPlatform); + return createSyclObjFromImpl(*MPlatform); } template <> std::vector diff --git a/sycl/source/detail/context_impl.hpp b/sycl/source/detail/context_impl.hpp index 9d88df4a77bbc..76b4728b63798 100644 --- a/sycl/source/detail/context_impl.hpp +++ b/sycl/source/detail/context_impl.hpp @@ -29,7 +29,6 @@ inline namespace _V1 { // Forward declaration class device; namespace detail { -using PlatformImplPtr = std::shared_ptr; class context_impl { public: /// Constructs a context_impl using a single SYCL devices. @@ -89,8 +88,10 @@ class context_impl { /// \return the Adapter associated with the platform of this context. const AdapterPtr &getAdapter() const { return MPlatform->getAdapter(); } + // TODO: Think more about `const` /// \return the PlatformImpl associated with this context. - const PlatformImplPtr &getPlatformImpl() const { return MPlatform; } + const platform_impl &getPlatformImpl() const { return *MPlatform; } + platform_impl &getPlatformImpl() { return *MPlatform; } /// Queries this context for information. /// @@ -257,7 +258,8 @@ class context_impl { async_handler MAsyncHandler; std::vector MDevices; ur_context_handle_t MContext; - PlatformImplPtr MPlatform; + // TODO: Make it a reference instead, but that needs a bit more refactoring: + platform_impl *MPlatform = nullptr; property_list MPropList; CachedLibProgramsT MCachedLibPrograms; std::mutex MCachedLibProgramsMutex; diff --git a/sycl/source/detail/device_impl.cpp b/sycl/source/detail/device_impl.cpp index 903af9d85c797..94e66b1c577bc 100644 --- a/sycl/source/detail/device_impl.cpp +++ b/sycl/source/detail/device_impl.cpp @@ -21,17 +21,17 @@ namespace detail { /// Constructs a SYCL device instance using the provided /// UR device instance. 
-device_impl::device_impl(ur_device_handle_t Device, PlatformImplPtr Platform) - : MDevice(Device), MPlatform(Platform), +device_impl::device_impl(ur_device_handle_t Device, platform_impl &Platform) + : MDevice(Device), MPlatform(&Platform), MDeviceHostBaseTime(std::make_pair(0, 0)) { - const AdapterPtr &Adapter = Platform->getAdapter(); + const AdapterPtr &Adapter = Platform.getAdapter(); // TODO catch an exception and put it to list of asynchronous exceptions Adapter->call( MDevice, UR_DEVICE_INFO_TYPE, sizeof(ur_device_type_t), &MType, nullptr); // No need to set MRootDevice when MAlwaysRootDevice is true - if (!Platform->MAlwaysRootDevice) { + if (!Platform.MAlwaysRootDevice) { // TODO catch an exception and put it to list of asynchronous exceptions Adapter->call( MDevice, UR_DEVICE_INFO_PARENT_DEVICE, sizeof(ur_device_handle_t), @@ -74,7 +74,7 @@ cl_device_id device_impl::get() const { } platform device_impl::get_platform() const { - return createSyclObjFromImpl(MPlatform); + return createSyclObjFromImpl(*MPlatform); } template @@ -177,7 +177,7 @@ std::vector device_impl::create_sub_devices( std::for_each(SubDevices.begin(), SubDevices.end(), [&res, this](const ur_device_handle_t &a_ur_device) { device sycl_device = detail::createSyclObjFromImpl( - MPlatform->getOrMakeDeviceImpl(a_ur_device, MPlatform)); + MPlatform->getOrMakeDeviceImpl(a_ur_device, *MPlatform)); res.push_back(sycl_device); }); return res; diff --git a/sycl/source/detail/device_impl.hpp b/sycl/source/detail/device_impl.hpp index 6c0e0632b3e75..a2f2f45f944be 100644 --- a/sycl/source/detail/device_impl.hpp +++ b/sycl/source/detail/device_impl.hpp @@ -30,14 +30,13 @@ namespace detail { // Forward declaration class platform_impl; -using PlatformImplPtr = std::shared_ptr; // TODO: Make code thread-safe class device_impl { public: /// Constructs a SYCL device instance using the provided /// UR device instance. 
- explicit device_impl(ur_device_handle_t Device, PlatformImplPtr Platform); + explicit device_impl(ur_device_handle_t Device, platform_impl &Platform); ~device_impl(); @@ -278,9 +277,9 @@ class device_impl { /// Get the backend of this device backend getBackend() const { return MPlatform->getBackend(); } + // TODO: const-correctness /// @brief Get the platform impl serving this device - /// @return PlatformImplPtr - const PlatformImplPtr &getPlatformImpl() const { return MPlatform; } + platform_impl &getPlatformImpl() const { return *MPlatform; } /// Get device info string std::string get_device_info_string(ur_device_info_t InfoCode) const; @@ -292,7 +291,7 @@ class device_impl { ur_device_handle_t MDevice = 0; ur_device_type_t MType; ur_device_handle_t MRootDevice = nullptr; - PlatformImplPtr MPlatform; + platform_impl *MPlatform = nullptr; bool MUseNativeAssert = false; mutable std::string MDeviceName; mutable std::once_flag MDeviceNameFlag; diff --git a/sycl/source/detail/device_info.hpp b/sycl/source/detail/device_info.hpp index b21d157d55446..7a2b22e628bfe 100644 --- a/sycl/source/detail/device_info.hpp +++ b/sycl/source/detail/device_info.hpp @@ -35,7 +35,6 @@ namespace sycl { inline namespace _V1 { namespace detail { - inline std::vector readMemoryOrderBitfield(ur_memory_order_capability_flags_t bits) { std::vector result; @@ -1171,9 +1170,9 @@ template <> struct get_device_info_impl { throw exception(make_error_code(errc::invalid), "No parent for device because it is not a subdevice"); - const auto &Platform = Dev.getPlatformImpl(); + platform_impl &Platform = Dev.getPlatformImpl(); return createSyclObjFromImpl( - Platform->getOrMakeDeviceImpl(result, Platform)); + Platform.getOrMakeDeviceImpl(result, Platform)); } }; @@ -1337,10 +1336,10 @@ struct get_device_info_impl< ext::oneapi::experimental::info::device::component_devices>::value, ResultSize, Devs.data(), nullptr); std::vector Result; - const auto &Platform = Dev.getPlatformImpl(); + platform_impl 
&Platform = Dev.getPlatformImpl(); for (const auto &d : Devs) Result.push_back(createSyclObjFromImpl( - Platform->getOrMakeDeviceImpl(d, Platform))); + Platform.getOrMakeDeviceImpl(d, Platform))); return Result; } @@ -1363,9 +1362,9 @@ struct get_device_info_impl< sizeof(Result), &Result, nullptr); if (Result) { - const auto &Platform = Dev.getPlatformImpl(); + platform_impl &Platform = Dev.getPlatformImpl(); return createSyclObjFromImpl( - Platform->getOrMakeDeviceImpl(Result, Platform)); + Platform.getOrMakeDeviceImpl(Result, Platform)); } throw sycl::exception(make_error_code(errc::invalid), "A component with aspect::ext_oneapi_is_component " diff --git a/sycl/source/detail/global_handler.cpp b/sycl/source/detail/global_handler.cpp index 078da1aab83ba..bfb8726a3e31d 100644 --- a/sycl/source/detail/global_handler.cpp +++ b/sycl/source/detail/global_handler.cpp @@ -186,7 +186,7 @@ ProgramManager &GlobalHandler::getProgramManager() { return PM; } -std::unordered_map & +std::unordered_map & GlobalHandler::getPlatformToDefaultContextCache() { // The optimization with static reference is not done because // there are public methods of the GlobalHandler @@ -207,8 +207,8 @@ Sync &GlobalHandler::getSync() { return sync; } -std::vector &GlobalHandler::getPlatformCache() { - static std::vector &PlatformCache = +std::vector> &GlobalHandler::getPlatformCache() { + static std::vector> &PlatformCache = getOrCreate(MPlatformCache); return PlatformCache; } diff --git a/sycl/source/detail/global_handler.hpp b/sycl/source/detail/global_handler.hpp index ef5bc32c0db0d..e801dda2ecebd 100644 --- a/sycl/source/detail/global_handler.hpp +++ b/sycl/source/detail/global_handler.hpp @@ -27,7 +27,6 @@ class ods_target_list; class XPTIRegistry; class ThreadPool; -using PlatformImplPtr = std::shared_ptr; using ContextImplPtr = std::shared_ptr; using AdapterPtr = std::shared_ptr; @@ -60,9 +59,9 @@ class GlobalHandler { bool isSchedulerAlive() const; ProgramManager &getProgramManager(); Sync 
&getSync(); - std::vector &getPlatformCache(); + std::vector> &getPlatformCache(); - std::unordered_map & + std::unordered_map & getPlatformToDefaultContextCache(); std::mutex &getPlatformToDefaultContextCacheMutex(); @@ -118,8 +117,8 @@ class GlobalHandler { InstWithLock MScheduler; InstWithLock MProgramManager; InstWithLock MSync; - InstWithLock> MPlatformCache; - InstWithLock> + InstWithLock>> MPlatformCache; + InstWithLock> MPlatformToDefaultContextCache; InstWithLock MPlatformToDefaultContextCacheMutex; InstWithLock MPlatformMapMutex; diff --git a/sycl/source/detail/kernel_impl.cpp b/sycl/source/detail/kernel_impl.cpp index 1eaad23c71a50..50fadfd1f4f0a 100644 --- a/sycl/source/detail/kernel_impl.cpp +++ b/sycl/source/detail/kernel_impl.cpp @@ -126,7 +126,7 @@ void kernel_impl::checkIfValidForNumArgsInfoQuery() const { } void kernel_impl::enableUSMIndirectAccess() const { - if (!MContext->getPlatformImpl()->supports_usm()) + if (!MContext->getPlatformImpl().supports_usm()) return; // Some UR Adapters (like OpenCL) require this call to enable USM diff --git a/sycl/source/detail/platform_impl.cpp b/sycl/source/detail/platform_impl.cpp index 2f8a1195e132d..a708a1f802fde 100644 --- a/sycl/source/detail/platform_impl.cpp +++ b/sycl/source/detail/platform_impl.cpp @@ -31,23 +31,21 @@ namespace sycl { inline namespace _V1 { namespace detail { -using PlatformImplPtr = std::shared_ptr; - -PlatformImplPtr +platform_impl & platform_impl::getOrMakePlatformImpl(ur_platform_handle_t UrPlatform, const AdapterPtr &Adapter) { - PlatformImplPtr Result; + std::shared_ptr Result; { const std::lock_guard Guard( GlobalHandler::instance().getPlatformMapMutex()); - std::vector &PlatformCache = + std::vector> &PlatformCache = GlobalHandler::instance().getPlatformCache(); // If we've already seen this platform, return the impl for (const auto &PlatImpl : PlatformCache) { if (PlatImpl->getHandleRef() == UrPlatform) - return PlatImpl; + return *PlatImpl; } // Otherwise make the impl. 
Our ctor/dtor are private, so std::make_shared @@ -57,13 +55,14 @@ platform_impl::getOrMakePlatformImpl(ur_platform_handle_t UrPlatform, : platform_impl(APlatform, AAdapter) {} }; Result = std::make_shared(UrPlatform, Adapter); + Result->Self = Result; PlatformCache.emplace_back(Result); } - return Result; + return *Result; } -PlatformImplPtr +platform_impl & platform_impl::getPlatformFromUrDevice(ur_device_handle_t UrDevice, const AdapterPtr &Adapter) { ur_platform_handle_t Plt = @@ -297,9 +296,9 @@ platform_impl::getDeviceImpl(ur_device_handle_t UrDevice) { return getDeviceImplHelper(UrDevice); } -std::shared_ptr platform_impl::getOrMakeDeviceImpl( - ur_device_handle_t UrDevice, - const std::shared_ptr &PlatformImpl) { +std::shared_ptr +platform_impl::getOrMakeDeviceImpl(ur_device_handle_t UrDevice, + platform_impl &PlatformImpl) { const std::lock_guard Guard(MDeviceMapMutex); // If we've already seen this device, return the impl std::shared_ptr Result = getDeviceImplHelper(UrDevice); @@ -335,7 +334,7 @@ static bool supportsPartitionProperty(const device &dev, static std::vector amendDeviceAndSubDevices( backend PlatformBackend, std::vector &DeviceList, ods_target_list *OdsTargetList, const std::vector &original_indices, - PlatformImplPtr PlatformImpl) { + platform_impl &PlatformImpl) { constexpr info::partition_property partitionProperty = info::partition_property::partition_by_affinity_domain; constexpr info::partition_affinity_domain affinityDomain = @@ -344,7 +343,7 @@ static std::vector amendDeviceAndSubDevices( std::vector FinalResult; // (Only) when amending sub-devices for ONEAPI_DEVICE_SELECTOR, all // sub-devices are treated as root. - TempAssignGuard TAG(PlatformImpl->MAlwaysRootDevice, true); + TempAssignGuard TAG(PlatformImpl.MAlwaysRootDevice, true); for (unsigned i = 0; i < DeviceList.size(); i++) { // device has already been screened. 
The question is whether it should be a @@ -528,12 +527,12 @@ platform_impl::get_devices(info::device_type DeviceType) const { // The next step is to inflate the filtered UrDevices into SYCL Device // objects. - PlatformImplPtr PlatformImpl = getOrMakePlatformImpl(MPlatform, MAdapter); + platform_impl &PlatformImpl = getOrMakePlatformImpl(MPlatform, MAdapter); std::transform( UrDevices.begin(), UrDevices.end(), std::back_inserter(Res), - [PlatformImpl](const ur_device_handle_t UrDevice) -> device { + [&PlatformImpl](const ur_device_handle_t UrDevice) -> device { return detail::createSyclObjFromImpl( - PlatformImpl->getOrMakeDeviceImpl(UrDevice, PlatformImpl)); + PlatformImpl.getOrMakeDeviceImpl(UrDevice, PlatformImpl)); }); // The reference counter for handles, that we used to create sycl objects, is diff --git a/sycl/source/detail/platform_impl.hpp b/sycl/source/detail/platform_impl.hpp index a262925a4226c..3b7bbddd1f59a 100644 --- a/sycl/source/detail/platform_impl.hpp +++ b/sycl/source/detail/platform_impl.hpp @@ -171,9 +171,8 @@ class platform_impl { /// \param PlatormImpl is the Platform for that Device /// /// \return a shared_ptr corresponding to the device - std::shared_ptr - getOrMakeDeviceImpl(ur_device_handle_t UrDevice, - const std::shared_ptr &PlatformImpl); + std::shared_ptr getOrMakeDeviceImpl(ur_device_handle_t UrDevice, + platform_impl &PlatformImpl); /// Queries the cache to see if the specified UR platform has been seen /// before. 
If so, return the cached platform_impl, otherwise create a new @@ -182,9 +181,8 @@ class platform_impl { /// \param UrPlatform is the UR Platform handle representing the platform /// \param Adapter is the UR adapter providing the backend for the platform /// \return the platform_impl representing the UR platform - static std::shared_ptr - getOrMakePlatformImpl(ur_platform_handle_t UrPlatform, - const AdapterPtr &Adapter); + static platform_impl &getOrMakePlatformImpl(ur_platform_handle_t UrPlatform, + const AdapterPtr &Adapter); /// Queries the cache for the specified platform based on an input device. /// If found, returns the the cached platform_impl, otherwise creates a new @@ -195,9 +193,11 @@ class platform_impl { /// \param Adapter is the UR adapter providing the backend for the device and /// platform /// \return the platform_impl that contains the input device - static std::shared_ptr - getPlatformFromUrDevice(ur_device_handle_t UrDevice, - const AdapterPtr &Adapter); + static platform_impl &getPlatformFromUrDevice(ur_device_handle_t UrDevice, + const AdapterPtr &Adapter); + + // Temporary while we're reducing usage of `std::shared_ptr`s. + std::shared_ptr getSharedPtrToSelf() { return Self.lock(); } // when getting sub-devices for ONEAPI_DEVICE_SELECTOR we may temporarily // ensure every device is a root one. @@ -223,8 +223,25 @@ class platform_impl { std::vector> MDeviceCache; std::mutex MDeviceMapMutex; + + // Temporary while we're reducing usage of `std::shared_ptr`s. 
+ std::weak_ptr Self; }; } // namespace detail } // namespace _V1 } // namespace sycl + +#include + +namespace sycl { +inline namespace _V1 { +namespace detail { +template +std::enable_if_t, platform> +createSyclObjFromImpl(platform_impl &p) { + return platform{p.getSharedPtrToSelf()}; +} +} // namespace detail +} // namespace _V1 +} // namespace sycl diff --git a/sycl/source/detail/program_manager/program_manager.cpp b/sycl/source/detail/program_manager/program_manager.cpp index be49934e8d678..c028e894d6cd9 100644 --- a/sycl/source/detail/program_manager/program_manager.cpp +++ b/sycl/source/detail/program_manager/program_manager.cpp @@ -369,7 +369,7 @@ static void appendCompileOptionsFromImage(std::string &CompileOpts, appendCompileOptionsForGRFSizeProperties(CompileOpts, Img, isEsimdImage); const detail::DeviceImplPtr &DeviceImpl = detail::getSyclObjImpl(Devs[0]); - const detail::PlatformImplPtr &PlatformImpl = DeviceImpl->getPlatformImpl(); + const platform_impl &PlatformImpl = DeviceImpl->getPlatformImpl(); // Add optimization flags. auto str = getUint32PropAsOptStr(Img, "optLevel"); @@ -388,7 +388,7 @@ static void appendCompileOptionsFromImage(std::string &CompileOpts, const char *backend_option = nullptr; // Empty string is returned in backend_option when no appropriate backend // option is available for a given frontend option. 
- PlatformImpl->getBackendOption(optLevelStr, &backend_option); + PlatformImpl.getBackendOption(optLevelStr, &backend_option); if (backend_option && backend_option[0] != '\0') { if (!CompileOpts.empty()) CompileOpts += " "; @@ -396,8 +396,8 @@ static void appendCompileOptionsFromImage(std::string &CompileOpts, } } bool IsIntelGPU = - (PlatformImpl->getBackend() == backend::ext_oneapi_level_zero || - PlatformImpl->getBackend() == backend::opencl) && + (PlatformImpl.getBackend() == backend::ext_oneapi_level_zero || + PlatformImpl.getBackend() == backend::opencl) && std::all_of(Devs.begin(), Devs.end(), [](const device &Dev) { return Dev.is_gpu() && Dev.get_info() == 0x8086; @@ -408,7 +408,7 @@ static void appendCompileOptionsFromImage(std::string &CompileOpts, Pos != std::string::npos) { const char *BackendOption = nullptr; if (IsIntelGPU) - PlatformImpl->getBackendOption(TargetCompileFast, &BackendOption); + PlatformImpl.getBackendOption(TargetCompileFast, &BackendOption); auto OptLen = strlen(TargetCompileFast); if (IsIntelGPU && BackendOption && BackendOption[0] != '\0') CompileOpts.replace(Pos, OptLen, BackendOption); @@ -451,8 +451,7 @@ static void appendCompileOptionsFromImage(std::string &CompileOpts, for (const std::string_view Opt : ReplaceOpts) { if (auto Pos = CompileOpts.find(Opt); Pos != std::string::npos) { const char *BackendOption = nullptr; - PlatformImpl->getBackendOption(std::string(Opt).c_str(), - &BackendOption); + PlatformImpl.getBackendOption(std::string(Opt).c_str(), &BackendOption); CompileOpts.replace(Pos, Opt.length(), BackendOption); } } @@ -1149,7 +1148,7 @@ ProgramManager::getOrCreateKernel(const ContextImplPtr &ContextImpl, Program, KernelName.data(), &Kernel); // Only set UR_USM_INDIRECT_ACCESS if the platform can handle it. 
- if (ContextImpl->getPlatformImpl()->supports_usm()) { + if (ContextImpl->getPlatformImpl().supports_usm()) { // Some UR Adapters (like OpenCL) require this call to enable USM // For others, UR will turn this into a NOP. const ur_bool_t UrTrue = true; @@ -1656,7 +1655,7 @@ getDeviceLibPrograms(const ContextImplPtr Context, Devices.begin(), Devices.end(), [&Context](ur_device_handle_t Device) { std::string DevExtList = Context->getPlatformImpl() - ->getDeviceImpl(Device) + .getDeviceImpl(Device) ->get_device_info_string( UrInfoCode::value); return (DevExtList.npos != DevExtList.find("cl_khr_fp64")); @@ -1667,7 +1666,7 @@ getDeviceLibPrograms(const ContextImplPtr Context, for (auto Device : Devices) { std::string DevExtList = Context->getPlatformImpl() - ->getDeviceImpl(Device) + .getDeviceImpl(Device) ->get_device_info_string( UrInfoCode::value); @@ -3199,7 +3198,7 @@ ProgramManager::getOrCreateKernel(const context &Context, &Kernel); // Only set UR_USM_INDIRECT_ACCESS if the platform can handle it. - if (Ctx->getPlatformImpl()->supports_usm()) { + if (Ctx->getPlatformImpl().supports_usm()) { bool EnableAccess = true; Adapter->call( Kernel, UR_KERNEL_EXEC_INFO_USM_INDIRECT_ACCESS, sizeof(ur_bool_t), diff --git a/sycl/source/detail/usm/usm_impl.cpp b/sycl/source/detail/usm/usm_impl.cpp index de6a48180ca1e..daa48e4f5e396 100644 --- a/sycl/source/detail/usm/usm_impl.cpp +++ b/sycl/source/detail/usm/usm_impl.cpp @@ -592,10 +592,10 @@ device get_pointer_device(const void *Ptr, const context &Ctxt) { // The device is not necessarily a member of the context, it could be a // member's descendant instead. Fetch the corresponding device from the cache. 
- const std::shared_ptr &PltImpl = + detail::platform_impl &PltImpl = detail::getSyclObjImpl(Ctxt)->getPlatformImpl(); std::shared_ptr DevImpl = - PltImpl->getDeviceImpl(DeviceId); + PltImpl.getDeviceImpl(DeviceId); if (DevImpl) return detail::createSyclObjFromImpl(DevImpl); throw exception(make_error_code(errc::runtime), diff --git a/sycl/source/device.cpp b/sycl/source/device.cpp index 1e27e4f898bc1..af651357ee11d 100644 --- a/sycl/source/device.cpp +++ b/sycl/source/device.cpp @@ -40,9 +40,9 @@ device::device(cl_device_id DeviceId) { Adapter->call( detail::ur::cast(DeviceId), Adapter->getUrAdapter(), nullptr, &Device); - auto Platform = + detail::platform_impl &Platform = detail::platform_impl::getPlatformFromUrDevice(Device, Adapter); - impl = Platform->getOrMakeDeviceImpl(Device, Platform); + impl = Platform.getOrMakeDeviceImpl(Device, Platform); __SYCL_OCL_CALL(clRetainDevice, DeviceId); } diff --git a/sycl/source/platform.cpp b/sycl/source/platform.cpp index 71c89460c9322..7e3637b20da94 100644 --- a/sycl/source/platform.cpp +++ b/sycl/source/platform.cpp @@ -29,7 +29,8 @@ platform::platform(cl_platform_id PlatformId) { Adapter->call( detail::ur::cast(PlatformId), Adapter->getUrAdapter(), /* pProperties = */ nullptr, &UrPlatform); - impl = detail::platform_impl::getOrMakePlatformImpl(UrPlatform, Adapter); + impl = detail::platform_impl::getOrMakePlatformImpl(UrPlatform, Adapter) + .getSharedPtrToSelf(); } // protected constructor for internal use @@ -88,9 +89,10 @@ platform::get_backend_info() const { #undef __SYCL_PARAM_TRAITS_SPEC context platform::khr_get_default_context() const { + // TODO: Is this still relevant? 
// Keeping the default context for platforms in the global cache to avoid // shared_ptr based circular dependency between platform and context classes - std::unordered_map + std::unordered_map &PlatformToDefaultContextCache = detail::GlobalHandler::instance().getPlatformToDefaultContextCache(); @@ -98,10 +100,10 @@ context platform::khr_get_default_context() const { detail::GlobalHandler::instance() .getPlatformToDefaultContextCacheMutex()}; - auto It = PlatformToDefaultContextCache.find(impl); + auto It = PlatformToDefaultContextCache.find(impl.get()); if (PlatformToDefaultContextCache.end() == It) std::tie(It, std::ignore) = PlatformToDefaultContextCache.insert( - {impl, detail::getSyclObjImpl(context{get_devices()})}); + {impl.get(), detail::getSyclObjImpl(context{get_devices()})}); return detail::createSyclObjFromImpl(It->second); } diff --git a/sycl/test/gdb/printers.cpp b/sycl/test/gdb/printers.cpp index 5d31ae98e87a2..43c9dfd1c03c5 100644 --- a/sycl/test/gdb/printers.cpp +++ b/sycl/test/gdb/printers.cpp @@ -64,8 +64,7 @@ sycl::range<1> r(3); // CHECK: 0 | class sycl::detail::device_impl // CHECK: 8 | ur_device_type_t MType -// CHECK: 24 | class std::shared_ptr MPlatform -// CHECK: 24 | element_type * _M_ptr +// CHECK: 24 | platform_impl * MPlatform // DEVICE: 0 | class sycl::detail::AccessorImplDevice<1> // DEVICE: 0 | class sycl::id<1> Offset diff --git a/sycl/unittests/program_manager/SubDevices.cpp b/sycl/unittests/program_manager/SubDevices.cpp index 602c16c3a5441..052a06b30932f 100644 --- a/sycl/unittests/program_manager/SubDevices.cpp +++ b/sycl/unittests/program_manager/SubDevices.cpp @@ -105,7 +105,7 @@ TEST(SubDevices, DISABLED_BuildProgramForSubdevices) { // Initialize root device rootDevice = sycl::detail::getSyclObjImpl(device)->getHandleRef(); // Initialize sub-devices - auto PltImpl = sycl::detail::getSyclObjImpl(Plt); + auto &PltImpl = *sycl::detail::getSyclObjImpl(Plt).get(); auto subDev1 = std::make_shared(urSubDev1, PltImpl); auto subDev2 = 
diff --git a/sycl/unittests/queue/DeviceCheck.cpp b/sycl/unittests/queue/DeviceCheck.cpp index 784541d7cb5f3..a6e940bdb855b 100644 --- a/sycl/unittests/queue/DeviceCheck.cpp +++ b/sycl/unittests/queue/DeviceCheck.cpp @@ -117,8 +117,9 @@ TEST(QueueDeviceCheck, CheckDeviceRestriction) { { ParentDevice = nullptr; device Device = detail::createSyclObjFromImpl( - std::make_shared(reinterpret_cast(0x01), - detail::getSyclObjImpl(Plt))); + std::make_shared( + reinterpret_cast(0x01), + *detail::getSyclObjImpl(Plt).get())); queue Q{Device}; EXPECT_NE(Q.get_context(), DefaultCtx); try {