
Cooperative Vector API #384


Open
wjakob wants to merge 9 commits into master from coopvec

Conversation

wjakob
Member

@wjakob wjakob commented Apr 15, 2025

This PR adds cooperative vector support to Dr.Jit. Cooperative vectors enable efficient compilation and evaluation of expressions involving matrix multiplication and cater to situations where each execution thread performs a sequence of independent multiplications by reasonably small matrices (e.g., 64x64). This enables the fully fused evaluation of small multilayer perceptrons (MLPs) within a larger program. That said, the feature isn't specific to MLPs and could also be used in other ways.

On NVIDIA GPUs (Turing or newer), cooperative vectors map to the OptiX cooperative vector API, leveraging the built-in tensor cores for acceleration. On the CPU (LLVM) backend, Dr.Jit compiles cooperative vector operations using available instruction set extensions (AVX512, NEON, etc.).

For further details on this new API and how to use it, refer to the documentation in docs/coop_vec.rst.
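
As a rough sketch of how the API fits together (illustration only, not taken from this PR; nn.CoopVec appears in the test snippets below, whereas helper names such as nn.pack and nn.matvec, their signatures, and unpacking via iteration are assumptions to be checked against docs/coop_vec.rst):

import drjit as dr
from drjit.llvm.ad import TensorXf16, Float16
from drjit import nn

# Hypothetical 16x16 layer: weights A and bias b, packed into a
# backend-optimal memory layout (names per docs/coop_vec.rst, assumed here)
A = dr.full(TensorXf16, 0.01, (16, 16))
b = dr.zeros(TensorXf16, 16)
buffer, A_view, b_view = nn.pack(A, b)

# Pack 16 per-thread values into a cooperative vector, evaluate
# y = A @ x + b in a fused manner, then unpack the result again
x = nn.CoopVec(*(Float16(i / 16) for i in range(16)))
y = nn.matvec(A_view, x, b_view)
outputs = list(y)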

@wjakob wjakob force-pushed the coopvec branch 2 times, most recently from 7c65d4b to cd67909 Compare April 15, 2025 10:18
@wjakob wjakob requested a review from merlinND April 15, 2025 12:41
@wjakob wjakob force-pushed the coopvec branch 2 times, most recently from 89499a0 to 0247195 Compare April 16, 2025 05:07
Member

@merlinND merlinND left a comment


Review part 1 (three big files left to review)

a, b = t(1), t(2)
dr.enable_grad(a, b)
z = nn.CoopVec(a, b) # pack
assert dr.grad_enabled(z)

Can also test that dr.enable_grad(z) raises as expected
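
For example, something along these lines (a sketch only; t is the array-type fixture from the surrounding tests, and the exact exception type raised is an assumption):

import pytest
import drjit as dr
from drjit import nn

def test_enable_grad_on_packed_vector_raises(t):
    a, b = t(1), t(2)
    z = nn.CoopVec(a, b)  # pack first, without enabling gradients
    # Enabling gradients on the already-packed vector is expected to fail
    # (exception type is a guess; could be TypeError instead)
    with pytest.raises(RuntimeError):
        dr.enable_grad(z)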

dr.schedule(x.grad, y.grad)
assert x.grad == 4
assert y.grad == 5
assert dr.grad_enabled(z)

Can also test dr.detach(z)
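
For instance (a sketch; assumes dr.detach returns a gradient-free copy of the packed vector):

import drjit as dr
from drjit import nn

def test_detach_packed_vector(t):
    a, b = t(1), t(2)
    dr.enable_grad(a, b)
    z = nn.CoopVec(a, b)
    assert dr.grad_enabled(z)
    # Detaching should produce a copy of 'z' without gradient tracking
    zd = dr.detach(z)
    assert not dr.grad_enabled(zd)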

z + 3
)
b = nn.cast(a, dr.float32_array_t(t))
c = nn.cast(b, dr.float16_array_t(t))

Test grad enabled / disabled and grad propagation through casts?
Ideally, gradients would just be converted to the new precision. It's an important use case to have most of a differentiable pipeline in fp32 and locally convert to fp16 for the MLP.
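
A sketch of what such a test might look like (assuming gradients do propagate through nn.cast and that packed vectors unpack via iteration; t is the fp16 array-type fixture):

import drjit as dr
from drjit import nn

def test_cast_propagates_grad(t):
    a, b = t(1), t(2)
    dr.enable_grad(a, b)
    z16 = nn.CoopVec(a, b)
    z32 = nn.cast(z16, dr.float32_array_t(t))
    assert dr.grad_enabled(z32)
    # Gradients should flow back through the cast to the fp16 leaves
    x, y = z32
    dr.backward(x + 2 * y)
    assert a.grad == 1 and b.grad == 2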

@merlinND
Member

merlinND commented Apr 17, 2025

One thing that will come up when we add the hash grid encoding, but good to keep in mind in general: atomic addition of f16 values is much slower than f16x2. I think it should be fairly easy to add a special case for f16 in the scatter_packet implementation?
(Not really for this PR, just to keep in mind for later)

@wjakob wjakob force-pushed the coopvec branch 2 times, most recently from 2d46aae to f80af5a Compare April 21, 2025 14:44
Wenzel Jakob and others added 7 commits April 22, 2025 07:55
Cooperative vectors enable efficient compilation and evaluation of
expressions involving matrix multiplication. They cater to a specific
use case, where each execution thread performs a sequence of independent
multiplications by reasonably small matrices (e.g., 64x64). This enables
the fully fused evaluation of small multilayer perceptrons within a
larger program. That said, the feature isn't specific to MLPs and could
also be used in other ways.

On NVIDIA GPUs (Turing or newer), cooperative vectors map to the OptiX
cooperative vector API, leveraging the built-in tensor cores for
acceleration. On the CPU (LLVM) backend, Dr.Jit compiles cooperative
vector operations using available instruction set extensions (AVX512,
NEON, etc.).

For further details on this new API and how to use it, refer to the
documentation in ``docs/coop_vec.rst``.
This commit improves handling of evaluated loops with grad-enabled state
variables. Previously, the AD variable ID of each differentiable state
variable changed in every iteration, even if the loop did not touch that
variable. This is an implementation detail of the loop evaluation code that
should, however, not leak into user code. This commit fixes this behavior.
This commit fixes bugs in the compilation of reverse-mode derivatives of
simple loops (i.e., loops with max_iterations==-1) and updates the test
suite to cover problematic cases.
This commit fixes bugs and adds tests to ensure that matrix
multiplication can be correctly differentiated in reverse-mode when it
occurs inside a "simple" loop (i.e., a loop with max_iterations==-1).
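
Roughly the pattern exercised by the tests described in the last two commits (a sketch under assumptions: dr.while_loop with the max_iterations=-1 flag as described above, nn.pack/nn.matvec names as in the cooperative vector docs, and CoopVec state inside the loop):

import drjit as dr
from drjit.llvm.ad import TensorXf16, Float16, UInt32
from drjit import nn

# Hypothetical 2x2 weight matrix, packed for cooperative-vector use
A = dr.full(TensorXf16, 0.5, (2, 2))
buffer, A_view = nn.pack(A)

a_in, b_in = Float16(1), Float16(2)
dr.enable_grad(a_in, b_in)
x = nn.CoopVec(a_in, b_in)
i = UInt32(0)

# "Simple" loop (max_iterations == -1) containing a matrix multiplication
i, x = dr.while_loop(
    state=(i, x),
    cond=lambda i, x: i < 3,
    body=lambda i, x: (i + 1, nn.matvec(A_view, x)),
    max_iterations=-1,
)

# Reverse-mode derivative through the loop and the matrix products
u, v = x
dr.backward(u + v)
print(a_in.grad, b_in.grad)
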
@wjakob
Member Author

wjakob commented Apr 22, 2025

> One thing that will come up when we add the hash grid encoding, but good to keep in mind in general: atomic addition of f16 values is much slower than f16x2. I think it should be fairly easy to add a special case for f16 in the scatter_packet implementation?
> (Not really for this PR, just to keep in mind for later)

Dr.Jit-Core always generates the f16x2 assembly operation, even when only scattering a single f16 value. In the case of your hash grid, would it be possible to make use of the f16x2 format to scatter two values at once?

Right now, packet atomics are ignored by the CUDA backend. I think that Blackwell is the first consumer architecture that really supports these besides the f16x2 special case. In any case, such changes are out of scope for this already very big PR.
