Skip to content

Cooperative vector API #141

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ if (DRJIT_ENABLE_OPTIX)
src/optix.h
src/optix_api.h
src/optix_api.cpp
src/optix_coop_vec.h
src/optix_coop_vec.cpp
src/optix_core.cpp)
endif()

Expand All @@ -119,6 +121,7 @@ add_library(
src/registry.h src/registry.cpp
src/util.h src/util.cpp
src/record_ts.h src/record_ts.cpp
src/coop_vec.h src/coop_vec.cpp

# CUDA backend
src/cuda_api.h
Expand Down Expand Up @@ -158,6 +161,8 @@ add_library(
src/llvm_packet.cpp
src/llvm_array.h
src/llvm_array.cpp
src/llvm_coop_vec.h
src/llvm_coop_vec.cpp

src/io.h src/io.cpp
src/eval.h src/eval.cpp
Expand Down
2 changes: 1 addition & 1 deletion ext/nanothread
18 changes: 9 additions & 9 deletions include/drjit-core/array.h
Original file line number Diff line number Diff line change
Expand Up @@ -406,15 +406,15 @@ template <JitBackend Backend_, typename Value_> struct JitArray {
friend bool any(const JitArray &a) { return jit_var_any(a.m_index); }
friend bool none(const JitArray &a) { return !jit_var_any(a.m_index); }

friend const char *label(const JitArray &v) {
return jit_var_label(v.m_index);
}

friend void set_label(JitArray &v, const char *label) {
uint32_t index = jit_var_set_label(v.m_index, 1, label);
jit_var_dec_ref(v.m_index);
v.m_index = index;
}
friend const char *label(const JitArray &v) {
return jit_var_label(v.m_index);
}

friend void set_label(JitArray &v, const char *label) {
uint32_t index = jit_var_set_label(v.m_index, 1, label);
jit_var_dec_ref(v.m_index);
v.m_index = index;
}
protected:
uint32_t m_index = 0;
};
Expand Down
158 changes: 157 additions & 1 deletion include/drjit-core/jit.h
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,9 @@ extern JIT_EXPORT void* jit_cuda_pop_context();
/// Query the compute capability of the current device (e.g. '52')
extern JIT_EXPORT int jit_cuda_compute_capability();

/// Get the major and minor CUDA version
extern JIT_EXPORT void jit_cuda_version(int *major, int *minor);

/**
* \brief Override generated PTX version and compute capability
*
Expand Down Expand Up @@ -637,7 +640,10 @@ enum class JitOp : uint32_t {
Rcp, Rsqrt,

// Multi-function generator (CUDA)
Sin, Cos, Exp2, Log2,
Sin, Cos, Exp2, Log2, Tanh,

// Step function
Step,

// Total number of operations
Count
Expand Down Expand Up @@ -799,6 +805,9 @@ extern JIT_EXPORT uint32_t jit_var_exp2_intrinsic(uint32_t a0);
/// Approximate `log2(a0)` and return a variable representing the result
extern JIT_EXPORT uint32_t jit_var_log2_intrinsic(uint32_t a0);

/// Approximate `tanh(a0)` and return a variable representing the result
extern JIT_EXPORT uint32_t jit_var_tanh_intrinsic(uint32_t a0);

/// Return a variable indicating valid lanes within a function call
extern JIT_EXPORT uint32_t jit_var_call_mask(JitBackend backend);

Expand Down Expand Up @@ -2403,6 +2412,7 @@ struct VarInfo {
void *data;
};
bool is_array;
bool is_coop_vec;
bool unaligned;
};

Expand Down Expand Up @@ -2677,6 +2687,152 @@ extern JIT_EXPORT void jit_freeze_abort(JitBackend backend);
*/
extern JIT_EXPORT void jit_freeze_destroy(Recording *recording);

// ====================================================================
// Cooperative vector API
// ====================================================================

// A cooperative vector groups a set of JIT variables into a special (opaque)
// object that can be manipulated with the API below. Standard JIT variable
// operations (<tt>jit_var_*</tt>) are not permitted on cooperative vectors.

/// Pack a set of regular Dr.Jit variables to form a cooperative vector
extern JIT_EXPORT uint32_t jit_coop_vec_pack(uint32_t n, const uint32_t *in);

/// Unpack a cooperative vector into its components
extern JIT_EXPORT void jit_coop_vec_unpack(uint32_t index, uint32_t n, uint32_t *out);

/// Create a cooperative vectors, whose components are a uniform literal constant
extern JIT_EXPORT uint32_t jit_coop_vec_literal(JIT_ENUM JitBackend backend,
JIT_ENUM VarType type,
const void *value,
size_t size JIT_DEF(1),
uint32_t length JIT_DEF(1));

/// Load a cooperative vector from memory
extern JIT_EXPORT uint32_t jit_coop_vec_load(uint32_t buffer,
uint32_t offset,
uint32_t length);

/// Determine the length of a cooperative vector
extern JIT_EXPORT uint32_t jit_coop_vec_length(uint32_t index);

/// Perform a unary operation on a cooperative vector
extern JIT_EXPORT uint32_t jit_coop_vec_unary_op(JitOp op, uint32_t a0);

/// Perform a binary operation on a pair of cooperative vectors
extern JIT_EXPORT uint32_t jit_coop_vec_binary_op(JitOp op, uint32_t a0, uint32_t a1);

/// Perform a ternary operation on a triplet of cooperative vectors
extern JIT_EXPORT uint32_t jit_coop_vec_ternary_op(JitOp op, uint32_t a0, uint32_t a1, uint32_t a2);

/// Encodes a type of request for jit_coop_vec_pack_matrices()
enum class MatrixLayout : uint32_t {
RowMajor,
InferencingOptimal,
TrainingOptimal
};

/// Summary of a matrix/vector that has been packed into a buffer. The packing
/// is vendor-specific and may involve padding to a larger size.
struct MatrixDescr {
VarType dtype; //< Variable type
MatrixLayout layout; //< Layout type
uint32_t rows, cols; //< Shape of the matrix
uint32_t offset; //< Offset from the beginning of the buffer (in elements)
uint32_t stride; //< Row stride (in elements)
uint32_t size; //< Total size (in elements)
};

/// Check if the backend supports cooperative vector operations
extern JIT_EXPORT bool jit_coop_vec_supported(JitBackend backend);

/**
* \brief Pack a sequence of matrices from row-major into a representation that
* is optimal for inference/training.
*
* The function can also perform the reverse (i.e., convert from
* training/inference-optimal back to row-major).
*
* The matrix/vector layout and offsets must be specified via the ``in_descr``
* and ``out_descr`` parameters, which should point to descriptor arrays with
* ``count`` records.
*
* The ``in`` and ``out`` arguments specify Dr.Jit variables from which data
* will be read/written.
*
* The operation launches a backend conversion kernel that runs asynchronously.
*/
extern JIT_EXPORT void jit_coop_vec_pack_matrices(uint32_t count,
uint32_t in,
const MatrixDescr *in_descr,
uint32_t out,
const MatrixDescr *out_descr);

/**
* \brief Query the backend to compute the size of an array/vector in a given
* layout
*
* This operation is needed to compute the output layout for a subsequent
* call to \ref jit_coop_vec_pack_matrices().
*
* The variable ``index`` is the variable index of a buffer holding the input
* matrix/vector, and ``in`` describes the associated matrix layout. The
* following elements of this data structure must be set: ``dtype``, ``layout``,
* ``row``, ``cols``, ``offset``.
*
* The function returns a new version of this record corresponding to the
* packed version (with elements copied or updated). It furthermore sets the
* ``stride`` and ``size`` members.
*/
extern JIT_EXPORT MatrixDescr
jit_coop_vec_compute_layout(uint32_t index, const MatrixDescr *in,
MatrixLayout layout, uint32_t offset);

/**
* \brief Perform a matrix-vector multiplication + bias addition (i.e.,
* <tt>A@x+b</tt>).
*
* The \c A_descr and \c b_descr arguments specify the layout and offset
* of the matrix and bias term in associated memory buffers `A_index` and
* ``b_index``.
*
* The ``x_index`` variable should provide the variable index of a cooperative
* vector created using the ``jit_coop_vec_*`` API.
*
* The \c b_index and \c b_descr arguments are optional when no bias term is
* desired and should be zero-initialized in that case.
*/
extern JIT_EXPORT uint32_t jit_coop_vec_matvec(uint32_t A_index,
const MatrixDescr *A_descr,
uint32_t x_index,
uint32_t b_index,
const MatrixDescr *b_descr,
int transpose);

/// Accumulate the coop. vector 'index' into the buffer 'target' with offset.
/// Potentially create a new buffer of size 'target_size' if target == 0.
extern JIT_EXPORT uint32_t jit_coop_vec_accum(uint32_t target,
uint32_t target_size,
uint32_t offset,
uint32_t index);

/// Cast a cooperative vector to a different precision
extern JIT_EXPORT uint32_t jit_coop_vec_cast(uint32_t index, VarType vt);

/**
* \brief Outer product accumulation for cooperative vectors
*
* Atomically accumulate the outer product of cooperative vectors \c a and \c b
* into buffer \c target, at a location described by \c descr. Potentially
* create a new buffer of size \c size if <tt>target == 0</tt>.
*/
extern JIT_EXPORT uint32_t jit_coop_vec_outer_product_accum(
uint32_t target,
uint32_t target_size,
const MatrixDescr *descr,
uint32_t a,
uint32_t b);

#if defined(__cplusplus)
}
#endif
Loading