Cooperative vector API #141


Open · wjakob wants to merge 9 commits into master

Conversation

@wjakob wjakob (Member) commented Apr 18, 2025

This PR adds cooperative vector code generation support to Dr.Jit-Core. The main documentation and test suite of this feature are in the Dr.Jit parent project.

@wjakob wjakob (Member, Author) commented Apr 18, 2025

@merlinND I forgot to create a separate PR for this part; done now.

Wenzel Jakob and others added 7 commits April 22, 2025 07:55
This commit adds cooperative vector code generation support to Dr.Jit-Core. The main documentation and test suite of this feature are in the Dr.Jit parent project.

Dr.Jit previously chose the lowest possible PTX version for each compute capability, but this ended up being too restrictive. It now ships a table containing a full driver version -> PTX version mapping and searches it for the highest possible PTX version.

Dr.Jit can elide scatter operations when their result can no longer be referenced by any other operation. The logic to do so, and to decide when reference count decreases are needed, was dispersed throughout ``eval.cpp``. This commit simplifies the underlying code.
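For illustration, here is a minimal sketch of the driver version -> PTX version lookup described in the second commit above; the table entries, version encoding, and names are made up for the example and are not taken from this PR.

#include <cstdint>

struct PTXVersionEntry {
    uint32_t driver_version; // minimum driver version (illustrative encoding)
    uint32_t ptx_version;    // highest PTX ISA version known to that driver
};

// Sorted by driver version; the real table shipped with Dr.Jit-Core is larger.
static const PTXVersionEntry ptx_table[] = {
    { 11080, 78 }, // e.g. CUDA 11.8 -> PTX ISA 7.8
    { 12000, 80 }, // e.g. CUDA 12.0 -> PTX ISA 8.0
    { 12040, 84 }  // e.g. CUDA 12.4 -> PTX ISA 8.4
};

// Return the highest PTX version usable with the given driver version,
// falling back to the oldest entry if the driver predates the table.
inline uint32_t highest_ptx_version(uint32_t driver_version) {
    uint32_t result = ptx_table[0].ptx_version;
    for (const PTXVersionEntry &entry : ptx_table) {
        if (entry.driver_version > driver_version)
            break;
        result = entry.ptx_version;
    }
    return result;
}
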
@merlinND merlinND (Member) left a comment

Review part 1 (does not yet cover llvm/cuda_coop_vec.cpp).

uint32_t rows, cols; //< Shape of the matrix
uint32_t offset; //< Offset from the beginning of the buffer (in elements)
uint32_t stride; //< Row stride (in elements)
uint32_t size; //< Total size (in elements)

Could you add a note that, in some layouts, the size in elements may be larger than the actual number of matrix elements (e.g. when there is mandatory padding)?
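To make this concrete, here is a hedged illustration using the MatrixDescr fields quoted above; the numbers, the padded layout, and the size = rows * stride convention are assumptions made up for the example, not taken from this PR.

// Hypothetical example: describe a 3x3 matrix stored with a padded row
// stride of 16 elements, e.g. to satisfy an alignment requirement of the
// chosen layout.
MatrixDescr make_padded_3x3_descr() {
    MatrixDescr d;
    d.rows   = 3;
    d.cols   = 3;
    d.offset = 0;
    d.stride = 16;                // row stride in elements, larger than 'cols'
    d.size   = d.rows * d.stride; // 48 elements, although the matrix has only 9
    return d;
}
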

uint32_t out,
const MatrixDescr *out_descr);

/// Query the backend to compute the size of an array/vector in a given layout

Specify which fields of `in` are ignored and which fields of the output are filled in. Are the fields that this function does not compute copied over from `in` into the return value?

MatrixLayout layout,
uint32_t offset);

/// Perform a matrix-vector multiplication + bias addition

Specify which arguments are optional (e.g. `b_index = 0` for no bias, if that is allowed?)

int transpose);

/// Accumulate the coop. vector 'index' into the buffer 'target' with offset.
/// Potentially create a new buffer of size 'size' if target == 0.

Suggested change
/// Potentially create a new buffer of size 'size' if target == 0.
/// Potentially create a new buffer of size 'target_size' if target == 0.


~CoopVecPackData() {
for (uint32_t index: indices)
jitc_var_dec_ref(index);

It looks like the variables don't have their reference count increased when building a CoopVecPackData, but it does get decreased upon destruction.

If it's not too heavy, would it be worth putting CoopVecPackData construction behind a static method `CoopVecPackData::steal(indices)` so that the ownership is clear?
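A rough sketch of what the suggested factory could look like, assuming CoopVecPackData simply stores its indices in a std::vector<uint32_t> (an assumption for this sketch; only the destructor and jitc_var_dec_ref appear in the excerpt above):

#include <cstdint>
#include <utility>
#include <vector>

// Stand-in declaration for the sketch; the real function is provided by Dr.Jit-Core.
extern void jitc_var_dec_ref(uint32_t index);

struct CoopVecPackData {
    std::vector<uint32_t> indices;

    /// Take ownership of the given variable indices *without* increasing
    /// their reference counts (mirroring the usual steal() semantics).
    static CoopVecPackData steal(std::vector<uint32_t> indices) {
        CoopVecPackData result;
        result.indices = std::move(indices);
        return result;
    }

    // Moving is fine (the moved-from vector becomes empty); copying is not.
    CoopVecPackData(CoopVecPackData &&) = default;
    CoopVecPackData(const CoopVecPackData &) = delete;

    ~CoopVecPackData() {
        for (uint32_t index : indices)
            jitc_var_dec_ref(index);
    }

private:
    CoopVecPackData() = default;
};
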


As an alternative, I think there was a small specialized DrJit vector (list) class that automatically increases the ref count?


if (!(a0_v->size == a1_v->size || a1_v->size == 1 || a0_v->size == 1))
jitc_raise(
"jit_coop_vec_binary_op(): incompatible thread count (%u and %u)!",

Suggested change
"jit_coop_vec_binary_op(): incompatible thread count (%u and %u)!",
"jit_coop_vec_binary_op(): incompatible width (%u and %u)!",

To stay consistent with dr::width()?

!(a1_v->size == max_size || a1_v->size == 1) ||
!(a2_v->size == max_size || a2_v->size == 1))
jitc_raise(
"jit_coop_vec_ternary_op(): incompatible thread count (%u, %u, and %u)!",

Suggested change
"jit_coop_vec_ternary_op(): incompatible thread count (%u, %u, and %u)!",
"jit_coop_vec_ternary_op(): incompatible width (%u, %u, and %u)!",


if (!supported)
jitc_raise("jit_coop_vec_matvec(): incompatible input types "
"(currently, only float16 is supported on the CUDA/OptiX)!");

Suggested change
"(currently, only float16 is supported on the CUDA/OptiX)!");
"(currently, only float16 is supported on the CUDA/OptiX backend,"
" and only float16 and float32 on the LLVM backend.");

Because people might try float64.

Comment on lines +564 to +571
jitc_var(result)->data = cvmvd.get();
jitc_var_set_callback(
result,
[](uint32_t, int free, void *p) {
if (free)
delete (CoopVecMatVecData *) p;
},
cvmvd.release(), true);

Do you think it would make sense to refactor this common pattern into a helper like `jitc_var_new_with_payload<T>(...)`?
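Here is a hedged sketch of what such a helper could look like, based on the pattern quoted above; the helper's name and its std::unique_ptr-based signature follow the reviewer's suggestion rather than code in this PR, and it assumes the Dr.Jit-Core declarations of jitc_var() and jitc_var_set_callback() are in scope.

#include <cstdint>
#include <memory>

// Hypothetical helper: attach a heap-allocated payload of type T to an
// existing variable and register a callback that deletes it once the
// variable is freed, mirroring the snippet above.
template <typename T>
void jitc_var_set_payload(uint32_t index, std::unique_ptr<T> payload) {
    jitc_var(index)->data = payload.get();
    jitc_var_set_callback(
        index,
        [](uint32_t /* index */, int free, void *p) {
            if (free)
                delete (T *) p;
        },
        payload.release(), true);
}
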


void *p = nullptr;
Ref tmp = steal(jitc_var_data(target, false, &p));
Ref target_ptr = steal(jitc_var_pointer(backend, p, tmp, 1));

Does the current implementation support multiple `outer_product_accumulate` calls targeting the same matrix?
It's quite common to evaluate the same MLP multiple times in the same forward pass (e.g. when taking finite-difference gradients of a neural field), so the backward pass would contain multiple outer products targeting the same matrix that we want to keep in the same kernel.
