
Commit e19530d

Wenzel Jakob (wjakob) authored and committed
Cooperative Vector API
Cooperative vectors enable efficient compilation and evaluation of expressions involving matrix multiplication. They cater to a specific use case where each execution thread performs a sequence of independent multiplications by reasonably small matrices (e.g., 64x64). This enables the fully fused evaluation of small multilayer perceptrons within a larger program. That said, the feature isn't specific to MLPs and could also be used in other ways. On NVIDIA GPUs (Turing or newer), cooperative vectors map to the OptiX cooperative vector API, leveraging the built-in tensor cores for acceleration. On the CPU (LLVM) backend, Dr.Jit compiles cooperative vector operations using available instruction set extensions (AVX512, NEON, etc.). For further details on this new API and how to use it, refer to the documentation in ``docs/coop_vec.rst``.

Fix derivative of ``nn.matmul()`` in simple symbolic loops

This commit fixes bugs and adds tests to ensure that matrix multiplication can be correctly differentiated in reverse mode when it occurs inside a "simple" loop (i.e., a loop with max_iterations==-1).
1 parent 01ef10e · commit e19530d

35 files changed: +3823 −155 lines

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -108,7 +108,7 @@ if (DRJIT_ENABLE_JIT)
   set_target_properties(nanothread PROPERTIES ${DRJIT_OUTPUT_DIRECTORY})
 endif()

-mark_as_advanced(NANOTHREAD_ENABLE_TESTS)
+mark_as_advanced(NANOTHREAD_ENABLE_TESTS NANOTHREAD_STATIC)
 mark_as_advanced(DRJIT_CORE_ENABLE_TESTS)
 mark_as_advanced(NB_TEST NB_TEST_SHARED_BUILD NB_TEST_STABLE_ABI NB_USE_SUBMODULE_DEPS NB_TEST_SANITZE NB_CREATE_INSTALL_RULES nanobind_DIR)
 mark_as_advanced(NB_TEST_CUDA NB_TEST_FREE_THREADED NB_TEST_SANITIZERS_ASAN NB_TEST_SANITIZERS_TSAN NB_TEST_SANITIZERS_UBSAN)

docs/autodiff.rst

Lines changed: 2 additions & 2 deletions
@@ -427,8 +427,8 @@ Dr.Jit how a particular operation should be differentiated. Reasons for this
 may include:

 - The automatic differentiation backend cannot keep track of computation
-  performed outside of Dr.Jit (e.g. using a highly optimized :ref:`CUDA kernel
-  <custom-cuda>`). In this case, review the section on :ref:`interoperability
+  performed outside of Dr.Jit (e.g. using custom CUDA kernels). In this case,
+  review the section on :ref:`interoperability
   <interop>`, since it presents a potentially simpler solution.

 - The derivative may admit a simplified analytic expression that is superior to

docs/changelog.rst

Lines changed: 1 addition & 1 deletion
@@ -348,7 +348,7 @@ Here is what's new:

 ⚠️ Compatibility ⚠️
--------------------
+--------------------

 - **Symbolic loop syntax**: the old "recorded loop" syntax is no longer
   supported. Existing code will need adjustments to use

docs/coop_vec.rst

Lines changed: 320 additions & 0 deletions
@@ -0,0 +1,320 @@

.. py:currentmodule:: drjit

.. cpp:namespace:: drjit

.. _coop_vec:

Cooperative vectors
===================

*Cooperative vectors* are a `new API
<https://github.com/KhronosGroup/GLSL/blob/main/extensions/nv/GLSL_NV_cooperative_vector.txt>`__
for evaluating matrix-vector products in certain types of GPU workloads. They
are designed to handle cases where each thread of a parallel program needs
to multiply a vector by a reasonably small matrix (e.g., 64x64 or fewer
entries). By working together, the threads can perform these multiplications
more efficiently, which is why the approach is called *cooperative*.

Cooperative vectors are especially useful for evaluating small `multilayer
perceptrons <https://en.wikipedia.org/wiki/Multilayer_perceptron>`__ (MLPs)
within larger programs while fully *fusing* all steps of the process into a
single kernel. Other workloads that heavily rely on matrix-vector products may
benefit as well.

Dr.Jit supports cooperative vectors on both of its backends:

- On **NVIDIA GPUs (Turing or newer)**, cooperative vectors map to the OptiX
  `cooperative vector API
  <https://raytracing-docs.nvidia.com/optix9/guide/index.html#cooperative_vectors#neural-rendering-with-cooperative-vectors>`__,
  leveraging built-in `tensor cores
  <https://www.nvidia.com/en-us/data-center/tensor-cores/>`__ for acceleration.
  Driver version R570 or newer is required to use this feature.

- On the **CPU (LLVM) backend**, compilation of cooperative vector operations
  targets the available instruction set extensions (AVX512, NEON, etc.).

Code snippets in the remainder of this section assume the following import
statements:

.. code-block:: python

   import drjit as dr
   import drjit.nn as nn
   from drjit.auto.ad import Float16, TensorXf16

Motivation
----------

The cooperative vector API is available via the :py:mod:`drjit.nn` submodule.
Below is an example demonstrating how to use it to perform a matrix-vector
multiplication.

.. code-block:: python

   # Matrix shape
   m, n = 3, 16

   # Create a random matrix + offset
   A = dr.normal(TensorXf16, (m, n))
   b = dr.rand(TensorXf16, m)

   # Pack 'A' and 'b' into a buffer with an optimal layout
   buffer, A_view, b_view = nn.pack(A, b)

   # Create a cooperative vector
   x = nn.CoopVec(... 16 values ...)

   # Evaluate A @ x + b
   v_out = nn.matvec(A_view, x, b_view)

   # Unpack the resulting cooperative vector
   x, y, z = v_out

This involves the following steps:

- Initializing matrix data and packing it into an optimized memory layout using
  :py:func:`nn.pack() <drjit.nn.pack>`.

- Constructing a :py:class:`nn.CoopVec` containing the inputs to the matrix
  multiplication.

- Performing one or more matrix-vector multiplications and other arithmetic,
  while keeping the state in cooperative vector form.

- Unpacking the final cooperative vector into regular Dr.Jit arrays.

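
For reference, the following self-contained variant of the example fills in
arbitrary placeholder inputs so that it can be run end to end. The concrete
values (and the use of ``TensorXf16``) are illustrative assumptions rather than
part of the API.

.. code-block:: python

   import drjit as dr
   import drjit.nn as nn
   from drjit.auto.ad import Float16, TensorXf16

   m, n = 3, 16

   A = dr.normal(TensorXf16, (m, n))   # random 3x16 matrix
   b = dr.rand(TensorXf16, m)          # random 3D offset

   # Pack 'A' and 'b' into a buffer with an optimal layout
   buffer, A_view, b_view = nn.pack(A, b)

   # Sixteen Float16 entries packed into a cooperative vector
   x = nn.CoopVec(*(Float16(i / n) for i in range(n)))

   # Evaluate A @ x + b and unpack the three outputs
   v_out = nn.matvec(A_view, x, b_view)
   r0, r1, r2 = v_out
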
Cooperative vectors
-------------------

The central type of this API is the *cooperative vector* class
:py:class:`nn.CoopVec`. This is a dynamically sized vector with uniformly
typed elements.

Unlike regular Dr.Jit arrays (e.g. :py:class:`drjit.cuda.ArrayXf`), cooperative
vectors *do not allow indexed element access*. For example, the following
operation raises an exception:

.. code-block:: pycon

   >>> vec = nn.CoopVec(Float16(1), Float16(2))
   >>> vec[1]
   Traceback (most recent call last):
     File "<stdin>", line 1, in <module>
   TypeError: 'drjit.nn.CoopVec' object is not subscriptable

This restriction exists because the compiler may arbitrarily distribute
cooperative vector components across threads for efficiency. Allowing direct
indexing would interfere with this optimization.

The :py:class:`drjit.nn.CoopVec` constructor accepts an arbitrary sequence
of :ref:`PyTrees <pytrees>` containing Dr.Jit arrays and Python scalars and
flattens them into a cooperative vector:

.. code-block:: python

   vec = nn.CoopVec(  # Construct a 4D vector
       Float16(1),
       3.0,
       Array2f(4, 5)
   )

Use the standard Python unpacking syntax to turn cooperative vectors back into
their components:

.. code-block:: python

   x, y, z = vec       # Unpack a cooperative 3D vector
   x, y, *extra = vec  # Unpack first 2 components, put rest into 'extra'

The same syntax can also be used to concatenate vectors:

.. code-block:: python

   vec_3 = nn.CoopVec(*vec_1, *vec_2)

Cooperative vectors can also be converted into nested arrays, tensors, or
Python lists:

.. code-block:: python

   vec_arr = Array3f(vec)
   vec_ten = TensorXf(vec)
   vec_lst = list(vec)

Cooperative vectors are compatible with Dr.Jit's symbolic tracing
infrastructure and may be used as state variables in
:py:func:`drjit.while_loop` and :py:func:`drjit.if_stmt`.

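
The sketch below illustrates this; it assumes the standard
:py:func:`drjit.while_loop` calling convention with ``state``, ``cond``, and
``body`` arguments, and the loop bound and values are arbitrary.

.. code-block:: python

   from drjit.auto.ad import UInt32

   x = nn.CoopVec(Float16(1), Float16(2))  # 2D cooperative vector
   i = UInt32(0)

   def cond(i, x):
       return i < 3

   def body(i, x):
       return i + 1, x + x   # arithmetic stays in cooperative-vector form

   # 'x' is carried through the symbolic loop as a state variable
   i, x = dr.while_loop(state=(i, x), cond=cond, body=body)

   a, b = x                  # after 3 iterations: (8, 16)
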
Arithmetic
^^^^^^^^^^

Cooperative vectors support a restricted set of arithmetic operations:

- elementary arithmetic operations: ``+``, ``-``, ``*`` (but no division),
- :py:func:`dr.fma() <fma>`,
- :py:func:`dr.minimum() <minimum>`, :py:func:`dr.maximum() <maximum>`,
- :py:func:`dr.log2() <log2>`, :py:func:`dr.exp2() <exp2>`,
- :py:func:`dr.tanh() <tanh>`,
- :py:func:`dr.step() <step>`,
- :py:func:`nn.matvec() <drjit.nn.matvec>`.

These operations directly map to hardware-optimized operations on CUDA/OptiX.
Operations outside of this set can be realized via unpacking/repacking, e.g.:

.. code-block:: python

   x: nn.CoopVec = ...
   y = nn.CoopVec(dr.sin(v) for v in x)

However, this may degrade performance. It is best to keep cooperative vectors
in their opaque layout whenever possible.

Arithmetic operations may mix cooperative vectors and regular Dr.Jit arrays or
Python scalars, which will undergo implicit broadcasting:

.. code-block:: python

   x: nn.CoopVec[dr.cuda.Float16] = ...
   y: dr.cuda.Float16 = ...
   z = dr.maximum(x, 0) + y

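
Supported operations can also be composed freely while the data stays in
cooperative form; the expression below is an arbitrary illustration rather
than a specific API feature.

.. code-block:: python

   x: nn.CoopVec = ...
   # Evaluate 2^(2*x + 1) using only supported primitives (fma and exp2)
   y = dr.exp2(dr.fma(x, 2, 1))
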
.. _matrix_views:

Matrix views
------------

Input matrices and bias vectors should generally be converted into a
hardware-dependent layout to improve performance compared to the default
row-major representation (also, many operations raise exceptions on the
OptiX/CUDA backend when matrices are not in such an optimal layout).

The function :py:func:`nn.pack() <drjit.nn.pack>` performs this conversion and
furthermore packs data into a shared buffer for optimal efficiency. The
function takes an arbitrary sequence of :ref:`PyTrees <pytrees>` as input and
returns a result with the same structure.

.. code-block:: python

   A: TensorXf = ...
   b: Float = ...
   buffer, A_view, b_view = nn.pack(A, b, layout='inference')

Every Dr.Jit array or tensor will be replaced by a
:py:class:`drjit.nn.MatrixView`, which is a thin pointer into a shared buffer
annotated with layout and type metadata. The function can generate optimal
memory layouts for either *inference* (the default) or *training*. You must
specify ``layout='training'`` if you wish to differentiate matrix
multiplication in reverse mode.

Following this step, ``A`` and ``b`` have been merged into ``buffer``, and
``A_view`` and ``b_view`` encode the offset and layout within this larger
buffer. Matrix views *cannot* be used in arithmetic expressions and are best
thought of as opaque handles. They only exist to describe the input of the
matrix-vector multiplication operation explained next.

Two other view-related operations may be useful in certain situations; please
see the linked documentation for details.

- :py:func:`drjit.nn.unpack` converts optimal-layout data back into a row-major layout.
- :py:func:`drjit.nn.view` creates row-major views.

Matrix-vector products
----------------------

The main purpose of cooperative vectors is the matrix-vector multiplication
operation :py:func:`nn.matvec() <drjit.nn.matvec>`:

.. code-block:: python

   y = nn.matvec(A, x, b)  # Compute y = A @ x + b

Here,

- ``A`` and ``b`` are *views* (:py:class:`nn.MatrixView`) created by
  :py:func:`nn.pack() <drjit.nn.pack>` or :py:func:`nn.view()
  <drjit.nn.view>`.
- ``x`` and ``y`` are cooperative vectors. They are interpreted as *column
  vectors*, i.e., ``y = A[:, 0] * x[0] + A[:, 1] * x[1] + ... + b``.
- the ``b`` term is optional.

The function also accepts an optional ``transpose=True`` parameter to compute
:math:`A^Tx + b`.

The standard Python ``A @ x`` and ``A.T @ x`` matrix multiplication syntax
works as well. However, if your computation requires the addition of a ``b``
vector, prefer :py:func:`nn.matvec() <drjit.nn.matvec>` over this syntax, since
it merges both steps into a single operation.

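
Chained calls of this kind are what make fully fused MLP evaluation possible.
The following sketch stacks two hidden layers with ReLU activations; the view
names and layer sizes are hypothetical and assume that the corresponding weight
and bias tensors were previously packed with :py:func:`nn.pack()
<drjit.nn.pack>`.

.. code-block:: python

   # x: nn.CoopVec with 16 Float16 entries (network input)
   h1 = dr.maximum(nn.matvec(A1_view, x,  b1_view), 0)  # layer 1: 16 -> 32, ReLU
   h2 = dr.maximum(nn.matvec(A2_view, h1, b2_view), 0)  # layer 2: 32 -> 32, ReLU
   y  = nn.matvec(A3_view, h2, b3_view)                 # output:  32 -> 3, linear
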
Differentiation
---------------

Cooperative vectors support automatic differentiation. Simply pack variables
with tracked gradients into cooperative vectors---the system will then
propagate derivatives through subsequent operations. Here is an example:

.. code-block:: python

   # Differentiable input
   a = Array2f16(...)
   dr.enable_grad(a)

   # Differentiable matrix + bias vector
   buffer, A_view, b_view = nn.pack(A, b, layout='training')
   dr.enable_grad(buffer)

   # Pack grad-enabled variables into a cooperative vector
   x = nn.CoopVec(a)

   # Differentiable matrix-vector multiplication
   y = nn.matvec(A_view, x, b_view)

   r0, r1 = y              # Unpack
   loss = r0**2 + r1**2    # Continue calculation and ..
   dr.backward_from(loss)  # .. eventually backpropagate

Specific views or cooperative vectors can also be detached via
:py:func:`drjit.detach()` to inhibit gradient propagation, e.g.:

.. code-block:: python

   y = nn.matvec(A_view, dr.detach(x), dr.detach(b_view))

Note that the conversion functions :py:func:`nn.pack() <drjit.nn.pack>` and
:py:func:`nn.unpack() <drjit.nn.unpack>` are *not differentiable*. This is
intentional: to train a neural network, convert the initial coefficient values
into a training-optimal layout and optimize this representation directly. Doing
so is more efficient than changing layouts twice in every optimization step
(once for the weights and once for their derivatives).

The following AD operations recognize :py:class:`nn.CoopVec
<drjit.nn.CoopVec>` and :py:class:`nn.MatrixView <drjit.nn.MatrixView>` objects:

- :py:func:`grad_enabled`, :py:func:`enable_grad`, :py:func:`disable_grad`,
- :py:func:`detach`.

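
To make the last point concrete, here is a minimal (and deliberately
simplistic) sketch of a single gradient-descent step that operates directly on
the packed ``buffer`` from the example above; a real training loop would
typically use an optimizer abstraction instead, and the learning rate is a
placeholder.

.. code-block:: python

   lr = 0.01

   dr.backward_from(loss)               # accumulate d(loss)/d(buffer)
   g = dr.grad(buffer)                  # fetch the gradient of the packed weights
   buffer = dr.detach(buffer) - lr * g  # plain gradient-descent update
   dr.enable_grad(buffer)               # re-attach gradients for the next step
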
Performance considerations
--------------------------

- **CUDA/OptiX** backend:

  - :py:func:`nn.matvec() <drjit.nn.matvec>` currently requires 16-bit
    floating point arguments. FP8 formats may be added in the future.

  - Tensor cores work with 8x8 and 16x16 blocks. Matrices whose row or column
    counts are not multiples of 8 or 16 will be zero-padded internally. There
    is no performance benefit in working with such intermediate sizes.

  - Unpacking cooperative vectors may degrade performance. It is best to keep
    them in their opaque layout whenever possible.

- **LLVM** backend:

  - There is no difference between row-major and training/inference-optimal
    layouts on the CPU. However, using :py:func:`nn.pack()
    <drjit.nn.pack>` is still recommended, since packing multiple arrays
    into a shared buffer has a small performance benefit.

  - On Intel-compatible processors, using half precision cooperative vectors is
    not recommended. FP16 matrix multiplication requires ``AVX512FP16``, an
    extension not yet available on consumer CPUs as of 2025. Without this
    extension, FP16 computation involves many costly FP16 ↔ FP32 roundtrips.

docs/index.rst

Lines changed: 2 additions & 0 deletions
@@ -46,6 +46,8 @@ public API.
    bench
    cpp
    textures
+   coop_vec
+   nn
    faq

 .. toctree::

docs/misc.rst

Lines changed: 1 addition & 1 deletion
@@ -529,7 +529,7 @@ resolve at a later point. So here, we have
 - ``SelfCp``: a forward reference to ``drjit.llvm.ad._Array2fCp`` (more on this shortly),
 - ``ValT``: :py:class:`drjit.llvm.ad.Float`,
 - ``ValCpT``: a forward reference to ``drjit.llvm.ad._FloatCp`` (more on this shortly),
-- ``RedT``: :py:class`drjit.llvm.ad.Float`,
+- ``RedT``: :py:class:`drjit.llvm.ad.Float`,
 - ``PlainT``: :py:class:`drjit.llvm.ad.Array2f`, and
 - ``MaskT``: :py:class:`drjit.llvm.ad.Array2b`.
