diff --git a/CMakeLists.txt b/CMakeLists.txt
index 9b3ac40d..170d4c01 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -93,6 +93,8 @@ if (DRJIT_ENABLE_OPTIX)
     src/optix.h
     src/optix_api.h src/optix_api.cpp
+    src/optix_coop_vec.h
+    src/optix_coop_vec.cpp
     src/optix_core.cpp)
 endif()

@@ -119,6 +121,7 @@ add_library(
   src/registry.h src/registry.cpp
   src/util.h src/util.cpp
   src/record_ts.h src/record_ts.cpp
+  src/coop_vec.h src/coop_vec.cpp

   # CUDA backend
   src/cuda_api.h
@@ -158,6 +161,8 @@ add_library(
   src/llvm_packet.cpp
   src/llvm_array.h src/llvm_array.cpp
+  src/llvm_coop_vec.h
+  src/llvm_coop_vec.cpp
   src/io.h src/io.cpp
   src/eval.h src/eval.cpp
diff --git a/ext/nanothread b/ext/nanothread
index 7ce0e33a..e76778dc 160000
--- a/ext/nanothread
+++ b/ext/nanothread
@@ -1 +1 @@
-Subproject commit 7ce0e33a24a6695438f3896636016aac6c93127a
+Subproject commit e76778dca2bad7e399cc9e408e371dc73569e125
diff --git a/include/drjit-core/array.h b/include/drjit-core/array.h
index 95e61424..184dbfdb 100644
--- a/include/drjit-core/array.h
+++ b/include/drjit-core/array.h
@@ -406,15 +406,15 @@ template struct JitArray {
     friend bool any(const JitArray &a) { return jit_var_any(a.m_index); }
     friend bool none(const JitArray &a) { return !jit_var_any(a.m_index); }

-    friend const char *label(const JitArray &v) {
-        return jit_var_label(v.m_index);
-    }
-
-    friend void set_label(JitArray &v, const char *label) {
-        uint32_t index = jit_var_set_label(v.m_index, 1, label);
-        jit_var_dec_ref(v.m_index);
-        v.m_index = index;
-    }
+    friend const char *label(const JitArray &v) {
+        return jit_var_label(v.m_index);
+    }
+
+    friend void set_label(JitArray &v, const char *label) {
+        uint32_t index = jit_var_set_label(v.m_index, 1, label);
+        jit_var_dec_ref(v.m_index);
+        v.m_index = index;
+    }

 protected:
     uint32_t m_index = 0;
 };
diff --git a/include/drjit-core/jit.h b/include/drjit-core/jit.h
index 5c3cbd37..200d3ccf 100644
--- a/include/drjit-core/jit.h
+++ b/include/drjit-core/jit.h
@@ -179,6 +179,9 @@ extern JIT_EXPORT void* jit_cuda_pop_context();
 /// Query the compute capability of the current device (e.g. '52')
 extern JIT_EXPORT int jit_cuda_compute_capability();

+/// Get the major and minor CUDA version
+extern JIT_EXPORT void jit_cuda_version(int *major, int *minor);
+
 /**
  * \brief Override generated PTX version and compute capability
  *
@@ -637,7 +640,10 @@ enum class JitOp : uint32_t {
     Rcp, Rsqrt,

     // Multi-function generator (CUDA)
-    Sin, Cos, Exp2, Log2,
+    Sin, Cos, Exp2, Log2, Tanh,
+
+    // Step function
+    Step,

     // Total number of operations
     Count
@@ -799,6 +805,9 @@ extern JIT_EXPORT uint32_t jit_var_exp2_intrinsic(uint32_t a0);
 /// Approximate `log2(a0)` and return a variable representing the result
 extern JIT_EXPORT uint32_t jit_var_log2_intrinsic(uint32_t a0);

+/// Approximate `tanh(a0)` and return a variable representing the result
+extern JIT_EXPORT uint32_t jit_var_tanh_intrinsic(uint32_t a0);
+
 /// Return a variable indicating valid lanes within a function call
 extern JIT_EXPORT uint32_t jit_var_call_mask(JitBackend backend);

@@ -2403,6 +2412,7 @@ struct VarInfo {
         void *data;
     };
     bool is_array;
+    bool is_coop_vec;
     bool unaligned;
 };

@@ -2677,6 +2687,97 @@ extern JIT_EXPORT void jit_freeze_abort(JitBackend backend);
  */
 extern JIT_EXPORT void jit_freeze_destroy(Recording *recording);

+// ====================================================================
+// Cooperative vector API
+// ====================================================================
+
+/// Pack a set of regular Dr.Jit variables to form a cooperative vector
+extern JIT_EXPORT uint32_t jit_coop_vec_pack(uint32_t n, const uint32_t *in);
+
+/// Unpack a cooperative vector into its components
+extern JIT_EXPORT void jit_coop_vec_unpack(uint32_t index, uint32_t n, uint32_t *out);
+
+/// Create a cooperative vector whose components are a uniform literal constant
+extern JIT_EXPORT uint32_t jit_coop_vec_literal(JIT_ENUM JitBackend backend,
+                                                JIT_ENUM VarType type,
+                                                const void *value,
+                                                size_t size JIT_DEF(1),
+                                                uint32_t length JIT_DEF(1));
+
+/// Load a cooperative vector from memory
+extern JIT_EXPORT uint32_t jit_coop_vec_load(uint32_t buffer, uint32_t offset, uint32_t length);
+
+/// Determine the length of a cooperative vector
+extern JIT_EXPORT uint32_t jit_coop_vec_length(uint32_t index);
+
+/// Perform a unary operation on a cooperative vector
+extern JIT_EXPORT uint32_t jit_coop_vec_unary_op(JitOp op, uint32_t a0);
+
+/// Perform a binary operation on a pair of cooperative vectors
+extern JIT_EXPORT uint32_t jit_coop_vec_binary_op(JitOp op, uint32_t a0, uint32_t a1);
+
+/// Perform a ternary operation on a triplet of cooperative vectors
+extern JIT_EXPORT uint32_t jit_coop_vec_ternary_op(JitOp op, uint32_t a0, uint32_t a1, uint32_t a2);
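Taken together, the declarations above form the round-trip core of the new API. The following is a minimal usage sketch (not part of the patch) under the assumption of an initialized LLVM backend; reference counting of the intermediate handles is elided for brevity:

```cpp
// Hypothetical sketch: pack three scalar literals into a cooperative
// 3-vector, double it, and unpack the result into regular variables.
float vals[3] = { 1.f, 2.f, 3.f };
uint32_t comp[3];
for (int i = 0; i < 3; ++i)
    comp[i] = jit_var_literal(JitBackend::LLVM, VarType::Float32, &vals[i], 1, 0);

uint32_t vec = jit_coop_vec_pack(3, comp);              // regular vars -> coop vec
uint32_t sum = jit_coop_vec_binary_op(JitOp::Add, vec, vec);

uint32_t out[3];
jit_coop_vec_unpack(sum, 3, out);                       // coop vec -> regular vars
// out[0..2] now refer to ordinary Dr.Jit variables holding 2, 4, 6
```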
+/// Encodes a type of request for jit_coop_vec_pack_matrices()
+enum class MatrixLayout : uint32_t {
+    RowMajor,
+    InferencingOptimal,
+    TrainingOptimal
+};
+
+/// Summary of a matrix/vector that has been packed into a buffer
+struct MatrixDescr {
+    VarType dtype;       //< Variable type
+    MatrixLayout layout; //< Layout type
+    uint32_t rows, cols; //< Shape of the matrix
+    uint32_t offset;     //< Offset from the beginning of the buffer (in elements)
+    uint32_t stride;     //< Row stride (in elements)
+    uint32_t size;       //< Total size (in elements)
+};
+
+/// Pack a sequence of matrices from row-major into a representation that is
+/// optimal for inference/training, or do the reverse.
+extern JIT_EXPORT void jit_coop_vec_pack_matrices(uint32_t count,
+                                                  uint32_t in,
+                                                  const MatrixDescr *in_descr,
+                                                  uint32_t out,
+                                                  const MatrixDescr *out_descr);
+
+/// Query the backend to compute the size of an array/vector in a given layout
+extern JIT_EXPORT MatrixDescr jit_coop_vec_compute_layout(uint32_t index,
+                                                          const MatrixDescr *in,
+                                                          MatrixLayout layout,
+                                                          uint32_t offset);
+
+/// Perform a matrix-vector multiplication + bias addition
+extern JIT_EXPORT uint32_t jit_coop_vec_matvec(uint32_t A_index,
+                                               const MatrixDescr *A_descr,
+                                               uint32_t x_index,
+                                               uint32_t b_index,
+                                               const MatrixDescr *b_descr,
+                                               int transpose);
+
+/// Accumulate the coop. vector 'index' into the buffer 'target' with offset.
+/// Potentially create a new buffer of size 'target_size' if target == 0.
+extern JIT_EXPORT uint32_t jit_coop_vec_accum(uint32_t target,
+                                              uint32_t target_size,
+                                              uint32_t offset,
+                                              uint32_t index);
+
+/// Cast a cooperative vector to a different precision
+extern JIT_EXPORT uint32_t jit_coop_vec_cast(uint32_t index, VarType vt);
+
+/// Accumulate the outer product of cooperative vectors 'a' and 'b' into buffer
+/// 'target', at a location described by 'descr'. Potentially create a new
+/// buffer of size 'target_size' if target == 0.
+extern JIT_EXPORT uint32_t jit_coop_vec_outer_product_accum(
+    uint32_t target,
+    uint32_t target_size,
+    const MatrixDescr *descr,
+    uint32_t a,
+    uint32_t b);
+
 #if defined(__cplusplus)
 }
 #endif
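The matrix entry points are meant to be used together: jit_coop_vec_compute_layout() answers how large the converted matrix will be, jit_coop_vec_pack_matrices() performs the conversion, and jit_coop_vec_matvec() consumes the result. A sketch under assumed sizes — `A_buf`, `A_opt` and `x` are hypothetical variable handles, and error handling is omitted:

```cpp
// A 64x32 float16 matrix stored row-major in the buffer variable 'A_buf'
MatrixDescr src { VarType::Float16, MatrixLayout::RowMajor,
                  64, 32,        // rows, cols
                  0,             // offset (in elements)
                  32,            // row stride (in elements)
                  64 * 32 };     // total size (in elements)

// Ask the backend for the inferencing-optimal layout of the same matrix
MatrixDescr dst =
    jit_coop_vec_compute_layout(A_buf, &src, MatrixLayout::InferencingOptimal, 0);

// Repack A_buf into 'A_opt', a buffer assumed large enough for dst.size elements
jit_coop_vec_pack_matrices(1, A_buf, &src, A_opt, &dst);

// y = A x  (pass transpose = 1 to evaluate A^T x; no bias vector here)
uint32_t y = jit_coop_vec_matvec(A_opt, &dst, x, 0, nullptr, 0);
```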
diff --git a/src/api.cpp b/src/api.cpp
index f9616053..e335b4c2 100644
--- a/src/api.cpp
+++ b/src/api.cpp
@@ -22,14 +22,15 @@
 #include "profile.h"
 #include "array.h"
 #include "record_ts.h"
+#include "coop_vec.h"
 #include
 #include
 #include
 #include

 #if defined(DRJIT_ENABLE_OPTIX)
-#include
-#include "optix.h"
+# include
+# include "optix.h"
 #endif

 #include
@@ -334,6 +335,14 @@ void jit_llvm_version(int *major, int *minor, int *patch) {
         *patch = jitc_llvm_version_patch;
 }

+void jit_cuda_version(int *major, int *minor) {
+    lock_guard guard(state.lock);
+    if (major)
+        *major = jitc_cuda_version_major;
+    if (minor)
+        *minor = jitc_cuda_version_minor;
+}
+
 uint32_t jit_llvm_vector_width() {
     return jitc_llvm_vector_width;
 }
@@ -562,7 +571,7 @@ uint32_t jit_var_scatter_inc(uint32_t *target, uint32_t index, uint32_t mask) {
 }

 uint32_t jit_var_pointer(JitBackend backend, const void *value,
-                          uint32_t dep, int write) {
+                         uint32_t dep, int write) {
     lock_guard guard(state.lock);
     return jitc_var_pointer(backend, value, dep, write);
 }
@@ -1337,6 +1346,11 @@ uint32_t jit_var_log2_intrinsic(uint32_t a0) {
     return jitc_var_log2_intrinsic(a0);
 }

+uint32_t jit_var_tanh_intrinsic(uint32_t a0) {
+    lock_guard guard(state.lock);
+    return jitc_var_tanh_intrinsic(a0);
+}
+
 uint32_t jit_var_cast(uint32_t index, VarType target_type,
                       int reinterpret) {
     lock_guard guard(state.lock);
@@ -1368,6 +1382,7 @@ VarInfo jit_set_backend(uint32_t index) noexcept {
     info.state = jitc_var_state(index);
     info.size = var->size;
     info.is_array = var->is_array();
+    info.is_coop_vec = var->coop_vec;
     info.unaligned = var->unaligned;
     if(info.state == VarState::Literal)
         info.literal = var->literal;
@@ -1584,3 +1599,88 @@ void jit_profile_stop() {
     if (cuProfilerStart)
         cuProfilerStop();
 }
+
+uint32_t jit_coop_vec_pack(uint32_t n, const uint32_t *in) {
+    lock_guard guard(state.lock);
+    return jitc_coop_vec_pack(n, in);
+}
+
+void jit_coop_vec_unpack(uint32_t index, uint32_t n, uint32_t *out) {
+    lock_guard guard(state.lock);
+    jitc_coop_vec_unpack(index, n, out);
+}
+
+uint32_t jit_coop_vec_literal(JitBackend backend, VarType type,
+                              const void *value, size_t size, uint32_t length) {
+    lock_guard guard(state.lock);
+    return jitc_coop_vec_literal(backend, type, value, size, length);
+}
+
+uint32_t jit_coop_vec_load(uint32_t buffer, uint32_t offset, uint32_t length) {
+    lock_guard guard(state.lock);
+    return jitc_coop_vec_load(buffer, offset, length);
+}
+
+uint32_t jit_coop_vec_unary_op(JitOp op, uint32_t a0) {
+    lock_guard guard(state.lock);
+    return jitc_coop_vec_unary_op(op, a0);
+}
+
+uint32_t jit_coop_vec_binary_op(JitOp op, uint32_t a0, uint32_t a1) {
+    lock_guard guard(state.lock);
+    return jitc_coop_vec_binary_op(op, a0, a1);
+}
+
+uint32_t jit_coop_vec_ternary_op(JitOp op, uint32_t a0, uint32_t a1, uint32_t a2) {
+    lock_guard guard(state.lock);
+    return jitc_coop_vec_ternary_op(op, a0, a1, a2);
+}
+
+void jit_coop_vec_pack_matrices(uint32_t count,
+                                uint32_t in, const MatrixDescr *in_descr,
+                                uint32_t out, const MatrixDescr *out_descr) {
+    lock_guard guard(state.lock);
+    jitc_coop_vec_pack_matrices(count, in, in_descr, out, out_descr);
+}
+
+MatrixDescr jit_coop_vec_compute_layout(uint32_t index,
+                                        const MatrixDescr *in,
+                                        MatrixLayout layout,
+                                        uint32_t offset) {
+    lock_guard guard(state.lock);
+    return jitc_coop_vec_compute_layout(index, in, layout, offset);
+}
+
+uint32_t jit_coop_vec_matvec(uint32_t A_index, const MatrixDescr *A_descr,
+                             uint32_t x_index, uint32_t b_index,
+                             const MatrixDescr *b_descr, int transpose) {
+    lock_guard guard(state.lock);
+    return jitc_coop_vec_matvec(A_index, A_descr, x_index, b_index, b_descr,
+                                transpose);
+}
+
+uint32_t jit_coop_vec_length(uint32_t index) {
+    lock_guard guard(state.lock);
+    const Variable *v = jitc_var(index);
+    if (!v->coop_vec)
+        jitc_raise("jit_coop_vec_length(): r%u is not a cooperative vector!", index);
+    return v->array_length;
+}
+
+uint32_t jit_coop_vec_accum(uint32_t target, uint32_t target_size,
+                            uint32_t offset, uint32_t index) {
+    lock_guard guard(state.lock);
+    return jitc_coop_vec_accum(target, target_size, offset, index);
+}
+
+uint32_t jit_coop_vec_outer_product_accum(uint32_t target, uint32_t target_size,
+                                          const MatrixDescr *descr, uint32_t a,
+                                          uint32_t b) {
+    lock_guard guard(state.lock);
+    return jitc_coop_vec_outer_product_accum(target, target_size, descr, a, b);
+}
+
+uint32_t jit_coop_vec_cast(uint32_t index, VarType vt) {
+    lock_guard guard(state.lock);
+    return jitc_coop_vec_cast(index, vt);
+}
diff --git a/src/call.cpp b/src/call.cpp
index 2ee1c84e..5b5e6d5e 100644
--- a/src/call.cpp
+++ b/src/call.cpp
@@ -21,6 +21,7 @@
 #include "trace.h"
 #include "util.h"
 #include "var.h"
+#include "coop_vec.h"

 std::vector calls_assembled;

@@ -90,7 +91,12 @@ void jitc_var_call(const char *name, bool symbolic, uint32_t self,
         } else if (!v->is_literal()) {
             jitc_raise("jit_var_call(): input variable r%u must either be a "
                        "literal or symbolic wrapper around another variable!", in[i]);
+        } else {
+            // Literal field, read temporarily stashed size (see
+            // jitc_var_call_input in var.cpp)
+            size = std::max(size, v->unused);
         }
+
         if (v->size != 1)
             jitc_raise(
                 "jit_var_call(): size of input variable r%u is %u (must be 1)!",
@@ -620,6 +626,10 @@ void jitc_var_call_analyze(CallData *call, uint32_t inst_id, uint32_t index,
         PacketScatterData *psd = (PacketScatterData *) v->data;
         for (uint32_t i : psd->values)
             jitc_var_call_analyze(call, inst_id, i, data_offset);
+    } else if (kind == VarKind::CoopVecPack) {
+        CoopVecPackData *cvid = (CoopVecPackData *) v->data;
+        for (uint32_t i : cvid->indices)
+            jitc_var_call_analyze(call, inst_id, i, data_offset);
     } else if (kind == VarKind::TraceRay) {
         TraceData *td = (TraceData *) v->data;
         for (uint32_t index_2: td->indices)
diff --git a/src/coop_vec.cpp b/src/coop_vec.cpp
new file mode 100644
index 00000000..405fa4f2
--- /dev/null
+++ b/src/coop_vec.cpp
@@ -0,0 +1,724 @@
+/*
+    src/coop_vec.cpp -- Backend-independent parts of the Cooperative Vector API
+
+    Copyright (c) 2025 Wenzel Jakob
+
+    All rights reserved. Use of this source code is governed by a BSD-style
+    license that can be found in the LICENSE file.
+*/
+
+#include "var.h"
+#include "coop_vec.h"
+#include "internal.h"
+#include "log.h"
+#include "op.h"
+#include "optix_api.h"
+#include "optix.h"
+#include "cuda.h"
+#include
+
+static uint32_t unwrap(uint32_t index) {
+    while (true) {
+        const Variable *v = jitc_var(index);
+        if (v->kind != (uint32_t) VarKind::LoopPhi)
+            return index;
+        index = borrow(v->dep[3]);
+    }
+}
+
+uint32_t jitc_coop_vec_pack(uint32_t n, const uint32_t *in) {
+    if (n == 0)
+        jitc_raise("jit_coop_vec_pack(): vector cannot be empty!");
+    if (n > 0xFFFF)
+        jitc_raise("jit_coop_vec_pack(): cooperative vector is too large!");
+
+    for (uint32_t i = 0; i < n; ++i) {
+        if (in[i] == 0)
+            jitc_raise("jit_coop_vec_pack(): argument %u is uninitialized!", i);
+    }
+
+    const Variable *arg_v = jitc_var(in[0]);
+    if (arg_v->backend == (uint32_t) JitBackend::CUDA) {
+        bool coop_vec_supported =
+            (jitc_cuda_version_major == 12 && jitc_cuda_version_minor >= 8) ||
+            jitc_cuda_version_major > 12;
+        if (!coop_vec_supported)
+            jitc_raise("jit_coop_vec_pack(): The use of cooperative vectors on "
+                       "the CUDA/OptiX backend requires CUDA 12.8 or newer "
+                       "(driver R570+).");
+    }
+
+    Variable v;
+    v.kind = (uint32_t) VarKind::CoopVecPack;
+    v.type = arg_v->type;
+    v.size = arg_v->size;
+    v.backend = arg_v->backend;
+    v.array_length = n;
+    v.coop_vec = true;
+    v.optix = v.backend == (uint32_t) JitBackend::CUDA;
+
+    drjit::unique_ptr<CoopVecPackData> cvid = new CoopVecPackData();
+    bool is_literal = true;
+    uint64_t literal = arg_v->literal;
+    cvid->indices.reserve(n);
+
+    jitc_log(Debug, "jit_coop_vec_pack(): building a cooperative vector with %u elements", n);
+    for (uint32_t i = 0; i < n; ++i) {
+        uint32_t index = in[i];
+        const Variable *v2 = jitc_var(index);
+        v.size = std::max(v.size, v2->size);
+        if (v2->backend != v.backend || v2->type != v.type)
+            jitc_raise("jit_coop_vec_pack(): inputs must have compatible types and backends!");
+
+        if (!v2->is_literal() || v2->literal != literal)
+            is_literal = false;
+
+        jitc_var_inc_ref(index);
+        cvid->indices.push_back(index);
+        jitc_log(Debug, " - entry %u: r%u", i, index);
+    }
+
+    if (is_literal)
+        return jitc_coop_vec_literal((JitBackend) v.backend, (VarType) v.type,
+                                     &literal, v.size, n);
+
+    uint32_t result = jitc_var_new(v, true);
+    jitc_var(result)->data = cvid.get();
+
+    jitc_var_set_callback(
+        result,
+        [](uint32_t, int free, void *p) {
+            if (free)
+                delete (CoopVecPackData *) p;
+        },
+        cvid.release(), true);
+
+    return result;
+}
+
+void jitc_coop_vec_unpack(uint32_t index, uint32_t n, uint32_t *out) {
+    Variable *vec_v = jitc_var(index);
+    Variable v;
+
+    if (!vec_v->coop_vec)
+        jitc_raise("jit_coop_vec_unpack(): source must be a cooperative vector!");
+    if (vec_v->array_length != n)
+        jitc_raise("jit_coop_vec_unpack(): internal error, array length did not match!");
+
+    uint32_t length = vec_v->array_length;
+    if (vec_v->is_coop_vec_literal()) {
+        uint64_t literal = vec_v->literal;
+        Ref r = steal(jitc_var_literal((JitBackend) vec_v->backend,
+                                       (VarType) vec_v->type, &literal,
+                                       vec_v->size, 0));
+        for (uint32_t i = 0; i < length; ++i) {
+            jitc_var_inc_ref(r);
+            out[i] = r;
+        }
+        return;
+    }
+
+    v.kind = (uint32_t) VarKind::CoopVecUnpack;
+    v.type = vec_v->type;
+    v.size = vec_v->size;
+    v.backend = vec_v->backend;
+    v.dep[0] = index;
+
+    for (uint32_t i = 0; i < length; ++i) {
+        jitc_var_inc_ref(index);
+        v.literal = i;
+        out[i] = jitc_var_new(v);
+    }
+}
+
+uint32_t jitc_coop_vec_literal(JitBackend backend,
+                               VarType type,
+                               const void *value,
+                               size_t size,
+                               uint32_t length) {
+    if (unlikely(size == 0))
+        return 0;
+
+    Variable v;
+    memcpy(&v.literal, value, type_size[(uint32_t) type]);
+    v.kind = (uint32_t) VarKind::CoopVecLiteral;
+    v.type = (uint32_t) type;
+    v.size = (uint32_t) size;
+    v.backend = (uint32_t) backend;
+    v.array_length = length;
+    v.coop_vec = true;
+    v.optix = v.backend == (uint32_t) JitBackend::CUDA;
+
+    return jitc_var_new(v);
+}
+
+uint32_t jitc_coop_vec_load(uint32_t buffer, uint32_t offset, uint32_t length) {
+    VarType vt;
+    JitBackend backend;
+    {
+        Variable *buffer_v = jitc_var(buffer);
+        vt = (VarType) buffer_v->type;
+        backend = (JitBackend) buffer_v->backend;
+    }
+
+    void *p = nullptr;
+    Ref tmp = steal(jitc_var_data(buffer, false, &p));
+    Ref buf_ptr = steal(jitc_var_pointer(backend, p, tmp, 0));
+
+    Ref mask = steal(jitc_var_bool(backend, true));
+    mask = steal(jitc_var_mask_apply(mask, 1));
+
+    Variable v;
+    v.kind = (uint32_t) VarKind::CoopVecLoad;
+    v.type = (uint32_t) vt;
+    v.size = 1;
+    v.backend = (uint32_t) backend;
+    v.array_length = length;
+    v.literal = offset;
+    v.coop_vec = true;
+    v.optix = backend == JitBackend::CUDA;
+    v.dep[0] = buf_ptr;
+    v.dep[1] = mask;
+    jitc_var_inc_ref(buf_ptr);
+    jitc_var_inc_ref(mask);
+
+    return jitc_var_new(v);
+}
+
+uint32_t jitc_coop_vec_unary_op(JitOp op, uint32_t a0) {
+    if (!a0)
+        return 0;
+
+    Variable *a0_v = jitc_var(a0);
+    Variable v;
+    v.kind = (uint32_t) VarKind::CoopVecUnaryOp;
+    v.literal = (uint32_t) op;
+    v.type = a0_v->type;
+    v.size = a0_v->size;
+    v.backend = a0_v->backend;
+    v.array_length = a0_v->array_length;
+    v.coop_vec = true;
+    v.dep[0] = a0;
+    jitc_var_inc_ref(a0, a0_v);
+    return jitc_var_new(v);
+}
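One detail worth noting in jitc_coop_vec_pack() above: when every input holds the same literal, no CoopVecPack node is created at all — the call collapses into jitc_coop_vec_literal(), and jitc_coop_vec_unpack() has a matching fast path. A sketch of the user-visible consequence (hypothetical handles, LLVM backend assumed):

```cpp
// Packing n copies of the same constant...
float z = 0.f;
uint32_t zero  = jit_var_literal(JitBackend::LLVM, VarType::Float32, &z, 1, 0);
uint32_t in[2] = { zero, zero };
uint32_t v1    = jit_coop_vec_pack(2, in);

// ...yields the same kind of node as requesting a uniform literal directly:
uint32_t v2 = jit_coop_vec_literal(JitBackend::LLVM, VarType::Float32, &z, 1, 2);
```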
+uint32_t jitc_coop_vec_cast(uint32_t index, VarType vt) {
+    if (!index)
+        return 0;
+
+    Variable *prev_v = jitc_var(index);
+    if ((VarType) prev_v->type == vt) {
+        jitc_var_inc_ref(index, prev_v);
+        return index;
+    }
+
+    /// The OptiX conversion intrinsic is currently too limited
+    if ((JitBackend) prev_v->backend == JitBackend::CUDA) {
+        uint32_t n = prev_v->array_length;
+        if (n == 0) // just here to silence a GCC warning..
+            return 0;
+        uint32_t *tmp1 = (uint32_t *) alloca(sizeof(uint32_t) * n),
+                 *tmp2 = (uint32_t *) alloca(sizeof(uint32_t) * n);
+        jitc_coop_vec_unpack(index, n, tmp1);
+        for (uint32_t i = 0; i < n; ++i)
+            tmp2[i] = jitc_var_cast(tmp1[i], vt, false);
+        uint32_t result = jitc_coop_vec_pack(n, tmp2);
+        for (uint32_t i = 0; i < n; ++i) {
+            jitc_var_dec_ref(tmp1[i]);
+            jitc_var_dec_ref(tmp2[i]);
+        }
+        return result;
+    }
+
+    Variable v;
+    v.kind = (uint32_t) VarKind::CoopVecCast;
+    v.type = (uint32_t) vt;
+    v.size = prev_v->size;
+    v.backend = prev_v->backend;
+    v.array_length = prev_v->array_length;
+    v.coop_vec = true;
+    v.dep[0] = index;
+    jitc_var_inc_ref(index, prev_v);
+
+    return jitc_var_new(v);
+}
+
+uint32_t jitc_coop_vec_binary_op(JitOp op, uint32_t a0, uint32_t a1) {
+    if (!a0 || !a1)
+        jitc_raise("jit_coop_vec_binary_op(): detected uninitialized inputs!");
+
+    Variable *a0_v = jitc_var(a0),
+             *a1_v = jitc_var(a1);
+
+    if (a0_v->array_length != a1_v->array_length)
+        jitc_raise("jit_coop_vec_binary_op(): the cooperative vectors have "
+                   "incompatible lengths (%u and %u)!",
+                   a0_v->array_length, a1_v->array_length);
+
+    if (a0_v->type != a1_v->type)
+        jitc_raise("jit_coop_vec_binary_op(): the cooperative vectors have "
+                   "incompatible types (%s and %s)!",
+                   type_name[a0_v->type], type_name[a1_v->type]);
+
+    if (!(a0_v->size == a1_v->size || a1_v->size == 1 || a0_v->size == 1))
+        jitc_raise(
+            "jit_coop_vec_binary_op(): incompatible thread count (%u and %u)!",
+            a0_v->size, a1_v->size);
+
+    uint32_t max_size = std::max(a0_v->size, a1_v->size);
+
+    // Exploit some basic optimization opportunities (useful for AD)
+    switch (op) {
+        case JitOp::Add:
+            if (jitc_is_any_zero(a0_v)) { return jitc_var_resize(a1, max_size); }
+            if (jitc_is_any_zero(a1_v)) { return jitc_var_resize(a0, max_size); }
+            break;
+
+        case JitOp::Mul:
+            if (jitc_is_one(a0_v)) { return jitc_var_resize(a1, max_size); }
+            if (jitc_is_one(a1_v)) { return jitc_var_resize(a0, max_size); }
+            break;
+
+        default:
+            break;
+    }
+
+    Variable v;
+    v.kind = (uint32_t) VarKind::CoopVecBinaryOp;
+    v.literal = (uint32_t) op;
+    v.type = a0_v->type;
+    v.size = max_size;
+    v.backend = a0_v->backend;
+    v.array_length = a0_v->array_length;
+    v.coop_vec = true;
+    v.dep[0] = a0;
+    v.dep[1] = a1;
+    jitc_var_inc_ref(a0, a0_v);
+    jitc_var_inc_ref(a1, a1_v);
+    return jitc_var_new(v);
+}
+
+uint32_t jitc_coop_vec_ternary_op(JitOp op, uint32_t a0, uint32_t a1, uint32_t a2) {
+    if (!a0 || !a1 || !a2)
+        jitc_raise("jit_coop_vec_ternary_op(): detected uninitialized inputs!");
+
+    Variable *a0_v = jitc_var(a0),
+             *a1_v = jitc_var(a1),
+             *a2_v = jitc_var(a2);
+
+    uint32_t max_size = std::max(std::max(a0_v->size, a1_v->size), a2_v->size);
+
+    if (a0_v->array_length != a1_v->array_length || a0_v->array_length != a2_v->array_length)
+        jitc_raise("jit_coop_vec_ternary_op(): the cooperative vectors have "
+                   "incompatible lengths (%u, %u, and %u)!",
+                   a0_v->array_length, a1_v->array_length, a2_v->array_length);
+
+    if (a0_v->type != a1_v->type || a0_v->type != a2_v->type)
+        jitc_raise("jit_coop_vec_ternary_op(): the cooperative vectors have "
+                   "incompatible types (%s, %s, and %s)!",
+                   type_name[a0_v->type], type_name[a1_v->type], type_name[a2_v->type]);
+
+    if (!(a0_v->size == max_size || a0_v->size == 1) ||
+        !(a1_v->size == max_size || a1_v->size == 1) ||
+        !(a2_v->size == max_size || a2_v->size == 1))
+        jitc_raise(
+            "jit_coop_vec_ternary_op(): incompatible thread count (%u, %u, and %u)!",
+            a0_v->size, a1_v->size, a2_v->size);
+
+    // Exploit some basic optimization opportunities (useful for AD)
+    if (op == JitOp::Fma) {
+        if (jitc_is_one(a0_v)) {
+            Ref result = steal(jitc_coop_vec_binary_op(JitOp::Add, a1, a2));
+            return jitc_var_resize(result, max_size);
+        }
+        if (jitc_is_one(a1_v)) {
+            Ref result = steal(jitc_coop_vec_binary_op(JitOp::Add, a0, a2));
+            return jitc_var_resize(result, max_size);
+        }
+        if (jitc_is_any_zero(a2_v)) {
+            Ref result = steal(jitc_coop_vec_binary_op(JitOp::Mul, a0, a1));
+            return jitc_var_resize(result, max_size);
+        }
+    }
+
+    Variable v;
+    v.kind = (uint32_t) VarKind::CoopVecTernaryOp;
+    v.literal = (uint32_t) op;
+    v.type = a0_v->type;
+    v.size = max_size;
+    v.backend = a0_v->backend;
+    v.array_length = a0_v->array_length;
+    v.coop_vec = true;
+    v.dep[0] = a0;
+    v.dep[1] = a1;
+    v.dep[2] = a2;
+    jitc_var_inc_ref(a0, a0_v);
+    jitc_var_inc_ref(a1, a1_v);
+    jitc_var_inc_ref(a2, a2_v);
+    return jitc_var_new(v);
+}
+
+MatrixDescr jitc_coop_vec_compute_layout(uint32_t index,
+                                         const MatrixDescr *in_,
+                                         MatrixLayout layout,
+                                         uint32_t offset) {
+    const MatrixDescr in = *in_;
+    const bool is_vector = in.cols == 1;
+
+#if defined(DRJIT_ENABLE_OPTIX)
+    JitBackend backend;
+    VarType vt;
+
+    {
+        const Variable *v = jitc_var(index);
+        vt = (VarType) v->type;
+        backend = (JitBackend) v->backend;
+    }
+
+    uint32_t tsize = type_size[(uint32_t) vt];
+
+    if (backend == JitBackend::CUDA) {
+        uint32_t offset_in_bytes = in.offset * tsize;
+        if (offset_in_bytes % 64 != 0)
+            jitc_raise(
+                "jit_coop_vec_compute_layout(): OptiX requires input matrices "
+                "to be 64-byte aligned. Encountered an input with "
+                "offset %u, which is not divisible by 64.", offset_in_bytes);
+
+        uint32_t out_align = is_vector ? 16 : 64;
+        offset = (ceil_div(offset * tsize, out_align) * out_align) / tsize;
+    }
+#else
+    (void) index;
+#endif
+
+    MatrixDescr r;
+    r.dtype = in.dtype;
+    r.layout = is_vector ? MatrixLayout::RowMajor : layout;
+    r.rows = in.rows;
+    r.cols = in.cols;
+    r.offset = offset;
+    r.stride = r.cols;
+    r.size = (r.rows - 1) * r.stride + r.cols;
+
+#if defined(DRJIT_ENABLE_OPTIX)
+    if (backend == JitBackend::CUDA && r.layout != MatrixLayout::RowMajor) {
+        OptixDeviceContext ctx = jitc_optix_context();
+        uint32_t type_id = jitc_optix_coop_vec_type_id(vt),
+                 layout_id = jitc_optix_coop_vec_layout_id(layout);
+        size_t size = 0;
+
+        if (vt != VarType::Float16)
+            jitc_raise(
+                "jit_coop_vec_compute_layout(): CUDA/OptiX conversion to "
+                "optimal layout is currently limited to half precision data.");
+
+        if (!optixCoopVecMatrixComputeSize)
+            jitc_raise("jit_coop_vec_compute_layout(): Cooperative vectors are not "
+                       "supported by your NVIDIA GPU driver. Please install "
+                       "driver version 570 or newer.");
+
+        jitc_optix_check(optixCoopVecMatrixComputeSize(
+            ctx, in.rows, in.cols, type_id, layout_id, 0, &size));
+        r.stride = 0;
+        r.size = (uint32_t) size / tsize;
+    }
+#endif
+
+    return r;
+}
+
+void jitc_coop_vec_pack_matrices(uint32_t count,
+                                 uint32_t in,
+                                 const MatrixDescr *in_descr,
+                                 uint32_t out,
+                                 const MatrixDescr *out_descr) {
+    void *in_p = nullptr, *out_p = nullptr;
+    Ref in_data = steal(jitc_var_data(in, true, &in_p));
+    Ref out_data = steal(jitc_var_data(out, true, &out_p));
+
+    JitBackend backend;
+    {
+        const Variable *out_v = jitc_var(out);
+        jitc_log(Debug, "jit_coop_vec_pack(): packing %u %s, %u bytes, r%u -> r%u",
+                 count, count == 1 ? "matrix" : "matrices",
+                 out_v->size * type_size[out_v->type], in, out);
+        backend = (JitBackend) out_v->backend;
+    }
+
+    thread_state(backend)->coop_vec_pack(count, in_p, in_descr, out_p, out_descr);
+}
+
+uint32_t jitc_coop_vec_matvec(uint32_t A_index,
+                              const MatrixDescr *A_descr,
+                              uint32_t x_index,
+                              uint32_t b_index,
+                              const MatrixDescr *b_descr,
+                              int transpose) {
+
+    if (!A_index || !x_index)
+        jitc_raise("jit_coop_vec_matvec(): detected uninitialized inputs!");
+
+    VarType a_vt = VarType::Void,
+            b_vt = VarType::Void,
+            x_vt = VarType::Void;
+    uint32_t size;
+    JitBackend backend;
+
+    uint32_t input_length = transpose ? A_descr->rows : A_descr->cols,
+             output_length = transpose ? A_descr->cols : A_descr->rows;
+
+    drjit::unique_ptr<CoopVecMatVecData> cvmvd = new CoopVecMatVecData();
+    {
+        Variable *x_v = jitc_var(x_index);
+        x_vt = (VarType) x_v->type;
+        backend = (JitBackend) x_v->backend;
+        size = x_v->size;
+
+        if (x_v->array_length != input_length)
+            jitc_raise(
+                "jit_coop_vec_matvec(): incompatible shapes. Attempted to "
+                "multiply a %ux%u matrix by a vector with %u elements.",
+                output_length, input_length, x_v->array_length);
+    }
+
+    A_index = unwrap(A_index);
+    if (b_index)
+        b_index = unwrap(b_index);
+
+    Ref a_ptr, b_ptr;
+    {
+        void *p = nullptr;
+        Ref tmp = steal(jitc_var_data(A_index, false, &p));
+
+        a_vt = (VarType) jitc_var(tmp)->type;
+        a_ptr = steal(jitc_var_pointer(backend, p, tmp, 0));
+        cvmvd->A_descr = *A_descr;
+
+        if (backend == JitBackend::CUDA) {
+            uint32_t tsize = type_size[(int) a_vt],
+                     offset_in_bytes = A_descr->offset * tsize,
+                     stride_in_bytes = A_descr->stride * tsize;
+
+            if (offset_in_bytes % 64)
+                jitc_raise("jit_coop_vec_matvec(): matrix offset (%u bytes) "
+                           "must be 64-byte aligned.\n", offset_in_bytes);
+
+            if (stride_in_bytes % 16)
+                jitc_raise("jit_coop_vec_matvec(): matrix stride (%u bytes) "
+                           "must be 16-byte aligned.\n", stride_in_bytes);
+        }
+    }
+
+    if (b_index && b_descr) {
+        void *p = nullptr;
+        Ref tmp = steal(jitc_var_data(b_index, false, &p));
+
+        b_vt = (VarType) jitc_var(tmp)->type;
+        b_ptr = steal(jitc_var_pointer(backend, p, tmp, 0));
+        cvmvd->b_descr = *b_descr;
+
+        if (b_descr->rows != output_length || b_descr->cols != 1)
+            jitc_raise(
+                "jit_coop_vec_matvec(): 'b' vector has an incompatible shape "
+                "(expected (%u x 1), got (%u x %u)).",
+                output_length, b_descr->rows, b_descr->cols);
+
+        if (b_descr->stride != 1)
+            jitc_raise(
+                "jit_coop_vec_matvec(): 'b' vector must be tightly packed.");
+    }
+
+    cvmvd->transpose = transpose;
+
+    bool supported = false, is_llvm = backend == JitBackend::LLVM;
+    supported |= a_vt == VarType::Float16 && x_vt == VarType::Float16 &&
+                 (b_vt == VarType::Void || b_vt == VarType::Float16);
+    supported |= is_llvm && (a_vt == VarType::Float32 && x_vt == VarType::Float32 &&
+                 (b_vt == VarType::Void || b_vt == VarType::Float32));
+
+    if (!supported)
+        jitc_raise("jit_coop_vec_matvec(): incompatible input types "
+                   "(currently, only float16 is supported on the CUDA/OptiX "
+                   "backend)!");
+
+    Ref mask = steal(jitc_var_bool(backend, true));
+    mask = steal(jitc_var_mask_apply(mask, size));
+
+    Variable v;
+    v.kind = (uint32_t) VarKind::CoopVecMatVec;
+    v.type = (uint32_t) x_vt;
+    v.size = std::max(size, jitc_var(mask)->size);
+    v.backend = (uint32_t) backend;
+    v.array_length = output_length;
+    v.coop_vec = true;
+    v.dep[0] = a_ptr;
+    v.dep[1] = x_index;
+    v.dep[2] = mask;
+    v.dep[3] = b_ptr;
+    jitc_var_inc_ref(a_ptr);
+    jitc_var_inc_ref(x_index);
+    jitc_var_inc_ref(b_ptr);
+    jitc_var_inc_ref(mask);
+
+    uint32_t result = jitc_var_new(v, true);
+    jitc_var(result)->data = cvmvd.get();
+    jitc_var_set_callback(
+        result,
+        [](uint32_t, int free, void *p) {
+            if (free)
+                delete (CoopVecMatVecData *) p;
+        },
+        cvmvd.release(), true);
+
+    return result;
+}
+uint32_t jitc_coop_vec_accum(uint32_t target_, uint32_t target_size,
+                             uint32_t offset, uint32_t index) {
+    JitBackend backend;
+    VarType vt;
+    uint32_t size;
+    {
+        const Variable *v = jitc_var(index);
+        backend = (JitBackend) v->backend;
+        vt = (VarType) v->type;
+        size = v->size;
+
+        if (backend == JitBackend::CUDA &&
+            !(vt == VarType::Float16 || vt == VarType::Float32))
+            jitc_raise(
+                "jit_coop_vec_accum(): this operation is restricted to "
+                "float16/float32 precision on the CUDA/OptiX backend.");
+    }
+
+    if (target_)
+        target_ = unwrap(target_);
+
+    Ref target = borrow(target_);
+    if (!target) {
+        uint64_t z = 0;
+        target = steal(jitc_var_literal(backend, vt, &z, target_size, true));
+    } else {
+        const Variable *target_v = jitc_var(target);
+        // Copy-on-Write logic. See the same line in jitc_var_scatter() for details
+        if (target_v->ref_count != 2 && target_v->ref_count_stashed != 1)
+            target = steal(jitc_var_copy(target));
+
+        if ((VarType) target_v->type != vt)
+            jitc_raise("jit_coop_vec_accum(): source/target "
+                       "buffers have an incompatible type!");
+    }
+
+    void *p = nullptr;
+    Ref tmp = steal(jitc_var_data(target, false, &p));
+    Ref target_ptr = steal(jitc_var_pointer(backend, p, tmp, 1));
+
+    Ref mask = steal(jitc_var_bool(backend, true));
+    mask = steal(jitc_var_mask_apply(mask, size));
+
+    Variable v;
+    v.kind = (uint32_t) VarKind::CoopVecAccum;
+    v.type = (uint32_t) VarType::Void;
+    v.size = std::max(size, jitc_var(mask)->size);
+    v.backend = (uint32_t) backend;
+    v.literal = offset;
+    v.symbolic = jitc_flag(JitFlag::SymbolicScope);
+    v.dep[0] = target_ptr;
+    v.dep[1] = index;
+    v.dep[2] = mask;
+    jitc_var_inc_ref(target_ptr);
+    jitc_var_inc_ref(index);
+    jitc_var_inc_ref(mask);
+
+    uint32_t result = jitc_var_new(v, true);
+    jitc_var_mark_side_effect(result);
+    return target.release();
+}
+
+uint32_t jitc_coop_vec_outer_product_accum(uint32_t target_,
+                                           uint32_t target_size,
+                                           const MatrixDescr *descr,
+                                           uint32_t a, uint32_t b) {
+    JitBackend backend;
+    VarType vt;
+    uint32_t size;
+
+    if (target_)
+        target_ = unwrap(target_);
+
+    {
+        const Variable *v_a = jitc_var(a),
+                       *v_b = jitc_var(b);
+
+        if (!v_a->coop_vec || !v_b->coop_vec || v_a->type != v_b->type)
+            jitc_raise("jit_coop_vec_outer_product_accum(): 'a' and 'b' must "
+                       "be cooperative vectors of a compatible type!");
+
+        backend = (JitBackend) v_a->backend;
+        vt = (VarType) v_a->type;
+        size = std::max(v_a->size, v_b->size);
+
+        if (backend == JitBackend::CUDA) {
+            if (vt != VarType::Float16)
+                jitc_raise(
+                    "jit_coop_vec_outer_product_accum(): this operation is "
+                    "restricted to float16 precision on the CUDA/OptiX backend.");
+            if (descr->layout != MatrixLayout::TrainingOptimal)
+                jitc_raise("jit_coop_vec_outer_product_accum(): the matrix "
+                           "must be in training-optimal layout!");
+        }
+    }
+
+    Ref target = borrow(target_);
+    if (!target) {
+        uint64_t z = 0;
+        target = steal(jitc_var_literal(backend, vt, &z, target_size, true));
+    } else {
+        const Variable *target_v = jitc_var(target);
+        // Copy-on-Write logic. See the same line in jitc_var_scatter() for details
+        if (target_v->ref_count != 2 && target_v->ref_count_stashed != 1)
+            target = steal(jitc_var_copy(target));
+
+        if ((VarType) target_v->type != vt)
+            jitc_raise("jit_coop_vec_outer_product_accum(): source/target "
+                       "buffers have an incompatible type!");
+    }
+
+    void *p = nullptr;
+    Ref tmp = steal(jitc_var_data(target, false, &p));
+    Ref target_ptr = steal(jitc_var_pointer(backend, p, tmp, 1));
+
+    Ref mask = steal(jitc_var_bool(backend, true));
+    mask = steal(jitc_var_mask_apply(mask, size));
+
+    Variable v;
+    v.kind = (uint32_t) VarKind::CoopVecOuterProductAccum;
+    v.type = (uint32_t) VarType::Void;
+    v.size = std::max(size, jitc_var(mask)->size);
+    v.backend = (uint32_t) backend;
+    v.symbolic = jitc_flag(JitFlag::SymbolicScope);
+    v.dep[0] = target_ptr;
+    v.dep[1] = a;
+    v.dep[2] = b;
+    v.dep[3] = mask;
+    jitc_var_inc_ref(target_ptr);
+    jitc_var_inc_ref(a);
+    jitc_var_inc_ref(b);
+    jitc_var_inc_ref(mask);
+
+    uint32_t result = jitc_var_new(v, true);
+    jitc_var_mark_side_effect(result);
+
+    drjit::unique_ptr<MatrixDescr> md = new MatrixDescr(*descr);
+
+    jitc_var(result)->data = md.get();
+    jitc_var_set_callback(
+        result,
+        [](uint32_t, int free, void *p) {
+            if (free)
+                delete (MatrixDescr *) p;
+        },
+        md.release(), true);
+
+    return target.release();
+}
diff --git a/src/coop_vec.h b/src/coop_vec.h
new file mode 100644
index 00000000..3e9d959b
--- /dev/null
+++ b/src/coop_vec.h
@@ -0,0 +1,70 @@
+/*
+    src/coop_vec.h -- Backend-independent parts of the Cooperative Vector API
+
+    Copyright (c) 2025 Wenzel Jakob
+
+    All rights reserved. Use of this source code is governed by a BSD-style
+    license that can be found in the LICENSE file.
+*/
+
+#pragma once
+
+#include
+
+struct CoopVecPackData {
+    std::vector<uint32_t> indices;
+
+    ~CoopVecPackData() {
+        for (uint32_t index: indices)
+            jitc_var_dec_ref(index);
+    }
+};
+
+struct CoopVecMatVecData {
+    MatrixDescr A_descr;
+    MatrixDescr b_descr;
+    bool transpose;
+};
+
+extern uint32_t jitc_coop_vec_pack(uint32_t n, const uint32_t *in);
+
+extern void jitc_coop_vec_unpack(uint32_t index, uint32_t n, uint32_t *out);
+
+extern uint32_t jitc_coop_vec_literal(JitBackend backend, VarType type,
+                                      const void *value, size_t size,
+                                      uint32_t length);
+
+extern uint32_t jitc_coop_vec_load(uint32_t buffer, uint32_t offset,
+                                   uint32_t length);
+
+extern uint32_t jitc_coop_vec_unary_op(JitOp op, uint32_t a0);
+
+extern uint32_t jitc_coop_vec_binary_op(JitOp op, uint32_t a0, uint32_t a1);
+
+extern uint32_t jitc_coop_vec_ternary_op(JitOp op, uint32_t a0, uint32_t a1, uint32_t a2);
+
+extern void jitc_coop_vec_pack_matrices(uint32_t count, uint32_t in,
+                                        const MatrixDescr *in_descr,
+                                        uint32_t out,
+                                        const MatrixDescr *out_descr);
+
+extern MatrixDescr jitc_coop_vec_compute_layout(uint32_t index,
+                                                const MatrixDescr *in,
+                                                MatrixLayout layout,
+                                                uint32_t offset);
+
+extern uint32_t jitc_coop_vec_matvec(uint32_t A_index,
+                                     const MatrixDescr *A_descr,
+                                     uint32_t x_index, uint32_t b_index,
+                                     const MatrixDescr *b_descr, int transpose);
+
+extern uint32_t jitc_coop_vec_accum(uint32_t target, uint32_t target_size,
+                                    uint32_t offset, uint32_t index);
+
+extern uint32_t jitc_coop_vec_outer_product_accum(uint32_t target, uint32_t target_size,
+                                                  const MatrixDescr *descr,
+                                                  uint32_t a, uint32_t b);
+
+extern uint32_t jitc_coop_vec_cast(uint32_t index, VarType vt);
diff --git a/src/cuda_core.cpp b/src/cuda_core.cpp
index cbcbdda9..60da1215 100644
--- a/src/cuda_core.cpp
+++ b/src/cuda_core.cpp
@@ -169,6 +169,7 @@ bool jitc_cuda_init() {
     bool cuda_12_1_or_newer = (jitc_cuda_version_major > 12 ||
                                (jitc_cuda_version_major == 12 &&
                                 jitc_cuda_version_minor >= 1));
+    jitc_cuda_arg_limit = cuda_12_1_or_newer ? 4096 : 512;

     jitc_log(Info, "jit_cuda_init(): enabling CUDA backend (version %i.%i)",
@@ -355,24 +356,38 @@ bool jitc_cuda_init() {
         device.sm_count = (uint32_t) sm_count;
         device.memory_pool = memory_pool != 0;
         device.preemptable = preemptable;
-        device.compute_capability = 50;
-        device.ptx_version = 65;
         device.context = context;
         cuda_check(cuStreamCreate(&device.stream, CU_STREAM_DEFAULT));
         cuda_check(cuEventCreate(&device.event, CU_EVENT_DISABLE_TIMING));
         cuda_check(cuEventCreate(&device.sync_stream_event, CU_EVENT_DISABLE_TIMING));

-        const uint32_t sm_table[][2] = { { 70, 65 }, { 71, 65 }, { 75, 65 }, { 80, 70 },
-                                         { 86, 71 }, { 89, 78 }, { 90, 78 } };
-        uint32_t cc_combined = cc_major * 10 + cc_minor;
-        for (int j = 0; j < 7; ++j) {
-            if (cc_combined >= sm_table[j][0]) {
-                device.compute_capability = sm_table[j][0];
-                device.ptx_version = sm_table[j][1];
-            }
+        uint32_t driver_to_ptx_isa_mapping[43][2] = {
+            { 10, 10 },  { 11, 11 },  { 20, 12 },  { 21, 13 },  { 22, 14 },
+            { 23, 14 },  { 30, 20 },  { 31, 21 },  { 32, 22 },  { 40, 23 },
+            { 41, 23 },  { 42, 30 },  { 50, 31 },  { 55, 32 },  { 60, 40 },
+            { 65, 42 },  { 70, 43 },  { 75, 50 },  { 80, 51 },  { 90, 60 },
+            { 91, 61 },  { 92, 62 },  { 100, 63 }, { 101, 64 }, { 102, 65 },
+            { 110, 70 }, { 111, 71 }, { 112, 72 }, { 113, 73 }, { 114, 74 },
+            { 115, 75 }, { 116, 76 }, { 117, 77 }, { 118, 78 }, { 120, 80 },
+            { 121, 81 }, { 122, 82 }, { 123, 83 }, { 124, 84 }, { 125, 85 },
+            { 126, 85 }, { 127, 86 }, { 128, 87 }
+        };
+
+        uint32_t driver_version = jitc_cuda_version_major * 10 + jitc_cuda_version_minor;
+        uint32_t ptx_version = 0;
+
+        for (uint32_t i = 0; i < 43; ++i) {
+            uint32_t driver_version_i = driver_to_ptx_isa_mapping[i][0],
+                     ptx_version_i = driver_to_ptx_isa_mapping[i][1];
+
+            if (driver_version >= driver_version_i)
+                ptx_version = ptx_version_i;
+            else
+                break;
         }
+
+        device.ptx_version = ptx_version;

         state.devices.push_back(device);
     }
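The table above replaces the fixed compute-capability table with a scan over CUDA versions: the last entry whose version does not exceed the installed one wins. A standalone sketch of the same lookup (hypothetical helper, table abbreviated to the newest rows):

```cpp
// e.g. ptx_isa_for_driver(12, 8) == 87, so the generated kernel header
// becomes ".version 8.7". Versions newer than the last row keep the
// final table entry.
uint32_t ptx_isa_for_driver(uint32_t major, uint32_t minor) {
    const uint32_t map[][2] = { { 126, 85 }, { 127, 86 }, { 128, 87 } };
    uint32_t version = major * 10 + minor, result = 0;
    for (const auto &entry : map) {
        if (version >= entry[0])
            result = entry[1];   // remember the best match so far
        else
            break;               // table is sorted; no later entry can match
    }
    return result;
}
```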
diff --git a/src/cuda_eval.cpp b/src/cuda_eval.cpp
index 34b91a77..811af74f 100644
--- a/src/cuda_eval.cpp
+++ b/src/cuda_eval.cpp
@@ -54,6 +54,7 @@
 #include "cuda_packet.h"
 #if defined(DRJIT_ENABLE_OPTIX)
 # include
+# include "optix_coop_vec.h"
 #endif

 // Forward declarations
@@ -96,16 +97,11 @@ void jitc_cuda_assemble(ThreadState *ts, ScheduledGroup group,
        %b3, %w3, %r3, %rd3, %f3, %d3, %p3: reserved for use in compound
        statements that must write a temporary result to a register.
     */
-    uint32_t ptx_version = ts->ptx_version;
-
-    // Using extended kernel parameter passing requires PTX ISA v8.1
-    if (n_params > 512)
-        ptx_version = std::max(ptx_version, 81u);

     fmt(".version $u.$u\n"
         ".target sm_$u\n"
         ".address_size 64\n\n",
-        ptx_version / 10, ptx_version % 10,
+        ts->ptx_version / 10, ts->ptx_version % 10,
         ts->compute_capability);

     if (!uses_optix) {
@@ -385,6 +381,11 @@ static void jitc_cuda_render(Variable *v) {
            *a2 = v->dep[2] ? jitc_var(v->dep[2]) : nullptr,
            *a3 = v->dep[3] ? jitc_var(v->dep[3]) : nullptr;

+#if defined(DRJIT_ENABLE_OPTIX)
+    if (v->coop_vec)
+        return jitc_optix_render_coop_vec(v, a0, a1, a2, a3);
+#endif
+
     const ThreadState *ts = thread_state_cuda;

     bool f32_upcast = jitc_is_half(v) &&
@@ -581,6 +582,20 @@ static void jitc_cuda_render(Variable *v) {
             jitc_cuda_render_array_select(v, a0, a1, a2);
             break;

+#if defined(DRJIT_ENABLE_OPTIX)
+        case VarKind::CoopVecUnpack:
+            jitc_optix_render_coop_vec_unpack(v, a0);
+            break;
+
+        case VarKind::CoopVecAccum:
+            jitc_optix_render_coop_vec_accum(v, a0, a1, a2);
+            break;
+
+        case VarKind::CoopVecOuterProductAccum:
+            jitc_optix_render_coop_vec_outer_product_accum(v, a0, a1, a2, a3);
+            break;
+#endif
+
         case VarKind::Select:
             if (!jitc_is_bool(a1)) {
                 fmt("    selp.$b $v, $v, $v, $v;\n", v, v, a1, a2, a0);
@@ -684,6 +699,10 @@ static void jitc_cuda_render(Variable *v) {
             fmt("    lg2.approx.ftz.$t $v, $v;\n", v, v, a0);
             break;

+        case VarKind::Tanh:
+            fmt("    tanh.approx.$t $v, $v;\n", v, v, a0);
+            break;
+
         case VarKind::Cast:
             if (jitc_is_bool(v)) {
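The new VarKind::Tanh case lowers to the hardware `tanh.approx` PTX instruction rather than a polynomial expansion; the user-facing entry point is the jit_var_tanh_intrinsic() wrapper added in api.cpp earlier in this patch. A one-line sketch (`x` is an assumed float32 CUDA variable handle; the emitted register names are illustrative):

```cpp
uint32_t y = jit_var_tanh_intrinsic(x);   // renders roughly as: tanh.approx.f32 %f_y, %f_x
```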
diff --git a/src/cuda_ts.cpp b/src/cuda_ts.cpp
index 32852c8f..e7467c3a 100644
--- a/src/cuda_ts.cpp
+++ b/src/cuda_ts.cpp
@@ -5,6 +5,7 @@
 #include "optix.h"
 #include "eval.h"
 #include "util.h"
+#include "optix_api.h"

 static uint8_t *kernel_params_global = nullptr;

@@ -859,3 +860,72 @@ void CUDAThreadState::enqueue_host_func(void (*callback)(void *),
     scoped_set_context guard(context);
     cuda_check(cuLaunchHostFunc(stream, callback, payload));
 }
+
+void CUDAThreadState::coop_vec_pack(uint32_t count, const void *in_,
+                                    const MatrixDescr *in_d, void *out_,
+                                    const MatrixDescr *out_d) {
+#if defined(DRJIT_ENABLE_OPTIX)
+    scoped_set_context guard(context);
+    OptixDeviceContext ctx = jitc_optix_context();
+    const uint8_t *in = (const uint8_t *) in_;
+    uint8_t *out = (uint8_t *) out_;
+
+    std::vector<OptixCoopVecMatrixDescription> in_o, out_o;
+    in_o.reserve(count);
+    out_o.reserve(count);
+
+    for (uint32_t i = 0; i < count; ++i) {
+        const MatrixDescr &id = in_d[i],
+                          &od = out_d[i];
+
+        uint32_t tsize = type_size[(int) id.dtype];
+
+        if (id.cols == 1) {
+            cuda_check(cuMemcpyAsync(out + od.offset * tsize,
+                                     in + id.offset * tsize,
+                                     id.size * tsize,
+                                     stream));
+        } else {
+            uint32_t type_id = jitc_optix_coop_vec_type_id(id.dtype);
+
+            OptixCoopVecMatrixDescription io;
+            io.N = id.rows;
+            io.K = id.cols;
+            io.offsetInBytes = id.offset * tsize;
+            io.elementType = type_id;
+            io.layout = jitc_optix_coop_vec_layout_id(id.layout);
+            io.rowColumnStrideInBytes = id.stride * tsize;
+            io.sizeInBytes = id.size * tsize;
+            in_o.push_back(io);
+
+            OptixCoopVecMatrixDescription oo;
+            oo.N = od.rows;
+            oo.K = od.cols;
+            oo.offsetInBytes = od.offset * tsize;
+            oo.elementType = type_id;
+            oo.layout = jitc_optix_coop_vec_layout_id(od.layout);
+            oo.rowColumnStrideInBytes = od.stride * tsize;
+            oo.sizeInBytes = od.size * tsize;
+            out_o.push_back(oo);
+        }
+    }
+
+    OptixNetworkDescription in_net, out_net;
+    in_net.layers = in_o.data();
+    in_net.numLayers = in_o.size();
+    out_net.layers = out_o.data();
+    out_net.numLayers = out_o.size();
+
+    if (!optixCoopVecMatrixConvert)
+        jitc_raise("jit_coop_vec_pack(): Cooperative vectors are not "
+                   "supported by your NVIDIA GPU driver. Please install "
+                   "driver version 570 or newer.");
+
+    if (in_net.numLayers)
+        jitc_optix_check(optixCoopVecMatrixConvert(
+            ctx, stream, 1, &in_net, (CUdeviceptr) in_, 0, &out_net, out_, 0));
+#else
+    (void) count; (void) in_; (void) in_d; (void) out_; (void) out_d;
+    jitc_raise("CUDAThreadState::coop_vec_pack(): requires OptiX support!");
+#endif
+}
diff --git a/src/cuda_ts.h b/src/cuda_ts.h
index 117cdf0d..17ed4c9a 100644
--- a/src/cuda_ts.h
+++ b/src/cuda_ts.h
@@ -51,4 +51,8 @@ struct CUDAThreadState : ThreadState {
                          uint32_t) override {
         jitc_raise("jitc_reduce_expanded(): unsupported by CUDAThreadState!");
     }
+
+    /// Pack a set of matrices/vectors for use with the cooperative vector API
+    void coop_vec_pack(uint32_t count, const void *in, const MatrixDescr *in_d,
+                       void *out, const MatrixDescr *out_d) override;
 };
diff --git a/src/eval.cpp b/src/eval.cpp
index 65414226..e32b497f 100644
--- a/src/eval.cpp
+++ b/src/eval.cpp
@@ -17,6 +17,7 @@
 #include "optix.h"
 #include "loop.h"
 #include "call.h"
+#include "coop_vec.h"
 #include "trace.h"
 #include "op.h"
 #include "array.h"
@@ -103,19 +104,15 @@ static std::vector visit_later;
 // ====================================================================

 // Don't perform scatters, whose output buffer is found to be unreferenced
-bool jitc_var_maybe_suppress_scatter(uint32_t index, Variable *v, uint32_t depth) {
+bool jitc_elide_scatter(uint32_t index, const Variable *v) {
+    if ((VarKind) v->kind != VarKind::Scatter)
+        return false;
     Variable *target = jitc_var(v->dep[0]);
     Variable *target_ptr = jitc_var(target->dep[3]);
-    if (target_ptr->ref_count != 0 || depth != 0)
-        return false;
-    jitc_log(Debug, "jit_eval(): eliding scatter r%u, whose output is unreferenced.", index);
-    if (callable_depth == 0)
-        jitc_var_dec_ref(index, v);
-    return true;
+    return target_ptr->ref_count == 0;
 }
-
 /// Recursively traverse the computation graph to find variables needed by a computation
 static void jitc_var_traverse(uint32_t size, uint32_t index, uint32_t depth = 0) {
     if (!visited.emplace(size, index, depth).second)
@@ -124,7 +121,7 @@ static void jitc_var_traverse(uint32_t size, uint32_t index, uint32_t depth = 0)
     Variable *v = jitc_var(index);
     switch ((VarKind) v->kind) {
         case VarKind::Scatter:
-            if (jitc_var_maybe_suppress_scatter(index, v, depth))
+            if (jitc_elide_scatter(index, v))
                 return;
             break;

@@ -189,6 +186,16 @@ static void jitc_var_traverse(uint32_t size, uint32_t index, uint32_t depth = 0)
             }
             break;

+        case VarKind::CoopVecPack: {
+                CoopVecPackData *cvid = (CoopVecPackData *) v->data;
+                for (uint32_t index2 : cvid->indices) {
+                    if (index2 == 0)
+                        continue;
+                    jitc_var_traverse(size, index2, depth);
+                }
+            }
+            break;
+
         case VarKind::TraceRay: {
             TraceData *call = (TraceData *) v->data;
             for (uint32_t i: call->indices)
@@ -272,8 +279,8 @@ void jitc_assemble(ThreadState *ts, ScheduledGroup group) {
         if (unlikely(v->ref_count == 0))
             jitc_fail("jit_assemble(): schedule contains unreferenced variable r%u!", index);
         if (unlikely(v->size != 1 && v->size != group.size))
-            jitc_fail("jit_assemble(): schedule contains variable r%u with incompatible size "
-                      "(%u and %u)!", index, v->size, group.size);
+            jitc_fail("jit_assemble(): schedule contains variable r%u of kind \"%s\" with incompatible size "
+                      "(var=%u and kernel=%u)!", index, var_kind_name[v->kind], v->size, group.size);
         if (unlikely(v->is_dirty()))
             jitc_fail("jit_assemble(): dirty variable r%u encountered!", index);
@@ -510,29 +517,29 @@ Task *jitc_run(ThreadState *ts, ScheduledGroup group) {
     bool cache_hit = false;
     if (ts->backend == JitBackend::CUDA) {
-            ProfilerPhase profiler(profiler_region_backend_compile);
-            if (!uses_optix) {
-                kernel.size = 1; // dummy size value to distinguish between OptiX and CUDA kernels
-                kernel.data = nullptr;
-                std::tie(kernel.cuda.mod, cache_hit) = jitc_cuda_compile(buffer.get());
-            } else {
-                #if defined(DRJIT_ENABLE_OPTIX)
-                cache_hit = jitc_optix_compile(
-                    ts, buffer.get(), buffer.size(), kernel_name, kernel);
-                #endif
-            }
-    } else {
+        ProfilerPhase profiler(profiler_region_backend_compile);
+        if (!uses_optix) {
+            kernel.size = 1; // dummy size value to distinguish between OptiX and CUDA kernels
+            kernel.data = nullptr;
+            std::tie(kernel.cuda.mod, cache_hit) = jitc_cuda_compile(buffer.get());
+        } else {
+            #if defined(DRJIT_ENABLE_OPTIX)
+            cache_hit = jitc_optix_compile(
+                ts, buffer.get(), buffer.size(), kernel_name, kernel);
+            #endif
+        }
+    } else {
         cache_hit = jitc_kernel_load(buffer.get(), (uint32_t) buffer.size(),
                                      ts->backend, kernel_hash, kernel);
-            if (!cache_hit) {
-                ProfilerPhase profiler(profiler_region_backend_compile);
-                jitc_llvm_compile(kernel);
+        if (!cache_hit) {
+            ProfilerPhase profiler(profiler_region_backend_compile);
+            jitc_llvm_compile(kernel);
             jitc_kernel_write(buffer.get(), (uint32_t) buffer.size(),
                               ts->backend, kernel_hash, kernel);
-                jitc_llvm_disasm(kernel);
-            }
-        }
+            jitc_llvm_disasm(kernel);
+        }
+    }

     if (ts->backend == JitBackend::CUDA && !uses_optix) {
         // Locate the kernel entry point
@@ -679,8 +686,14 @@ void jitc_eval_impl(ThreadState *ts) {
     ts->scheduled.clear();

-    for (uint32_t index: ts->side_effects)
-        jitc_var_traverse(jitc_var(index)->size, index);
+    for (uint32_t index: ts->side_effects) {
+        Variable *v = jitc_var(index);
+
+        if (jitc_elide_scatter(index, v))
+            jitc_var_dec_ref(index);
+        else
+            jitc_var_traverse(v->size, index);
+    }

     ts->side_effects.clear();

@@ -736,7 +749,6 @@ void jitc_eval_impl(ThreadState *ts) {
     for (ScheduledGroup &group : schedule_groups) {
         jitc_assemble(ts, group);
-
         jitc_run(ts, group);
     }
diff --git a/src/internal.h b/src/internal.h
index 4467231d..33cc9460 100644
--- a/src/internal.h
+++ b/src/internal.h
@@ -74,7 +74,7 @@ enum class VarKind : uint32_t {
     Rcp, RcpApprox, RSqrtApprox,

     // Multi-function generator (CUDA)
-    Sin, Cos, Exp2, Log2,
+    Sin, Cos, Exp2, Log2, Tanh,

     // Casts
     Cast, Bitcast,
@@ -169,6 +169,19 @@ enum class VarKind : uint32_t {
     // Write an element to a variable array
     ArrayWrite,

+    // Cooperative Vector API
+    CoopVecLiteral,
+    CoopVecPack,
+    CoopVecUnpack,
+    CoopVecLoad,
+    CoopVecCast,
+    CoopVecUnaryOp,
+    CoopVecBinaryOp,
+    CoopVecTernaryOp,
+    CoopVecMatVec,
+    CoopVecAccum,
+    CoopVecOuterProductAccum,
+
     // Denotes the number of different node types
     Count
 };
@@ -260,8 +273,8 @@ struct alignas(64) Variable {
     /// If set, evaluation will have side effects on other variables
     uint32_t side_effect : 1;

-    /// Unused flag
-    uint32_t unused_2: 1;
+    /// Is this a cooperative vector?
+    uint32_t coop_vec : 1;

     // =========== Entries that are temporarily used in jitc_eval() ============
     // (+11 bits -> 32 bits with all the preceding individiual bits = 4 bytes)
@@ -304,7 +317,8 @@ struct alignas(64) Variable {
         /// Reference count stash, see \ref jit_var_stash_ref()
         uint16_t ref_count_stashed;

-        /// Variable arrays (is_array() == 1) store their array length here
+        /// Variable arrays (is_array() == 1) and cooperative vectors
+        /// (coop_vec == 1) store their array length here
         uint16_t array_length;
     };

@@ -319,6 +333,7 @@ struct alignas(64) Variable {
     bool is_node() const { return (uint32_t) kind > (uint32_t) VarKind::Literal; }
     bool is_dirty() const { return ref_count_se > 0; }
     bool is_array() const { return array_state != (uint32_t) ArrayState::Invalid; }
+    bool is_coop_vec_literal() const { return kind == (uint32_t) VarKind::CoopVecLiteral; }
 };

 static_assert(sizeof(Variable) == 64);
@@ -347,38 +362,34 @@ struct VariableExtra {
 struct VariableKey {
     uint32_t size;
     uint32_t dep[4];
-    uint32_t kind      : 8;
-    uint32_t backend   : 2;
-    uint32_t type      : 4;
-    uint32_t write_ptr : 1;
-    uint32_t unused    : 1;
-    uint32_t scope_lo  : 16;
+    uint32_t kind         : 8;
+    uint32_t backend      : 2;
+    uint32_t type         : 4;
+    uint32_t write_ptr    : 1;
+    uint32_t unused       : 1;
+    uint32_t array_length : 16;
+    uint32_t scope;
     uint64_t literal;

     // The LVN data structure is significantly more efficient when
     // a single key fits into exactly 32 bytes. Hence the elaborate
     // bit packing below.
     VariableKey(const Variable &v) {
-        uint32_t scope_hi = v.scope;
         size = v.size;
-        for (int i = 0; i < 4; ++i) {
-            uint32_t d = v.dep[i];
-            d ^= scope_hi & 0xF0000000;
-            dep[i] = d;
-            scope_hi <<= 4;
-        }
+        memcpy(dep, v.dep, sizeof(uint32_t)*4);
         kind = v.kind;
         backend = v.backend;
         type = v.type;
         write_ptr = v.write_ptr;
         unused = 0;
-        scope_lo = v.scope;
+        scope = v.scope;
+        array_length = (v.is_array() || v.coop_vec) ? v.array_length : 0;
         literal = v.literal;
     }

     bool operator==(const VariableKey &v) const {
         return memcmp((const void *) this, (const void *) &v,
-                      8 * sizeof(uint32_t)) == 0;
+                      9 * sizeof(uint32_t)) == 0;
     }
 };

@@ -389,7 +400,7 @@ struct VariableKeyHasher {
     size_t operator()(const VariableKey &k) const {
         // 'scope_hi' field not included in hash key, the hash function
         // is faster when processing exactly 32 bytes.
-        return hash((const void *) &k, 8 * sizeof(uint32_t), 0);
+        return hash((const void *) &k, 9 * sizeof(uint32_t), 0);
     }
 };

@@ -401,7 +412,7 @@ using LVNMap =
                /* StoreHash = */ true>;

 static_assert(
-    sizeof(VariableKey) == 8 * sizeof(uint32_t),
+    sizeof(VariableKey) == 9 * sizeof(uint32_t),
     "VariableKey: incorrect size, likely an issue with padding/packing!");

 /// Caches basic information about a CUDA device
@@ -548,6 +559,8 @@ struct OptixPipelineCompileOptions {
     unsigned int exceptionFlags;
     const char* pipelineLaunchParamsVariableName;
     unsigned int usesPrimitiveTypeFlags;
+    int allowOpacityMicromaps;
+    int allowClusteredGeometry; // OptiX 9.0 ABI
 };

 struct OptixPipelineData {
@@ -723,6 +736,11 @@ struct ThreadState : public ThreadStateBase {
     virtual void reduce_expanded(VarType vt, ReduceOp op, void *data,
                                  uint32_t exp, uint32_t size) = 0;

+    /// Pack a set of matrices/vectors for use with the cooperative vector API
+    virtual void coop_vec_pack(uint32_t count, const void *in,
+                               const MatrixDescr *in_d, void *out,
+                               const MatrixDescr *out_d) = 0;
+
     /// Notify the \c ThreadState that \c jitc_free has been called on a pointer.
     /// This is required for kernel freezing.
     virtual void notify_free(const void *ptr);
@@ -1064,7 +1082,7 @@ inline bool jitc_is_bool(const Variable *v) { return jitc_is_bool((VarType) v->t
 inline bool jitc_is_zero(Variable *v) { return v->is_literal() && v->literal == 0; }

 inline bool jitc_is_any_zero(Variable *v) {
-    if (!v->is_literal())
+    if (!v->is_literal() && !v->is_coop_vec_literal())
         return false;

     switch ((VarType) v->type) {
@@ -1076,7 +1094,7 @@ inline bool jitc_is_any_zero(Variable *v) {
 }

 inline bool jitc_is_one(Variable *v) {
-    if (!v->is_literal())
+    if (!v->is_literal() && !v->is_coop_vec_literal())
         return false;

     uint64_t one;
diff --git a/src/llvm.h b/src/llvm.h
index 0eed45a8..c923c168 100644
--- a/src/llvm.h
+++ b/src/llvm.h
@@ -62,9 +62,12 @@ extern int jitc_llvm_version_patch;

 /// Pre-generated strings for use by the template engine

-/// String of all ones, for different variable types
+/// Literal one, for different variable types
 extern char **jitc_llvm_ones_str;

+/// All bits set, for different variable types
+extern char **jitc_llvm_ones_bit_str;
+
 /// (up to the current vector width)
 extern char *jitc_llvm_u32_arange_str;
diff --git a/src/llvm_api.cpp b/src/llvm_api.cpp
index 4f3e9529..7e4a7439 100644
--- a/src/llvm_api.cpp
+++ b/src/llvm_api.cpp
@@ -99,7 +99,6 @@ bool jitc_llvm_api_init() {
     jitc_llvm_version_patch = -1;

     void *handle = jitc_llvm_handle;
-    LOAD(core, LLVMLinkInMCJIT);
     LOAD(core, LLVMInitializeDrJitAsmPrinter);
     LOAD(core, LLVMInitializeDrJitDisassembler);
     LOAD(core, LLVMInitializeDrJitTarget);
@@ -139,6 +138,7 @@ bool jitc_llvm_api_init() {
     LOAD(pb_new, LLVMDisposePassBuilderOptions);
     LOAD(pb_new, LLVMRunPasses);

+    LOAD(mcjit, LLVMLinkInMCJIT);
     LOAD(mcjit, LLVMModuleCreateWithName);
     LOAD(mcjit, LLVMGetExecutionEngineTargetMachine);
     LOAD(mcjit, LLVMCreateMCJITCompilerForModule);
@@ -240,7 +240,6 @@ void jitc_llvm_api_shutdown() {
     if (!jitc_llvm_handle)
         return;

-    CLEAR(LLVMLinkInMCJIT);
     CLEAR(LLVMInitializeDrJitAsmPrinter);
     CLEAR(LLVMInitializeDrJitDisassembler);
     CLEAR(LLVMInitializeDrJitTarget);
@@ -284,6 +283,7 @@ void jitc_llvm_api_shutdown() {
     CLEAR(LLVMRunPasses);

     // MCJIT
+    CLEAR(LLVMLinkInMCJIT);
     CLEAR(LLVMModuleCreateWithName);
     CLEAR(LLVMGetExecutionEngineTargetMachine);
     CLEAR(LLVMCreateMCJITCompilerForModule);
diff --git a/src/llvm_coop_vec.cpp b/src/llvm_coop_vec.cpp
new file mode 100644
index 00000000..cd2d2606
--- /dev/null
+++ b/src/llvm_coop_vec.cpp
@@ -0,0 +1,456 @@
+/*
+    src/llvm_coop_vec.cpp -- LLVM fallback code generation for Cooperative Vectors
+
+    Copyright (c) 2025 Wenzel Jakob
+
+    All rights reserved. Use of this source code is governed by a BSD-style
+    license that can be found in the LICENSE file.
+*/
+
+#include "llvm.h"
+#include "var.h"
+#include "eval.h"
+#include "coop_vec.h"
+#include "llvm_eval.h"
+#include "llvm_scatter.h"
+
+void jitc_llvm_render_coop_vec_unpack(const Variable *v, const Variable *a0) {
+    put("    ; coop_vec_unpack\n");
+    fmt("    $v = bitcast $V_$u to $T\n", v, a0, v->literal, v);
+}
+
+void jitc_llvm_render_coop_vec_accum(const Variable *v, const Variable *target,
+                                     const Variable *value, const Variable *mask) {
+    put("    ; coop_vec_accum\n");
+
+    fmt("    $v_p = bitcast $<{i8*}$> $v to $<{$t*}$>\n", v, target, value);
+
+    if (callable_depth == 0) {
+        fmt_intrinsic("declare $t @llvm$e.vector.reduce.fadd.v$w$h($t, $T)",
+                      value, value, value, value);
+    } else {
+        jitc_llvm_append_reduce_op_local((VarType) value->type, ReduceOp::Add, value);
+        fmt("    $v_mask = bitcast $V to i$w\n",
+            v, mask);
+    }
+
+    for (uint32_t i = 0; i < value->array_length; ++i) {
+        fmt("    $v_p_$u = getelementptr inbounds $t, $<{$t*}$> $v_p, i32 $u\n",
+            v, i, value, value, v, (uint32_t) v->literal + i);
+
+        if (callable_depth == 0) {
+            fmt("    $v_valid_$u = select $V, $V_$u, $T zeroinitializer\n"
+                "    $v_red_$u = call $t @llvm$e.vector.reduce.fadd.v$w$h($t zeroinitializer, $T $v_valid_$u)\n"
+                "    atomicrmw fadd {$t*} $v_p_$u, $t $v_red_$u monotonic\n",
+                v, i, mask, value, i, value,
+                v, i, value, value, value, value, v, i,
+                value, v, i, value, v, i);
+        } else {
+            fmt("    call fastcc void @reduce_add_$h_atomic_local(<$w x $p> $v_p_$u, $V_$u, i$w $v_mask)\n",
+                value, value, v, i, value, i, v);
+        }
+    }
+}
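For reference, the code generated by jitc_llvm_render_coop_vec_accum() above implements the following scalar semantics (a sketch only; the real IR reduces each component across the SIMD lanes with `llvm.vector.reduce.fadd` and performs the final update via `atomicrmw fadd`):

```cpp
// 'width' is the SIMD width, 'length' the cooperative vector length;
// value[i * width + l] holds component i of lane l.
void coop_vec_accum_ref(float *target, uint32_t offset, const float *value,
                        const bool *mask, uint32_t length, uint32_t width) {
    for (uint32_t i = 0; i < length; ++i) {
        float sum = 0.f;
        for (uint32_t l = 0; l < width; ++l)   // reduce across SIMD lanes
            sum += mask[l] ? value[i * width + l] : 0.f;
        target[offset + i] += sum;             // atomic in the generated code
    }
}
```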
i1 $v_j_cont, label %l$u_body, label %l$u_outer\n\n"
+        "l$u_body:\n",
+        v->reg_index,
+        v->reg_index,
+        v->reg_index,
+
+        // outer:
+        v->reg_index,
+        v, v->reg_index, v, v->reg_index,
+        v, v, m,
+        v, v,
+        v, v->reg_index, v->reg_index,
+
+        // load:
+        v->reg_index,
+        v, v0, v0, v, v,
+        v, v0, v0, v, v0,
+        v->reg_index,
+
+        // inner:
+        v->reg_index,
+        v, v->reg_index, v, v->reg_index,
+        v, v,
+        v, v, n,
+        v, v->reg_index, v->reg_index,
+        v->reg_index
+    );
+
+    fmt("    $v_p1j = getelementptr inbounds $T, {$T*} $v_p1, i32 $v_j\n"
+        "    $v_v1j = load $T, {$T*} $v_p1j, align $A\n"
+        "    $v_ij = fmul $T $v_v0i, $v_v1j\n",
+        v, v1, v1, v, v,
+        v, v1, v1, v, v1,
+        v, v1, v, v);
+
+    fmt("    $v_po_ij = getelementptr inbounds [$u x $t], $<{[$u x $t]*}$> $v_po, i32 $v_i, i32 $v_j\n",
+        v, d->stride, v0, d->stride, v0, v, v, v);
+
+    if (callable_depth == 0) {
+        fmt_intrinsic("declare $t @llvm$e.vector.reduce.fadd.v$w$h($t, $T)",
+                      v1, v1, v1, v1);
+        fmt("    $v_valid = select $V, $T $v_ij, $T zeroinitializer\n"
+            "    $v_red = call $t @llvm$e.vector.reduce.fadd.v$w$h($t zeroinitializer, $T $v_valid)\n"
+            "    atomicrmw fadd {$t*} $v_po_ij, $t $v_red monotonic\n",
+            v, mask, v1, v, v1,
+            v, v1, v1, v1, v1, v,
+            v1, v, v1, v);
+    } else {
+        jitc_llvm_append_reduce_op_local((VarType) v1->type, ReduceOp::Add, v1);
+        fmt("    call fastcc void @reduce_add_$h_atomic_local(<$w x $p> $v_po_ij, $T $v_ij, i$w $v_mask)\n",
+            v1, v1, v, v1, v, v);
+    }
+
+    fmt("    br label %l$u_inner\n"
+        "\n"
+        "l$u_done:\n",
+        v->reg_index,
+        v->reg_index);
+}
+
+void jitc_llvm_render_coop_vec(const Variable *v, const Variable *a0,
+                               const Variable *a1, const Variable *a2,
+                               const Variable *a3) {
+    fmt("    ; $s\n", var_kind_name[v->kind]);
+
+    switch ((VarKind) v->kind) {
+        case VarKind::CoopVecLiteral:
+            fmt("    $v_p = insertelement $T undef, $t $l, i32 0\n"
+                "    $v_0 = shufflevector $T $v_p, $T undef, <$w x i32> $z\n",
+                v, v, v, v,
+                v, v, v, v);
+            for (uint32_t i = 1; i < v->array_length; ++i)
+                fmt("    $v_$u = bitcast $V_0 to $T\n", v, i, v, v);
+            break;
+
+        case VarKind::CoopVecPack: {
+                const std::vector<uint32_t> &indices = ((const CoopVecPackData *) v->data)->indices;
+                for (uint32_t i = 0; i < (uint32_t) indices.size(); ++i)
+                    fmt("    $v_$u = bitcast $V to $T\n", v, i, jitc_var(indices[i]), v);
+            }
+            break;
+
+        case VarKind::CoopVecBinaryOp: {
+            const char *op = nullptr;
+
+            bool is_float = jitc_is_float(v),
+                 is_sint = jitc_is_sint(v);
+            bool is_intrinsic = false;
+
+            switch ((JitOp) v->literal) {
+                case JitOp::Add: op = is_float ? "fadd" : "add"; break;
+                case JitOp::Mul: op = is_float ? "fmul" : "mul"; break;
+                case JitOp::Sub: op = is_float ? "fsub" : "sub"; break;
+                case JitOp::Min: op = is_float ? "minnum" : (is_sint ? "smin" : "umin"); is_intrinsic = true; break;
+                case JitOp::Max: op = is_float ? "maxnum" : (is_sint ? "smax" : "umax"); is_intrinsic = true; break;
+                case JitOp::Step: op = is_float ?
"fcmp olt" : "icmp lt"; break; + default: + jitc_fail("CoopVecBinaryOp: unsupported operation!"); + } + + if ((JitOp) v->literal == JitOp::Step) { + for (uint32_t i = 0; i < v->array_length; ++i) { + fmt(" $v_$u_m = $s $V_$u, $v_$u\n", v, i, op, a0, i, a1, i); + fmt(" $v_$u = select <$w x i1> $v_$u_m, $T zeroinitializer, $T $s\n", + v, i, v, i, v, v, jitc_llvm_ones_str[v->type]); + } + } else { + if (!is_intrinsic) { + for (uint32_t i = 0; i < v->array_length; ++i) + fmt(" $v_$u = $s $V_$u, $v_$u\n", v, i, op, a0, i, a1, i); + } else { + fmt_intrinsic("declare $T @llvm.$s.v$w$h($T, $T)", v, op, v, a0, a1); + for (uint32_t i = 0; i < v->array_length; ++i) + fmt(" $v_$u = call fast $T @llvm.$s.v$w$h($V_$u, $V_$u)\n", + v, i, v, op, v, a0, i, a1, i); + } + } + } + break; + + case VarKind::CoopVecTernaryOp: + if ((JitOp) v->literal != JitOp::Fma) + jitc_fail("CoopVecTernaryOp: unsupported operation!"); + + fmt_intrinsic("declare $T @llvm.fma.v$w$h($T, $T, $T)", v, v, + a0, a1, a2); + for (uint32_t i = 0; i < v->array_length; ++i) + fmt(" $v_$u = call $T @llvm.fma.v$w$h($V_$u, $V_$u, $V_$u)\n", + v, i, v, v, a0, i, a1, i, a2, i); + break; + + case VarKind::Bitcast: + for (uint32_t i = 0; i < v->array_length; ++i) + fmt(" $v_$u = bitcast $V_u to $T\n", v, i, a0, i, v); + break; + + case VarKind::CoopVecCast: { + bool bigger = type_size[v->type] > type_size[a0->type], + dst_float = jitc_is_float(v), + src_signed = jitc_is_float(a0), + dst_signed = jitc_is_float(v), + src_float = jitc_is_float(a0); + + const char *op; + if (src_float && dst_float) + op = bigger ? "fpext" : "fptrunc"; + else if (!src_float && !dst_float) + op = bigger ? (src_signed ? "sext" : "zext") : "trunc"; + else if (src_float && !dst_float) + op = dst_signed ? "fptosi" : "fptoui"; + else + op = src_signed ? "sitofp" : "uitofp"; + for (uint32_t i = 0; i < v->array_length; ++i) + fmt(" $v_$u = $s $V_$u to $T\n", v, i, op, a0, i, v); + } + break; + + case VarKind::CoopVecLoad: + fmt(" $v_p = bitcast $<{i8*}$> $v to $<{$t*}$>\n", v, a0, v); + + for (uint32_t i = 0; i < v->array_length; ++i) { + fmt(" $v_p_$u = getelementptr inbounds $t, $<{$t*}$> $v_p, i32 $u\n", + v, i, v, v, v, (uint32_t) v->literal + i); + + if (callable_depth == 0) { + fmt(" $v_$u_0 = load $t, {$t *} $v_p_$u, align $a\n" + " $v_$u_1 = insertelement $T undef, $t $v_$u_0, i32 0\n" + " $v_$u = shufflevector $T $v_$u_1, $T undef, <$w x i32> $z\n", + v, i, v, v, v, i, v, + v, i, v, v, v, i, + v, i, v, v, i, v); + } else { + fmt_intrinsic("declare $T @llvm.masked.gather.v$w$h(<$w x {$t*}>, i32, <$w x i1>, $T)", + v, v, v, v); + fmt(" $v_$u = call $T @llvm.masked.gather.v$w$h(<$w x {$t*}> $v_p_$u, i32 $a, $V, $T $z)\n", + v, i, v, v, v, v, i, v, a1, v); + } + } + break; + + case VarKind::CoopVecMatVec: { + CoopVecMatVecData *d = (CoopVecMatVecData *) v->data; + bool transpose = d->transpose; + const Variable *mask = a2; + const Variable *bias = a3; + + uint32_t tsize = type_size[v->type], + vec_size = jitc_llvm_vector_width * tsize, + n = transpose ? d->A_descr.rows : d->A_descr.cols, + m = transpose ? 
d->A_descr.cols : d->A_descr.rows; + + alloca_size = std::max(alloca_size, (int32_t) (vec_size * (n + m))); + alloca_align = std::max(alloca_align, (int32_t) (vec_size)); + + fmt( " $v_pi = bitcast {i8*} %buffer to {$T*}\n" + " $v_po = getelementptr inbounds $T, {$T*} $v_pi, i32 $u\n" + " $v_pa_0 = bitcast $<{i8*}$> $v to $<{$t*}$>\n" + " $v_pa{_1|} = getelementptr inbounds $t, $<{$t*}$> $v_pa_0, i32 $u\n" + "{ $v_pa = bitcast $<$t*$> $v_pa_1 to $<[$u x $t]*$>\n|}", + v, v, + v, v, v, v, n, + v, a0, v, + v, v, v, v, d->A_descr.offset, + v, v, v, transpose ? m : n, v); + if (bias) { + fmt(" $v_pb_0 = bitcast $<{i8*}$> $v to $<{$t*}$>\n" + " $v_pb = getelementptr inbounds $t, $<{$t*}$> $v_pb_0, i32 $u\n", + v, bias, v, + v, v, v, v, d->b_descr.offset); + } + + put("\n ; Prepare input\n"); + for (uint32_t i = 0; i < n; ++i) { + fmt(" $v_pi_$u = getelementptr inbounds $T, {$T*} $v_pi, i32 $u\n" + " store $V_$u, {$T*} $v_pi_$u, align $A\n", + v, i, v, v, v, i, + a1, i, v, v, i, v); + } + + put("\n ; Prepare output\n"); + for (uint32_t i = 0; i < m; ++i) { + if (bias) { + fmt(" $v_b_$u_1 = getelementptr inbounds $t, $<{$t*}$> $v_pb, i32 $u\n", + v, i, v, v, v, i); + if (callable_depth == 0) { + fmt(" $v_b_$u_2 = load $t, {$t *} $v_b_$u_1, align $a\n" + " $v_b_$u_3 = insertelement $T undef, $t $v_b_$u_2, i32 0\n" + " $v_b_$u = shufflevector $T $v_b_$u_3, $T undef, <$w x i32> $z\n", + v, i, v, v, v, i, v, + v, i, v, v, v, i, + v, i, v, v, i, v); + } else { + fmt(" $v_b_$u = call $T @llvm.masked.gather.v$w$h(<$w x {$t*}> $v_b_$u_1, i32 $a, $V, $T $z)\n", + v, i, v, v, v, v, i, v, mask, v); + } + } + + fmt(" $v_po_$u = getelementptr inbounds $T, {$T*} $v_po, i32 $u\n", v, i, v, v, v, i); + if (bias) + fmt(" store $V_b_$u, {$T*} $v_po_$u, align $A\n", v, i, v, v, i, v); + else + fmt(" store $T zeroinitializer, {$T*} $v_po_$u, align $A\n", v, v, v, i, v); + } + + put("\n ; Matrix multiplication\n"); + fmt(" br label %l$u_before\n" + "\n" + "l$u_before:\n" + " br label %l$u_outer\n" + "\n" + "l$u_outer:\n" + " $v_j = phi i32 [ 0, %l$u_before ], [ $v_j_next, %l$u_inner ]\n" + " $v_j_cont = icmp ult i32 $v_j, $u\n" + " $v_j_next = add nuw nsw i32 $v_j, 1\n" + " br i1 $v_j_cont, label %l$u_load, label %l$u_done\n" + "\n" + "l$u_load:\n" + " $v_x1 = getelementptr inbounds $T, {$T*} $v_pi, i32 $v_j\n" + " $v_x = load $T, {$T*} $v_x1, align $A\n" + " br label %l$u_inner;\n" + "\n" + "l$u_inner:\n" + " $v_i = phi i32 [ 0, %l$u_load ], [ $v_i_next, %l$u_body ]\n" + " $v_i_next = add nuw nsw i32 $v_i, 1\n" + " $v_i_cont = icmp ult i32 $v_i, $u\n" + " br i1 $v_i_cont, label %l$u_body, label %l$u_outer\n\n" + "l$u_body:\n", + v->reg_index, + v->reg_index, + v->reg_index, + v->reg_index, + v, v->reg_index, v, v->reg_index, + v, v, n, + v, v, + + v, v->reg_index, v->reg_index, + + v->reg_index, + v, v, v, v, v, + v, v, v, v, v, + v->reg_index, + + v->reg_index, + v, v->reg_index, v, v->reg_index, + v, v, + v, v, m, + v, v->reg_index, v->reg_index, + v->reg_index + ); + + fmt(" $v_a1 = getelementptr inbounds [$u x $t], $<{[$u x $t]*}$> $v_pa, i32 $v_$s, i32 $v_$s\n", + v, transpose ? m : n, + v, transpose ? m : n, + v, v, + v, transpose ? "j" : "i", + v, transpose ? 
"i" : "j"); + + if (callable_depth == 0) { + fmt(" $v_a2 = load $t, {$t *} $v_a1, align $a\n" + " $v_a3 = insertelement $T undef, $t $v_a2, i32 0\n" + " $v_a = shufflevector $T $v_a3, $T undef, <$w x i32> $z\n", + v, v, v, v, v, + v, v, v, v, + v, v, v, v); + } else { + fmt_intrinsic("declare $T @llvm.masked.gather.v$w$h(<$w x {$t*}>, i32, <$w x i1>, $T)", + v, v, v, v); + fmt(" $v_a = call $T @llvm.masked.gather.v$w$h(<$w x {$t*}> $v_a1, i32 $a, $V, $T $z)\n", + v, v, v, v, v, v, mask, v); + } + + fmt_intrinsic("declare $T @llvm.fma.v$w$h($T, $T, $T)", v, v, + v, v, v); + + fmt(" $v_y1 = getelementptr inbounds $T, {$T*} $v_po, i32 $v_i\n" + " $v_y = load $T, {$T*} $v_y1, align $A\n" + " $v_r = call $T @llvm.fma.v$w$h($V_a, $V_x, $V_y)\n" + " store $V_r, {$T*} $v_y1, align $A\n", + v, v, v, v, v, + v, v, v, v, v, + v, v, v, v, v, v, + v, v, v, v); + + fmt(" br label %l$u_inner\n" + "\n" + "l$u_done:\n", + v->reg_index, + v->reg_index); + + put(" ; Read back results\n"); + for (uint32_t i = 0; i < m; ++i) + fmt(" $v_$u = load $T, {$T*} $v_po_$u, align $A\n", v, i, v, v, v, i, v); + } + break; + + default: + jitc_fail("jitc_llvm_render_coop_vec(): unhandled variable " + "kind \"%s\"!", + var_kind_name[(uint32_t) v->kind]); + } +} diff --git a/src/llvm_coop_vec.h b/src/llvm_coop_vec.h new file mode 100644 index 00000000..d3b64c65 --- /dev/null +++ b/src/llvm_coop_vec.h @@ -0,0 +1,25 @@ +/* + src/llvm_coop_vec.h -- LLVM fallback code generation for Cooperative Vectors + + Copyright (c) 2025 Wenzel Jakob + + All rights reserved. Use of this source code is governed by a BSD-style + license that can be found in the LICENSE file. +*/ + +#pragma once + +extern void jitc_llvm_render_coop_vec(const Variable *v, const Variable *a0, + const Variable *a1, const Variable *a2, + const Variable *a3); +extern void jitc_llvm_render_coop_vec_unpack(const Variable *v, + const Variable *a0); +extern void jitc_llvm_render_coop_vec_accum(const Variable *v, + const Variable *a0, + const Variable *a1, + const Variable *a2); +extern void jitc_llvm_render_coop_vec_outer_product_accum(const Variable *v, + const Variable *a0, + const Variable *a1, + const Variable *a2, + const Variable *a3); diff --git a/src/llvm_core.cpp b/src/llvm_core.cpp index 6d0ed7e1..efff3611 100644 --- a/src/llvm_core.cpp +++ b/src/llvm_core.cpp @@ -48,9 +48,12 @@ bool jitc_llvm_opaque_pointers = false; // Strings related to the vector width, used by template engine -/// String of all ones, for different variable types +/// The literal one, for different variable types char **jitc_llvm_ones_str = nullptr; +/// All bits set, for different variable types +char **jitc_llvm_ones_bit_str = nullptr; + /// (up to the current vector width) char *jitc_llvm_u32_arange_str = nullptr; @@ -96,7 +99,8 @@ bool jitc_llvm_init() { } - LLVMLinkInMCJIT(); + if (LLVMLinkInMCJIT) + LLVMLinkInMCJIT(); LLVMInitializeDrJitTargetInfo(); LLVMInitializeDrJitTarget(); LLVMInitializeDrJitTargetMC(); @@ -229,7 +233,12 @@ void jitc_llvm_shutdown() { free(jitc_llvm_ones_str[i]); free(jitc_llvm_ones_str); } - jitc_llvm_ones_str = nullptr; + if (jitc_llvm_ones_bit_str) { + for (uint32_t i = 0; i < (uint32_t) VarType::Count; ++i) + free(jitc_llvm_ones_bit_str[i]); + free(jitc_llvm_ones_bit_str); + } + jitc_llvm_ones_bit_str = nullptr; free(jitc_llvm_u32_arange_str); jitc_llvm_u32_arange_str = nullptr; free(jitc_llvm_u32_width_str); @@ -251,13 +260,39 @@ void jitc_llvm_update_strings() { free(jitc_llvm_ones_str[i]); free(jitc_llvm_ones_str); } + if (jitc_llvm_ones_bit_str) { + 
for (uint32_t i = 0; i < (uint32_t) VarType::Count; ++i) + free(jitc_llvm_ones_bit_str[i]); + free(jitc_llvm_ones_bit_str); + } jitc_llvm_ones_str = (char **) malloc(sizeof(char *) * (uint32_t) VarType::Count); + jitc_llvm_ones_bit_str = + (char **) malloc(sizeof(char *) * (uint32_t) VarType::Count); for (uint32_t i = 0; i < (uint32_t) VarType::Count; ++i) { VarType vt = (VarType) i; + buf.clear(); + buf.put('<'); + for (uint32_t j = 0; j < width; ++j) { + buf.put(type_name_llvm[i], strlen(type_name_llvm[i])); + buf.put(' '); + + if (vt == VarType::Float16 || vt == VarType::Float32 || + vt == VarType::Float64){ + buf.put("1.0"); + } else { + buf.put("1"); + } + + if (j + 1 < width) + buf.put(", "); + } + buf.put('>'); + jitc_llvm_ones_str[i] = strdup(buf.get()); + buf.clear(); buf.put('<'); for (uint32_t j = 0; j < width; ++j) { @@ -279,7 +314,7 @@ void jitc_llvm_update_strings() { buf.put(", "); } buf.put('>'); - jitc_llvm_ones_str[i] = strdup(buf.get()); + jitc_llvm_ones_bit_str[i] = strdup(buf.get()); } buf.clear(); diff --git a/src/llvm_eval.cpp b/src/llvm_eval.cpp index a65b0fe2..a8d8ce64 100644 --- a/src/llvm_eval.cpp +++ b/src/llvm_eval.cpp @@ -77,6 +77,7 @@ #include "llvm_array.h" #include "llvm_eval.h" #include "llvm_packet.h" +#include "llvm_coop_vec.h" // Forward declaration static void jitc_llvm_render(Variable *v); @@ -473,6 +474,9 @@ static void jitc_llvm_render(Variable *v) { *a2 = v->dep[2] ? jitc_var(v->dep[2]) : nullptr, *a3 = v->dep[3] ? jitc_var(v->dep[3]) : nullptr; + if (v->coop_vec) + return jitc_llvm_render_coop_vec(v, a0, a1, a2, a3); + bool f32_upcast = jitc_is_half(v) && !jitc_fp16_supported_llvm((VarKind)v->kind); if (f32_upcast) { @@ -523,10 +527,10 @@ static void jitc_llvm_render(Variable *v) { fmt(" $v_0 = bitcast $V to $B\n" " $v_1 = xor $B $v_0, $s\n" " $v = bitcast $B $v_1 to $T\n", - v, a0, v, v, v, v, jitc_llvm_ones_str[(int) itype], + v, a0, v, v, v, v, jitc_llvm_ones_bit_str[(int) itype], v, v, v, v); } else { - fmt(" $v = xor $V, $s\n", v, a0, jitc_llvm_ones_str[v->type]); + fmt(" $v = xor $V, $s\n", v, a0, jitc_llvm_ones_bit_str[v->type]); } break; @@ -953,6 +957,7 @@ static void jitc_llvm_render(Variable *v) { } } break; + case VarKind::Scatter: if (v->literal) jitc_llvm_render_scatter_reduce(v, a0, a1, a2, a3); @@ -1058,6 +1063,18 @@ static void jitc_llvm_render(Variable *v) { (uint32_t) v->literal, v); break; + case VarKind::CoopVecUnpack: + jitc_llvm_render_coop_vec_unpack(v, a0); + break; + + case VarKind::CoopVecAccum: + jitc_llvm_render_coop_vec_accum(v, a0, a1, a2); + break; + + case VarKind::CoopVecOuterProductAccum: + jitc_llvm_render_coop_vec_outer_product_accum(v, a0, a1, a2, a3); + break; + case VarKind::ThreadIndex: fmt(" $v_1 = insertelement <$w x i32> undef, i32 %thread_id, i32 0\n" " $v = shufflevector <$w x i32> $v_1, <$w x i32> undef, <$w x i32> $z\n", v, v, v); @@ -1369,7 +1386,7 @@ static void jitc_llvm_render_trace(const Variable *v, fmt("\n ; -------- Ray $s -------\n", shadow_ray ? 
"test" : "trace"); - // Copy input parameters to staging area + // Copy input parameters to staging area uint32_t offset = 0; for (uint32_t i = 0; i < 13; ++i) { if (jitc_llvm_vector_width == 1 && i == 0) @@ -1386,17 +1403,17 @@ static void jitc_llvm_render_trace(const Variable *v, offset += type_size[v2->type] * width; } - // Reset geomID field to ones as required + // Reset geomID field to ones as required if (!shadow_ray) { fmt( " $v_in_geomid_{0|1} = getelementptr inbounds i8, {i8*} %buffer, i32 $u\n" "{ $v_in_geomid_1 = bitcast i8* $v_in_geomid_0 to <$w x i32> *\n|}" " store <$w x i32> $s, {<$w x i32>*} $v_in_geomid_1, align $u\n", v, (14 * float_size + 5 * 4) * width, v, v, - jitc_llvm_ones_str[(int) VarType::Int32], v, float_size * width); + jitc_llvm_ones_bit_str[(int) VarType::Int32], v, float_size * width); } - // Determine whether to mark the rays as coherent or incoherent + // Determine whether to mark the rays as coherent or incoherent const Variable *coherent = jitc_var(indices[0]); fmt( " $v_in_ctx_{0|1} = getelementptr inbounds i8, {i8*} %buffer, i32 $u\n" @@ -1410,7 +1427,7 @@ static void jitc_llvm_render_trace(const Variable *v, fmt_intrinsic("declare i1 @llvm$e.vector.reduce.and.v$wi1(<$w x i1>)"); fmt(" $v_coherent_0 = call i1 @llvm$e.vector.reduce.and.v$wi1($V)\n" - " $v_coherent_1 = zext i1 $v_coherent_0 to i32\n" + " $v_coherent_1 = zext i1 $v_coherent_0 to i32\n" " $v_ctx = insertelement <6 x i32> , i32 $v_coherent_1, i32 0\n" " store <6 x i32> $v_ctx, {<6 x i32>*} $v_in_ctx_1, align 4\n", v, coherent, v, v, v, v, v, v); diff --git a/src/llvm_scatter.cpp b/src/llvm_scatter.cpp index 0100fa00..e298b2cd 100644 --- a/src/llvm_scatter.cpp +++ b/src/llvm_scatter.cpp @@ -148,9 +148,9 @@ static const char *reduce_op_name[(int) ReduceOp::Count] = { "", "add", "mul", "min", "max", "and", "or" }; -static const char *append_reduce_op_direct(VarType vt, ReduceOp op, const Variable *v) { +static const char *jitc_llvm_append_reduce_op_direct(VarType vt, ReduceOp op, const Variable *v) { if (jitc_llvm_vector_width > 32) - jitc_fail("append_reduce_op_direct(): internal error -- code generation " + jitc_fail("jitc_llvm_append_reduce_op_direct(): internal error -- code generation " "assumes a vector length of <= 32 entries"); uint32_t ptr_align = (uint32_t) sizeof(void *), @@ -219,7 +219,7 @@ static const char *append_reduce_op_direct(VarType vt, ReduceOp op, const Variab return "atomic"; // variant name } -static const char *append_reduce_op_local(VarType vt, ReduceOp op, const Variable *v) { +const char *jitc_llvm_append_reduce_op_local(VarType vt, ReduceOp op, const Variable *v) { uint32_t ptr_align = (uint32_t) sizeof(void *), ptr_align_vec = std::min(ptr_align * jitc_llvm_vector_width, jitc_llvm_max_align), shiftamt = log2i_ceil(type_size[(int) vt]); @@ -229,7 +229,7 @@ static const char *append_reduce_op_local(VarType vt, ReduceOp op, const Variabl *cmp_op = jitc_is_float(v) ? 
"fcmp one" : "icmp ne"; auto [vector_reduce_name, vector_reduce_modifier, vector_reduce_identity, - vector_reduce_identity_type, vector_reduce_version] + vector_reduce_identity_type, vector_reduce_version] = jitc_llvm_vector_reduce_config(vt, op); fmt_intrinsic("declare $t @llvm$e.vector.reduce$s.$s.v$w$h($s$T)", @@ -328,7 +328,7 @@ static const char *append_reduce_op_local(VarType vt, ReduceOp op, const Variabl return "atomic_local"; // variant name } -static const char *append_reduce_op_noconflict(VarType vt, ReduceOp op, const Variable *v) { +static const char *jitc_llvm_append_reduce_op_noconflict(VarType vt, ReduceOp op, const Variable *v) { uint32_t ptr_align = (uint32_t) sizeof(void *), ptr_align_vec = std::min(ptr_align * jitc_llvm_vector_width, jitc_llvm_max_align), shiftamt = log2i_ceil(type_size[(int) vt]); @@ -338,7 +338,7 @@ static const char *append_reduce_op_noconflict(VarType vt, ReduceOp op, const Va auto [vector_reduce_name, vector_reduce_modifier, vector_reduce_identity, - vector_reduce_identity_type, vector_reduce_version] + vector_reduce_identity_type, vector_reduce_version] = jitc_llvm_vector_reduce_config(vt, op); fmt_intrinsic("declare $t @llvm$e.vector.reduce$s.$s.v$w$h($s$T)", @@ -373,7 +373,7 @@ static const char *append_reduce_op_noconflict(VarType vt, ReduceOp op, const Va break; default: - jitc_fail("append_reduce_op_noconflict(): unsupported operation!"); + jitc_fail("jitc_llvm_append_reduce_op_noconflict(): unsupported operation!"); } char scalar_op[128]; @@ -518,15 +518,15 @@ void jitc_llvm_render_scatter_reduce(const Variable *v, switch (mode) { case ReduceMode::Direct: - variant = append_reduce_op_direct(vt, op, value); + variant = jitc_llvm_append_reduce_op_direct(vt, op, value); break; case ReduceMode::Local: - variant = append_reduce_op_local(vt, op, value); + variant = jitc_llvm_append_reduce_op_local(vt, op, value); break; case ReduceMode::NoConflicts: - variant = append_reduce_op_noconflict(vt, op, value); + variant = jitc_llvm_append_reduce_op_noconflict(vt, op, value); break; default: diff --git a/src/llvm_scatter.h b/src/llvm_scatter.h index 444b7ed8..b9262332 100644 --- a/src/llvm_scatter.h +++ b/src/llvm_scatter.h @@ -12,22 +12,26 @@ #include "eval.h" -void jitc_llvm_render_scatter(const Variable *v, const Variable *ptr, - const Variable *value, const Variable *index, - const Variable *mask); - -void jitc_llvm_render_scatter_reduce(const Variable *v, - const Variable *ptr, +extern void jitc_llvm_render_scatter(const Variable *v, const Variable *ptr, const Variable *value, const Variable *index, const Variable *mask); -void jitc_llvm_render_scatter_add_kahan(const Variable *v, - const Variable *ptr_1, - const Variable *ptr_2, - const Variable *index, - const Variable *value); +extern void jitc_llvm_render_scatter_reduce(const Variable *v, + const Variable *ptr, + const Variable *value, + const Variable *index, + const Variable *mask); + +extern void jitc_llvm_render_scatter_add_kahan(const Variable *v, + const Variable *ptr_1, + const Variable *ptr_2, + const Variable *index, + const Variable *value); -void jitc_llvm_render_scatter_inc(Variable *v, const Variable *ptr, - const Variable *index, const Variable *mask); +extern void jitc_llvm_render_scatter_inc(Variable *v, const Variable *ptr, + const Variable *index, + const Variable *mask); +extern const char *jitc_llvm_append_reduce_op_local(VarType vt, ReduceOp op, + const Variable *v); diff --git a/src/llvm_ts.cpp b/src/llvm_ts.cpp index 210d22fe..ce04e6cb 100644 --- a/src/llvm_ts.cpp +++ 
b/src/llvm_ts.cpp @@ -836,3 +836,57 @@ void LLVMThreadState::reduce_expanded(VarType vt, ReduceOp op, void *ptr, size, std::max(1u, blocks)); }
+
+void LLVMThreadState::coop_vec_pack(uint32_t count, const void *in,
+                                    const MatrixDescr *in_d, void *out,
+                                    const MatrixDescr *out_d) {
+    struct PackTask {
+        drjit::unique_ptr<MatrixDescr[]> in_d, out_d;
+        const uint8_t *in;
+        uint8_t *out;
+
+        void run(uint32_t index) {
+            const MatrixDescr &id = in_d[index], &od = out_d[index];
+            uint32_t tsize = type_size[(int) id.dtype];
+
+            if (id.stride == id.cols && od.stride == od.cols) {
+                std::memcpy(
+                    out + od.offset*tsize,
+                    in + id.offset*tsize,
+                    id.size * tsize
+                );
+            } else {
+                for (uint32_t j = 0; j < id.rows; ++j) {
+                    std::memcpy(
+                        out + (od.offset + j * od.stride)*tsize,
+                        in + (id.offset + j * id.stride)*tsize,
+                        id.cols * tsize);
+                }
+            }
+        }
+    };
+
+    drjit::unique_ptr<PackTask> task = new PackTask();
+    task->in_d = new MatrixDescr[count];
+    task->out_d = new MatrixDescr[count];
+    task->in = (const uint8_t *) in;
+    task->out = (uint8_t *) out;
+    std::memcpy(task->in_d.get(), in_d, count * sizeof(MatrixDescr));
+    std::memcpy(task->out_d.get(), out_d, count * sizeof(MatrixDescr));
+
+    Task *new_task = task_submit_dep(
+        nullptr, &jitc_task, 1, count,
+        [](uint32_t index, void *p) {
+            ((PackTask *) p)->run(index);
+        },
+        task.release(),
+        sizeof(void *),
+        [](void *p) {
+            delete (PackTask *) p;
+        },
+        0
+    );
+
+    task_release(jitc_task);
+    jitc_task = new_task;
+} diff --git a/src/llvm_ts.h b/src/llvm_ts.h index e76a641b..caec2ff8 100644 --- a/src/llvm_ts.h +++ b/src/llvm_ts.h @@ -50,4 +50,8 @@ struct LLVMThreadState : ThreadState { /// dr.ReduceOp.Expand void reduce_expanded(VarType vt, ReduceOp op, void *data, uint32_t exp, uint32_t size) override;
+
+    /// Pack a set of matrices/vectors for use with the cooperative vector API
+    void coop_vec_pack(uint32_t count, const void *in, const MatrixDescr *in_d,
+                       void *out, const MatrixDescr *out_d) override;
 };
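The packing task above boils down to strided row copies that collapse into a single memcpy when both layouts are dense. For reference, a minimal standalone sketch of the same repacking logic (hypothetical free function, offsets and strides given in elements as in MatrixDescr):

#include <cstddef>
#include <cstdint>
#include <cstring>

// Repack one matrix between two strided layouts. When both sides are densely
// packed (stride == cols), a single contiguous copy suffices; otherwise the
// rows are copied one at a time.
void repack_matrix(const uint8_t *in, uint8_t *out,
                   uint32_t rows, uint32_t cols, uint32_t elem_size,
                   uint32_t in_offset, uint32_t in_stride,
                   uint32_t out_offset, uint32_t out_stride) {
    if (in_stride == cols && out_stride == cols) {
        std::memcpy(out + (size_t) out_offset * elem_size,
                    in + (size_t) in_offset * elem_size,
                    (size_t) rows * cols * elem_size);
    } else {
        for (uint32_t j = 0; j < rows; ++j)
            std::memcpy(out + (size_t) (out_offset + j * out_stride) * elem_size,
                        in + (size_t) (in_offset + j * in_stride) * elem_size,
                        (size_t) cols * elem_size);
    }
}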
diff --git a/src/lock.h b/src/lock.h index ed482450..2d6909c8 100644 --- a/src/lock.h +++ b/src/lock.h @@ -51,9 +51,6 @@ class unlock_guard { unlock_guard(Lock &lock) : m_lock(lock) { lock_release(m_lock); } ~unlock_guard() { lock_acquire(m_lock);
-        #if defined(DRJIT_SANITIZE_INTENSE)
-            jitc_sanitation_checkpoint();
-        #endif
 } unlock_guard(const unlock_guard &) = delete; unlock_guard &operator=(const unlock_guard &) = delete; diff --git a/src/log.cpp b/src/log.cpp index fd17e8f8..2ca4880e 100644 --- a/src/log.cpp +++ b/src/log.cpp @@ -141,8 +141,11 @@ static void print_float_with_unit(char *buf, size_t bufsize, double value, bool accurate, const char *unit) { int digits_after_comma = accurate ? 5 : 3;
-    digits_after_comma =
-        std::max(digits_after_comma - int(std::log10(value)), 0);
+    if (value == 0)
+        digits_after_comma = 0;
+    else
+        digits_after_comma =
+            std::max(digits_after_comma - int(std::log10(value)), 0);
 int pos = snprintf(buf, bufsize, "%.*f", digits_after_comma, value); diff --git a/src/op.cpp b/src/op.cpp index 02546bfc..e841f27a 100644 --- a/src/op.cpp +++ b/src/op.cpp @@ -1469,6 +1469,29 @@ uint32_t jitc_var_log2_intrinsic(uint32_t a0) { // --------------------------------------------------------------------------
+
+template <typename T, enable_if_t<std::is_floating_point_v<T>> = 0>
+T eval_tanh(T value) { return T(std::tanh(value)); }
+
+template <typename T, enable_if_t<!std::is_floating_point_v<T>> = 0>
+T eval_tanh(T) { jitc_fail("eval_tanh(): unsupported operands!"); }
+
+uint32_t jitc_var_tanh_intrinsic(uint32_t a0) {
+    auto [info, v0] = jitc_var_check("jit_var_tanh_intrinsic", a0);
+
+    uint32_t result = 0;
+    if (info.simplify && info.literal)
+        result = jitc_eval_literal(info, [](auto l0) { return eval_tanh(l0); }, v0);
+
+    if (!result && info.size)
+        result = jitc_var_new_node_1(info.backend, VarKind::Tanh, info.type,
+                                     info.size, info.symbolic, a0, v0);
+
+    jitc_trace("jit_var_tanh_intrinsic(r%u <- r%u)", result, a0);
+    return result;
+}
+
+// --------------------------------------------------------------------------
+
 uint32_t jitc_var_cast(uint32_t a0, VarType target_type, int reinterpret) { if (a0 == 0) return 0; @@ -1777,6 +1800,16 @@ uint32_t jitc_var_gather(uint32_t src_, uint32_t index, uint32_t mask) { msg = " [elided, scalar source]"; }
+
+    if (!result) {
+        index_v = jitc_var(index);
+        src_v = jitc_var(src);
+        if ((VarKind) index_v->kind == VarKind::Counter &&
+            index_v->size == src_v->size) {
+            result = jitc_var_and(src, mask);
+            msg = " [elided, identity gather]";
+        }
+    }
+
 // Don't perform the gather operation if the inputs are trivial / can be re-indexed if (!result) { Ref index_2 = steal(jitc_var_cast(index, VarType::UInt32, 0)); @@ -2743,6 +2776,7 @@ uint32_t jitc_var_op(JitOp op, const uint32_t *dep) { case JitOp::Cos: return jitc_var_cos_intrinsic(dep[0]); case JitOp::Exp2: return jitc_var_exp2_intrinsic(dep[0]); case JitOp::Log2: return jitc_var_log2_intrinsic(dep[0]);
+        case JitOp::Tanh: return jitc_var_tanh_intrinsic(dep[0]);
 case JitOp::Eq: return jitc_var_eq(dep[0], dep[1]); case JitOp::Neq: return jitc_var_neq(dep[0], dep[1]); case JitOp::Lt: return jitc_var_lt(dep[0], dep[1]); diff --git a/src/op.h b/src/op.h index d5c74c53..9d494730 100644 --- a/src/op.h +++ b/src/op.h @@ -112,6 +112,7 @@ extern uint32_t jitc_var_sin_intrinsic(uint32_t a0); extern uint32_t jitc_var_cos_intrinsic(uint32_t a0); extern uint32_t jitc_var_exp2_intrinsic(uint32_t a0); extern uint32_t jitc_var_log2_intrinsic(uint32_t a0);
+extern uint32_t jitc_var_tanh_intrinsic(uint32_t a0);
 /// Extra data describing a packet scatter operation struct PacketScatterData { diff --git a/src/optix.h b/src/optix.h index 484f93dd..4ffe3a66 100644 --- a/src/optix.h +++ b/src/optix.h @@ -19,6 +19,9 @@ struct OptixShaderBindingTable; struct ThreadState; struct OptixPipelineData;
+// Is the cooperative vector ABI available?
+extern bool jitc_optix_has_abi_105; + /// Create an OptiX device context on the current ThreadState extern OptixDeviceContext jitc_optix_context(); @@ -72,3 +75,13 @@ extern void jitc_optix_launch(ThreadState *ts, const Kernel &kernel, /// Optional: set the desired launch size extern void jitc_optix_set_launch_size(uint32_t width, uint32_t height, uint32_t samples); + +/// Convert a Dr.Jit variable type into an OptiX CoopVec variable type +extern uint32_t jitc_optix_coop_vec_type_id(VarType vt); + +/// Convert a Dr.Jit matrix layout type into an OptiX CoopVec matrix layout type +extern uint32_t jitc_optix_coop_vec_layout_id(MatrixLayout ml); + +#define jitc_optix_check(err) jitc_optix_check_impl((err), __FILE__, __LINE__) +using OptixResult = int; +extern void jitc_optix_check_impl(OptixResult errval, const char *file, const int line); diff --git a/src/optix_api.cpp b/src/optix_api.cpp index fe100b73..d4d23d58 100644 --- a/src/optix_api.cpp +++ b/src/optix_api.cpp @@ -8,8 +8,12 @@ */ #define DR_OPTIX_SYM(...) __VA_ARGS__ = nullptr; -#define DR_OPTIX_ABI_VERSION 87 -#define DR_OPTIX_FUNCTION_TABLE_SIZE 48 + +// We target ABI 87 and upgrade to 105 if it is available +#define DR_OPTIX_ABI_VERSION_87 87 +#define DR_OPTIX_ABI_VERSION_105 105 +#define DR_OPTIX_FUNCTION_TABLE_SIZE_87 48 +#define DR_OPTIX_FUNCTION_TABLE_SIZE_105 52 #include "optix.h" #include "optix_api.h" @@ -29,8 +33,10 @@ static void *jitc_optix_handle = nullptr; void *jitc_optix_win32_load_alternative(); #endif -static void *jitc_optix_table[DR_OPTIX_FUNCTION_TABLE_SIZE] { }; -static const char *jitc_optix_table_names[DR_OPTIX_FUNCTION_TABLE_SIZE] = { +static void *jitc_optix_table_87[DR_OPTIX_FUNCTION_TABLE_SIZE_87] { }; +static void *jitc_optix_table_105[DR_OPTIX_FUNCTION_TABLE_SIZE_105] { }; + +static const char *jitc_optix_table_names_87[DR_OPTIX_FUNCTION_TABLE_SIZE_87] = { "optixGetErrorName", "optixGetErrorString", "optixDeviceContextCreate", @@ -81,6 +87,63 @@ static const char *jitc_optix_table_names[DR_OPTIX_FUNCTION_TABLE_SIZE] = { "optixDenoiserCreateWithUserModel" }; +static const char *jitc_optix_table_names_105[DR_OPTIX_FUNCTION_TABLE_SIZE_105] = { + "optixGetErrorName", + "optixGetErrorString", + "optixDeviceContextCreate", + "optixDeviceContextDestroy", + "optixDeviceContextGetProperty", + "optixDeviceContextSetLogCallback", + "optixDeviceContextSetCacheEnabled", + "optixDeviceContextSetCacheLocation", + "optixDeviceContextSetCacheDatabaseSizes", + "optixDeviceContextGetCacheEnabled", + "optixDeviceContextGetCacheLocation", + "optixDeviceContextGetCacheDatabaseSizes", + "optixModuleCreate", + "optixModuleCreateWithTasks", + "optixModuleGetCompilationState", + "optixModuleDestroy", + "optixBuiltinISModuleGet", + "optixTaskExecute", + "optixProgramGroupCreate", + "optixProgramGroupDestroy", + "optixProgramGroupGetStackSize", + "optixPipelineCreate", + "optixPipelineDestroy", + "optixPipelineSetStackSize", + "optixAccelComputeMemoryUsage", + "optixAccelBuild", + "optixAccelGetRelocationInfo", + "optixCheckRelocationCompatibility", + "optixAccelRelocate", + "optixAccelCompact", + "optixAccelEmitProperty", + "optixConvertPointerToTraversableHandle", + "optixOpacityMicromapArrayComputeMemoryUsage", + "optixOpacityMicromapArrayBuild", + "optixOpacityMicromapArrayGetRelocationInfo", + "optixOpacityMicromapArrayRelocate", + "optixDisplacementMicromapArrayComputeMemoryUsage", + "optixDisplacementMicromapArrayBuild", + "optixClusterAccelComputeMemoryUsage", + "optixClusterAccelBuild", + "optixSbtRecordPackHeader", + 
"optixLaunch", + "optixCoopVecMatrixConvert", + "optixCoopVecMatrixComputeSize", + "optixDenoiserCreate", + "optixDenoiserDestroy", + "optixDenoiserComputeMemoryResources", + "optixDenoiserSetup", + "optixDenoiserInvoke", + "optixDenoiserComputeIntensity", + "optixDenoiserComputeAverageColor", + "optixDenoiserCreateWithUserModel" +}; + +bool jitc_optix_has_abi_105 = false; + bool jitc_optix_api_init() { if (jitc_optix_handle) return true; @@ -140,16 +203,28 @@ bool jitc_optix_api_init() { return false; } - int rv = optixQueryFunctionTable(DR_OPTIX_ABI_VERSION, 0, 0, 0, - &jitc_optix_table, sizeof(jitc_optix_table)); + int rv = optixQueryFunctionTable(DR_OPTIX_ABI_VERSION_105, 0, 0, 0, + &jitc_optix_table_105, + sizeof(jitc_optix_table_105)); if (rv) { - jitc_log(Warn, - "jit_optix_api_init(): Failed to load OptiX library! Very likely, " - "your NVIDIA graphics driver is too old and not compatible " - "with the version of OptiX that is being used. In particular, " - "OptiX 8.0 requires driver revision R535 or newer."); - jitc_optix_api_shutdown(); - return false; + jitc_optix_has_abi_105 = false; + memset(jitc_optix_table_105, 0, sizeof(jitc_optix_table_105)); + + // Next, try ABI 87 + int rv = optixQueryFunctionTable(DR_OPTIX_ABI_VERSION_87, 0, 0, 0, + &jitc_optix_table_87, + sizeof(jitc_optix_table_87)); + if (rv) { + jitc_log(Warn, + "jit_optix_api_init(): Failed to load OptiX library! Very likely, " + "your NVIDIA graphics driver is too old and not compatible " + "with the version of OptiX that is being used. In particular, " + "OptiX 8.0 requires driver revision R535 or newer."); + jitc_optix_api_shutdown(); + return false; + } + } else { + jitc_optix_has_abi_105 = true; } #define LOAD(name) name = (decltype(name)) jitc_optix_lookup(#name) @@ -174,9 +249,15 @@ bool jitc_optix_api_init() { LOAD(optixPipelineSetStackSize); LOAD(optixProgramGroupGetStackSize); + if (jitc_optix_has_abi_105) { + LOAD(optixCoopVecMatrixConvert); + LOAD(optixCoopVecMatrixComputeSize); + } + #undef LOAD - jitc_log(Info, "jit_optix_api_init(): loaded OptiX (via 8.0 ABI)."); + jitc_log(Info, "jit_optix_api_init(): loaded OptiX (via %s ABI).", + jitc_optix_has_abi_105 ? 
"9.0" : "8.0"); return true; } @@ -195,7 +276,8 @@ void jitc_optix_api_shutdown() { #endif jitc_optix_handle = nullptr; - memset(jitc_optix_table, 0, sizeof(jitc_optix_table)); + memset(jitc_optix_table_87, 0, sizeof(jitc_optix_table_87)); + memset(jitc_optix_table_105, 0, sizeof(jitc_optix_table_105)); #define Z(x) x = nullptr Z(optixGetErrorName); Z(optixGetErrorString); Z(optixDeviceContextCreate); @@ -206,13 +288,21 @@ void jitc_optix_api_shutdown() { Z(optixProgramGroupDestroy); Z(optixPipelineCreate); Z(optixPipelineDestroy); Z(optixLaunch); Z(optixSbtRecordPackHeader); Z(optixPipelineSetStackSize); Z(optixProgramGroupGetStackSize); + Z(optixCoopVecMatrixConvert); Z(optixCoopVecMatrixComputeSize); #undef Z } void *jitc_optix_lookup(const char *name) { - for (size_t i = 0; i < DR_OPTIX_FUNCTION_TABLE_SIZE; ++i) { - if (strcmp(name, jitc_optix_table_names[i]) == 0) - return jitc_optix_table[i]; + if (jitc_optix_has_abi_105) { + for (size_t i = 0; i < DR_OPTIX_FUNCTION_TABLE_SIZE_105; ++i) { + if (strcmp(name, jitc_optix_table_names_105[i]) == 0) + return jitc_optix_table_105[i]; + } + } else { + for (size_t i = 0; i < DR_OPTIX_FUNCTION_TABLE_SIZE_87; ++i) { + if (strcmp(name, jitc_optix_table_names_87[i]) == 0) + return jitc_optix_table_87[i]; + } } jitc_raise("jit_optix_lookup(): function \"%s\" not found!", name); } diff --git a/src/optix_api.h b/src/optix_api.h index e8575081..35a06458 100644 --- a/src/optix_api.h +++ b/src/optix_api.h @@ -30,31 +30,38 @@ using OptixTask = void*; using OptixModule = void*; using OptixProgramGroup = void*; using OptixPipeline = void*; - -#define OPTIX_EXCEPTION_FLAG_NONE 0 -#define OPTIX_EXCEPTION_FLAG_STACK_OVERFLOW 1 -#define OPTIX_EXCEPTION_FLAG_TRACE_DEPTH 2 -#define OPTIX_EXCEPTION_FLAG_DEBUG 8 -#define OPTIX_ERROR_VALIDATION_FAILURE 7053 -#define OPTIX_COMPILE_DEBUG_LEVEL_NONE 0x2350 -#define OPTIX_COMPILE_DEBUG_LEVEL_MINIMAL 0x2351 -#define OPTIX_COMPILE_DEBUG_LEVEL_MODERATE 0x2353 -#define OPTIX_COMPILE_DEBUG_LEVEL_FULL 0x2352 -#define OPTIX_COMPILE_OPTIMIZATION_LEVEL_0 0x2340 -#define OPTIX_COMPILE_OPTIMIZATION_LEVEL_1 0x2341 -#define OPTIX_COMPILE_OPTIMIZATION_LEVEL_2 0x2342 -#define OPTIX_COMPILE_OPTIMIZATION_LEVEL_3 0x2343 -#define OPTIX_DEVICE_CONTEXT_VALIDATION_MODE_OFF 0 -#define OPTIX_DEVICE_CONTEXT_VALIDATION_MODE_ALL ((int) 0xFFFFFFFF) -#define OPTIX_MODULE_COMPILE_STATE_COMPLETED 0x2364 -#define OPTIX_PROGRAM_GROUP_KIND_RAYGEN 0x2421 -#define OPTIX_PROGRAM_GROUP_KIND_CALLABLES 0x2425 -#define OPTIX_PROGRAM_GROUP_KIND_MISS 0x2422 -#define OPTIX_SBT_RECORD_HEADER_SIZE 32 -#define OPTIX_TRAVERSABLE_GRAPH_FLAG_ALLOW_ANY 0 -#define OPTIX_TRAVERSABLE_GRAPH_FLAG_ALLOW_SINGLE_GAS 1 +using OptixCoopVecElemType = int; +using OptixCoopVecMatrixLayout = int; +struct OptixPipelineCompileOptions; +struct OptixShaderBindingTable; + +#define OPTIX_EXCEPTION_FLAG_NONE 0 +#define OPTIX_EXCEPTION_FLAG_STACK_OVERFLOW 1 +#define OPTIX_EXCEPTION_FLAG_TRACE_DEPTH 2 +#define OPTIX_EXCEPTION_FLAG_DEBUG 8 +#define OPTIX_ERROR_VALIDATION_FAILURE 7053 +#define OPTIX_COMPILE_DEBUG_LEVEL_NONE 0x2350 +#define OPTIX_COMPILE_DEBUG_LEVEL_MINIMAL 0x2351 +#define OPTIX_COMPILE_DEBUG_LEVEL_MODERATE 0x2353 +#define OPTIX_COMPILE_DEBUG_LEVEL_FULL 0x2352 +#define OPTIX_COMPILE_OPTIMIZATION_LEVEL_0 0x2340 +#define OPTIX_COMPILE_OPTIMIZATION_LEVEL_1 0x2341 +#define OPTIX_COMPILE_OPTIMIZATION_LEVEL_2 0x2342 +#define OPTIX_COMPILE_OPTIMIZATION_LEVEL_3 0x2343 +#define OPTIX_DEVICE_CONTEXT_VALIDATION_MODE_OFF 0 +#define OPTIX_DEVICE_CONTEXT_VALIDATION_MODE_ALL ((int) 
0xFFFFFFFF) +#define OPTIX_MODULE_COMPILE_STATE_COMPLETED 0x2364 +#define OPTIX_PROGRAM_GROUP_KIND_RAYGEN 0x2421 +#define OPTIX_PROGRAM_GROUP_KIND_CALLABLES 0x2425 +#define OPTIX_PROGRAM_GROUP_KIND_MISS 0x2422 +#define OPTIX_SBT_RECORD_HEADER_SIZE 32 +#define OPTIX_TRAVERSABLE_GRAPH_FLAG_ALLOW_ANY 0 +#define OPTIX_TRAVERSABLE_GRAPH_FLAG_ALLOW_SINGLE_GAS 1 +#define OPTIX_PRIMITIVE_TYPE_FLAGS_TRIANGLE (1 << 31) +#define OPTIX_COOP_VEC_MATRIX_LAYOUT_ROW_MAJOR 0x2A40 +#define OPTIX_COOP_VEC_MATRIX_LAYOUT_TRAINING_OPTIMAL 0x2A43 +#define OPTIX_COOP_VEC_MATRIX_LAYOUT_INFERENCING_OPTIMAL 0x2A42 #define OPTIX_TRAVERSABLE_GRAPH_FLAG_ALLOW_SINGLE_LEVEL_INSTANCING (1u << 1) -#define OPTIX_PRIMITIVE_TYPE_FLAGS_TRIANGLE (1 << 31) struct OptixDeviceContextOptions { OptixLogCallback logCallbackFunction; @@ -130,6 +137,21 @@ struct OptixProgramGroupOptions { OptixPayloadType *payloadType; }; +struct OptixCoopVecMatrixDescription { + unsigned int N; + unsigned int K; + unsigned int offsetInBytes; + OptixCoopVecElemType elementType; + OptixCoopVecMatrixLayout layout; + unsigned int rowColumnStrideInBytes; + unsigned int sizeInBytes; +}; + +struct OptixNetworkDescription { + OptixCoopVecMatrixDescription* layers; + unsigned int numLayers; +}; + DR_OPTIX_SYM(OptixResult (*optixQueryFunctionTable)(int, unsigned int, void *, const void **, void *, size_t)); @@ -174,3 +196,11 @@ DR_OPTIX_SYM(OptixResult (*optixPipelineSetStackSize)( DR_OPTIX_SYM(OptixResult (*optixProgramGroupGetStackSize)(OptixProgramGroup, OptixStackSizes *, OptixPipeline)); + +DR_OPTIX_SYM(OptixResult (*optixCoopVecMatrixConvert)( + OptixDeviceContext, CUstream, unsigned int, const OptixNetworkDescription *, + CUdeviceptr, size_t, const OptixNetworkDescription *, CUdeviceptr, size_t)); + +DR_OPTIX_SYM(OptixResult (*optixCoopVecMatrixComputeSize)( + OptixDeviceContext, unsigned int, unsigned int, OptixCoopVecElemType, + OptixCoopVecMatrixLayout, size_t, size_t *)); diff --git a/src/optix_coop_vec.cpp b/src/optix_coop_vec.cpp new file mode 100644 index 00000000..268cf5fb --- /dev/null +++ b/src/optix_coop_vec.cpp @@ -0,0 +1,488 @@ +/* + src/optix_coop_vec.cpp -- OptiX code generation for Cooperative Vectors + + Copyright (c) 2025 Wenzel Jakob + + All rights reserved. Use of this source code is governed by a BSD-style + license that can be found in the LICENSE file. 
+*/ + +#include "var.h" +#include "eval.h" +#include "coop_vec.h" +#include "cuda_eval.h" +#include "optix_api.h" + +#define OPTIX_COOP_VEC_OP_CVT 0x2A2A + +// -------------------------------------------------------------------------- +// Mappings to Optix ABI IDs + +static uint32_t jitc_optix_coop_vec_op_id(JitOp op) { + switch (op) { + case JitOp::Exp2: return 0x2A21; + case JitOp::Log2: return 0x2A22; + case JitOp::Tanh: return 0x2A23; + case JitOp::Max: return 0x2A24; + case JitOp::Min: return 0x2A25; + case JitOp::Fma: return 0x2A26; + case JitOp::Mul: return 0x2A27; + case JitOp::Add: return 0x2A28; + case JitOp::Sub: return 0x2A29; + case JitOp::Step: return 0x2A2B; + default: jitc_fail("get_coop_vec_type_id(): unsupported operation!"); + } +} + +uint32_t jitc_optix_coop_vec_type_id(VarType vt) { + switch (vt) { + case VarType::Float16: return 0x2A01; + case VarType::Float32: return 0x2A03; + case VarType::UInt32: return 0x2A08; + case VarType::Int32: return 0x2A09; + default: + jitc_fail("jitc_optix_coop_vec_type_id(): unsupported variable type!"); + } +} + +uint32_t jitc_optix_coop_vec_layout_id(MatrixLayout ml) { + switch (ml) { + case MatrixLayout::TrainingOptimal: return OPTIX_COOP_VEC_MATRIX_LAYOUT_TRAINING_OPTIMAL; + case MatrixLayout::InferencingOptimal: return OPTIX_COOP_VEC_MATRIX_LAYOUT_INFERENCING_OPTIMAL; + case MatrixLayout::RowMajor: return OPTIX_COOP_VEC_MATRIX_LAYOUT_ROW_MAJOR; + default: + jitc_fail("jitc_optix_coop_vec_layout_id(): unsupported layout type!"); + } +} + +// -------------------------------------------------------------------------- +// Helper routines to generate frequently used PTX fragments + +static uint32_t get_reg_count(uint32_t l) { + if (l <= 16) + return 16; + else if (l <= 64) + return 64; + else + return 0; +} + +static void declare_buffer(const char *name, const Variable *v) { + uint32_t tsize = type_size[v->type], + length = v->array_length; + + fmt(" .local .align $u .b8 $s[$u];\n" + " .reg.u64 %$s;\n" + " cvta.local.u64 %$s, $s;\n", + tsize, name, tsize * length, + name, + name, name); +} + +static void copy_to_buffer(const char *name, const Variable *v) { + uint32_t tsize = type_size[v->type], + length = v->array_length; + for (uint32_t i = 0; i < length; ++i) + fmt(" st.local.$b [$s+$u], %cv$u_$u;\n", v, name, tsize * i, v->reg_index, i); +} + +static void copy_from_buffer(const char *name, const Variable *v) { + uint32_t tsize = type_size[v->type], + length = v->array_length; + + for (uint32_t i = 0; i < length; ++i) + fmt(" ld.local.$b %cv$u_$u, [$s+$u];\n", v, v->reg_index, i, name, tsize * i); +} + +static void put_elem(const Variable *v, uint32_t reg_count, bool trailing_comma, bool is_return_value = false) { + for (uint32_t i = 0; i < reg_count; ++i) { + if (i < v->array_length || is_return_value) + fmt("%cv$u_$u, ", v->reg_index, i); + else + put("%u, "); + } + if (!trailing_comma) + buffer.rewind(2); +} + +void jitc_optix_render_coop_vec_unpack(const Variable *v, const Variable *a0) { + uint32_t tsize = type_size[v->type]; + + put(" // coop_vec_unpack\n"); + if (tsize == 4) { + fmt(" mov.b32 $v, %cv$u_$u;\n", v, a0->reg_index, (uint32_t) v->literal); + } else { + fmt(" {\n" + " .reg.$b %temp;\n" + " cvt.u$u.u32 %temp, %cv$u_$u;\n" + " mov.$b $v, %temp;\n" + " }\n", + v, + tsize*8, a0->reg_index, (uint32_t) v->literal, + v, v); + } +} + +void jitc_optix_render_coop_vec_accum(const Variable *v, const Variable *target, + const Variable *value, const Variable *mask) { + bool use_mask = !mask->is_literal() || mask->literal != 1; + 
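+    // The reduce-sum accumulate below has a memory side effect, so lanes with
+    // a false mask must jump past the call ('@!mask bra' below).
+    // get_reg_count() selects the 16- or 64-register form of the OptiX
+    // intrinsic; a value of 0 selects the local-buffer '_ptr' fallback.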
uint32_t length = value->array_length,
+             reg_count = get_reg_count(length);
+
+    if (use_mask)
+        fmt("    @!$v bra l_$u_done;\n", mask, v->reg_index);
+
+    fmt("    { // coop_vec_accum\n"
+        "        .reg.b32 %type, %size, %offset, %u;\n"
+        "        .reg.b64 %dst;\n"
+        "        mov.b32 %type, $u;\n"
+        "        mov.b32 %size, $u;\n"
+        "        mov.b32 %offset, $u;\n"
+        "        mov.b64 %dst, $v;\n",
+        jitc_optix_coop_vec_type_id((VarType) value->type),
+        length,
+        (uint32_t) v->literal * type_size[value->type],
+        target
+    );
+
+    if (reg_count) {
+        fmt("        call (), _optix_reduce_sum_accumulate_$uxi32, (%type, %size, %dst, %offset, ",
+            reg_count);
+        put_elem(value, reg_count, false);
+        put(");\n");
+    } else {
+        declare_buffer("src", value);
+        copy_to_buffer("src", value);
+        put("        call (), _optix_reduce_sum_accumulate_ptr, (%type, %size, %dst, %offset, %src);\n");
+    }
+    put("    }\n");
+    if (use_mask)
+        fmt("\nl_$u_done:\n", v->reg_index);
+}
+
+void jitc_optix_render_coop_vec_outer_product_accum(const Variable *v, const Variable *target,
+                                                    const Variable *v0, const Variable *v1,
+                                                    const Variable *mask) {
+    uint32_t op_length = std::max(v0->array_length, v1->array_length),
+             reg_count = get_reg_count(op_length),
+             tsize = type_size[jitc_var(target->dep[3])->type];
+
+    const MatrixDescr *d = (const MatrixDescr *) v->data;
+
+    bool use_mask = !mask->is_literal() || mask->literal != 1;
+    if (use_mask)
+        fmt("    @!$v bra l_$u_done;\n", mask, v->reg_index);
+
+    fmt("    { // coop_vec_outer_product_accum\n"
+        "        .reg.b32 %type_0, %type_1, %size_0, %size_1, %offset, %layout, %stride, %u;\n"
+        "        .reg.b64 %dst;\n"
+        "        mov.b32 %type_0, $u;\n"
+        "        mov.b32 %type_1, $u;\n"
+        "        mov.b32 %size_0, $u;\n"
+        "        mov.b32 %size_1, $u;\n"
+        "        mov.b32 %offset, $u;\n"
+        "        mov.b32 %layout, $u;\n"
+        "        mov.b32 %stride, $u;\n"
+        "        mov.b64 %dst, $v;\n",
+        jitc_optix_coop_vec_type_id((VarType) v0->type),
+        jitc_optix_coop_vec_type_id((VarType) v1->type),
+        v0->array_length,
+        v1->array_length,
+        d->offset * tsize,
+        jitc_optix_coop_vec_layout_id(d->layout),
+        d->stride * tsize,
+        target);
+
+    if (reg_count) {
+        fmt("        call (), _optix_outer_product_accumulate_$uxi32, (%type_0, "
+            "%size_0, %type_1, %size_1, %dst, %offset, %layout, %stride, ",
+            reg_count);
+        put_elem(v0, reg_count, true);
+        put_elem(v1, reg_count, false);
+        put(");\n");
+    } else {
+        declare_buffer("src_0", v0);
+        declare_buffer("src_1", v1);
+        copy_to_buffer("src_0", v0);
+        copy_to_buffer("src_1", v1);
+        put("        call (), _optix_outer_product_accumulate_ptr, (%type_0, "
+            "%size_0, %type_1, %size_1, %dst, %offset, %layout, %stride, "
+            "%src_0, %src_1);\n");
+    }
+    put("    }\n");
+    if (use_mask)
+        fmt("\nl_$u_done:\n", v->reg_index);
+}
+
+void jitc_optix_render_coop_vec(const Variable *v, const Variable *a0,
+                                const Variable *a1, const Variable *a2,
+                                const Variable *a3) {
+    uint32_t tsize = type_size[v->type],
+             length = v->array_length,
+             op_length = length;
+
+    if ((VarKind) v->kind == VarKind::CoopVecMatVec)
+        op_length = std::max(op_length, (uint32_t) a1->array_length);
+
+    uint32_t reg_count = get_reg_count(op_length);
+
+    fmt("    .reg.b32 %cv$u_<$u>;\n", v->reg_index, std::max(reg_count, (uint32_t) v->array_length));
+
+    fmt("    { // $s\n", var_kind_name[v->kind]);
+
+    switch ((VarKind) v->kind) {
+        case VarKind::CoopVecLiteral:
+            for (uint32_t i = 0; i < length; ++i)
+                fmt("        mov.b32 %cv$u_$u, $l;\n", v->reg_index, i, v);
+            break;
+
+        case VarKind::CoopVecPack: {
+                if (tsize != 4)
+                    fmt("        .reg.b$u %temp;\n", tsize*8);
+                const std::vector<uint32_t> &indices = ((const CoopVecPackData *) v->data)->indices;
+                for (uint32_t i = 0; i
< (uint32_t) indices.size(); ++i) { + if (tsize != 4) { + fmt(" mov.b$u %temp, $v;\n" + " cvt.u32.u$u %cv$u_$u, %temp;\n", + tsize*8, jitc_var(indices[i]), + tsize*8, v->reg_index, i); + } else { + fmt(" mov.b32 %cv$u_$u, $v;\n", + v->reg_index, i, jitc_var(indices[i])); + } + } + } + break; + + case VarKind::CoopVecLoad: + fmt(" .reg.b32 %type, %size, %u;\n" + " .reg.b64 %src;\n" + " mov.b32 %type, $u;\n" + " mov.b32 %size, $u;\n" + " add.u64 %src, $v, $u;\n", + jitc_optix_coop_vec_type_id((VarType) v->type), + length, + a0, (uint32_t) v->literal * type_size[v->type] + ); + + if (reg_count) { + put(" call ("); + put_elem(v, reg_count, false, true); + fmt("), _optix_vector_load_$uxi32, (%type, %size, %src);\n", + reg_count); + } else { + declare_buffer("dst", v); + fmt(" call (), _optix_vector_load_ptr, (%type, %size, %src, %dst);\n", + v->reg_index); + copy_from_buffer("dst", v); + } + break; + + case VarKind::CoopVecUnaryOp: + fmt(" .reg.b32 %op, %type, %size, %u;\n" + " mov.b32 %op, $u;\n" + " mov.b32 %type, $u;\n" + " mov.b32 %size, $u;\n", + jitc_optix_coop_vec_op_id((JitOp) v->literal), + jitc_optix_coop_vec_type_id((VarType) v->type), + length + ); + + if (reg_count) { + put(" call ("); + put_elem(v, reg_count, false, true); + fmt("), _optix_vector_op1_$uxi32, (%op, %type, %size, %type, %size, ", reg_count); + put_elem(a0, reg_count, false); + put(");\n"); + } else { + declare_buffer("src", a0); + declare_buffer("dst", v); + copy_to_buffer("src", a0); + put(" call (), _optix_vector_op1_ptr, (%op, %type, %size, %type, %size, %src, %dst);\n"); + copy_from_buffer("dst", v); + } + break; + + case VarKind::CoopVecBinaryOp: + fmt(" .reg.b32 %op, %type, %size, %u;\n" + " mov.b32 %op, $u;\n" + " mov.b32 %type, $u;\n" + " mov.b32 %size, $u;\n", + jitc_optix_coop_vec_op_id((JitOp) v->literal), + jitc_optix_coop_vec_type_id((VarType) v->type), + length + ); + + if (reg_count) { + put(" call ("); + put_elem(v, reg_count, false, true); + fmt("), _optix_vector_op2_$uxi32, (%op, %type, %size, %type, %size, ", reg_count); + put_elem(a0, reg_count, true); + put_elem(a1, reg_count, false); + put(");\n"); + } else { + declare_buffer("src_0", a0); + declare_buffer("src_1", a1); + declare_buffer("dst", v); + copy_to_buffer("src_0", a0); + copy_to_buffer("src_1", a1); + put(" call (), _optix_vector_op2_ptr, (%op, %type, %size, %type, %size, %src_0, %src_1, %dst);\n"); + copy_from_buffer("dst", v); + } + break; + + case VarKind::CoopVecTernaryOp: + fmt(" .reg.b32 %op, %type, %size, %u;\n" + " mov.b32 %op, $u;\n" + " mov.b32 %type, $u;\n" + " mov.b32 %size, $u;\n", + jitc_optix_coop_vec_op_id((JitOp) v->literal), + jitc_optix_coop_vec_type_id((VarType) v->type), + length + ); + + if (reg_count) { + put(" call ("); + put_elem(v, reg_count, false, true); + fmt("), _optix_vector_op3_$uxi32, (%op, %type, %size, %type, %size, ", reg_count); + put_elem(a0, reg_count, true); + put_elem(a1, reg_count, true); + put_elem(a2, reg_count, false); + put(");\n"); + } else { + declare_buffer("src_0", a0); + declare_buffer("src_1", a1); + declare_buffer("src_2", a2); + declare_buffer("dst", v); + copy_to_buffer("src_0", a0); + copy_to_buffer("src_1", a1); + copy_to_buffer("src_2", a2); + put(" call (), _optix_vector_op3_ptr, (%op, %type, %size, %type, %size, %src_0, %src_1, %src_2, %dst);\n"); + copy_from_buffer("dst", v); + } + break; + + case VarKind::Bitcast: + for (uint32_t i = 0; i < reg_count; ++i) + fmt(" mov.b32 %cv$u_$u, %cv$u_$u;\n", + v->reg_index, i, a0->reg_index, i); + break; + + case VarKind::CoopVecCast: + 
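+            // Casts reuse the generic unary-op entry point: the opcode is
+            // OPTIX_COOP_VEC_OP_CVT, with separate input/output type ids
+            // passed to _optix_vector_op1_* below.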
fmt(" .reg.b32 %op, %in_type, %out_type, %size, %u;\n" + " mov.b32 %op, $u;\n" + " mov.b32 %out_type, $u;\n" + " mov.b32 %in_type, $u;\n" + " mov.b32 %size, $u;\n", + OPTIX_COOP_VEC_OP_CVT, + jitc_optix_coop_vec_type_id((VarType) v->type), + jitc_optix_coop_vec_type_id((VarType) a0->type), + length + ); + + if (reg_count) { + put(" call ("); + put_elem(v, reg_count, false, true); + fmt("), _optix_vector_op1_$uxi32, (%op, %in_type, %size, %out_type, %size, ", reg_count); + put_elem(a0, reg_count, false); + put(");\n"); + } else { + declare_buffer("src", a0); + declare_buffer("dst", v); + copy_to_buffer("src", a0); + put(" call (), _optix_vector_op1_ptr, (%op, %in_type, %size, %out_type, %size, %src, %dst);\n"); + copy_from_buffer("dst", v); + } + break; + + case VarKind::CoopVecMatVec: { + CoopVecMatVecData *d = (CoopVecMatVecData *) v->data; + const Variable *matrix_v = jitc_var(a0->dep[3]); + const Variable *bias_v = a3 ? jitc_var(a3->dep[3]) : nullptr; + + uint32_t input_type_id = jitc_optix_coop_vec_type_id((VarType) a1->type); + uint32_t output_type_id = jitc_optix_coop_vec_type_id((VarType) v->type); + uint32_t matrix_type_id = jitc_optix_coop_vec_type_id((VarType) matrix_v->type); + uint32_t bias_type_id = bias_v ? jitc_optix_coop_vec_type_id((VarType) bias_v->type) : 0; + uint32_t mat_tsize = type_size[matrix_v->type]; + if (!bias_type_id) + bias_type_id = output_type_id; + + fmt(" .reg.b32 %out_type, %out_size, %in_type, %in_size, " + "%in_interp, %mat_type, %mat_offset, %mat_stride, " + "%mat_layout, %mat_n, %mat_k, %mat_transpose, " + "%bias_type, %bias_offset, %u;\n" + " .reg.b64 %mat_ptr, %bias_ptr;\n" + " mov.b32 %out_type, $u;\n" + " mov.b32 %out_size, $u;\n" + " mov.b32 %in_type, $u;\n" + " mov.b32 %in_size, $u;\n" + " mov.b32 %in_interp, $u;\n" + " mov.b32 %mat_type, $u;\n" + " mov.b64 %mat_ptr, $v;\n" + " mov.b32 %mat_offset, $u;\n" + " mov.b32 %mat_stride, $u;\n" + " mov.b32 %mat_layout, $u;\n" + " mov.b32 %mat_n, $u;\n" + " mov.b32 %mat_k, $u;\n" + " mov.b32 %mat_transpose, $u;\n" + " mov.b32 %bias_type, $u;\n", + output_type_id, + v->array_length, + input_type_id, + a1->array_length, + matrix_type_id, + matrix_type_id, + a0, + d->A_descr.offset * mat_tsize, + d->A_descr.stride * mat_tsize, + jitc_optix_coop_vec_layout_id(d->A_descr.layout), + d->transpose ? d->A_descr.cols : d->A_descr.rows, + d->transpose ? 
d->A_descr.rows : d->A_descr.cols, + (uint32_t) d->transpose, + bias_type_id + ); + if (bias_v) { + fmt(" mov.b64 %bias_ptr, $v;\n" + " mov.b32 %bias_offset, $u;\n", + a3, + d->b_descr.offset * type_size[bias_v->type]); + } else { + put(" mov.b64 %bias_ptr, 0;\n" + " mov.b32 %bias_offset, 0;\n"); + } + + if (reg_count) { + put(" call ("); + put_elem(v, reg_count, false, true); + fmt("), _optix_matvecmul_$uxi32, (%out_type, %out_size, " + "%in_type, %in_size, %in_interp, %mat_n, " + "%mat_k, %mat_ptr, %mat_offset, " + "%mat_stride, %mat_layout, %mat_transpose, " + "%mat_type, %bias_ptr, %bias_offset, " + "%bias_type, ", + reg_count); + put_elem(a1, reg_count, false); + put(");\n"); + } else { + declare_buffer("src", a1); + declare_buffer("dst", v); + copy_to_buffer("src", a1); + put(" call (), _optix_matvecmul_ptr, (%out_type, %out_size, " + "%in_type, %in_size, %in_interp, %mat_n, " + "%mat_k, %mat_ptr, %mat_offset, " + "%mat_stride, %mat_layout, %mat_transpose, " + "%mat_type, %bias_ptr, %bias_offset, " + "%bias_type, %src, %dst);\n"); + + copy_from_buffer("dst", v); + } + } + break; + + default: + jitc_fail("jitc_optix_render_coop_vec(): unhandled variable kind \"%s\"!", + var_kind_name[(uint32_t) v->kind]); + } + put(" }\n"); +} diff --git a/src/optix_coop_vec.h b/src/optix_coop_vec.h new file mode 100644 index 00000000..b7372f45 --- /dev/null +++ b/src/optix_coop_vec.h @@ -0,0 +1,25 @@ +/* + src/optix_coop_vec.h -- OptiX code generation for Cooperative Vectors + + Copyright (c) 2025 Wenzel Jakob + + All rights reserved. Use of this source code is governed by a BSD-style + license that can be found in the LICENSE file. +*/ + +#pragma once + +extern void jitc_optix_render_coop_vec(const Variable *v, const Variable *a0, + const Variable *a1, const Variable *a2, + const Variable *a3); +extern void jitc_optix_render_coop_vec_unpack(const Variable *v, + const Variable *a0); +extern void jitc_optix_render_coop_vec_accum(const Variable *v, + const Variable *a0, + const Variable *a1, + const Variable *a2); +extern void jitc_optix_render_coop_vec_outer_product_accum(const Variable *v, + const Variable *a0, + const Variable *a1, + const Variable *a2, + const Variable *a3); diff --git a/src/optix_core.cpp b/src/optix_core.cpp index d5809974..738e4fd0 100644 --- a/src/optix_core.cpp +++ b/src/optix_core.cpp @@ -14,17 +14,14 @@ #define DRJIT_ENABLE_OPTIX_DEBUG_VALIDATION_ON #endif -#define jitc_optix_check(err) jitc_optix_check_impl((err), __FILE__, __LINE__) -extern void jitc_optix_check_impl(OptixResult errval, const char *file, const int line); - static bool jitc_optix_cache_hit = false; static bool jitc_optix_cache_global_disable = false; void jitc_optix_log(unsigned int level, const char *tag, const char *message, void *) { - size_t len = strlen(message); - if (level <= (uint32_t) state.log_level_stderr) - fprintf(stderr, "jit_optix_log(): [%s] %s%s", tag, message, - (len > 0 && message[len - 1] == '\n') ? "" : "\n"); + // Note: cannot use jitc_var_log here. 
Parallel OptiX compilation may enter this + // region from another thread, causing deadlocks (with the Dr.Jit-Core lock + Python GIL) + if (level <= (uint32_t) std::max(state.log_level_callback, state.log_level_stderr)) + fprintf(stderr, "jit_optix_log(): [%s] %s", tag, message); if (strcmp(tag, "DISKCACHE") == 0 && strncmp(message, "Cache miss for key", 18) == 0) @@ -49,8 +46,7 @@ static OptixPipelineCompileOptions jitc_optix_default_compile_options() { #ifndef DRJIT_ENABLE_OPTIX_DEBUG_VALIDATION_ON pco.exceptionFlags = OPTIX_EXCEPTION_FLAG_NONE; #else - pco.exceptionFlags = OPTIX_EXCEPTION_FLAG_DEBUG | - OPTIX_EXCEPTION_FLAG_TRACE_DEPTH | + pco.exceptionFlags = OPTIX_EXCEPTION_FLAG_TRACE_DEPTH | OPTIX_EXCEPTION_FLAG_STACK_OVERFLOW; #endif @@ -258,6 +254,7 @@ bool jitc_optix_compile(ThreadState *ts, const char *buf, size_t buf_size, OptixDeviceContext &optix_context = state.devices[ts->device].optix_context; OptixPipelineData &pipeline = *ts->optix_pipeline; +#if 1 // Parallel compilation OptixTask task; error_log[0] = '\0'; int rv = optixModuleCreateWithTasks( @@ -291,6 +288,11 @@ bool jitc_optix_compile(ThreadState *ts, const char *buf, size_t buf_size, ); }; execute_task(task); +#else // Fallback: serial compilation + int rv = optixModuleCreate( + optix_context, &mco, &pipeline.compile_options, buf, buf_size, + error_log, &log_size, &kernel.optix.mod); +#endif int compilation_state = 0; jitc_optix_check( diff --git a/src/record_ts.cpp b/src/record_ts.cpp index 1d664d88..a515e983 100644 --- a/src/record_ts.cpp +++ b/src/record_ts.cpp @@ -1911,6 +1911,13 @@ uint32_t RecordThreadState::capture_call_offset(const void *ptr, size_t dsize) { return slot; } +void RecordThreadState::coop_vec_pack(uint32_t count, const void *in, + const MatrixDescr *in_d, void *out, + const MatrixDescr *out_d) { + (void) count; (void) in; (void) in_d; (void) out; (void) out_d; + jitc_raise("drjit.[un]pack(): currently not supported within frozen functions."); +} + /** * This function tries to capture a variable that is not known to the * recording \c ThreadState. @@ -2103,7 +2110,7 @@ struct DisabledThreadState : ThreadState { */ void record_exception() { m_raised = true; - }; + } /** * Actually throws the exception, if any was thrown during recording. 
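The hunk below converts the remaining overrides of DisabledThreadState to the same pattern: each backend entry point merely records the failure and returns a dummy value, and the exception surfaces later from a safe point. A minimal sketch of this deferred-error idiom (hypothetical names, not the actual class):

#include <stdexcept>

// Deferred-error sketch: forbidden operations only set a flag, since they may
// be invoked from contexts where throwing is unsafe; the exception is raised
// afterwards from a point where it can safely propagate.
struct DisabledState {
    bool m_raised = false;
    void record_exception() { m_raised = true; }
    int compress() { record_exception(); return 0; } // stand-in for an override
    void rethrow_if_needed() {
        if (m_raised)
            throw std::runtime_error("operation not permitted in this state");
    }
};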
diff --git a/src/record_ts.cpp b/src/record_ts.cpp
index 1d664d88..a515e983 100644
--- a/src/record_ts.cpp
+++ b/src/record_ts.cpp
@@ -1911,6 +1911,13 @@ uint32_t RecordThreadState::capture_call_offset(const void *ptr, size_t dsize) {
     return slot;
 }
 
+void RecordThreadState::coop_vec_pack(uint32_t count, const void *in,
+                                      const MatrixDescr *in_d, void *out,
+                                      const MatrixDescr *out_d) {
+    (void) count; (void) in; (void) in_d; (void) out; (void) out_d;
+    jitc_raise("drjit.[un]pack(): currently not supported within frozen functions.");
+}
+
 /**
  * This function tries to capture a variable that is not known to the
  * recording \c ThreadState.
@@ -2103,7 +2110,7 @@ struct DisabledThreadState : ThreadState {
      */
     void record_exception() {
         m_raised = true;
-    };
+    }
 
     /**
      * Actually throws the exception, if any was thrown during recording.
@@ -2123,69 +2130,72 @@ struct DisabledThreadState : ThreadState {
         }
     }
 
-    void barrier() override { record_exception(); };
+    void barrier() override { record_exception(); }
     Task *launch(Kernel /*kernel*/, KernelKey * /*key*/, XXH128_hash_t /*hash*/,
                  uint32_t /*size*/, std::vector<void *> * /*kernel_params*/,
                  const std::vector<uint32_t> * /*kernel_param_ids*/) override {
         record_exception();
         return nullptr;
-    };
+    }
     void memset_async(void * /*ptr*/, uint32_t /*size*/, uint32_t /*isize*/,
                       const void * /*src*/) override {
         record_exception();
-    };
+    }
     uint32_t compress(const uint8_t * /*in*/, uint32_t /*size*/,
                       uint32_t * /*out*/) override {
         record_exception();
         return 0;
-    };
+    }
    uint32_t mkperm(const uint32_t * /*values*/, uint32_t /*size*/,
                     uint32_t /*bucket_count*/, uint32_t * /*perm*/,
                     uint32_t * /*offsets*/) override {
         record_exception();
         return 0;
-    };
+    }
     void memcpy(void * /*dst*/, const void * /*src*/,
                 size_t /*size*/) override {
         record_exception();
-    };
+    }
     void memcpy_async(void * /*dst*/, const void * /*src*/,
                       size_t /*size*/) override {
         record_exception();
-    };
+    }
     void block_reduce(VarType /*vt*/, ReduceOp /*op*/, uint32_t /*size*/,
                       uint32_t /*block_size*/, const void * /*in*/,
                       void * /*out*/) override {
         record_exception();
-    };
+    }
     void block_prefix_reduce(VarType /*vt*/, ReduceOp /*op*/, uint32_t /*size*/,
                              uint32_t /*block_size*/, bool /*exclusive*/,
                              bool /*reverse*/, const void * /*in*/,
                              void * /*out*/) override {
         record_exception();
-    };
+    }
     void reduce_dot(VarType /*type*/, const void * /*ptr_1*/,
                     const void * /*ptr_2*/, uint32_t /*size*/,
                     void * /*out*/) override {
         record_exception();
-    };
+    }
     void poke(void * /*dst*/, const void * /*src*/, uint32_t /*size*/) override {
         record_exception();
-    };
+    }
     void aggregate(void * /*dst*/, AggregationEntry * /*agg*/,
                    uint32_t /*size*/) override {
         record_exception();
-    };
+    }
     void enqueue_host_func(void (* /*callback*/)(void *),
                            void * /*payload*/) override {
         record_exception();
-    };
+    }
     void notify_expand(uint32_t /*index*/) override {};
     void reduce_expanded(VarType /*vt*/, ReduceOp /*reduce_op*/,
                          void * /*data*/, uint32_t /*exp*/,
-                         uint32_t /*size*/) override {};
-    void notify_free(const void * /*ptr*/) override {};
+                         uint32_t /*size*/) override {}
+    void notify_free(const void * /*ptr*/) override {}
+    void coop_vec_pack(uint32_t /* count */, const void * /* in */,
+                       const MatrixDescr * /* in_d */, void * /* out */,
+                       const MatrixDescr * /* out_d */) override { }
 };
 
 void set_disabled_thread_state(ThreadState **ts, JitBackend recording_backend) {
diff --git a/src/record_ts.h b/src/record_ts.h
index 9216371d..fcaf3b1d 100644
--- a/src/record_ts.h
+++ b/src/record_ts.h
@@ -428,6 +428,11 @@ struct RecordThreadState : ThreadState {
     void reduce_expanded(VarType vt, ReduceOp reduce_op, void *data,
                          uint32_t exp, uint32_t size) override;
 
+    /// Pack a set of matrices/vectors for use with the cooperative vector API
+    void coop_vec_pack(uint32_t count, const void *in,
+                       const MatrixDescr *in_d, void *out,
+                       const MatrixDescr *out_d) override;
+
     /**
      * This function is called every time a pointer is freed using \ref
      * jitc_free. It records the operation and removes the mapping from that
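Both thread states now expose coop_vec_pack(): the recording state raises immediately, while DisabledThreadState follows its usual record-now, throw-later discipline. The standalone sketch below summarizes that pattern (hypothetical names, not the actual ThreadState interface): overrides only set a flag, and the pending error surfaces once control returns to a safe point.

#include <cstdio>
#include <stdexcept>

struct DisabledState {
    bool m_raised = false;

    // Overrides only make a note of the problem ...
    void record_exception() { m_raised = true; }
    void barrier() { record_exception(); }

    // ... and the pending error surfaces at a well-defined point later
    void rethrow_if_needed() {
        if (m_raised)
            throw std::runtime_error("not permitted while recording is disabled");
    }
};

int main() {
    DisabledState s;
    s.barrier();   // silently records the failure
    try {
        s.rethrow_if_needed();
    } catch (const std::exception &e) {
        std::printf("caught: %s\n", e.what());
    }
    return 0;
}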
diff --git a/src/strbuf.cpp b/src/strbuf.cpp
index 723f2163..a06fa1ee 100644
--- a/src/strbuf.cpp
+++ b/src/strbuf.cpp
@@ -53,6 +53,12 @@ void StringBuffer::rewind_to(size_t pos) {
         *m_cur = '\0';
 }
 
+void StringBuffer::rewind(size_t rel) {
+    m_cur -= rel;
+    if (m_start != m_end)
+        *m_cur = '\0';
+}
+
 void StringBuffer::move_suffix(size_t suffix_start, size_t suffix_target) {
     size_t buffer_size = size(),
            suffix_size = buffer_size - suffix_start;
diff --git a/src/strbuf.h b/src/strbuf.h
index 9e54636a..b0f67906 100644
--- a/src/strbuf.h
+++ b/src/strbuf.h
@@ -77,6 +77,9 @@ struct StringBuffer {
      */
     void rewind_to(size_t pos);
 
+    /// Like \ref rewind_to(), but relative to the end of the buffer
+    void rewind(size_t rel);
+
     /// Delete trailing spaces and commas
     void delete_trailing_commas();
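StringBuffer::rewind() complements rewind_to() by trimming a suffix of known length, which is convenient for retracting speculatively emitted text such as a trailing separator. A standalone analog follows; MiniBuffer is invented for the example and does no bounds checking:

#include <cstdio>
#include <cstring>

struct MiniBuffer {
    char data[64] = { };
    char *cur = data;

    void put(const char *s) {
        size_t n = std::strlen(s); // no bounds check: sketch only
        std::memcpy(cur, s, n);
        cur += n;
        *cur = '\0';
    }

    // Analog of StringBuffer::rewind(): drop the last 'rel' characters
    void rewind(size_t rel) {
        cur -= rel;
        *cur = '\0';
    }
};

int main() {
    MiniBuffer buf;
    buf.put("%a, %b, ");
    buf.rewind(2);                 // retract the trailing ", "
    std::printf("%s\n", buf.data); // prints: %a, %b
    return 0;
}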
diff --git a/src/var.cpp b/src/var.cpp
index 3a1029c4..cd1c6717 100644
--- a/src/var.cpp
+++ b/src/var.cpp
@@ -14,6 +14,7 @@
 #include "eval.h"
 #include "util.h"
 #include "op.h"
+#include "coop_vec.h"
 #include "registry.h"
 #include "llvm.h"
 
@@ -178,7 +179,7 @@ const char *var_kind_name[(int) VarKind::Count] {
     "rcp", "rcp.approx", "rsqrt.approx",
 
     // Multi-function generator (CUDA)
-    "sin", "cos", "exp2", "log2",
+    "sin", "cos", "exp2", "log2", "tanh",
 
     // Casts
     "cast", "bitcast",
@@ -271,7 +272,20 @@ const char *var_kind_name[(int) VarKind::Count] {
     "array_read",
 
     // Write an element to a variable array
-    "array_write"
+    "array_write",
+
+    // Cooperative vector API
+    "coop_vec_literal",
+    "coop_vec_pack",
+    "coop_vec_unpack",
+    "coop_vec_load",
+    "coop_vec_cast",
+    "coop_vec_unary_op",
+    "coop_vec_binary_op",
+    "coop_vec_ternary_op",
+    "coop_vec_mat_vec",
+    "coop_vec_accum",
+    "coop_vec_outer_product_accum"
 };
 
 /// Temporary string buffer for miscellaneous variable-related tasks
@@ -840,6 +854,10 @@ uint32_t jitc_var_pointer(JitBackend backend, const void *value,
     std::swap(scope_backup, ts->scope);
     uint32_t result = jitc_var_new(v);
     std::swap(scope_backup, ts->scope);
+
+    jitc_log(Debug, "jit_var_pointer(): pointer r%u = " DRJIT_PTR " (r%u, write=%i)",
+             result, (uintptr_t) value, dep, write);
+
     return result;
 }
 
@@ -886,10 +904,16 @@ uint32_t jitc_var_call_input(uint32_t index) {
     v2.size = 1;
 
     bool optimize = jitc_flags() & (uint32_t) JitFlag::OptimizeCalls;
+    bool disable_lvn = !optimize;
 
     if (v->is_literal() && optimize) {
         v2.kind = (uint32_t) VarKind::Literal;
         v2.literal = v->literal;
+        // Temporarily stash the size here (subsequently read in call.cpp)
+        // Will have to be redesigned if the 'unused' field is ever used
+        // for another purpose.
+        v2.unused = v->size;
+        disable_lvn = true;
         return jitc_var_new(v2);
     } else {
         v2.kind = (uint32_t) VarKind::CallInput;
@@ -898,7 +922,7 @@ uint32_t jitc_var_call_input(uint32_t index) {
         jitc_var_inc_ref(index);
     }
 
-    return jitc_var_new(v2, !optimize);
+    return jitc_var_new(v2, disable_lvn);
 }
 
 uint32_t jitc_new_scope(JitBackend backend) {
@@ -1700,6 +1724,8 @@ uint32_t jitc_var_resize(uint32_t index, size_t size) {
         v2.symbolic = v->symbolic;
         v2.size = (uint32_t) size;
         v2.dep[0] = index;
+        v2.coop_vec = v->coop_vec;
+        v2.array_length = v->array_length;
         jitc_var_inc_ref(index, v);
         result = jitc_var_new(v2, true);
     }
@@ -2825,6 +2851,18 @@ const char *jitc_var_graphviz() {
             }
             var_buffer.put(";\n");
         }
+
+        switch ((VarKind) v.kind) {
+            case VarKind::CoopVecPack: {
+                    CoopVecPackData *cvid = (CoopVecPackData *) v.data;
+                    for (uint32_t index2 : cvid->indices)
+                        var_buffer.fmt(" %u -> %zu", index2, index);
+                }
+                break;
+
+            default:
+                break;
+        }
     }
 
     var_buffer.put(
diff --git a/tests/optix_stubs.h b/tests/optix_stubs.h
index b518bb35..d5b66e2d 100644
--- a/tests/optix_stubs.h
+++ b/tests/optix_stubs.h
@@ -173,6 +173,8 @@ struct OptixPipelineCompileOptions {
     unsigned int exceptionFlags;
     const char* pipelineLaunchParamsVariableName;
     unsigned int usesPrimitiveTypeFlags;
+    int allowOpacityMicromaps;
+    int allowClusteredGeometry; // OptiX 9.0 ABI
 };
 
 struct OptixAccelEmitDesc {
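Finally, the test stubs track the growth of OptixPipelineCompileOptions in the OptiX 9.0 ABI; a stale local mirror of a driver struct is a classic source of silent breakage. One cheap guard is a compile-time size check, sketched below with a trimmed, hypothetical stand-in whose expected size covers only the three fields shown (assuming 4-byte int on common ABIs):

#include <cstddef>

// Trimmed, hypothetical stand-in: only the tail of the struct is shown
struct PipelineCompileOptionsStub {
    unsigned int usesPrimitiveTypeFlags;
    int allowOpacityMicromaps;
    int allowClusteredGeometry; // OptiX 9.0 ABI
};

// Expected size of the three fields above on common 32/64-bit ABIs;
// a driver-side struct change would trip this at compile time
static_assert(sizeof(PipelineCompileOptionsStub) == 12,
              "pipeline compile options stub out of sync with the ABI");

int main() { return 0; }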