diff --git a/dpctl/tensor/_dlpack.pxd b/dpctl/tensor/_dlpack.pxd
index 1a10c79fc7..f44db7c05c 100644
--- a/dpctl/tensor/_dlpack.pxd
+++ b/dpctl/tensor/_dlpack.pxd
@@ -41,6 +41,7 @@ cdef extern from "dlpack/dlpack.h" nogil:
     int device_WebGPU "kDLWebGPU"
     int device_Hexagon "kDLHexagon"
     int device_MAIA "kDLMAIA"
+    int device_Trn "kDLTrn"
 
 cpdef object to_dlpack_capsule(usm_ndarray array) except +
 cpdef object to_dlpack_versioned_capsule(
diff --git a/dpctl/tensor/_dlpack.pyx b/dpctl/tensor/_dlpack.pyx
index 29b3eccead..e1ba968621 100644
--- a/dpctl/tensor/_dlpack.pyx
+++ b/dpctl/tensor/_dlpack.pyx
@@ -36,7 +36,7 @@ from .._backend cimport (
     DPCTLSyclDeviceRef,
     DPCTLSyclUSMRef,
 )
-from ._usmarray cimport USM_ARRAY_C_CONTIGUOUS, USM_ARRAY_WRITABLE, usm_ndarray
+from ._usmarray cimport USM_ARRAY_WRITABLE, usm_ndarray
 
 import ctypes
 
@@ -76,6 +76,7 @@ cdef extern from "dlpack/dlpack.h" nogil:
         kDLWebGPU
         kDLHexagon
         kDLMAIA
+        kDLTrn
 
     ctypedef struct DLDevice:
         DLDeviceType device_type
@@ -88,6 +89,17 @@ cdef extern from "dlpack/dlpack.h" nogil:
         kDLBfloat
         kDLComplex
         kDLBool
+        kDLFloat8_e3m4
+        kDLFloat8_e4m3
+        kDLFloat8_e4m3b11fnuz
+        kDLFloat8_e4m3fn
+        kDLFloat8_e4m3fnuz
+        kDLFloat8_e5m2
+        kDLFloat8_e5m2fnuz
+        kDLFloat8_e8m0fnu
+        kDLFloat6_e2m3fn
+        kDLFloat6_e3m2fn
+        kDLFloat4_e2m1fn
 
     ctypedef struct DLDataType:
         uint8_t code
@@ -254,7 +266,6 @@ cpdef to_dlpack_capsule(usm_ndarray usm_ary):
     cdef int64_t *shape_strides_ptr = NULL
     cdef int i = 0
     cdef int device_id = -1
-    cdef int flags = 0
     cdef Py_ssize_t element_offset = 0
     cdef Py_ssize_t byte_offset = 0
     cdef Py_ssize_t si = 1
@@ -269,22 +280,21 @@ cpdef to_dlpack_capsule(usm_ndarray usm_ary):
         raise MemoryError(
             "to_dlpack_capsule: Could not allocate memory for DLManagedTensor"
         )
-    shape_strides_ptr = stdlib.malloc((sizeof(int64_t) * 2) * nd)
-    if shape_strides_ptr is NULL:
-        stdlib.free(dlm_tensor)
-        raise MemoryError(
-            "to_dlpack_capsule: Could not allocate memory for shape/strides"
-        )
-    shape_ptr = usm_ary.get_shape()
-    for i in range(nd):
-        shape_strides_ptr[i] = shape_ptr[i]
-    strides_ptr = usm_ary.get_strides()
-    flags = usm_ary.flags_
-    if strides_ptr:
+    if nd > 0:
+        shape_strides_ptr = stdlib.malloc((sizeof(int64_t) * 2) * nd)
+        if shape_strides_ptr is NULL:
+            stdlib.free(dlm_tensor)
+            raise MemoryError(
+                "to_dlpack_capsule: Could not allocate memory for shape/strides"
+            )
+        shape_ptr = usm_ary.get_shape()
         for i in range(nd):
-        shape_strides_ptr[nd + i] = strides_ptr[i]
-    else:
-        if not (flags & USM_ARRAY_C_CONTIGUOUS):
+            shape_strides_ptr[i] = shape_ptr[i]
+        strides_ptr = usm_ary.get_strides()
+        if strides_ptr:
+            for i in range(nd):
+                shape_strides_ptr[nd + i] = strides_ptr[i]
+        else:
             si = 1
             for i in range(0, nd):
                 shape_strides_ptr[nd + i] = si
@@ -300,11 +310,8 @@ cpdef to_dlpack_capsule(usm_ndarray usm_ary):
     dl_tensor.data = (data_ptr - byte_offset)
     dl_tensor.ndim = nd
     dl_tensor.byte_offset = byte_offset
-    dl_tensor.shape = &shape_strides_ptr[0]
-    if strides_ptr is NULL:
-        dl_tensor.strides = NULL
-    else:
-        dl_tensor.strides = &shape_strides_ptr[nd]
+    dl_tensor.shape = &shape_strides_ptr[0] if nd > 0 else NULL
+    dl_tensor.strides = &shape_strides_ptr[nd] if nd > 0 else NULL
     dl_tensor.device.device_type = kDLOneAPI
     dl_tensor.device.device_id = device_id
     dl_tensor.dtype.lanes = 1
@@ -384,24 +391,24 @@ cpdef to_dlpack_versioned_capsule(usm_ndarray usm_ary, bint copied):
             "to_dlpack_versioned_capsule: Could not allocate memory "
             "for DLManagedTensorVersioned"
         )
-    shape_strides_ptr = stdlib.malloc((sizeof(int64_t) * 2) * nd)
-    if shape_strides_ptr is NULL:
-        stdlib.free(dlmv_tensor)
-        raise MemoryError(
-            "to_dlpack_versioned_capsule: Could not allocate memory "
-            "for shape/strides"
-        )
-    # this can be a separate function for handling shapes and strides
-    shape_ptr = usm_ary.get_shape()
-    for i in range(nd):
-        shape_strides_ptr[i] = shape_ptr[i]
-    strides_ptr = usm_ary.get_strides()
-    flags = usm_ary.flags_
-    if strides_ptr:
+    if nd > 0:
+        shape_strides_ptr = stdlib.malloc((sizeof(int64_t) * 2) * nd)
+        if shape_strides_ptr is NULL:
+            stdlib.free(dlmv_tensor)
+            raise MemoryError(
+                "to_dlpack_versioned_capsule: Could not allocate memory "
+                "for shape/strides"
+            )
+        # this can be a separate function for handling shapes and strides
+        shape_ptr = usm_ary.get_shape()
         for i in range(nd):
-        shape_strides_ptr[nd + i] = strides_ptr[i]
-    else:
-        if not (flags & USM_ARRAY_C_CONTIGUOUS):
+            shape_strides_ptr[i] = shape_ptr[i]
+        strides_ptr = usm_ary.get_strides()
+        flags = usm_ary.flags_
+        if strides_ptr:
+            for i in range(nd):
+                shape_strides_ptr[nd + i] = strides_ptr[i]
+        else:
             si = 1
             for i in range(0, nd):
                 shape_strides_ptr[nd + i] = si
@@ -419,11 +426,8 @@ cpdef to_dlpack_versioned_capsule(usm_ndarray usm_ary, bint copied):
     dl_tensor.data = (data_ptr - byte_offset)
     dl_tensor.ndim = nd
     dl_tensor.byte_offset = byte_offset
-    dl_tensor.shape = &shape_strides_ptr[0]
-    if strides_ptr is NULL:
-        dl_tensor.strides = NULL
-    else:
-        dl_tensor.strides = &shape_strides_ptr[nd]
+    dl_tensor.shape = &shape_strides_ptr[0] if nd > 0 else NULL
+    dl_tensor.strides = &shape_strides_ptr[nd] if nd > 0 else NULL
    dl_tensor.device.device_type = kDLOneAPI
     dl_tensor.device.device_id = device_id
     dl_tensor.dtype.lanes = 1
@@ -503,10 +507,9 @@ cpdef numpy_to_dlpack_versioned_capsule(ndarray npy_ary, bint copied):
             "for DLManagedTensorVersioned"
         )
 
-    is_c_contiguous = npy_ary.flags["C"]
     shape = npy_ary.ctypes.shape_as(ctypes.c_int64)
     strides = npy_ary.ctypes.strides_as(ctypes.c_int64)
-    if not is_c_contiguous:
+    if nd > 0:
         if npy_ary.size != 1:
             for i in range(nd):
                 if shape[i] != 1 and strides[i] % itemsize != 0:
@@ -517,18 +520,14 @@ cpdef numpy_to_dlpack_versioned_capsule(ndarray npy_ary, bint copied):
                         "itemsize"
                     )
         shape_strides_ptr = stdlib.malloc((sizeof(int64_t) * 2) * nd)
-    else:
-        # no need to pass strides in this case
-        shape_strides_ptr = stdlib.malloc(sizeof(int64_t) * nd)
-    if shape_strides_ptr is NULL:
-        stdlib.free(dlmv_tensor)
-        raise MemoryError(
-            "numpy_to_dlpack_versioned_capsule: Could not allocate memory "
-            "for shape/strides"
-        )
-    for i in range(nd):
-        shape_strides_ptr[i] = shape[i]
-        if not is_c_contiguous:
+        if shape_strides_ptr is NULL:
+            stdlib.free(dlmv_tensor)
+            raise MemoryError(
+                "numpy_to_dlpack_versioned_capsule: Could not allocate memory "
+                "for shape/strides"
+            )
+        for i in range(nd):
+            shape_strides_ptr[i] = shape[i]
             shape_strides_ptr[nd + i] = strides[i] // itemsize
 
     writable_flag = npy_ary.flags["W"]
@@ -540,11 +539,8 @@ cpdef numpy_to_dlpack_versioned_capsule(ndarray npy_ary, bint copied):
     dl_tensor.data = npy_ary.data
     dl_tensor.ndim = nd
     dl_tensor.byte_offset = byte_offset
-    dl_tensor.shape = &shape_strides_ptr[0]
-    if is_c_contiguous:
-        dl_tensor.strides = NULL
-    else:
-        dl_tensor.strides = &shape_strides_ptr[nd]
+    dl_tensor.shape = &shape_strides_ptr[0] if nd > 0 else NULL
+    dl_tensor.strides = &shape_strides_ptr[nd] if nd > 0 else NULL
     dl_tensor.device.device_type = kDLCPU
     dl_tensor.device.device_id = 0
     dl_tensor.dtype.lanes = 1
@@ -816,12 +812,8 @@ cpdef object from_dlpack_capsule(object py_caps):
         raise BufferError(
            "Can not import DLPack tensor with lanes != 1"
         )
-    offset_min = 0
-    if dl_tensor.strides is NULL:
-        for i in range(dl_tensor.ndim):
-            sz = sz * dl_tensor.shape[i]
-        offset_max = sz - 1
-    else:
+    if dl_tensor.ndim > 0:
+        offset_min = 0
         offset_max = 0
         for i in range(dl_tensor.ndim):
             stride_i = dl_tensor.strides[i]
@@ -876,15 +868,17 @@ cpdef object from_dlpack_capsule(object py_caps):
            (q).get_queue_ref(),
            memory_owner=tmp
        )
+
     py_shape = list()
-    for i in range(dl_tensor.ndim):
-        py_shape.append(dl_tensor.shape[i])
-    if (dl_tensor.strides is NULL):
-        py_strides = None
-    else:
+    if (dl_tensor.shape is not NULL):
+        for i in range(dl_tensor.ndim):
+            py_shape.append(dl_tensor.shape[i])
+    if (dl_tensor.strides is not NULL):
         py_strides = list()
         for i in range(dl_tensor.ndim):
             py_strides.append(dl_tensor.strides[i])
+    else:
+        py_strides = None
     if (dl_tensor.dtype.code == kDLUInt):
         ary_dt = np.dtype("u" + str(element_bytesize))
     elif (dl_tensor.dtype.code == kDLInt):
diff --git a/dpctl/tensor/_usmarray.pyx b/dpctl/tensor/_usmarray.pyx
index 0e42b42faf..70bb4243f6 100644
--- a/dpctl/tensor/_usmarray.pyx
+++ b/dpctl/tensor/_usmarray.pyx
@@ -86,6 +86,8 @@ class DLDeviceType(IntEnum):
         Qualcomm Hexagon DSP
     ``kDLMAIA``:
         Microsoft MAIA device
+    ``kDLTrn``:
+        AWS Trainium device
     """
     kDLCPU = c_dlpack.device_CPU
     kDLCUDA = c_dlpack.device_CUDA
@@ -101,6 +103,7 @@ class DLDeviceType(IntEnum):
     kDLWebGPU = c_dlpack.device_WebGPU
     kDLHexagon = c_dlpack.device_Hexagon
     kDLMAIA = c_dlpack.device_MAIA
+    kDLTrn = c_dlpack.device_Trn
 
 
 cdef class InternalUSMArrayError(Exception):
diff --git a/dpctl/tensor/include/dlpack/dlpack.h b/dpctl/tensor/include/dlpack/dlpack.h
index bcb77949a8..9a710ebde6 100644
--- a/dpctl/tensor/include/dlpack/dlpack.h
+++ b/dpctl/tensor/include/dlpack/dlpack.h
@@ -1,5 +1,5 @@
 /*!
- * Copyright (c) 2017 by Contributors
+ * Copyright (c) 2017 - by Contributors
  * \file dlpack.h
  * \brief The common header of DLPack.
  */
@@ -19,7 +19,7 @@
 #define DLPACK_MAJOR_VERSION 1
 
 /*! \brief The current minor version of dlpack */
-#define DLPACK_MINOR_VERSION 0
+#define DLPACK_MINOR_VERSION 2
 
 /*! \brief DLPACK_DLL prefix for windows */
 #ifdef _WIN32
@@ -118,6 +118,8 @@ typedef enum {
   kDLHexagon = 16,
   /*! \brief Microsoft MAIA devices */
   kDLMAIA = 17,
+  /*! \brief AWS Trainium */
+  kDLTrn = 18,
 } DLDeviceType;
 
 /*!
@@ -157,6 +159,26 @@ typedef enum {
   kDLComplex = 5U,
   /*! \brief boolean */
   kDLBool = 6U,
+  /*! \brief FP8 data types */
+  kDLFloat8_e3m4 = 7U,
+  kDLFloat8_e4m3 = 8U,
+  kDLFloat8_e4m3b11fnuz = 9U,
+  kDLFloat8_e4m3fn = 10U,
+  kDLFloat8_e4m3fnuz = 11U,
+  kDLFloat8_e5m2 = 12U,
+  kDLFloat8_e5m2fnuz = 13U,
+  kDLFloat8_e8m0fnu = 14U,
+  /*! \brief FP6 data types
+   * Setting bits != 6 is currently unspecified, and the producer must ensure it is set
+   * while the consumer must stop importing if the value is unexpected.
+   */
+  kDLFloat6_e2m3fn = 15U,
+  kDLFloat6_e3m2fn = 16U,
+  /*! \brief FP4 data types
+   * Setting bits != 4 is currently unspecified, and the producer must ensure it is set
+   * while the consumer must stop importing if the value is unexpected.
+   */
+  kDLFloat4_e2m1fn = 17U,
 } DLDataTypeCode;
 
 /*!
@@ -170,6 +192,12 @@ typedef enum {
  * - int8: type_code = 0, bits = 8, lanes = 1
  * - std::complex<float>: type_code = 5, bits = 64, lanes = 1
  * - bool: type_code = 6, bits = 8, lanes = 1 (as per common array library convention, the underlying storage size of bool is 8 bits)
+ * - float8_e4m3: type_code = 8, bits = 8, lanes = 1 (packed in memory)
+ * - float6_e3m2fn: type_code = 16, bits = 6, lanes = 1 (packed in memory)
+ * - float4_e2m1fn: type_code = 17, bits = 4, lanes = 1 (packed in memory)
+ *
+ * When a sub-byte type is packed, DLPack requires the data to be in little bit-endian, i.e.,
+ * for a packed data set D, ((D >> (i * bits)) & bit_mask) stores the i-th element.
 */
typedef struct {
  /*!
@@ -196,8 +224,8 @@ typedef struct {
   * types. This pointer is always aligned to 256 bytes as in CUDA. The
   * `byte_offset` field should be used to point to the beginning of the data.
   *
-   * Note that as of Nov 2021, multiply libraries (CuPy, PyTorch, TensorFlow,
-   * TVM, perhaps others) do not adhere to this 256 byte aligment requirement
+   * Note that as of Nov 2021, multiple libraries (CuPy, PyTorch, TensorFlow,
+   * TVM, perhaps others) do not adhere to this 256 byte alignment requirement
   * on CPU/CUDA/ROCm, and always use `byte_offset=0`. This must be fixed
   * (after which this note will be updated); at the moment it is recommended
   * to not rely on the data pointer being correctly aligned.
@@ -226,11 +254,23 @@ typedef struct {
  int32_t ndim;
  /*! \brief The data type of the pointer*/
  DLDataType dtype;
-  /*! \brief The shape of the tensor */
+  /*!
+   * \brief The shape of the tensor
+   *
+   * When ndim == 0, shape can be set to NULL.
+   */
  int64_t* shape;
  /*!
-   * \brief strides of the tensor (in number of elements, not bytes)
-   *  can be NULL, indicating tensor is compact and row-majored.
+   * \brief strides of the tensor (in number of elements, not bytes);
+   * it cannot be NULL if ndim != 0 and must point to
+   * an array of ndim elements that specifies the strides,
+   * so the consumer can always rely on strides[dim] being valid for 0 <= dim < ndim.
+   *
+   * When ndim == 0, strides can be set to NULL.
+   *
+   * \note Before DLPack v1.2, strides could be NULL to indicate contiguous data.
+   * This is not allowed in DLPack v1.2 and later. The rationale
+   * is to simplify consumer handling.
   */
  int64_t* strides;
  /*! \brief The offset in bytes to the beginning pointer to data */
@@ -267,7 +307,7 @@ typedef struct DLManagedTensor {
  void (*deleter)(struct DLManagedTensor * self);
} DLManagedTensor;
 
-// bit masks used in in the DLManagedTensorVersioned
+// bit masks used in the DLManagedTensorVersioned
 
/*! \brief bit mask to indicate that the tensor is read only. */
#define DLPACK_FLAG_BITMASK_READ_ONLY (1UL << 0UL)
@@ -280,6 +320,14 @@ typedef struct DLManagedTensor {
 */
#define DLPACK_FLAG_BITMASK_IS_COPIED (1UL << 1UL)
 
+/*!
+ * \brief bit mask to indicate whether a sub-byte type is packed or padded.
+ *
+ * The default for sub-byte types (ex: fp4/fp6) is assumed packed. This flag can
+ * be set by the producer to signal that a tensor of sub-byte type is padded.
+ */
+#define DLPACK_FLAG_BITMASK_IS_SUBBYTE_TYPE_PADDED (1UL << 2UL)
+
/*!
 * \brief A versioned and managed C Tensor object, manage memory of DLTensor.
 *
@@ -290,7 +338,7 @@ typedef struct DLManagedTensor {
 *
 * \note This is the current standard DLPack exchange data structure.
 */
-struct DLManagedTensorVersioned {
+typedef struct DLManagedTensorVersioned {
  /*!
   * \brief The API and ABI version of the current managed Tensor
   */
@@ -324,7 +372,266 @@ struct DLManagedTensorVersioned {
  uint64_t flags;
  /*! \brief DLTensor which is being memory managed */
  DLTensor dl_tensor;
-};
+} DLManagedTensorVersioned;
+
+//----------------------------------------------------------------------
+// DLPack `__c_dlpack_exchange_api__` fast exchange protocol definitions
+//----------------------------------------------------------------------
+/*!
+ * \brief Request a producer library to create a new tensor.
+ *
+ * Create a new `DLManagedTensorVersioned` within the context of the producer
+ * library. The allocation is defined via the prototype DLTensor.
+ *
+ * This function is exposed by the framework through the DLPackExchangeAPI.
+ *
+ * \param prototype The prototype DLTensor. Only the dtype, ndim, shape,
+ *        and device fields are used.
+ * \param out The output DLManagedTensorVersioned.
+ * \param error_ctx Context for `SetError`.
+ * \param SetError The function to set the error.
+ * \return 0 on success, -1 on failure; on success the owning
+ *         DLManagedTensorVersioned* is written to `out`. SetError is called
+ *         exactly when -1 is returned (the implementor must ensure this).
+ * \note - As a C function, must not throw C++ exceptions.
+ *       - Errors are propagated via SetError to avoid any direct need
+ *         for the Python API. Due to this, `SetError` may have to ensure the
+ *         GIL is held, since it will presumably set a Python error.
+ *
+ * \sa DLPackExchangeAPI
+ */
+typedef int (*DLPackManagedTensorAllocator)(                                   //
+    DLTensor* prototype, DLManagedTensorVersioned** out, void* error_ctx,      //
+    void (*SetError)(void* error_ctx, const char* kind, const char* message)   //
+);
+
+/*!
+ * \brief Exports a PyObject* Tensor/NDArray to a DLManagedTensorVersioned.
+ *
+ * This function does not perform any stream synchronization. The consumer should query
+ * DLPackCurrentWorkStream to get the current work stream and launch kernels on it.
+ *
+ * This function is exposed by the framework through the DLPackExchangeAPI.
+ *
+ * \param py_object The Python object to convert. Must have the same type
+ *        as the one the `DLPackExchangeAPI` was discovered from.
+ * \param out The output DLManagedTensorVersioned.
+ * \return 0 on success, -1 on failure with a Python exception set; on success
+ *         the owning DLManagedTensorVersioned* is written to `out`. If the data
+ *         cannot be described using DLPack, the exception should be a
+ *         BufferError if possible.
+ * \note - As a C function, must not throw C++ exceptions.
+ *
+ * \sa DLPackExchangeAPI, DLPackCurrentWorkStream
+ */
+typedef int (*DLPackManagedTensorFromPyObjectNoSync)(  //
+    void* py_object,                                   //
+    DLManagedTensorVersioned** out                     //
+);
+
+/*!
+ * \brief Exports a PyObject* Tensor/NDArray to a provided DLTensor.
+ *
+ * This function provides a faster interface for temporary, non-owning, exchange.
+ * The producer (implementor) still owns the memory of data, strides, shape.
+ * The liveness of the DLTensor and the data it views is only guaranteed until
+ * control is returned.
+ *
+ * This function currently assumes that the producer (implementor) can fill
+ * in the DLTensor shape and strides without the need for temporary allocations.
+ *
+ * This function does not perform any stream synchronization. The consumer should query
+ * DLPackCurrentWorkStream to get the current work stream and launch kernels on it.
+ *
+ * This function is exposed by the framework through the DLPackExchangeAPI.
+ *
+ * \param py_object The Python object to convert. Must have the same type
+ *        as the one the `DLPackExchangeAPI` was discovered from.
+ * \param out The output DLTensor, whose space is pre-allocated on the stack.
+ * \return 0 on success, -1 on failure with a Python exception set.
+ * \note - As a C function, must not throw C++ exceptions.
+ *
+ * \sa DLPackExchangeAPI, DLPackCurrentWorkStream
+ */
+typedef int (*DLPackDLTensorFromPyObjectNoSync)(  //
+    void* py_object,                              //
+    DLTensor* out                                 //
+);
+
+/*!
+ * \brief Obtain the current work stream of a device.
+ *
+ * Obtain the current work stream of a device from the producer framework.
+ * For example, it should map to torch.cuda.current_stream in PyTorch.
+ *
+ * When device_type is kDLCPU, the consumer does not have to query the stream
+ * and the producer can simply return a NULL stream when queried.
+ * The consumer does not have to do anything for stream synchronization or setting,
+ * so a CPU-only framework can simply provide a dummy implementation that
+ * always sets out_current_stream[0] to NULL.
+ *
+ * \param device_type The device type.
+ * \param device_id The device id.
+ * \param out_current_stream The output current work stream.
+ *
+ * \return 0 on success, -1 on failure with a Python exception set.
+ * \note - As a C function, must not throw C++ exceptions.
+ *
+ * \sa DLPackExchangeAPI
+ */
+typedef int (*DLPackCurrentWorkStream)(  //
+    DLDeviceType device_type,            //
+    int32_t device_id,                   //
+    void** out_current_stream            //
+);
+
+/*!
+ * \brief Imports a DLManagedTensorVersioned to a PyObject* Tensor/NDArray.
+ *
+ * Convert an owning DLManagedTensorVersioned* to the Python tensor of the
+ * producer (implementor) library with the correct type.
+ *
+ * This function does not perform any stream synchronization.
+ *
+ * This function is exposed by the framework through the DLPackExchangeAPI.
+ *
+ * \param tensor The DLManagedTensorVersioned to convert; ownership of the
+ *        tensor is stolen.
+ * \param out_py_object The output Python object.
+ * \return 0 on success, -1 on failure with a Python exception set.
+ *
+ * \sa DLPackExchangeAPI
+ */
+typedef int (*DLPackManagedTensorToPyObjectNoSync)(  //
+    DLManagedTensorVersioned* tensor,                //
+    void** out_py_object                             //
+);
+
+/*!
+ * \brief DLPackExchangeAPI stable header.
+ * \sa DLPackExchangeAPI
+ */
+typedef struct DLPackExchangeAPIHeader {
+  /*!
+   * \brief The provided DLPack version; the consumer must check major version
+   * compatibility before using this struct.
+   */
+  DLPackVersion version;
+  /*!
+   * \brief Optional pointer to an older DLPackExchangeAPI in the chain.
+   *
+   * It must be NULL if the framework does not support older versions.
+   * If the current major version is larger than the one supported by the
+   * consumer, the consumer may walk this to find an earlier supported version.
+   *
+   * \sa DLPackExchangeAPI
+   */
+  struct DLPackExchangeAPIHeader* prev_api;
+} DLPackExchangeAPIHeader;
+
+/*!
+ * \brief Framework-specific function pointer table for DLPack exchange.
+ *
+ * In addition to `__dlpack__()`, we define a C function table sharable by
+ * Python implementations via `__c_dlpack_exchange_api__`.
+ * This attribute must be set on the type as a Python integer compatible
+ * with `PyLong_FromVoidPtr`/`PyLong_AsVoidPtr`.
+ *
+ * A consumer library may use a pattern such as:
+ *
+ * \code
+ *
+ * PyObject *api_obj = type(tensor_obj).__c_dlpack_exchange_api__;  // as C-code
+ * MyDLPackExchangeAPI *api = PyLong_AsVoidPtr(api_obj);
+ * if (api == NULL && PyErr_Occurred()) { goto handle_error; }
+ *
+ * \endcode
+ *
+ * Note that this must be defined on the type. The consumer should look up the
+ * attribute on the type and may cache the result for each unique type.
+ *
+ * The precise API table is given by:
+ * \code
+ * struct MyDLPackExchangeAPI : public DLPackExchangeAPI {
+ *   MyDLPackExchangeAPI() {
+ *     header.version.major = DLPACK_MAJOR_VERSION;
+ *     header.version.minor = DLPACK_MINOR_VERSION;
+ *     header.prev_api = nullptr;
+ *
+ *     managed_tensor_allocator = MyDLPackManagedTensorAllocator;
+ *     managed_tensor_from_py_object_no_sync = MyDLPackManagedTensorFromPyObjectNoSync;
+ *     managed_tensor_to_py_object_no_sync = MyDLPackManagedTensorToPyObjectNoSync;
+ *     dltensor_from_py_object_no_sync = MyDLPackDLTensorFromPyObjectNoSync;
+ *     current_work_stream = MyDLPackCurrentWorkStream;
+ *   }
+ *
+ *   static const DLPackExchangeAPI* Global() {
+ *     static MyDLPackExchangeAPI inst;
+ *     return &inst;
+ *   }
+ * };
+ * \endcode
+ *
+ * Guidelines for leveraging DLPackExchangeAPI:
+ *
+ * There are generally two kinds of consumer needs for DLPack exchange:
+ * - N0: library support, where consumer.kernel(x, y, z) would like to run a kernel
+ *   with the data from x, y, z. The consumer is also expected to run the kernel in the same
+ *   stream context as the producer. For example, when x, y, z are torch.Tensor objects,
+ *   the consumer should query exchange_api->current_work_stream to get the
+ *   current stream and launch the kernel on that stream.
+ *   This setup is necessary to avoid synchronization on kernel launch and for maximum
+ *   compatibility with CUDA graph capture in the producer.
+ *   This is the desirable behavior for library extension support for frameworks like PyTorch.
+ * - N1: data ingestion and retention.
+ *
+ * Note that the obj.__dlpack__() API should provide suitable means for N1.
+ * The primary focus of the current DLPackExchangeAPI is to enable faster exchange for N0
+ * with the support of the function pointer current_work_stream.
+ *
+ * Array/Tensor libraries should statically create and initialize this structure,
+ * then expose a pointer to the DLPackExchangeAPI as an int value on the Tensor/Array type.
+ * The DLPackExchangeAPI* must stay alive throughout the lifetime of the process.
+ *
+ * One simple way to do so is to create a static instance of DLPackExchangeAPI
+ * within the framework and return a pointer to it. The example above shows one
+ * way to do this in C++. It should also be reasonably easy
+ * to do so in other languages.
+ */
+typedef struct DLPackExchangeAPI {
+  /*!
+   * \brief The header that remains stable across versions.
+   */
+  DLPackExchangeAPIHeader header;
+  /*!
+   * \brief Producer function pointer for DLPackManagedTensorAllocator
+   * This function must not be NULL.
+   * \sa DLPackManagedTensorAllocator
+   */
+  DLPackManagedTensorAllocator managed_tensor_allocator;
+  /*!
+   * \brief Producer function pointer for DLPackManagedTensorFromPyObjectNoSync
+   * This function must not be NULL.
+   * \sa DLPackManagedTensorFromPyObjectNoSync
+   */
+  DLPackManagedTensorFromPyObjectNoSync managed_tensor_from_py_object_no_sync;
+  /*!
+   * \brief Producer function pointer for DLPackManagedTensorToPyObjectNoSync
+   * This function must not be NULL.
+   * \sa DLPackManagedTensorToPyObjectNoSync
+   */
+  DLPackManagedTensorToPyObjectNoSync managed_tensor_to_py_object_no_sync;
+  /*!
+   * \brief Producer function pointer for DLPackDLTensorFromPyObjectNoSync
+   * This function can be NULL when the producer does not support this function.
+   * \sa DLPackDLTensorFromPyObjectNoSync
+   */
+  DLPackDLTensorFromPyObjectNoSync dltensor_from_py_object_no_sync;
+  /*!
+   * \brief Producer function pointer for DLPackCurrentWorkStream
+   * This function must not be NULL.
+   * \sa DLPackCurrentWorkStream
+   */
+  DLPackCurrentWorkStream current_work_stream;
+} DLPackExchangeAPI;
 
 #ifdef __cplusplus
 }  // DLPACK_EXTERN_C
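
A minimal consumer-side sketch of the ndim == 0 and strides rules introduced above, assuming only the updated dlpack.h from this patch; the helper names are illustrative and not part of dpctl or DLPack:

```c
/* Sketch only: shows how a consumer can honor the rule that shape/strides
 * may be NULL when ndim == 0, and that strides must be non-NULL otherwise. */
#include <stdint.h>
#include "dlpack/dlpack.h"

/* Number of elements described by a DLTensor; a 0-d tensor is a scalar. */
static int64_t element_count(const DLTensor *t) {
    int64_t n = 1;
    for (int32_t i = 0; i < t->ndim; ++i) {
        n *= t->shape[i];  /* shape must be non-NULL when ndim > 0 */
    }
    return n;
}

/* Per the DLPack v1.2 strides note, a consumer no longer needs a
 * "strides == NULL means C-contiguous" fallback for ndim > 0. */
static int strides_ok(const DLTensor *t) {
    return (t->ndim == 0) || (t->strides != NULL);
}
```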
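A hedged sketch of the consumer-side lookup of `__c_dlpack_exchange_api__` described in the header comment, assuming CPython and the dlpack.h from this patch; error handling is abbreviated and nothing here is an existing dpctl API:

```c
#include <Python.h>
#include "dlpack/dlpack.h"

/* Resolve the exchange-API table advertised on the tensor's type, walking
 * prev_api until a compatible major version is found (NULL if none). */
static const DLPackExchangeAPI *get_exchange_api(PyObject *tensor_obj) {
    /* The attribute must be looked up on the type, not the instance. */
    PyObject *api_obj = PyObject_GetAttrString(
        (PyObject *)Py_TYPE(tensor_obj), "__c_dlpack_exchange_api__");
    if (api_obj == NULL) {
        return NULL;  /* AttributeError already set */
    }
    const DLPackExchangeAPI *api =
        (const DLPackExchangeAPI *)PyLong_AsVoidPtr(api_obj);
    Py_DECREF(api_obj);
    if (api == NULL && PyErr_Occurred()) {
        return NULL;
    }
    const DLPackExchangeAPIHeader *hdr = &api->header;
    while (hdr != NULL && hdr->version.major != DLPACK_MAJOR_VERSION) {
        hdr = hdr->prev_api;
    }
    return (const DLPackExchangeAPI *)hdr;
}
```

The result may be cached per type, as the header comment suggests.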
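Building on the previous sketch, and still only as an illustration of the N0 usage pattern described in the DLPackExchangeAPI comment (launch_my_kernel is a placeholder, not a real function):

```c
/* Sketch: obtain a temporary, non-owning DLTensor view and the producer's
 * current work stream, then enqueue work on that stream without synchronizing. */
static int run_on_producer_stream(const DLPackExchangeAPI *api, PyObject *x) {
    DLTensor view;  /* stack-allocated; the producer keeps ownership of the data */
    if (api->dltensor_from_py_object_no_sync == NULL ||
        api->dltensor_from_py_object_no_sync(x, &view) != 0) {
        return -1;  /* Python exception set by the producer, or fast path absent */
    }
    void *stream = NULL;
    if (api->current_work_stream(view.device.device_type,
                                 view.device.device_id, &stream) != 0) {
        return -1;
    }
    /* launch_my_kernel(stream, &view);  placeholder: enqueue on `stream` */
    return 0;
}
```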
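Finally, a small hedged sketch of the little bit-endian sub-byte layout noted in the DLDataType comment for the packed FP4/FP6 codes (bit_width of 4 or 6); this only extracts the raw bit pattern of the i-th element and does not decode a float value:

```c
#include <stddef.h>
#include <stdint.h>

/* ((D >> (i * bits)) & bit_mask) applied byte-wise to a packed buffer. */
static uint8_t extract_subbyte(const uint8_t *data, size_t nbytes,
                               int64_t i, int bit_width) {
    const unsigned bit_mask = (1u << bit_width) - 1u;
    const int64_t bit_pos = i * bit_width;  /* little bit-endian packing */
    const int64_t byte_pos = bit_pos / 8;
    const int shift = (int)(bit_pos % 8);
    unsigned window = data[byte_pos];
    if (shift + bit_width > 8 && (size_t)(byte_pos + 1) < nbytes) {
        window |= (unsigned)data[byte_pos + 1] << 8;  /* element straddles a byte boundary */
    }
    return (uint8_t)((window >> shift) & bit_mask);
}
```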