|
14 | 14 | #include <sycl/detail/ur.hpp> |
15 | 15 | #include <sycl/device.hpp> |
16 | 16 | #include <sycl/ext/intel/experimental/usm_properties.hpp> |
| 17 | +#include <sycl/ext/oneapi/experimental/register_host_memory.hpp> |
17 | 18 | #include <sycl/ext/oneapi/memcpy2d.hpp> |
18 | 19 | #include <sycl/usm.hpp> |
19 | 20 |
|
@@ -609,6 +610,83 @@ void release_from_device_copy(const void *Ptr, const queue &Queue) { |
609 | 610 | release_from_usm_device_copy(Ptr, Queue.get_context()); |
610 | 611 | } |
611 | 612 |
|
| 613 | +// Host memory registration APIs, see sycl_ext_oneapi_register_host_memory. |
| 614 | + |
| 615 | +namespace detail { |
| 616 | + |
| 617 | +// Throws errc::feature_not_supported unless every device in the context |
| 618 | +// reports aspect::ext_oneapi_register_host_memory. |
| 619 | +static void checkRegisterHostMemorySupport(const context &Ctxt) { |
| 620 | + detail::context_impl &CtxtImpl = *detail::getSyclObjImpl(Ctxt); |
| 621 | + for (detail::device_impl &Dev : CtxtImpl.getDevices()) { |
| 622 | + if (!Dev.has(aspect::ext_oneapi_register_host_memory)) |
| 623 | + throw sycl::exception( |
| 624 | + make_error_code(errc::feature_not_supported), |
| 625 | + "At least one device in the context does not support registering " |
| 626 | + "host memory (aspect::ext_oneapi_register_host_memory)."); |
| 627 | + } |
| 628 | +} |
| 629 | + |
| 630 | +// Maps a failed UR result from the host memory registration APIs to a |
| 631 | +// sycl::exception with the error code mandated by the extension specification. |
| 632 | +// Invalid argument conditions map to errc::invalid; anything else is a backend |
| 633 | +// error. |
| 634 | +static void throwRegisterHostMemoryError(ur_result_t Err, const char *What) { |
| 635 | + errc Code; |
| 636 | + switch (Err) { |
| 637 | + case UR_RESULT_ERROR_INVALID_NULL_POINTER: |
| 638 | + case UR_RESULT_ERROR_INVALID_VALUE: |
| 639 | + case UR_RESULT_ERROR_INVALID_ARGUMENT: |
| 640 | + Code = errc::invalid; |
| 641 | + break; |
| 642 | + default: |
| 643 | + Code = errc::runtime; |
| 644 | + break; |
| 645 | + } |
| 646 | + throw detail::set_ur_error(sycl::exception(make_error_code(Code), What), Err); |
| 647 | +} |
| 648 | + |
| 649 | +void register_host_memory(void *Ptr, size_t NumBytes, const context &Ctxt, |
| 650 | + uint32_t Flags) { |
| 651 | + if (Ptr == nullptr) |
| 652 | + throw sycl::exception(make_error_code(errc::invalid), |
| 653 | + "register_host_memory: pointer must not be null."); |
| 654 | + if (NumBytes == 0) |
| 655 | + throw sycl::exception(make_error_code(errc::invalid), |
| 656 | + "register_host_memory: size must not be zero."); |
| 657 | + checkRegisterHostMemorySupport(Ctxt); |
| 658 | + |
| 659 | + ur_exp_usm_host_alloc_register_properties_t Props = { |
| 660 | + UR_STRUCTURE_TYPE_EXP_USM_HOST_ALLOC_REGISTER_PROPERTIES, |
| 661 | + /*pNext=*/nullptr, |
| 662 | + /*flags=*/0}; |
| 663 | + if (Flags & register_host_memory_flag_read_only) |
| 664 | + Props.flags |= UR_EXP_USM_HOST_ALLOC_REGISTER_FLAG_READ_ONLY; |
| 665 | + |
| 666 | + auto [urCtx, Adapter] = get_ur_handles(Ctxt); |
| 667 | + ur_result_t Err = |
| 668 | + Adapter->call_nocheck<detail::UrApiKind::urUSMHostAllocRegisterExp>( |
| 669 | + urCtx, Ptr, NumBytes, &Props); |
| 670 | + if (Err != UR_RESULT_SUCCESS) |
| 671 | + throwRegisterHostMemoryError(Err, "register_host_memory failed."); |
| 672 | +} |
| 673 | + |
| 674 | +void unregister_host_memory(void *Ptr, const context &Ctxt) { |
| 675 | + if (Ptr == nullptr) |
| 676 | + throw sycl::exception(make_error_code(errc::invalid), |
| 677 | + "unregister_host_memory: pointer must not be null."); |
| 678 | + checkRegisterHostMemorySupport(Ctxt); |
| 679 | + |
| 680 | + auto [urCtx, Adapter] = get_ur_handles(Ctxt); |
| 681 | + ur_result_t Err = |
| 682 | + Adapter->call_nocheck<detail::UrApiKind::urUSMHostAllocUnregisterExp>( |
| 683 | + urCtx, Ptr); |
| 684 | + if (Err != UR_RESULT_SUCCESS) |
| 685 | + throwRegisterHostMemoryError(Err, "unregister_host_memory failed."); |
| 686 | +} |
| 687 | + |
| 688 | +} // namespace detail |
| 689 | + |
612 | 690 | void *malloc_device(size_t numBytes, const device &syclDevice, |
613 | 691 | const property_list &propList) { |
614 | 692 | sycl::context ctxt = syclDevice.get_platform().khr_get_default_context(); |
|
0 commit comments