Skip to content

Cooperative vector API #141

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ if (DRJIT_ENABLE_OPTIX)
src/optix.h
src/optix_api.h
src/optix_api.cpp
src/optix_coop_vec.h
src/optix_coop_vec.cpp
src/optix_core.cpp)
endif()

Expand All @@ -119,6 +121,7 @@ add_library(
src/registry.h src/registry.cpp
src/util.h src/util.cpp
src/record_ts.h src/record_ts.cpp
src/coop_vec.h src/coop_vec.cpp

# CUDA backend
src/cuda_api.h
Expand Down Expand Up @@ -158,6 +161,8 @@ add_library(
src/llvm_packet.cpp
src/llvm_array.h
src/llvm_array.cpp
src/llvm_coop_vec.h
src/llvm_coop_vec.cpp

src/io.h src/io.cpp
src/eval.h src/eval.cpp
Expand Down
2 changes: 1 addition & 1 deletion ext/nanothread
18 changes: 9 additions & 9 deletions include/drjit-core/array.h
Original file line number Diff line number Diff line change
Expand Up @@ -406,15 +406,15 @@ template <JitBackend Backend_, typename Value_> struct JitArray {
/// Forward to `jit_var_any()` for the JIT variable wrapped by `a` and return its result.
friend bool any(const JitArray &a) { return jit_var_any(a.m_index); }
/// Logical negation of `any()`: true exactly when `jit_var_any()` returns false for `a`.
friend bool none(const JitArray &a) { return !jit_var_any(a.m_index); }

/// Return the string that `jit_var_label()` reports for the JIT variable wrapped by `v`.
friend const char *label(const JitArray &v) {
return jit_var_label(v.m_index);
}

/// Attach `label` to the JIT variable wrapped by `v`.
///
/// `jit_var_set_label()` hands back a variable index; `v` is re-pointed to
/// that index after `jit_var_dec_ref()` has been called on the one it held
/// before.
friend void set_label(JitArray &v, const char *label) {
// NOTE(review): the literal `1` is presumably the number of label
// components being passed -- confirm against jit_var_set_label()
uint32_t index = jit_var_set_label(v.m_index, 1, label);
// Call dec_ref on the old index *before* it is overwritten below
jit_var_dec_ref(v.m_index);
v.m_index = index;
}
/// Return the string that `jit_var_label()` reports for the JIT variable wrapped by `v`.
friend const char *label(const JitArray &v) {
return jit_var_label(v.m_index);
}

/// Attach `label` to the JIT variable wrapped by `v`.
///
/// `jit_var_set_label()` hands back a variable index; `v` is re-pointed to
/// that index after `jit_var_dec_ref()` has been called on the one it held
/// before.
friend void set_label(JitArray &v, const char *label) {
// NOTE(review): the literal `1` is presumably the number of label
// components being passed -- confirm against jit_var_set_label()
uint32_t index = jit_var_set_label(v.m_index, 1, label);
// Call dec_ref on the old index *before* it is overwritten below
jit_var_dec_ref(v.m_index);
v.m_index = index;
}
protected:
uint32_t m_index = 0;
};
Expand Down
103 changes: 102 additions & 1 deletion include/drjit-core/jit.h
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,9 @@ extern JIT_EXPORT void* jit_cuda_pop_context();
/// Query the compute capability of the current device (e.g. '52')
extern JIT_EXPORT int jit_cuda_compute_capability();

/// Get the major and minor CUDA version
extern JIT_EXPORT void jit_cuda_version(int *major, int *minor);

/**
* \brief Override generated PTX version and compute capability
*
Expand Down Expand Up @@ -637,7 +640,10 @@ enum class JitOp : uint32_t {
Rcp, Rsqrt,

// Multi-function generator (CUDA)
Sin, Cos, Exp2, Log2,
Sin, Cos, Exp2, Log2, Tanh,

// Step function
Step,

// Total number of operations
Count
Expand Down Expand Up @@ -799,6 +805,9 @@ extern JIT_EXPORT uint32_t jit_var_exp2_intrinsic(uint32_t a0);
/// Approximate `log2(a0)` and return a variable representing the result
extern JIT_EXPORT uint32_t jit_var_log2_intrinsic(uint32_t a0);

/// Approximate `tanh(a0)` and return a variable representing the result
extern JIT_EXPORT uint32_t jit_var_tanh_intrinsic(uint32_t a0);

/// Return a variable indicating valid lanes within a function call
extern JIT_EXPORT uint32_t jit_var_call_mask(JitBackend backend);

Expand Down Expand Up @@ -2403,6 +2412,7 @@ struct VarInfo {
void *data;
};
bool is_array;
bool is_coop_vec;
bool unaligned;
};

Expand Down Expand Up @@ -2677,6 +2687,97 @@ extern JIT_EXPORT void jit_freeze_abort(JitBackend backend);
*/
extern JIT_EXPORT void jit_freeze_destroy(Recording *recording);

// ====================================================================
// Cooperative vector API
// ====================================================================

/// Pack a set of regular Dr.Jit variables to form a cooperative vector
extern JIT_EXPORT uint32_t jit_coop_vec_pack(uint32_t n, const uint32_t *in);

/// Unpack a cooperative vector into its components
extern JIT_EXPORT void jit_coop_vec_unpack(uint32_t index, uint32_t n, uint32_t *out);

/// Create a cooperative vector whose components are a uniform literal constant
extern JIT_EXPORT uint32_t jit_coop_vec_literal(JIT_ENUM JitBackend backend,
JIT_ENUM VarType type,
const void *value,
size_t size JIT_DEF(1),
uint32_t length JIT_DEF(1));

/// Load a cooperative vector from memory
extern JIT_EXPORT uint32_t jit_coop_vec_load(uint32_t buffer, uint32_t offset, uint32_t length);

/// Determine the length of a cooperative vector
extern JIT_EXPORT uint32_t jit_coop_vec_length(uint32_t index);

/// Perform a unary operation on a cooperative vector
extern JIT_EXPORT uint32_t jit_coop_vec_unary_op(JitOp op, uint32_t a0);

/// Perform a binary operation on a pair of cooperative vectors
extern JIT_EXPORT uint32_t jit_coop_vec_binary_op(JitOp op, uint32_t a0, uint32_t a1);

/// Perform a ternary operation on a triplet of cooperative vectors
extern JIT_EXPORT uint32_t jit_coop_vec_ternary_op(JitOp op, uint32_t a0, uint32_t a1, uint32_t a2);

/// Selects the memory layout requested from jit_coop_vec_pack_matrices()
/// (and recorded in the 'layout' field of MatrixDescr).
enum class MatrixLayout : uint32_t {
    /// Plain row-major storage
    RowMajor = 0,

    /// Representation that is optimal for inference
    InferencingOptimal = 1,

    /// Representation that is optimal for training
    TrainingOptimal = 2
};

/// Summary of a matrix/vector that has been packed into a buffer
struct MatrixDescr {
VarType dtype; //< Variable type
MatrixLayout layout; //< Layout type
uint32_t rows, cols; //< Shape of the matrix
uint32_t offset; //< Offset from the beginning of the buffer (in elements)
uint32_t stride; //< Row stride (in elements)
uint32_t size; //< Total size (in elements)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could make a note that in some layouts, the size in number of elements may be larger than the actual number of elements of the matrix? (e.g. when there's mandatory padding, etc).

};

/// Pack a sequence of matrices from row-major into a representation that is
/// optimal for inference/training, or do the reverse.
extern JIT_EXPORT void jit_coop_vec_pack_matrices(uint32_t count,
uint32_t in,
const MatrixDescr *in_descr,
uint32_t out,
const MatrixDescr *out_descr);

/// Query the backend to compute the size of an array/vector in a given layout
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Specify which fields will be ignored in in, and which fields will be filled-in in the output. Are the fields that are not computed by this function copied over from in into the return value?

extern JIT_EXPORT MatrixDescr jit_coop_vec_compute_layout(uint32_t index,
const MatrixDescr *in,
MatrixLayout layout,
uint32_t offset);

/// Perform a matrix-vector multiplication + bias addition
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Specify which arguments are optional (e.g. b_index = 0 for no bias, if allowed?)

extern JIT_EXPORT uint32_t jit_coop_vec_matvec(uint32_t A_index,
const MatrixDescr *A_descr,
uint32_t x_index,
uint32_t b_index,
const MatrixDescr *b_descr,
int transpose);

/// Accumulate the coop. vector 'index' into the buffer 'target' with offset.
/// Potentially create a new buffer of size 'target_size' if target == 0.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
/// Potentially create a new buffer of size 'size' if target == 0.
/// Potentially create a new buffer of size 'target_size' if target == 0.

extern JIT_EXPORT uint32_t jit_coop_vec_accum(uint32_t target,
uint32_t target_size,
uint32_t offset,
uint32_t index);

/// Cast a cooperative vector to a different precision
extern JIT_EXPORT uint32_t jit_coop_vec_cast(uint32_t index, VarType vt);

/// Accumulate the outer product of cooperative vectors 'a' and 'b' into buffer
/// 'target', at a location described by 'descr'. Potentially create a new
/// buffer of size 'target_size' if target == 0.
extern JIT_EXPORT uint32_t jit_coop_vec_outer_product_accum(
uint32_t target,
uint32_t target_size,
const MatrixDescr *descr,
uint32_t a,
uint32_t b);

#if defined(__cplusplus)
}
#endif
106 changes: 103 additions & 3 deletions src/api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,15 @@
#include "profile.h"
#include "array.h"
#include "record_ts.h"
#include "coop_vec.h"
#include <thread>
#include <condition_variable>
#include <drjit-core/half.h>
#include <drjit-core/texture.h>

#if defined(DRJIT_ENABLE_OPTIX)
#include <drjit-core/optix.h>
#include "optix.h"
# include <drjit-core/optix.h>
# include "optix.h"
#endif

#include <nanothread/nanothread.h>
Expand Down Expand Up @@ -334,6 +335,14 @@ void jit_llvm_version(int *major, int *minor, int *patch) {
*patch = jitc_llvm_version_patch;
}

/// Report the major and minor CUDA version.
///
/// Each output pointer is optional: a null `major` or `minor` is skipped.
/// The values are read from `jitc_cuda_version_major` /
/// `jitc_cuda_version_minor` while a `lock_guard` on `state.lock` is alive.
void jit_cuda_version(int *major, int *minor) {
    lock_guard guard(state.lock);

    if (major != nullptr) {
        *major = jitc_cuda_version_major;
    }

    if (minor != nullptr) {
        *minor = jitc_cuda_version_minor;
    }
}

/// Return the current value of `jitc_llvm_vector_width`. Unlike the
/// neighboring API entry points, this one reads the value without
/// constructing a `lock_guard` on `state.lock`.
uint32_t jit_llvm_vector_width() {
return jitc_llvm_vector_width;
}
Expand Down Expand Up @@ -562,7 +571,7 @@ uint32_t jit_var_scatter_inc(uint32_t *target, uint32_t index, uint32_t mask) {
}

uint32_t jit_var_pointer(JitBackend backend, const void *value,
uint32_t dep, int write) {
uint32_t dep, int write) {
lock_guard guard(state.lock);
return jitc_var_pointer(backend, value, dep, write);
}
Expand Down Expand Up @@ -1337,6 +1346,11 @@ uint32_t jit_var_log2_intrinsic(uint32_t a0) {
return jitc_var_log2_intrinsic(a0);
}

/// Approximate `tanh(a0)` and return a variable representing the result.
///
/// Public entry point: constructs a `lock_guard` on `state.lock`, then
/// forwards to `jitc_var_tanh_intrinsic()`.
uint32_t jit_var_tanh_intrinsic(uint32_t a0) {
lock_guard guard(state.lock);
return jitc_var_tanh_intrinsic(a0);
}

uint32_t jit_var_cast(uint32_t index, VarType target_type,
int reinterpret) {
lock_guard guard(state.lock);
Expand Down Expand Up @@ -1368,6 +1382,7 @@ VarInfo jit_set_backend(uint32_t index) noexcept {
info.state = jitc_var_state(index);
info.size = var->size;
info.is_array = var->is_array();
info.is_coop_vec = var->coop_vec;
info.unaligned = var->unaligned;
if(info.state == VarState::Literal)
info.literal = var->literal;
Expand Down Expand Up @@ -1584,3 +1599,88 @@ void jit_profile_stop() {
if (cuProfilerStart)
cuProfilerStop();
}

/// Pack `n` regular Dr.Jit variables (indices in `in`) into a cooperative
/// vector and return the index of the result.
///
/// Public entry point: constructs a `lock_guard` on `state.lock`, then
/// forwards to `jitc_coop_vec_pack()`.
uint32_t jit_coop_vec_pack(uint32_t n, const uint32_t *in) {
lock_guard guard(state.lock);
return jitc_coop_vec_pack(n, in);
}

/// Unpack the cooperative vector `index` into its components, which are
/// written to `out`.
///
/// Public entry point: constructs a `lock_guard` on `state.lock`, then
/// forwards to `jitc_coop_vec_unpack()`.
// NOTE(review): `n` is presumably the capacity of `out` and expected to
// match the vector length -- confirm against jitc_coop_vec_unpack()
void jit_coop_vec_unpack(uint32_t index, uint32_t n, uint32_t *out) {
lock_guard guard(state.lock);
jitc_coop_vec_unpack(index, n, out);
}

/// Create a cooperative vector whose components are a uniform literal
/// constant and return its index.
///
/// Public entry point: constructs a `lock_guard` on `state.lock`, then
/// forwards all arguments unchanged to `jitc_coop_vec_literal()`.
uint32_t jit_coop_vec_literal(JitBackend backend, VarType type,
const void *value, size_t size, uint32_t length) {
lock_guard guard(state.lock);
return jitc_coop_vec_literal(backend, type, value, size, length);
}

/// Load a cooperative vector from memory and return its index.
///
/// Public entry point: constructs a `lock_guard` on `state.lock`, then
/// forwards to `jitc_coop_vec_load()`.
// NOTE(review): the unit of `offset` (elements vs. bytes) is not visible
// here -- confirm against jitc_coop_vec_load()
uint32_t jit_coop_vec_load(uint32_t buffer, uint32_t offset, uint32_t length) {
lock_guard guard(state.lock);
return jitc_coop_vec_load(buffer, offset, length);
}

/// Perform the unary operation `op` on the cooperative vector `a0` and
/// return the index of the result.
///
/// Public entry point: constructs a `lock_guard` on `state.lock`, then
/// forwards to `jitc_coop_vec_unary_op()`.
uint32_t jit_coop_vec_unary_op(JitOp op, uint32_t a0) {
lock_guard guard(state.lock);
return jitc_coop_vec_unary_op(op, a0);
}

/// Perform the binary operation `op` on the pair of cooperative vectors
/// `a0`, `a1` and return the index of the result.
///
/// Public entry point: constructs a `lock_guard` on `state.lock`, then
/// forwards to `jitc_coop_vec_binary_op()`.
uint32_t jit_coop_vec_binary_op(JitOp op, uint32_t a0, uint32_t a1) {
lock_guard guard(state.lock);
return jitc_coop_vec_binary_op(op, a0, a1);
}

/// Perform the ternary operation `op` on the triplet of cooperative vectors
/// `a0`, `a1`, `a2` and return the index of the result.
///
/// Public entry point: constructs a `lock_guard` on `state.lock`, then
/// forwards to `jitc_coop_vec_ternary_op()`.
uint32_t jit_coop_vec_ternary_op(JitOp op, uint32_t a0, uint32_t a1, uint32_t a2) {
lock_guard guard(state.lock);
return jitc_coop_vec_ternary_op(op, a0, a1, a2);
}

/// Pack a sequence of `count` matrices from row-major into a representation
/// that is optimal for inference/training, or do the reverse.
///
/// `in` / `out` are the source and destination variables; `in_descr` /
/// `out_descr` describe the matrices within them.
///
/// Public entry point: constructs a `lock_guard` on `state.lock`, then
/// forwards to `jitc_coop_vec_pack_matrices()`.
// NOTE(review): `in_descr` and `out_descr` presumably each point to `count`
// descriptors -- confirm against jitc_coop_vec_pack_matrices()
void jit_coop_vec_pack_matrices(uint32_t count,
uint32_t in, const MatrixDescr *in_descr,
uint32_t out, const MatrixDescr *out_descr) {
lock_guard guard(state.lock);
jitc_coop_vec_pack_matrices(count, in, in_descr, out, out_descr);
}

/// Query the backend to compute the size of an array/vector in a given
/// layout; the result is returned as a `MatrixDescr`.
///
/// Public entry point: constructs a `lock_guard` on `state.lock`, then
/// forwards to `jitc_coop_vec_compute_layout()`.
// NOTE(review): which fields of `in` are consulted and which fields of the
// returned descriptor are (re)computed is not visible here -- confirm
// against jitc_coop_vec_compute_layout()
MatrixDescr jit_coop_vec_compute_layout(uint32_t index,
const MatrixDescr *in,
MatrixLayout layout,
uint32_t offset) {
lock_guard guard(state.lock);
return jitc_coop_vec_compute_layout(index, in, layout, offset);
}

/// Perform a matrix-vector multiplication + bias addition and return the
/// index of the resulting cooperative vector.
///
/// `A_index` / `A_descr` identify the matrix, `x_index` the input vector,
/// and `b_index` / `b_descr` the bias.
///
/// Public entry point: constructs a `lock_guard` on `state.lock`, then
/// forwards to `jitc_coop_vec_matvec()`.
// NOTE(review): presumably `b_index == 0` disables the bias term and a
// nonzero `transpose` multiplies by the transposed matrix -- confirm
// against jitc_coop_vec_matvec()
uint32_t jit_coop_vec_matvec(uint32_t A_index, const MatrixDescr *A_descr,
uint32_t x_index, uint32_t b_index,
const MatrixDescr *b_descr, int transpose) {
lock_guard guard(state.lock);
return jitc_coop_vec_matvec(A_index, A_descr, x_index, b_index, b_descr,
transpose);
}

uint32_t jit_coop_vec_length(uint32_t index) {
lock_guard guard(state.lock);
const Variable *v = jitc_var(index);
if (!v->coop_vec)
jitc_raise("jit_coop_vec_length(): r%u is not a cooperative vector!", index);
return v->array_length;
}

/// Accumulate the cooperative vector `index` into the buffer `target` at
/// the given `offset`. Potentially create a new buffer of size
/// `target_size` if `target == 0`.
///
/// Public entry point: constructs a `lock_guard` on `state.lock`, then
/// forwards to `jitc_coop_vec_accum()` and returns its result.
uint32_t jit_coop_vec_accum(uint32_t target, uint32_t target_size,
uint32_t offset, uint32_t index) {
lock_guard guard(state.lock);
return jitc_coop_vec_accum(target, target_size, offset, index);
}

/// Accumulate the outer product of cooperative vectors `a` and `b` into the
/// buffer `target`, at a location described by `descr`. Potentially create
/// a new buffer of size `target_size` if `target == 0`.
///
/// Public entry point: constructs a `lock_guard` on `state.lock`, then
/// forwards to `jitc_coop_vec_outer_product_accum()` and returns its result.
uint32_t jit_coop_vec_outer_product_accum(uint32_t target, uint32_t target_size,
const MatrixDescr *descr, uint32_t a,
uint32_t b) {
lock_guard guard(state.lock);
return jitc_coop_vec_outer_product_accum(target, target_size, descr, a, b);
}

/// Cast the cooperative vector `index` to a different precision `vt` and
/// return the index of the result.
///
/// Public entry point: constructs a `lock_guard` on `state.lock`, then
/// forwards to `jitc_coop_vec_cast()`.
uint32_t jit_coop_vec_cast(uint32_t index, VarType vt) {
lock_guard guard(state.lock);
return jitc_coop_vec_cast(index, vt);
}
10 changes: 10 additions & 0 deletions src/call.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "trace.h"
#include "util.h"
#include "var.h"
#include "coop_vec.h"

std::vector<CallData *> calls_assembled;

Expand Down Expand Up @@ -90,7 +91,12 @@ void jitc_var_call(const char *name, bool symbolic, uint32_t self,
} else if (!v->is_literal()) {
jitc_raise("jit_var_call(): input variable r%u must either be a "
"literal or symbolic wrapper around another variable!", in[i]);
} else {
// Literal field, read temporarily stashed size (see
// jitc_var_call_input in var.cpp)
size = std::max(size, v->unused);
}

if (v->size != 1)
jitc_raise(
"jit_var_call(): size of input variable r%u is %u (must be 1)!",
Expand Down Expand Up @@ -620,6 +626,10 @@ void jitc_var_call_analyze(CallData *call, uint32_t inst_id, uint32_t index,
PacketScatterData *psd = (PacketScatterData *) v->data;
for (uint32_t i : psd->values)
jitc_var_call_analyze(call, inst_id, i, data_offset);
} else if (kind == VarKind::CoopVecPack) {
CoopVecPackData *cvid = (CoopVecPackData *) v->data;
for (uint32_t i : cvid->indices)
jitc_var_call_analyze(call, inst_id, i, data_offset);
} else if (kind == VarKind::TraceRay) {
TraceData *td = (TraceData *) v->data;
for (uint32_t index_2: td->indices)
Expand Down
Loading