From 136497b0eeaaf857e451cff30bd16dd7ecb1e997 Mon Sep 17 00:00:00 2001 From: Wenzel Jakob Date: Tue, 18 Mar 2025 22:48:13 +0900 Subject: [PATCH 1/9] linear/sRGB conversion functions --- docs/reference.rst | 2 ++ drjit/__init__.py | 38 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+) diff --git a/docs/reference.rst b/docs/reference.rst index 6d369a13..8bab5507 100644 --- a/docs/reference.rst +++ b/docs/reference.rst @@ -146,6 +146,8 @@ Miscellaneous operations .. autofunction:: binary_search .. autofunction:: make_opaque .. autofunction:: copy +.. autofunction:: linear_to_srgb +.. autofunction:: srgb_to_linear Just-in-time compilation ------------------------ diff --git a/drjit/__init__.py b/drjit/__init__.py index f0bab723..76e9cb0c 100644 --- a/drjit/__init__.py +++ b/drjit/__init__.py @@ -2647,6 +2647,44 @@ def assert_equal( **kwargs, ) +def srgb_to_linear(x: ArrayT, clip_range: bool = True) -> ArrayT: + """ + Convert a sRGB gamma-corrected intensity value on the interval [0, 1] into + a linear intensity value on the interval [0, 1]. + + Values outside of the range [0, 1] are clipped by default. You may specify + `clip_range=False` to avoid this step if your data is already guranteed to be in + this range. + """ + + if clip_range: + x = clip(x, 0, 1) + + return select( + x < 0.04045, + x / 12.92, + fma(x, 1 / 1.055, 0.055 / 1.055) ** 2.4 + ) + +def linear_to_srgb(x: ArrayT, clip_range: bool = True) -> ArrayT: + """ + Convert a linear intensity value on the interval [0, 1] to into a sRGB + value by applying the underlying gamma correction curve. + + Values outside of the range [0, 1] are clipped by default. You may specify + `clip_range=False` to avoid this step if your data is already guranteed to be in + this range. + """ + + if clip_range: + x = clip(x, 0, 1) + + return select( + x < 0.0031308, + x * 12.92, + fma(1.055, x ** (1.0 / 2.4), -0.055) + ) + newaxis = None From 92576c4207f020271a69047354c2f602fc6003fd Mon Sep 17 00:00:00 2001 From: Wenzel Jakob Date: Thu, 6 Mar 2025 01:15:34 +0900 Subject: [PATCH 2/9] Cooperative Vector API Cooperative vectors enable efficient compilation and evaluation of expressions involving matrix multiplication. They cater to a specific use case, where each execution thread performs a sequence of independent multiplications by reasonably small matrices (e.g., 64x64). This enables the fully fused evaluation of small multilayer perceptrons within a larger program. That said, the feature isn't specific to MLPs and could also be used in other ways. On NVIDIA GPUs (Turing or newer), cooperative vectors map to the OptiX cooperative vector API leveraging the builtin tensor core for acceleration. On the CPU (LLVM) backend, Dr.Jit compiles cooperative vector operations using available instruction set extensions (AVX512, NEON, etc.). For further details on this new API and now to use it, refer to the documentation in ``docs/coop_vec.rst``. --- CMakeLists.txt | 2 +- docs/autodiff.rst | 4 +- docs/changelog.rst | 2 +- docs/coop_vec.rst | 316 ++++++++++++++++ docs/index.rst | 2 + docs/misc.rst | 2 +- docs/nn.rst | 134 +++++++ docs/reference.rst | 122 ++++++- docs/what.rst | 8 +- drjit/__init__.py | 2 +- drjit/nn.py | 482 +++++++++++++++++++++++++ drjit/stubs.pat | 53 ++- ext/drjit-core | 2 +- include/drjit/extra.h | 28 ++ src/extra/autodiff.cpp | 627 ++++++++++++++++++++++++++++++-- src/extra/math.cpp | 21 +- src/python/CMakeLists.txt | 10 +- src/python/apply.cpp | 61 ++-- src/python/autodiff.cpp | 33 ++ src/python/base.cpp | 2 +- src/python/coop_vec.cpp | 740 ++++++++++++++++++++++++++++++++++++++ src/python/coop_vec.h | 83 +++++ src/python/detail.cpp | 9 +- src/python/dlpack.cpp | 4 + src/python/docstr.rst | 195 +++++++++- src/python/eval.cpp | 3 + src/python/init.cpp | 86 ++++- src/python/main.cpp | 2 + src/python/memop.cpp | 39 +- src/python/meta.cpp | 1 - src/python/random.h | 2 +- src/python/reduce.cpp | 5 + src/python/tracker.cpp | 134 ++++--- tests/test_coop_vec.py | 527 +++++++++++++++++++++++++++ 34 files changed, 3596 insertions(+), 147 deletions(-) create mode 100644 docs/coop_vec.rst create mode 100644 docs/nn.rst create mode 100644 drjit/nn.py create mode 100644 src/python/coop_vec.cpp create mode 100644 src/python/coop_vec.h create mode 100644 tests/test_coop_vec.py diff --git a/CMakeLists.txt b/CMakeLists.txt index e9070f1a..3cfc2dae 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -108,7 +108,7 @@ if (DRJIT_ENABLE_JIT) set_target_properties(nanothread PROPERTIES ${DRJIT_OUTPUT_DIRECTORY}) endif() -mark_as_advanced(NANOTHREAD_ENABLE_TESTS) +mark_as_advanced(NANOTHREAD_ENABLE_TESTS NANOTHREAD_STATIC) mark_as_advanced(DRJIT_CORE_ENABLE_TESTS) mark_as_advanced(NB_TEST NB_TEST_SHARED_BUILD NB_TEST_STABLE_ABI NB_USE_SUBMODULE_DEPS NB_TEST_SANITZE NB_CREATE_INSTALL_RULES nanobind_DIR) mark_as_advanced(NB_TEST_CUDA NB_TEST_FREE_THREADED NB_TEST_SANITIZERS_ASAN NB_TEST_SANITIZERS_TSAN NB_TEST_SANITIZERS_UBSAN) diff --git a/docs/autodiff.rst b/docs/autodiff.rst index 805ebfea..e3eb7b3a 100644 --- a/docs/autodiff.rst +++ b/docs/autodiff.rst @@ -427,8 +427,8 @@ Dr.Jit how a particular operation should be differentiated. Reasons for this may include: - The automatic differentiation backend cannot keep track of computation - performed outside of Dr.Jit (e.g. using a highly optimized :ref:`CUDA kernel - `). In this case, review the section on :ref:`interoperability + performed outside of Dr.Jit (e.g. using custom CUDA kernels). In this case, + review the section on :ref:`interoperability `, since it presents a potentially simpler solution. - The derivative may admit a simplified analytic expression that is superior to diff --git a/docs/changelog.rst b/docs/changelog.rst index 6f24f3d3..63818329 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -348,7 +348,7 @@ Here is what's new: ⚠️ Compatibility ⚠️ ------------------- +------------------- - **Symbolic loop syntax**: the old "recorded loop" syntax is no longer supported. Existing code will need adjustments to use diff --git a/docs/coop_vec.rst b/docs/coop_vec.rst new file mode 100644 index 00000000..2a5b4489 --- /dev/null +++ b/docs/coop_vec.rst @@ -0,0 +1,316 @@ +.. py:currentmodule:: drjit + +.. cpp:namespace:: drjit + +.. _coop_vec: + +Cooperative vectors +=================== + +*Cooperative vectors* are a `new API +`__ +for evaluating matrix-vector products in certain types of GPU workloads. They +are designed to handle cases, where each thread of a parallel program needs +to multiply a vector by a reasonably small matrix (e.g., 64x64 or fewer +entries). By working together, the threads can perform these multiplications +more efficiently, which is why the approach is called *cooperative*. + +Cooperative vectors are especially useful for evaluating small `multilayer +perceptrons `__ (MLPs) +within larger programs while fully *fusing* all steps of the process into a +single kernel. Other workloads that heavily rely on matrix-vector products may +benefit as well. + +Dr.Jit supports cooperative vectors on both of its backends: + +- On **NVIDIA GPUs (Turing or newer)**, cooperative vectors map to the OptiX + `cooperative vector API + `__, + leveraging built-in `tensor cores + `__ for acceleration. + +- On the **CPU (LLVM) backend**, compilation of cooperative vector operations + targets the available instruction set extensions (AVX512, NEON, etc.). + +Code snippets in the remainder of this section assume the following include +directives: + +.. code-block:: python + + import drjit as dr + import drjit.nn as nn + from drjit.auto.ad import Float16, TensorXf16 + +Motivation +---------- + +The cooperative vector API is available via the :py:mod:`drjit.nn` submodule. +Below is an example demonstrating how to use it to perform a matrix +multiplication. + +.. code-block:: python + + # Matrix shape + m, n = 3, 16 + + # Create a random matrix + offset + A = dr.normal(TensorXf, (m, n)) + b = dr.rand(TensorXf, m) + + # Pack 'A' and 'b' into a buffer with an optimal layout + buffer, A_view, b_view = nn.pack(A, b) + + # Create a cooperative vector + x = nn.CoopVec(... 16 values ...) + + # Evaluate A @ x + b + v_out = nn.matvec(A_view, v_in, b_view) + + # Unpack the resulting cooperative vector + x, y, z = v_out + +This involves the following steps: + +- Initializing matrix data and packing it into an optimized memory layout using + :py:func:`nn.pack() `. + +- Constructing a :py:class:`nn.CoopVec` containing the inputs to the matrix + multiplication.inputs. + +- Performing one or more matrix-vector multiplications and other arithmetic, + while keeping the state in cooperative vector form. + +- Unpacking the final cooperative vector into regular Dr.Jit arrays. + +Cooperative vectors +------------------- + +The central type of this API is the *cooperative vector* class +:py:class:`nn.CoopVec`. This is a dynamically sized vector with uniformly +typed elements. + +Unlike regular Dr.Jit arrays (e.g. :py:class:`drjit.cuda.ArrayXf`), cooperative +vectors *do not allow indexed element access*. For example, the following +operation raises an exception: + +.. code-block:: pycon + + >>> vec = nn.CoopVec(Float16(1), Float16(2)) + >>> vec[1] + Traceback (most recent call last): + File "", line 1, in + TypeError: 'drjit.nn.CoopVec' object is not subscriptable + +This restriction exists because the compiler may arbitrarily distribute +cooperative vector components across threads for efficiency. Allowing direct +indexing would interfere with this optimization. + +The :py:class:`drjit.nn.CoopVec` constructor accepts an arbitrary sequence +of :ref:`PyTrees ` containing Dr.Jit array and Python scalars and +flattens them into a cooperative vector: + +.. code-block:: python + + vec = nn.CoopVec( # Construct a 4D vector + Float16(1), + 3.0, + Array2f(4, 5) + ) + +Use the standard Python unpacking syntax to turn cooperative vectors back into +their components: + +.. code-block:: python + + x, y, z = vec # Unpack a cooperative 3D vector + x, y, *extra = vec # Unpack first 2 components, put rest into 'extra' + +The same syntax can also be used to concatenate vectors: + +.. code-block:: python + + vec_3 = nn.CoopVec(*vec_1, *vec_2) + +Cooperative vectors can also be converted into nested arrays, tensors, or +Python lists: + +.. code-block:: python + + vec_arr = Array3f(vec) + vec_ten = TensorXf(vec) + vec_lst = list(vec) + +Cooperative vectors are compatible with Dr.Jit's symbolic tracing +infrastructure and may be used as state variables in +:py:func:`drjit.while_loop` and :py:func:`drjit.if_stmt`. + +Arithmetic +^^^^^^^^^^ + +Cooperative vectors support a restricted set of arithmetic operations: + +- Elementary arithmetic operations: ``+``, ``-``, ``*`` (but no division) +- :py:func:`dr.fma() `, +- :py:func:`dr.minimum() `, :py:func:`dr.maximum() `, +- :py:func:`dr.log2() `, :py:func:`dr.exp2() `, +- :py:func:`dr.tanh() `, +- :py:func:`dr.step() `. +- :py:func:`nn.matvec() ` + +These operations directly map to hardware-optimized operations on CUDA/OptiX. +Operations outside of this set can be realized via unpacking/repacking, e.g.: + +.. code-block:: + + x : nn.CoopVec = ... + y = nn.CoopVec(dr.sin(v) for v in x) + +However, this may degrade performance. It is best to keep cooperative vectors +in their opaque layout whenever possible. + +Arithmetic operations may mix cooperative vectors and regular Dr.Jit arrays or +Python scalars, which will undergo implicit broadcasting. + +.. code-block:: + + x: nn.CoopVec[dr.cuda.Float16] = ... + y: dr.cuda.Float16 = ... + z = dr.maximum(x, 0) + y + +.. _matrix_views: + +Matrix views +------------ + +Input matrices and bias vectors should generally be converted into a +hardware-dependent layout to improve performance compared to the default +row-major representation (also, many operations raise exceptions on the +OptiX/CUDA backend when matrices are not in such an optimal layout). + +The function :py:func:`nn.pack() ` performs this conversion and +furthermore packs data into a shared buffer for optimal efficiency. The +function takes an arbitrary sequence of :ref:`PyTrees ` as input and +returns a result with the same structure. + +.. code-block:: python + + A: TensorXf = ... + b: Float = ... + A_view, b_view = nn.pack(A, b, layout='inference') + +Every Dr.Jit array or tensor will be replaced by a +:py:class:`drjit.nn.MatrixView`, which is a thin pointer into a shared buffer +annotated with layout and type metadata. The function can generate optimal +memory layouts for either *inference* (the default) and *training*. You must +specify ``layout='training'`` if you wish to differentiate matrix +multiplication in reverse mode. + +Following this step, ``A`` and ``b`` have been merged into ``buffer``, and +``A_view`` and ``b_view`` encode the offset and layout within this larger +buffer. Matrix views *cannot* be used in arithmetic expressions and are best +thought of as opaque handles. They only exist to describe the input of the +matrix-vector multiplication operation explained next. + +Two other view-related operations be useful in certain situations, please +see the linked documentation for details. + +- :py:func:`drjit.nn.unpack` converts optimal-layout data back into a row-major layout. +- :py:func:`drjit.nn.view` creates row-major views. + +Matrix-vector products +---------------------- + +The main purpose of cooperative vectors is the matrix-vector multiplication +operation :py:func:`nn.matvec() `: + +.. code-block:: python + + y = nn.matvec(A, x, b) # Compute y = A @ x + b + +Here, + +- ``A`` and ``b`` are *views* (:py:class:`nn.MatrixView`) created by + :py:func:`nn.pack() ` or :py:func:`nn.view() + `. +- ``x`` and ``y`` are cooperative vectors. They are interpreted as *column + vectors*, i.e., ``y = A[:, 0] * x[0] + A[:, 1] * x[1] + ... + b``. +- the ``b`` term is optional. + +The function also accepts an optional ``transpose=True`` parameter to compute +:math:`A^Tx + b`. + +The standard Python ``A @ x`` and ``A.T @ x`` matrix multiplication syntax +works as well. However, if your computation requires the addition of a ``b`` +vector, prefer :py:func:`nn.matvec() ` over this syntax, since +it merges both steps into a single operation. + +Differentiation +--------------- + +Cooperative vectors support automatic differentiation. Simply pack variables +with tracked gradients into cooperative vectors---the system will then +propagate derivatives through subsequent operations. Here is an example: + +.. code-block:: python + + # Differentiable input + a = Array2f16(..) + dr.enable_grad(a) + + # Differentiable matrix + bias vector + buffer, A_view, b_view = nn.pack(A, b) + dr.enable_grad(buffer) + + # Pack grad-enabled variables into a cooperative vector + x = nn.CoopVec(a) + + # Differentiable matrix-vector multiplication + y = dr.matvec(A_view, x, b_view) + + r0, r1 = y # Unpack + loss = r0**2 + r1**2 # Continue calculation and .. + dr.backward_from(loss) # .. eventually backpropagate + +Specific views or cooperative vectors can also be detached via +:py:func:`drjit.detach()` to inhibit gradient propagation, e.g.: + +.. code-block:: python + + y = nn.matvec(A_view, dr.detach(x), dr.detach(b_view)) + +Note that the conversion functions :py:func:`nn.pack() ` and +:py:func:`nn.unpack() ` are *not differentiable*. This is +intentional: to train a neural network, convert the initial coefficient values +into training-optimal layout and optimize this representation directly. Doing +so is more efficient than changing layouts twice in every optimization step +(once for the weights and once for their derivatives). + +The following AD operations recognize :py:func:`nn.CoopVec +` and :py:func:`nn.MatrixView ` objects: + +- :py:func:`grad_enabled`, :py:func:`enable_grad`, :py:func:`disable_grad`. +- :py:func:`detach`. + +Performance considerations +-------------------------- + +- **CUDA/OptiX** backend: + + - :py:func:`nn.matvec() ` currently requires 16-bit + floating point arguments. FP8 formats may be added in the future. + + - Tensor cores work with 8x8 and 16x16 blocks. Matrices, whose row or column + counts are not a multiples of 8 or 16 will be zero-padded internally. There + is no performance benefit in working with such intermediate sizes. + +- **LLVM** backend: + + - There is no difference between row-major and training/inference-optimal + layouts on the CPU. However, using :py:func:`nn.pack() + ` is still recommended, since packing multiple arrays + into a shared buffer has a small performance benefit. + + - On Intel-compatible processors, using half precision cooperative vectors is + not recommended. FP16 matrix multiplication requires ``AVX512FP16``, an + extension not yet available on consumer CPUs as of 2025. Without this + extension, FP16 computation involves many costly FP16 ↔ FP32 roundtrips. diff --git a/docs/index.rst b/docs/index.rst index 0b5ac4fe..e10e6611 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -46,6 +46,8 @@ public API. bench cpp textures + coop_vec + nn faq .. toctree:: diff --git a/docs/misc.rst b/docs/misc.rst index 9321e6c9..07ccad5d 100644 --- a/docs/misc.rst +++ b/docs/misc.rst @@ -529,7 +529,7 @@ resolve at a later point. So here, we have - ``SelfCp``: a forward reference to ``drjit.llvm.ad._Array2fCp`` (more on this shortly), - ``ValT``: :py:class:`drjit.llvm.ad.Float`, - ``ValCpT``: a forward reference to ``drjit.llvm.ad._FloatCp`` (more on this shortly), -- ``RedT``: :py:class`drjit.llvm.ad.Float`, +- ``RedT``: :py:class:`drjit.llvm.ad.Float`, - ``PlainT``: :py:class:`drjit.llvm.ad.Array2f`, and - ``MaskT``: :py:class:`drjit.llvm.ad.Array2b`. diff --git a/docs/nn.rst b/docs/nn.rst new file mode 100644 index 00000000..aa658594 --- /dev/null +++ b/docs/nn.rst @@ -0,0 +1,134 @@ +.. py:currentmodule:: drjit.nn + +.. _neural_nets: + +Neural Networks +=============== + +Dr.Jit's neural network infrastructure builds on :ref:`cooperative vectors +`. Please review their documentation before reading this section. + +The module :py:mod:`drjit.nn` provides convenient modular abstractions to +construct, evaluate, and optimize neural networks. Their design resembles the +PyTorch `nn.Module +`__ classes. +The Dr.Jit :py:class:`nn.Module ` class takes a cooperative vector as input +and produces another cooperative vector. Modules can be chained to form longer +sequential pipelines. + +.. warning:: + + The neural network classes are experimental and subject to change in future + versions of Dr.Jit. + +List +---- + +The set of neural network module currently includes: + +- Sequential evaluation of a list of models: :py:class:`nn.Sequential `. + +- Linear and affine layers: :py:class:`nn.Linear `. + +- Encoding layers: :py:class:`nn.SinEncode `, :py:class:`nn.TriEncode `. + +- Activation functions and other nonlinear transformations: :py:class:`nn.ReLU `, :py:class:`nn.LeakyReLU `, + :py:class:`nn.Exp `, :py:class:`nn.Exp2 `, :py:class:`nn.Tanh `. + +- Miscellaneous: :py:class:`nn.Cast `, :py:class:`nn.ScaleAdd `. + +Example +------- + +Below is a fully worked out example demonstrating how to use it to declare and +optimize a small `multilayer perceptron +`__ (MLP). This network +implements a 2D neural field (right) that we then fit to a low-resolution image of `The +Great Wave off Kanagawa +`__ (left). + +.. image:: https://rgl.s3.eu-central-1.amazonaws.com/media/uploads/wjakob/2024/06/coopvec-screenshot.png + :width: 600 + :align: center + +The optimization uses the *Adam* optimizer (:py:class:`dr.opt.Adam +`) optimizer and a *gradient scaler* +(:py:class:`dr.opt.GradScaler `) for adaptive +mixed-precision training. + +.. code-block:: python + + from tqdm.auto import tqdm + import imageio.v3 as iio + import drjit as dr + import drjit.nn as nn + from drjit.opt import Adam, GradScaler + from drjit.auto.ad import Texture2f, TensorXf, TensorXf16, Float16, Float32, Array2f, Array3f + + # Load a test image and construct a texture object + ref = TensorXf(iio.imread("https://rgl.s3.eu-central-1.amazonaws.com/media/uploads/wjakob/2024/06/wave-128.png") / 256) + tex = Texture2f(ref) + + # Ensure consistent results when re-running the following + dr.seed(0) + + # Establish the network structure + net = nn.Sequential( + nn.TriEncode(16, 0.2), + nn.Cast(Float16), + nn.Linear(-1, -1, bias=False), + nn.LeakyReLU(), + nn.Linear(-1, -1, bias=False), + nn.LeakyReLU(), + nn.Linear(-1, -1, bias=False), + nn.LeakyReLU(), + nn.Linear(-1, 3, bias=False), + nn.Exp() + ) + + # Instantiate the network for a specific backend + input size + net = net.alloc(TensorXf16, 2) + + # Convert to training-optimal layout + coeffs, net = nn.pack(net, layout='training') + print(net) + + # Optimize a single precision copy of the parameters + opt = Adam(lr=1e-3, params={'coeffs': Float32(coeffs)}) + + # This is an adaptive mixed-precision (AMP) optimization, where a half + # precision computation runs within a larger single precision program. + # Gradient scaling is required to make this numerically well-behaved. + scaler = GradScaler() + + res = 256 + + for i in tqdm(range(40000)): + # Update network state from optimizer + coeffs[:] = Float16(opt['coeffs']) + + # Generate jittered positions on [0, 1]^2 + t = dr.arange(Float32, res) + p = (Array2f(dr.meshgrid(t, t)) + dr.rand(Array2f, (2, res*res))) / res + + # Evaluate neural net + L2 loss + img = Array3f(net(nn.CoopVec(p))) + loss = dr.squared_norm(tex.eval(p)-img) + + # Mixed-precision training: take suitably scaled steps + dr.backward(scaler.scale(loss)) + scaler.step(opt) + + # Done optimizing, now let's plot the result + t = dr.linspace(Float32, 0, 1, res) + p= Array2f(dr.meshgrid(t, t)) + img = Array3f(net(nn.CoopVec(p))) + img = dr.reshape(TensorXf(img, flip_axes=True), (res, res, 3)) + + import matplotlib.pyplot as plt + fig, ax = plt.subplots(1, 2, figsize=(10,5)) + ax[0].imshow(ref) + ax[1].imshow(dr.clip(img, 0, 1)) + fig.tight_layout() + plt.show() + diff --git a/docs/reference.rst b/docs/reference.rst index 8bab5507..a71737bd 100644 --- a/docs/reference.rst +++ b/docs/reference.rst @@ -278,6 +278,7 @@ Standard mathematical functions .. autofunction:: sign .. autofunction:: copysign .. autofunction:: mulsign +.. autofunction:: step Operations for vectors and matrices ----------------------------------- @@ -633,6 +634,8 @@ Low-level bits .. py:currentmodule:: drjit.detail .. autofunction:: set_leak_warnings .. autofunction:: leak_warnings +.. autofunction:: llvm_version +.. autofunction:: cuda_version .. py:currentmodule:: drjit Typing @@ -692,7 +695,6 @@ gradient-based optimization and adaptive mixed-precision training. .. automethod:: __delitem__ .. automethod:: __contains__ .. automethod:: __len__ - .. automethod:: update .. automethod:: keys .. automethod:: values .. automethod:: items @@ -715,3 +717,121 @@ gradient-based optimization and adaptive mixed-precision training. .. automethod:: step .. automethod:: scale .. automethod:: unscale + +.. _coop_vec_ref: + +Cooperative Vectors +------------------- + +.. py:module:: drjit.nn + +The :py:mod:`drjit.nn` module provides infrastructure to implement small +neural networks and revolves around the notion of *cooperative vectors* that +facilitate code generation of matrix-vector products. Please see the separate +:ref:`documentation section ` for an introduction. + +.. autoclass:: CoopVec + + .. automethod:: __init__ + .. automethod:: __add__ + .. automethod:: __sub__ + .. automethod:: __mul__ + .. automethod:: __len__ + .. automethod:: __repr__ + + .. property:: index + :type: int + + Stores the Dr.Jit variable index of the cooperative vector. + + .. property:: type + :type: type[drjit.ArrayBase] + + Stores the element type + +.. autoclass:: MatrixView + + .. automethod:: __getitem__ + + .. property:: dtype + :type: drjit.VarType + + Scalar type underlying the view. + + .. property:: shape + :type: tuple[int, int] + + Number of rows/columns. Vectors are stored as matrices with one column. + + .. property:: layout + :type: MatrixLayout + + One of several possible matrix layouts (training/inference-optimal and + row-major). + + .. property:: stride + :type: int + + Row stride (in # of elements) + + .. property:: size + :type: int + + Total number of elements + + .. property:: transpose + :type: bool + + The ``MatrixView.T`` property flips this flag (all other + values stay unchanged). + + .. property:: buffer + :type: drjit.ArrayBase + + The underlying buffer, which may contain additional matrices/vectors + besides the data referenced by the :py:class:`MatrixView`. + + .. property:: T + :type: MatrixView + + Return a transposed view. + + .. property:: grad + :type: MatrixView + + Return an analogous view of the gradient. + +.. autofunction:: view +.. autofunction:: pack +.. autofunction:: unpack +.. autofunction:: matvec +.. autofunction:: cast + +Neural Networks +--------------- + +Besides :ref:`cooperative vector classes `, the +:py:mod:`drjit.nn` module also provides convenient abstractions to declare, +evaluate, and train networks. Please see the separate :ref:`documentation +section ` for an introduction. + +.. autoclass:: Model + + .. automethod:: __call__ + .. automethod:: alloc + +.. autoclass:: Sequential + + .. automethod:: __len__ + .. automethod:: __getitem__ + +.. autoclass:: Linear +.. autoclass:: ReLU +.. autoclass:: LeakyReLU +.. autoclass:: SinEncode +.. autoclass:: TriEncode +.. autoclass:: Exp +.. autoclass:: Exp2 +.. autoclass:: Tanh +.. autoclass:: Cast +.. autoclass:: ScaleAdd diff --git a/docs/what.rst b/docs/what.rst index e76af0f1..59250170 100644 --- a/docs/what.rst +++ b/docs/what.rst @@ -22,9 +22,11 @@ Using Dr.Jit involves two steps: **That's it**. It doesn't do much, but it does this *very efficiently*. Perhaps the most significant difference to the majority of existing tools is -that Dr.Jit is *not* a machine learning library. Its sweet spot are non-neural -programs characterized by *embarrassing parallelism*---that is to say, programs -with large data-parallel regions. A good example of this are `Monte Carlo +that Dr.Jit is *not primarily* a machine learning library. While it does +provide support for neural network :ref:`evaluation and training `, +it its sweet spot are non-neural programs characterized by *embarrassing +parallelism*---that is to say, programs with large data-parallel regions. A +good example of this are `Monte Carlo `__ methods with their parallel sample evaluation (indeed, the reason why this project was originally created was to provide the foundation of `Mitsuba 3 diff --git a/drjit/__init__.py b/drjit/__init__.py index 76e9cb0c..0b96ef7d 100644 --- a/drjit/__init__.py +++ b/drjit/__init__.py @@ -2269,7 +2269,7 @@ def upsample(t, shape=None, scale_factor=None): _rand_seed : int = 0 -def seed(value: int): +def seed(value: int) -> None: """ Reset the seed value that is used for pseudorandom number generation. diff --git a/drjit/nn.py b/drjit/nn.py new file mode 100644 index 00000000..e0ae987f --- /dev/null +++ b/drjit/nn.py @@ -0,0 +1,482 @@ +from __future__ import annotations +import drjit +import sys + +if sys.version_info < (3, 11): + from typing_extensions import Tuple, Sequence, Union, Type, TypeAlias, Optional, Any +else: + from typing import Tuple, Sequence, Union, Type, TypeAlias, Optional, Any + +# Import classes/functions from C++ extension +MatrixView = drjit.detail.nn.MatrixView +CoopVec = drjit.detail.nn.CoopVec +pack = drjit.detail.nn.pack +unpack = drjit.detail.nn.unpack +matvec = drjit.detail.nn.matvec +view = drjit.detail.nn.view +cast = drjit.detail.nn.cast +T = drjit.detail.nn.T + +TensorOrViewOrNone: TypeAlias = Union[ + drjit.ArrayBase, + MatrixView, + None +] + +class Module: + """ + This is the base class of a modular set of operations that make + the specification of neural network architectures more convenient. + + Module subclasses are :ref:`PyTrees `, which means that various + Dr.Jit operations can automatically traverse them. + + Constructing a neural network generally involves the following pattern: + + .. code-block:: + + # 1. Establish the network structure + net = nn.Sequential( + nn.Linear(-1, 32, bias=False), + nn.ReLU(), + nn.Linear(-1, 3) + ) + + # 2. Instantiate the network for a specific backend + input size + net = net.alloc(TensorXf16, 2) + + # 3. Pack coefficients into a training-optimal layout + coeffs, net = nn.pack(net, layout='training') + + Network evaluation expects a :ref:`cooperative vector ` as input + (i.e., ``net(nn.CoopVec(...))``) and returns another cooperative vector. + The ``coeffs`` buffer contains all weight/bias data in training-optimal + format and can be optimized, which will directly impact the packed network + ``net`` that references this buffer. + """ + def __call__(self, arg: CoopVec, /) -> CoopVec: + """ + Evaluate the model with an input cooperative vector and return the result. + """ + raise NotImplementedError(f"{type(self).__name__}.__call__() implementation is missing.") + + def _alloc(self, dtype: Type[drjit.ArrayBase], size: int, /) -> Tuple[Module, int]: + return self, size + + def alloc(self, dtype: Type[drjit.ArrayBase], size: int = -1) -> Module: + """ + Returns a new instance of the model with allocated weights. + + This function expects a suitable tensor ``dtype`` (e.g. + :py:class:`drjit.cuda.ad.TensorXf16` or + :py:class:`drjit.llvm.ad.TensorXf`) that will be used to store the + weights on the device. + + If the model or one of its sub-models is automatically sized (e.g., + ``input_features=-1`` in :py:class:`drjit.nn.Linear`), the final + network configuration may ambiguous and an exception will be raised. + Specify the optional ``size`` parameter in such cases to inform the + allocation about the size of the input cooperative vector. + """ + return self._alloc(dtype, size)[0] + + def __repr__(self) -> str: + return f"{type(self).__name__}()" + +class Sequential(Module, Sequence[Module]): + """ + This model evaluates provided arguments ``arg[0]``, ``arg[1]``, ..., in sequence. + """ + DRJIT_STRUCT = { 'layers' : tuple } + + layers: tuple[Module, ...] + + def __init__(self, *args: Module): + self.layers = args + + def __call__(self, arg: CoopVec, /) -> CoopVec: + for l in self.layers: + arg = l(arg) + return arg + + def _alloc(self, dtype: Type[drjit.ArrayBase], size: int = -1, /) -> Tuple[Module, int]: + result = [] + for l in self.layers: + l_new, size = l._alloc(dtype, size) + result.append(l_new) + return Sequential(*result), size + + def __len__(self): + """Return the number of contained models""" + return len(self.layers) + + def __getitem__(self, index: Union[int], /) -> Module: # type: ignore + """Return the model at position ``index``""" + return self.layers[index] + + def __repr__(self) -> str: + s = 'Sequential(\n' + n = len(self.layers) + for i in range(n): + s += ' ' + repr(self.layers[i]).replace('\n', '\n ') + if i + 1 < n: + s += ',' + s += '\n' + s += ')' + return s + +class ReLU(Module): + r""" + ReLU (rectified linear unit) activation function. + + This model evaluates the following expression: + + .. math:: + + \mathrm{ReLU}(x) = \mathrm{max}\{x, 0\}. + + """ + + DRJIT_STRUCT = { } + def __call__(self, arg: CoopVec, /) -> CoopVec: + return drjit.maximum(arg, 0) + +class LeakyReLU(Module): + r""" + "Leaky" ReLU (rectified linear unit) activation function. + + This model evaluates the following expression: + + .. math:: + + \mathrm{LeakyReLU}(x) = \begin{cases} + x,&\mathrm{if}\ x\ge 0,\\ + \texttt{negative\_slope}\cdot x,&\mathrm{otherwise}. + \end{cases} + """ + + DRJIT_STRUCT = { 'negative_slope': float } + def __init__(self, negative_slope: float = 1e-2): + self.negative_slope = negative_slope + + def __call__(self, arg: CoopVec, /) -> CoopVec: + return drjit.maximum(arg, 0) + drjit.minimum(arg, 0.0) * self.negative_slope + + +class Exp2(Module): + r""" + Applies the base-2 exponential function to each component. + + .. math:: + + \mathrm{Exp2}(x) = 2^x + + On the CUDA backend, this function directly maps to an efficient native GPU instruction. + """ + DRJIT_STRUCT = { } + def __call__(self, arg: CoopVec, /) -> CoopVec: + return drjit.exp2(arg) + +class Exp(Module): + r""" + Applies the exponential function to each component. + + .. math:: + + \mathrm{Exp}(x) = e^x + """ + DRJIT_STRUCT = { } + def __call__(self, arg: CoopVec, /) -> CoopVec: + return drjit.exp2(arg * (1 / drjit.log(2))) + +class Tanh(Module): + r""" + Applies the hyperbolic tangent function to each component. + + .. math:: + + \mathrm{Tanh}(x) = \frac{\exp(x)-\exp(-x)}{\exp(x)+\exp(-x)} + + On the CUDA backend, this function directly maps to an efficient native GPU instruction. + """ + DRJIT_STRUCT = { } + def __call__(self, arg: CoopVec, /) -> CoopVec: + return drjit.tanh(arg) + +class ScaleAdd(Module): + r""" + Scale the input by a fixed scale and apply an offset. + + Note that ``scale`` and ``offset`` are assumed to be constant (i.e., not trainable). + + .. math:: + + \mathrm{ScaleAdd}(x) = x\cdot\texttt{scale} + \texttt{offset} + """ + DRJIT_STRUCT = {'scale': Union[None, float, int, drjit.ArrayBase], + 'offset': Union[None, float, int, drjit.ArrayBase]} + def __init__(self, scale: Union[float, int, drjit.ArrayBase, None] = None, + offset: Union[float, int, drjit.ArrayBase, None] = None): + self.scale = scale + self.offset = offset + def __call__(self, arg: CoopVec, /) -> CoopVec: + if not self.scale or not self.offset: + raise Exception("drjit.nn.ScaleAdd(): you must set a scale and offset!") + return drjit.fma(arg, self.scale, self.offset) + +class Cast(Module): + """ + Cast the input cooperative vector to a different precision. Should be + instantiated with the desired element type, e.g. ``Cast(drjit.cuda.ad.Float32)`` + """ + DRJIT_STRUCT = { 'dtype': Optional[Type[drjit.ArrayBase]] } + def __init__(self, dtype: Optional[Type[drjit.ArrayBase]] = None): + self.dtype = dtype + def __call__(self, arg: CoopVec, /) -> CoopVec: + return cast(arg, self.dtype) + +class Linear(Module): + r""" + This layer represents a learnable affine linear transformation of the input + data following the expression :math:`\mathbf{y} = \mathbf{A}\mathbf{x} + + \mathbf{b}`. + + It takes ``in_features`` inputs and returns a cooperative vector with + ``out_features`` dimensions. The following parameter values have a special + a meaning: + + - ``in_features=-1``: set the input size to match the previous model's + output (or the input of the network, if there is no previous model). + + - ``out_features=-1``: set the output size to match the input size. + + The bias (:math:`\textbf{b}`) term is optional and can be disabled by + specifying ``bias=False``. + + The method :py:func:`Module.alloc` initializes the underlying coefficient + storage with random weights following a uniform Xavier initialization, + i.e., uniform variates on the interval :math:`[-k,k]` where + :math:`k=1/\sqrt{\texttt{out\_features}}`. Call :py:func:`drjit.seed()` prior + to this step to ensure that weights are always initialized with the same + values, which can be helpful for hyperpararameter tuning and + reproducibility. + """ + config: Tuple[int, int, bool] + weights: TensorOrViewOrNone + bias: TensorOrViewOrNone + + DRJIT_STRUCT = { + 'config': Tuple[int, int, bool], + 'weights': TensorOrViewOrNone, + 'bias': TensorOrViewOrNone + } + + def __init__(self, in_features: int = -1, out_features: int = -1, bias = True) -> None: + self.config = (in_features, out_features, bias) + self.weights = self.bias = None + + def __repr__(self) -> str: + s = f'Linear({self.config[0]}, {self.config[1]}' + if not self.config[2]: + s += ', bias=False' + s += ')' + return s + + def __call__(self, arg: CoopVec, /) -> CoopVec: + if self.weights is None: + raise RuntimeError( + "Uninitialized network. Call 'net = net.alloc(""" + ")' to initialize the weight storage first. Following this, " + "use 'drjit.nn.pack()' to transform the network into an " + "optimal layout for evaluation." + ) + elif not isinstance(self.weights, MatrixView) or \ + (self.bias is not None and not isinstance(self.bias, MatrixView)): + raise RuntimeError( + "Uninitialized network. Use 'drjit.nn.pack()' to transform" + "the network into an optimal layout for evaluation." + ) + return matvec(self.weights, arg, self.bias) + + def _alloc(self, dtype: Type[drjit.ArrayBase], size : int = -1, /) -> Tuple[Module, int]: + in_features, out_features, bias = self.config + if in_features < 0: + in_features = size + if out_features < 0: + out_features = in_features + if in_features == -1 or out_features == -1: + raise RuntimeError("The network contains layers with an unspecified " + "size. You must specify the input size to drjit.nn.Module.alloc().") + + result = Linear(in_features, out_features, bias) + # Xavier (uniform) initialization, matches PyTorch + scale = drjit.sqrt(1 / out_features) + Float32 = drjit.float32_array_t(dtype) + samples = drjit.rand(Float32, (out_features, in_features)) + result.weights = dtype(drjit.fma(samples, 2, -1) * scale) + if bias: + result.bias = drjit.zeros(dtype, out_features) + return result, out_features + +def _sincos_tri(t: T) -> tuple[T, T]: + """Implementation detail of the TriEncode class""" + s = t - .25 + st = s - drjit.round(s) + ct = t - drjit.round(t) + return ( + drjit.fma(drjit.abs(st), -4, 1), + drjit.fma(drjit.abs(ct), -4, 1) + ) + +class TriEncode(Module): + r""" + Map an input onto a higher-dimensional space by transforming it using + triangular sine and cosine approximations of an increasing frequency. + + .. math:: + + x\mapsto \begin{bmatrix} + \sin_\triangle(2^0\,x)\\ + \cos_\triangle(2^0\,x)\\ + \vdots\\ + \cos_\triangle(2^{n-1}\, x)\\ + \sin_\triangle(2^{n-1}\, x) + \end{bmatrix} + + where + + .. math:: + + \cos_\triangle(x) = 1-4\left|x-\mathrm{round}(x)\right| + + and + + .. math:: + + \sin_\triangle(x) = \cos_\triangle(x-1/4) + + The value :math:`n` refers to the number of *octaves*. This layer increases + the dimension by a factor of :math:`2n`. + + Note that this encoding has period 1. If your input exceeds the interval + :math:`[0, 1]`, it is advisable that you reduce it to this range to avoid + losing information. + + Minima/maxima of higher frequency components conincide on a regular + lattice, which can lead to reduced fitting performance at those locations. + Specify the optional parameter ``shift`` to phase-shift the :math:`i`-th + frequency by :math:`2\,\pi\,\mathrm{shift}` to avoid this behavior. + + The following plot shows the first two octaves applied to the linear + function on :math:`[0, 1]` (without shift). + + .. image:: https://rgl.s3.eu-central-1.amazonaws.com/media/uploads/wjakob/2024/06/tri_encode_light.svg + :class: only-light + :width: 600px + :align: center + + .. image:: https://rgl.s3.eu-central-1.amazonaws.com/media/uploads/wjakob/2024/06/tri_encode_dark.svg + :class: only-dark + :width: 600px + :align: center + """ + + def __init__(self, octaves: int = 0, shift: float = 0) -> None: + self.octaves = octaves + self.shift = shift + + def _alloc(self, dtype: Type[drjit.ArrayBase], size : int = -1, /) -> Tuple[Module, int]: + return self, size * self.octaves * 2 + + def __repr__(self) -> str: + return f'TriEncode({self.octaves})' + + def __call__(self, arg: CoopVec, /) -> CoopVec: + args, r = list(arg), list() + for arg in args: + for i in range(self.octaves): + s, c = _sincos_tri(drjit.fma(arg, 2**i, self.shift*i)) + r.append(s) + r.append(c) + return CoopVec(r) + + +class SinEncode(Module): + r""" + Map an input onto a higher-dimensional space by transforming it using sines + and cosines of an increasing frequency. + + .. math:: + + x\mapsto \begin{bmatrix} + \sin(2^0\, 2\pi x)\\ + \cos(2^0\, 2\pi x)\\ + \vdots\\ + \sin(2^{n-1}\, 2\pi x)\\ + \cos(2^{n-1}\, 2\pi x)\\ + \end{bmatrix} + + + The value :math:`n` refers to the number of *octaves*. This layer increases + the dimension by a factor of :math:`2n`. + + Note that this encoding has period 1. If your input exceeds the interval + :math:`[0, 1]`, it is advisable that you reduce it to this range to avoid + losing information. + + Minima/maxima of higher frequency components conincide on a regular + lattice, which can lead to reduced fitting performance at those locations. + Specify the optional parameter ``shift`` to phase-shift the :math:`i`-th + frequency by :math:`\mathrm{shift}` radians to avoid this behavior. + + The following plot shows the first two octaves applied to the linear + function on :math:`[0, 1]` (without shift). + + .. image:: https://rgl.s3.eu-central-1.amazonaws.com/media/uploads/wjakob/2024/06/sin_encode_light.svg + :class: only-light + :width: 600px + :align: center + + .. image:: https://rgl.s3.eu-central-1.amazonaws.com/media/uploads/wjakob/2024/06/sin_encode_dark.svg + :class: only-dark + :width: 600px + :align: center + """ + + def __init__(self, octaves: int = 0, shift: float = 0) -> None: + self.octaves = octaves + + if shift == 0: + self.shift = None + else: + self.shift = (drjit.sin(shift*2*drjit.pi), + drjit.cos(shift*2*drjit.pi)) + + def _alloc(self, dtype: Type[drjit.ArrayBase], size : int = -1, /) -> Tuple[Module, int]: + return self, size * self.octaves * 2 + + def __repr__(self) -> str: + return f'SinEncode({self.octaves})' + + def __call__(self, arg: CoopVec, /) -> CoopVec: + args, r = list(arg), list() + for arg in args: + s, c = drjit.sincos(arg * 2 * drjit.pi) + r.append(s) + r.append(c) + for _ in range(1, self.octaves): + # Recurrence for double angle sine/cosine + s2 = 2 * s + s, c = s2 * c, drjit.fma(-s2, s, 1) + r.append(s) + r.append(c) + + if self.shift: + # Recurrence for sine/cosine angle addition + ss, cs = self.shift + s, c = drjit.fma(s, cs, c*ss), \ + drjit.fma(c, cs, -s*ss) + + return CoopVec(r) + + diff --git a/drjit/stubs.pat b/drjit/stubs.pat index a38804d3..437bc732 100644 --- a/drjit/stubs.pat +++ b/drjit/stubs.pat @@ -108,13 +108,30 @@ drjit.select$: @overload def select(arg0: bool | AnyArray, arg1: T, arg2: T) -> T: ... -drjit.(atan2|minimum|maximum)$: +drjit.atan2$: @overload def \1(arg0: ArrayBase[SelfT, SelfCpT, ValT, ValCpT, RedT, PlainT, MaskT], arg1: SelfCpT, /) -> SelfT: \doc @overload def \1(arg0: SelfCpT, arg1: ArrayBase[SelfT, SelfCpT, ValT, ValCpT, RedT, PlainT, MaskT], /) -> SelfT: ... + +drjit.step$: + @overload + def \1(arg0: ArrayBase[SelfT, SelfCpT, ValT, ValCpT, RedT, PlainT, MaskT], arg1: SelfCpT, /) -> SelfT: + \doc + @overload + def \1(arg0: SelfCpT, arg1: ArrayBase[SelfT, SelfCpT, ValT, ValCpT, RedT, PlainT, MaskT], /) -> SelfT: ... + +drjit.(minimum|maximum)$: + @overload + def \1(arg0: ArrayBase[SelfT, SelfCpT, ValT, ValCpT, RedT, PlainT, MaskT], arg1: SelfCpT, /) -> SelfT: + \doc + @overload + def \1(arg0: SelfCpT, arg1: ArrayBase[SelfT, SelfCpT, ValT, ValCpT, RedT, PlainT, MaskT], /) -> SelfT: ... + @overload + def \1(arg0: CoopVec[ArrayT], arg1: object) -> CoopVec[ArrayT]: ... @overload + def \1(arg0: object, arg1: CoopVec[ArrayT]) -> CoopVec[ArrayT]: ... def \1(arg0: T, arg1: T, /) -> T: ... drjit.(empty|zeros|ones)$: @@ -128,21 +145,41 @@ drjit.(full|opaque)$: @overload def \1(dtype: type[T], value: T, shape: int | Sequence[int]) -> T: ... -drjit.(fma|lerp)$: +drjit.lerp$: @overload - def \1(arg0: ArrayBase[SelfT, SelfCpT, ValT, ValCpT, RedT, PlainT, MaskT], arg1: SelfCpT, arg2: SelfCpT, /) -> SelfT: + def lerp(arg0: ArrayBase[SelfT, SelfCpT, ValT, ValCpT, RedT, PlainT, MaskT], arg1: SelfCpT, arg2: SelfCpT, /) -> SelfT: \doc @overload - def \1(arg0: SelfCpT, arg1: ArrayBase[SelfT, SelfCpT, ValT, ValCpT, RedT, PlainT, MaskT], arg2: SelfCpT, /) -> SelfT: ... + def lerp(arg0: SelfCpT, arg1: ArrayBase[SelfT, SelfCpT, ValT, ValCpT, RedT, PlainT, MaskT], arg2: SelfCpT, /) -> SelfT: ... @overload - def \1(arg0: SelfCpT, arg1: SelfCpT, arg2: ArrayBase[SelfT, SelfCpT, ValT, ValCpT, RedT, PlainT, MaskT], /) -> SelfT: ... + def lerp(arg0: SelfCpT, arg1: SelfCpT, arg2: ArrayBase[SelfT, SelfCpT, ValT, ValCpT, RedT, PlainT, MaskT], /) -> SelfT: ... @overload - def \1(arg0: T, arg1: T, arg2: T) -> T: ... + def lerp(arg0: T, arg1: T, arg2: T) -> T: ... + +drjit.fma$: + @overload + def fma(arg0: ArrayBase[SelfT, SelfCpT, ValT, ValCpT, RedT, PlainT, MaskT], arg1: SelfCpT, arg2: SelfCpT, /) -> SelfT: + \doc + @overload + def fma(arg0: SelfCpT, arg1: ArrayBase[SelfT, SelfCpT, ValT, ValCpT, RedT, PlainT, MaskT], arg2: SelfCpT, /) -> SelfT: ... + @overload + def fma(arg0: SelfCpT, arg1: SelfCpT, arg2: ArrayBase[SelfT, SelfCpT, ValT, ValCpT, RedT, PlainT, MaskT], /) -> SelfT: ... + @overload + def fma(arg0: CoopVec[ArrayT], arg1: object, arg2: object) -> CoopVec[ArrayT]: ... + @overload + def fma(arg0: object, arg1: CoopVec[ArrayT], arg2: object) -> CoopVec[ArrayT]: ... + @overload + def fma(arg0: object, arg1: object, arg2: CoopVec[ArrayT]) -> CoopVec[ArrayT]: ... + @overload + def fma(arg0: T, arg1: T, arg2: T) -> T: ... drjit.reshape$: \from typing import Literal + @overload def reshape(dtype: type[T], value: object, shape: int | Sequence[int], order: Literal['A', 'C', 'F'] = 'A', shrink: bool = False) -> T: \doc + @overload + def reshape(value: object, shape: int | Sequence[int], order: Literal['A', 'C', 'F'] = 'A', shrink: bool = False) -> T: ... drjit.(isnan|isinf|isfinite)$: @overload @@ -265,7 +302,6 @@ drjit.sh_eval$: def sh_eval(d: ArrayBase[SelfT, SelfCpT, ValT, ValCpT, RedT, PlainT, MaskT], order: int) -> list[ValT]: \doc - # -------------- drjit.syntax, interop, detail ---------------- # Clean the drjit.interop stub @@ -645,3 +681,6 @@ drjit.__prefix__: \from typing import TypeAlias \from collections.abc import Iterable, Sequence Axis: TypeAlias = int | tuple[int] | None + +drjit.coop.__prefix__: + \from typing import overload, Literal diff --git a/ext/drjit-core b/ext/drjit-core index 32486b64..f6ccae53 160000 --- a/ext/drjit-core +++ b/ext/drjit-core @@ -1 +1 @@ -Subproject commit 32486b64dcbf3a8f7d0a28c274863d2c8ea25f65 +Subproject commit f6ccae53828aa7d320b697011d34931c2ab9934c diff --git a/include/drjit/extra.h b/include/drjit/extra.h index d2f24dd0..5d94e57c 100644 --- a/include/drjit/extra.h +++ b/include/drjit/extra.h @@ -520,6 +520,34 @@ DRJIT_INLINE void ad_var_dec_ref(uint64_t index) JIT_NOEXCEPT { // Return the AD reference count of a variable (for debugging) extern DRJIT_EXTRA_EXPORT uint32_t ad_var_ref(uint64_t index); +/// --------------------- Cooperative vector API --------------------- + +/// Pack a set of regular Dr.Jit variables to form a cooperative vector +extern DRJIT_EXTRA_EXPORT uint64_t ad_coop_vec_pack(uint32_t n, const uint64_t *in); + +/// Unpack a cooperative vector into its components +extern DRJIT_EXTRA_EXPORT void ad_coop_vec_unpack(uint64_t index, uint32_t n, uint64_t *out); + +/// Perform a unary operation on a cooperative vector +extern DRJIT_EXTRA_EXPORT uint64_t ad_coop_vec_unary_op(JitOp op, uint64_t a0); + +/// Perform a binary operation on a pair of cooperative vectors +extern DRJIT_EXTRA_EXPORT uint64_t ad_coop_vec_binary_op(JitOp op, uint64_t a0, uint64_t a1); + +/// Perform a ternary operation on a triplet of cooperative vectors +extern DRJIT_EXTRA_EXPORT uint64_t ad_coop_vec_ternary_op(JitOp op, uint64_t a0, uint64_t a1, uint64_t a2); + +/// Perform a matrix-vector multiplication + bias addition +extern DRJIT_EXTRA_EXPORT uint64_t ad_coop_vec_matvec(uint64_t A_index, + const MatrixDescr *A_descr, + uint64_t x_index, + uint64_t b_index, + const MatrixDescr *b_descr, + int transpose); + +/// Cast a cooperative vector to a different precision +extern JIT_EXPORT uint64_t ad_coop_vec_cast(uint64_t index, VarType vt); + #if defined(__cplusplus) } #endif diff --git a/src/extra/autodiff.cpp b/src/extra/autodiff.cpp index 11c5fda2..e4d1b128 100644 --- a/src/extra/autodiff.cpp +++ b/src/extra/autodiff.cpp @@ -43,6 +43,7 @@ */ #include "common.h" +#include "drjit-core/jit.h" #include #include #include @@ -140,6 +141,26 @@ DRJIT_NOINLINE JitVar scalar(JitBackend backend, VarType type, double value) { } } +/// As above, but for cooperative vectors +DRJIT_NOINLINE JitVar scalar_coop_vec(JitBackend backend, VarType type, double value, uint32_t length) { + switch (type) { + case VarType::Float16: { + drjit::half v = (drjit::half) value; + return JitVar::steal(jit_coop_vec_literal(backend, VarType::Float16, &v, 1, length)); + } + case VarType::Float32: { + float v = (float) value; + return JitVar::steal(jit_coop_vec_literal(backend, VarType::Float32, &v, 1, length)); + } + case VarType::Float64: { + return JitVar::steal(jit_coop_vec_literal(backend, VarType::Float64, &value, 1, length)); + } + default: + ad_fail("scalar_coop_vec(): unsupported AD scalar type"); + return JitVar(); + } +} + /// Create a scalar Jit variable with the same floating point type and backend /// as an already existing variable with the provided ``index`` DRJIT_INLINE JitVar scalar(Index index, double value) { @@ -147,6 +168,12 @@ DRJIT_INLINE JitVar scalar(Index index, double value) { return scalar(info.backend, info.type, value); } +/// As above, but for cooperative vectors +DRJIT_INLINE JitVar scalar_coop_vec(Index index, double value) { + VarInfo info = jit_set_backend(jit_index(index)); + return scalar_coop_vec(info.backend, info.type, value, jit_coop_vec_length(index)); +} + // ========================================================================== // Central data structures: edges, variables, global state // ========================================================================== @@ -216,7 +243,10 @@ enum VariableFlags : uint8_t { Visited = 1 << 4, /// Is this variable on an iteration boundary of an evaluated loop? - LoopBoundary = 1 << 5 + LoopBoundary = 1 << 5, + + /// Does this variable store a cooperative vector? + CoopVec = 1 << 6 }; /** @@ -311,6 +341,15 @@ struct Variable { void mul_accum(const JitVar &v1, const JitVar &v2, size_t src_size) { JitVar zero = scalar(v1.index(), 0.f), weight; + if (unlikely(flags & CoopVec)) { + // Specialized gradient propagation for cooperative vectors + if (grad.valid()) + grad = JitVar::steal(jit_coop_vec_ternary_op(JitOp::Fma, v1.index(), v2.index(), grad.index())); + else + grad = JitVar::steal(jit_coop_vec_binary_op(JitOp::Mul, v1.index(), v2.index())); + return; + } + // Elide the zero check if ``v2`` is known not to be NaN/infinite if (jit_var_is_finite_literal(v2.index())) weight = v2; @@ -353,6 +392,15 @@ struct Variable { * optimizations. */ void accum(const JitVar& v, size_t src_size) { + if (unlikely(flags & CoopVec)) { + // Specialized gradient propagation for cooperative vectors + if (grad.valid()) + grad = JitVar::steal(jit_coop_vec_binary_op(JitOp::Add, v.index(), grad.index())); + else + grad = v; + return; + } + if (size == 1 && src_size != 1) { /* When this variable is scalar (size == 1) and the source is not (src_size != 1), the gradient must be reduced to a single @@ -857,6 +905,9 @@ static void ad_propagate_size(Variable *v) { } } +/// A tag to signal cooperative weights in the Arg() constructor +struct coop { }; + // This data structure encodes an ordinary dependence on a function argument struct Arg { Arg() = default; @@ -867,6 +918,9 @@ struct Arg { Arg(Index index, double value) : ad_index(::ad_index(index)), weight(scalar(index, value)) { } + Arg(Index index, double value, coop) + : ad_index(::ad_index(index)), weight(scalar_coop_vec(index, value)) { } + Arg(Arg &&a) = default; Arg(const Arg &a) = delete; Arg &operator=(const Arg &a) = delete; @@ -1016,6 +1070,8 @@ DRJIT_NOINLINE Index ad_var_new_impl(const char *label, JitVar &&result, auto [ad_index, var] = ad_var_new(info.backend, info.size, info.type, symbolic, reuse_indices, label); + if (info.is_coop_vec) + var->flags |= VariableFlags::CoopVec; const char *tname = jit_type_name(info.type); if constexpr (N == 0) { @@ -2918,8 +2974,7 @@ Index ad_var_map_get(Index index) { /// Potentially use ad_var_map_get to rewrite the source or target of a /// gatter/scatter operation static Index ad_var_memop_remap(Index index, bool input) { - uint32_t flags = jit_flags(); - if (flags & (uint32_t) JitFlag::SymbolicScope) { + if (jit_flags() & (uint32_t) JitFlag::SymbolicScope) { index = ad_var_map_get(index); // Add to set of implicit variable dependencies @@ -3047,8 +3102,8 @@ class PacketGather : public dr::detail::CustomOpBase { void ad_var_gather_packet(size_t n, Index source, JitIndex offset, JitIndex mask, uint64_t *out, ReduceMode mode) { - uint32_t *out2 = (uint32_t *) alloca(sizeof(uint32_t) * n); - jit_var_gather_packet(n, jit_index(source), offset, mask, out2); + uint32_t *tmp = (uint32_t *) alloca(sizeof(uint32_t) * n); + jit_var_gather_packet(n, jit_index(source), offset, mask, tmp); ADIndex source_ad = ad_index(source); const std::vector &scopes = local_state.scopes; @@ -3064,16 +3119,16 @@ void ad_var_gather_packet(size_t n, Index source, JitIndex offset, op->add_index(backend, source_ad, true); for (size_t i = 0; i < n; ++i) { - out[i] = ad_var_new(out2[i]); - jit_var_dec_ref(out2[i]); + out[i] = ad_var_new(tmp[i]); + jit_var_dec_ref(tmp[i]); op->add_output(ad_index(out[i])); } if (!ad_custom_op(op.get())) - ad_raise("ad_var_gather_packet(): could not create CustomOp"); + ad_raise("ad_var_gather_packet(): could not create CustomOp!"); } else { for (size_t i = 0; i < n; ++i) - out[i] = out2[i]; + out[i] = tmp[i]; } } @@ -3152,11 +3207,11 @@ class PacketScatter : public dr::detail::CustomOpBase { if (op == ReduceOp::Identity && mode != ReduceMode::Permute) { JitMask value(true); - uint32_t *values = (uint32_t *) alloca(sizeof(uint32_t)*n); + uint32_t *tmp = (uint32_t *) alloca(sizeof(uint32_t)*n); for (size_t i = 0; i < n; ++i) - values[i] = value.index(); + tmp[i] = value.index(); m_blend = JitMask::steal(jit_var_scatter_packet( - n, m_blend.index(), values, offset, mask)); + n, m_blend.index(), tmp, offset, mask)); } if (op != ReduceOp::Add && op != ReduceOp::Identity) @@ -3171,7 +3226,7 @@ class PacketScatter : public dr::detail::CustomOpBase { void forward() override { std::lock_guard guard(state.lock); - JitIndex *grad_in = (JitIndex *) alloca(sizeof(JitIndex) * m_n); + JitIndex *grad_in = (JitIndex *) alloca(sizeof(JitIndex) * m_n); size_t n_valid = 0; JitVar zero = scalar(m_backend, m_type, 0.0); @@ -3212,7 +3267,7 @@ class PacketScatter : public dr::detail::CustomOpBase { void backward() override { std::lock_guard guard(state.lock); - JitIndex *out = (JitIndex *) alloca(sizeof(JitIndex) * m_n); + JitIndex *out = (JitIndex *) alloca(sizeof(JitIndex) * m_n); Variable *v = state[m_output_indices[0]]; if (!v->grad.valid()) @@ -3268,23 +3323,20 @@ class PacketScatter : public dr::detail::CustomOpBase { Index ad_var_scatter_packet(size_t n, Index target, const Index *values, JitIndex offset, JitIndex mask, ReduceOp op, ReduceMode mode) { - JitIndex *values2 = (JitIndex *) alloca(sizeof(JitIndex) * n); + JitIndex *tmp = (JitIndex *) alloca(sizeof(JitIndex) * n); bool attached = ad_index(target) != 0; for (size_t i = 0; i < n; ++i) { Index index = values[i]; - values2[i] = jit_index(index); - if (ad_index(index)) - attached = true; + tmp[i] = jit_index(index); + attached |= ad_index(index) != 0; } JitVar result = JitVar::steal(jit_var_scatter_packet( - n, jit_index(target), values2, offset, mask, op, mode)); + n, jit_index(target), tmp, offset, mask, op, mode)); bool perm_scatter = op == ReduceOp::Identity && mode == ReduceMode::Permute; - if (!attached) { - return result.release(); - } else { + if (attached) { // Track implicit dependencies & potentially remap variable IDs target = ad_var_memop_remap(target, false); @@ -3299,11 +3351,13 @@ Index ad_var_scatter_packet(size_t n, Index target, const Index *values, uint64_t ad_result = ad_var_new(result.index()); ps->add_output(ad_index(ad_result)); - if (!ad_custom_op(ps.get())) - ad_raise("ad_var_scatter_packet(): could not create CustomOp"); + if (ad_custom_op(ps.get())) + return ad_result; - return ad_result; + ad_var_dec_ref(ad_result); } + + return result.release(); } void ad_var_scatter_add_kahan(Index *target_1, Index *target_2, Index value, @@ -3563,9 +3617,10 @@ const char *ad_var_graphviz() { if (v->flags & VariableFlags::Symbolic) buffer.put("|{Symbolic}"); - buffer.fmt("|{Type: %s|Size: %zu}|{a%u|Refs: %u}}\"", - type_name_short[v->type], v->size, - index, (uint32_t) v->ref_count); + buffer.fmt("|{Type: %s%s|Size: %zu}|{a%u|Refs: %u}}\"", + type_name_short[v->type], + (v->flags & VariableFlags::CoopVec) ? " [coop]" : "", + v->size, index, (uint32_t) v->ref_count); if (color) buffer.fmt(" fillcolor=%s style=filled", color); @@ -3605,7 +3660,7 @@ const char *ad_var_graphviz() { " l4 [style=filled fillcolor=yellowgreen label=\"Gradient present\"];\n" " l3 [style=filled fillcolor=salmon label=\"Input\"];\n" " l2 [style=filled fillcolor=lightblue2 label=\"Output\"];\n" - " l1 [style=filled fillcolor=wheat label=\"Labeled\"];\n" + " l0 [style=filled fillcolor=wheat label=\"Labeled\"];\n" " }\n" "}\n"); @@ -3739,6 +3794,520 @@ void ad_copy_implicit_deps(drjit::vector& result, bool input) { } } +// ========================================================================== +// Cooperative vector API +// ========================================================================== + +class CoopVecPack : public dr::detail::CustomOpBase { +public: + ~CoopVecPack() { + std::lock_guard guard(state.lock); + for (uint32_t index: m_output_indices) + ad_var_dec_ref_int(index, state[index]); + } + + void forward() override { + std::lock_guard guard(state.lock); + uint32_t size = (uint32_t) m_input_indices.size(); + JitIndex *tmp = (JitIndex *) alloca(sizeof(JitIndex) * size); + size_t n_valid = 0; + + Variable *target = state[m_output_indices[0]]; + JitVar zero = scalar(m_backend, (VarType) target->type, 0.0); + + for (uint32_t i = 0; i < size; ++i) { + tmp[i] = zero.index(); + + if (m_inputs[i]) { + Variable *v2 = state[m_inputs[i]]; + if (v2->grad.valid()) { + tmp[i] = v2->grad.index(); + n_valid++; + } + } + } + + if (n_valid) { + JitVar packed = JitVar::steal(jit_coop_vec_pack(size, tmp)); + target->accum(packed, target->size); + } + } + + void backward() override { + std::lock_guard guard(state.lock); + uint32_t n = (uint32_t) m_input_indices.size(); + + Variable *v = state[m_output_indices[0]]; + if (!v->grad.valid()) + return; + + JitIndex *tmp = (JitIndex *) alloca(sizeof(JitIndex) * n); + jit_coop_vec_unpack(v->grad.index(), n, tmp); + + for (size_t i = 0; i < m_input_indices.size(); ++i) { + Variable *v2 = state[m_inputs[i]]; + v2->accum(JitVar::steal(tmp[i]), v2->size); + } + } + + void add_input(JitBackend backend, uint32_t index) { + add_index(backend, index, true); + // No need for extra reference counting + m_inputs.push_back(index); + } + + void add_output(JitBackend backend, uint32_t index) { + add_index(backend, index, false); + std::lock_guard guard(state.lock); + ad_var_inc_ref_int(index, state[index]); + } + + const char *name() const override { return "pack"; } + +private: + std::vector m_inputs; +}; + +/// Pack a set of regular Dr.Jit variables to form a cooperative vector +Index ad_coop_vec_pack(uint32_t n, const Index *in) { + JitIndex *tmp = (JitIndex *) alloca(sizeof(JitIndex) * n); + bool attached = false; + + if (n == 0) + return 0; + + for (uint32_t i = 0; i < n; ++i) { + Index index = in[i]; + tmp[i] = jit_index(index); + attached |= ad_index(index) != 0; + } + + JitVar result = JitVar::steal(jit_coop_vec_pack(n, tmp)); + + if (attached) { + VarInfo vi = jit_set_backend(result.index()); + + ref ps = new CoopVecPack(); + for (size_t i = 0; i < n; ++i) + ps->add_input(vi.backend, ad_index(in[i])); + + uint64_t ad_result = ad_var_new(result.index()); + ps->add_output(vi.backend, ad_index(ad_result)); + + if (ad_custom_op(ps.get())) + return ad_result; + + ad_var_dec_ref(ad_result); + } + + return result.release(); +} + +class CoopVecUnpack : public dr::detail::CustomOpBase { +public: + ~CoopVecUnpack() { + std::lock_guard guard(state.lock); + for (ADIndex index : m_output_indices) + ad_var_dec_ref_int(index, state[index]); + } + + void forward() override { + std::lock_guard guard(state.lock); + size_t n = m_output_indices.size(); + + const Variable *v = state[m_input_indices[0]]; + if (!v->grad.valid()) + return; + + JitIndex *tmp = (JitIndex *) alloca(sizeof(JitIndex) * n); + jit_coop_vec_unpack(v->grad.index(), n, tmp); + + for (size_t i = 0; i < n; ++i) { + Variable *vo = state[m_output_indices[i]]; + vo->accum(JitVar::steal(tmp[i]), vo->size); + } + } + + void backward() override { + std::lock_guard guard(state.lock); + size_t n = m_output_indices.size(); + + JitIndex *tmp = (JitIndex *) alloca(sizeof(JitIndex) * n); + + for (size_t i = 0; i < n; ++i) { + const Variable *v = state[m_output_indices[i]]; + uint32_t index = v->grad.index(); + + if (index) + jit_var_inc_ref(index); + else + index = scalar(m_backend, (VarType) v->type, 0.0).release(); + + tmp[i] = index; + } + + JitVar packed = JitVar::steal(jit_coop_vec_pack(n, tmp)); + for (size_t i = 0; i < m_output_indices.size(); ++i) + jit_var_dec_ref(tmp[i]); + + Variable *source = state[m_input_indices[0]]; + source->accum(packed, source->size); + } + + void add_output(uint32_t index) { + add_index(m_backend, index, false); + + std::lock_guard guard(state.lock); + ad_var_inc_ref_int(index, state[index]); + } + + const char *name() const override { return "unpack"; } +}; + +/// Unpack a cooperative vector into its components +void ad_coop_vec_unpack(uint64_t index, uint32_t n, uint64_t *out) { + uint32_t *tmp = (uint32_t *) alloca(sizeof(uint32_t) * n); + jit_coop_vec_unpack(index, n, tmp); + + ADIndex ad_index = ::ad_index(index); + const std::vector &scopes = local_state.scopes; + if (!scopes.empty()) + scopes.back().maybe_disable(ad_index); + + if (ad_index) { + ref op = new CoopVecUnpack(); + JitBackend backend = jit_set_backend(jit_index(index)).backend; + op->add_index(backend, ad_index, true); + + for (uint32_t i = 0; i < n; ++i) { + out[i] = ad_var_new(tmp[i]); + jit_var_dec_ref(tmp[i]); + op->add_output(::ad_index(out[i])); + } + + if (!ad_custom_op(op.get())) + ad_raise("ad_coop_vec_unpack(): could not create CustomOp!"); + } else { + for (uint32_t i = 0; i < n; ++i) + out[i] = tmp[i]; + } +} + +/// Perform a unary operation on a cooperative vector +uint64_t ad_coop_vec_unary_op(JitOp op, uint64_t i0) { + JitVar result = JitVar::steal( + jit_coop_vec_unary_op(op, jit_index(i0))); + + if (is_detached(i0)) { + return result.release(); + } else { + switch (op) { + case JitOp::Exp2: { + JitVar scale = scalar_coop_vec(i0, dr::LogTwo); + JitVar w0 = JitVar::steal(jit_coop_vec_binary_op(JitOp::Mul, result.index(), scale.index())); + return ad_var_new("exp2", std::move(result), Arg(i0, std::move(w0))); + } + + case JitOp::Tanh: { + // Mini-max polynomial fit made using Sollya (max. relative error = 0.0052) + // Q1 = fpminimax(4*y/((1 + y)^2)-y, [|1, 2, 3, 4, 5|], [|halfprecision...|], [0, 1-1e-20]); + // Q2 = horner(Q1 + y); + // print(Q2); + + JitVar scale = scalar_coop_vec(i0, -2.8853900817779268), // -2/log(2) + c0 = scalar_coop_vec(i0, 3.98046875), + c1 = scalar_coop_vec(i0, -7.4140625), + c2 = scalar_coop_vec(i0, 8.2421875), + c3 = scalar_coop_vec(i0, -5.1640625), + c4 = scalar_coop_vec(i0, 1.35546875); + + JitVar x0 = JitVar::steal(jit_coop_vec_binary_op(JitOp::Mul, (uint32_t) i0, scale.index())), + x1 = JitVar::steal(jit_coop_vec_unary_op(JitOp::Exp2, x0.index())), + y0 = JitVar::steal(jit_coop_vec_ternary_op(JitOp::Fma, x1.index(), c4.index(), c3.index())), + y1 = JitVar::steal(jit_coop_vec_ternary_op(JitOp::Fma, x1.index(), y0.index(), c2.index())), + y2 = JitVar::steal(jit_coop_vec_ternary_op(JitOp::Fma, x1.index(), y1.index(), c1.index())), + y3 = JitVar::steal(jit_coop_vec_ternary_op(JitOp::Fma, x1.index(), y2.index(), c0.index())), + y4 = JitVar::steal(jit_coop_vec_binary_op(JitOp::Mul, x1.index(), y3.index())); + + return ad_var_new("tanh", std::move(result), Arg(i0, std::move(y4))); + } + + default: + ad_raise("ad_coop_vec_unary_op(): differentiable version not implemented."); + } + } +} + +/// Perform a binary operation on a pair of cooperative vectors +uint64_t ad_coop_vec_binary_op(JitOp op, uint64_t i0, uint64_t i1) { + JitVar result = JitVar::steal( + jit_coop_vec_binary_op(op, jit_index(i0), jit_index(i1))); + + if (is_detached(i0, i1)) { + return result.release(); + } else { + switch (op) { + case JitOp::Add: + return ad_var_new("add", std::move(result), + Arg(i0, 1.0, coop{}), + Arg(i1, 1.0, coop{})); + break; + + case JitOp::Sub: + return ad_var_new("sub", std::move(result), + Arg(i0, 1.0, coop{}), + Arg(i1, -1.0, coop{})); + break; + + case JitOp::Mul: + return ad_var_new("mul", std::move(result), + Arg(i0, JitVar::borrow(jit_index(i1))), + Arg(i1, JitVar::borrow(jit_index(i0)))); + break; + + case JitOp::Min: { + JitVar w0 = JitVar::steal(jit_coop_vec_binary_op(JitOp::Step, jit_index(i0), jit_index(i1))), + w1 = JitVar::steal(jit_coop_vec_binary_op(JitOp::Step, jit_index(i1), jit_index(i0))); + return ad_var_new("min", std::move(result), + Arg(i0, std::move(w1)), + Arg(i1, std::move(w0))); + } + break; + + case JitOp::Max: { + JitVar w0 = JitVar::steal(jit_coop_vec_binary_op(JitOp::Step, jit_index(i0), jit_index(i1))), + w1 = JitVar::steal(jit_coop_vec_binary_op(JitOp::Step, jit_index(i1), jit_index(i0))); + return ad_var_new("max", std::move(result), + Arg(i0, std::move(w0)), + Arg(i1, std::move(w1))); + } + break; + + case JitOp::Step: + return result.release(); + + default: + ad_raise("ad_coop_vec_binary_op(): differentiable version not implemented."); + } + } +} + +/// Perform a ternary operation on a triplet of cooperative vectors +uint64_t ad_coop_vec_ternary_op(JitOp op, uint64_t i0, uint64_t i1, uint64_t i2) { + JitVar result = JitVar::steal( + jit_coop_vec_ternary_op(op, jit_index(i0), jit_index(i1), jit_index(i2))); + + if (is_detached(i0, i1, i2)) { + return result.release(); + } else { + switch (op) { + case JitOp::Fma: + return ad_var_new("fma", std::move(result), + Arg(i0, JitVar::borrow(jit_index(i1))), + Arg(i1, JitVar::borrow(jit_index(i0))), + Arg(i2, 1.0, coop{})); + + default: + ad_raise("ad_coop_vec_ternary_op(): differentiable version not implemented."); + } + } +} + + +struct CoopCast : Special { + CoopCast(VarType v1, VarType v2) : v1(v1), v2(v2) { } + + void backward(Variable *source, const Variable *target) override { + source->accum(JitVar::steal(jit_coop_vec_cast(target->grad.index(), v1)), + target->size); + } + + void forward(const Variable *source, Variable *target) override { + target->accum(JitVar::steal(jit_coop_vec_cast(source->grad.index(), v2)), + source->size); + } + + VarType v1, v2; +}; + + +Index ad_coop_vec_cast(Index i0, VarType vt) { + JitVar result = JitVar::steal(jit_coop_vec_cast(jit_index(i0), vt)); + + if (is_detached(i0)) { + return result.release(); + } else { + return ad_var_new("cast", std::move(result), + SpecialArg(i0, new CoopCast(jit_var_type((JitIndex) i0), vt))); + } +} + + +class CoopMatVec : public dr::detail::CustomOpBase { +public: + CoopMatVec(Index A_index, const MatrixDescr *A_descr, Index x_index, + Index b_index, const MatrixDescr *b_descr, int transpose) + : m_A(A_index), m_A_descr(*A_descr), m_x(x_index), m_b(b_index), + m_transpose(transpose) { + if (b_descr) + m_b_descr = *b_descr; + ad_var_inc_ref(m_A); + ad_var_inc_ref(m_b); + ad_var_inc_ref(m_x); + m_out = 0; + } + + ~CoopMatVec() { + ad_var_dec_ref(m_A); + ad_var_dec_ref(m_b); + ad_var_dec_ref(m_x); + ad_var_dec_ref(m_out); + } + + void forward() override { + std::lock_guard guard(state.lock); + + const Variable *A_v = ad_index(m_A) ? state[ad_index(m_A)] : nullptr, + *x_v = ad_index(m_x) ? state[ad_index(m_x)] : nullptr, + *b_v = ad_index(m_b) ? state[ad_index(m_b)] : nullptr; + + Variable *out_v = state[m_output_indices[0]]; + bool has_b_grad = b_v && b_v->grad.valid(); + + if (A_v && A_v->grad.valid()) { + JitVar result = JitVar::steal(jit_coop_vec_matvec( + A_v->grad.index(), &m_A_descr, jit_index(m_x), + has_b_grad ? b_v->grad.index() : 0, + has_b_grad ? &m_b_descr : nullptr, m_transpose)); + out_v->accum(result, out_v->size); + has_b_grad = false; + } + + if (x_v && x_v->grad.valid()) { + JitVar result = JitVar::steal(jit_coop_vec_matvec( + jit_index(m_A), &m_A_descr, x_v->grad.index(), + has_b_grad ? b_v->grad.index() : 0, + has_b_grad ? &m_b_descr : nullptr, m_transpose)); + out_v->accum(result, out_v->size); + has_b_grad = false; + } + + if (has_b_grad) { + JitVar result = JitVar::steal(jit_coop_vec_load( + b_v->grad.index(), m_b_descr.offset, m_b_descr.rows)); + out_v->accum(result, out_v->size); + } + } + + void backward() override { + std::lock_guard guard(state.lock); + Variable *out_v = state[m_output_indices[0]]; + const JitVar &grad = out_v->grad; + + if (!grad.valid()) + return; + + Variable *A_v = ad_index(m_A) ? state[ad_index(m_A)] : nullptr, + *x_v = ad_index(m_x) ? state[ad_index(m_x)] : nullptr, + *b_v = ad_index(m_b) ? state[ad_index(m_b)] : nullptr; + + if (x_v) { + JitVar result = JitVar::steal(jit_coop_vec_matvec( + jit_index(m_A), &m_A_descr, grad.index(), 0, + nullptr, m_transpose == 0 ? 1 : 0)); + x_v->accum(result, x_v->size); + } + + if (A_v) { + uint32_t vec_a = jit_index(m_x), + vec_b = jit_index(grad.index()); + if (m_transpose) + std::swap(vec_a, vec_b); + + A_v->grad = JitVar::steal(jit_coop_vec_outer_product_accum( + A_v->grad.index(), jit_var_size(jit_index(m_A)), &m_A_descr, + vec_b, vec_a)); + } + + if (b_v) { + b_v->grad = JitVar::steal(jit_coop_vec_accum( + b_v->grad.index(), jit_var_size(jit_index(m_b)), m_b_descr.offset, + grad.index())); + } + } + + void set_output(JitBackend backend, Index index) { + add_index(backend, ad_index(index), false); + m_out = (index >> 32) << 32; + ad_var_inc_ref(m_out); + } + + const char *name() const override { return "matvec"; } + +private: + Index m_A; + MatrixDescr m_A_descr; + Index m_x; + Index m_b; + Index m_out; + MatrixDescr m_b_descr; + int m_transpose; +}; + +uint64_t ad_coop_vec_matvec(uint64_t A_index, const MatrixDescr *A_descr, + uint64_t x_index, uint64_t b_index, + const MatrixDescr *b_descr, int transpose) { + + uint32_t A_index_j = jit_index(A_index), + x_index_j = jit_index(x_index), + b_index_j = jit_index(b_index), + A_index_a = ad_index(A_index), + x_index_a = ad_index(x_index), + b_index_a = ad_index(b_index); + + if (A_index_a || x_index_a || b_index_a) { + const std::vector &scopes = local_state.scopes; + if (!scopes.empty()) { + const Scope &s = scopes.back(); + s.maybe_disable(A_index_a); + s.maybe_disable(x_index_a); + s.maybe_disable(b_index_a); + } + } + + JitVar result = JitVar::steal(jit_coop_vec_matvec( + A_index_j, A_descr, x_index_j, b_index_j, b_descr, transpose)); + + if (!A_index_a && !x_index_a && !b_index_a) { + return result.release(); + } else { + { + std::lock_guard guard(state.lock); + A_index = ad_var_memop_remap(A_index, true); + b_index = ad_var_memop_remap(b_index, true); + A_index_j = jit_index(A_index); + b_index_j = jit_index(b_index); + A_index_a = ad_index(A_index); + b_index_a = ad_index(b_index); + } + + ref op = new CoopMatVec(A_index, A_descr, x_index, + b_index, b_descr, transpose); + JitBackend backend = jit_set_backend(x_index_j).backend; + op->add_index(backend, A_index_a, true); + op->add_index(backend, x_index_a, true); + op->add_index(backend, b_index_a, true); + + uint64_t result_diff = ad_var_new(result.index()); + op->set_output(backend, result_diff); + + if (!ad_custom_op(op.get())) + ad_raise("ad_coop_vec_matvec(): could not create CustomOp!"); + + return result_diff; + } +} + // ========================================================================== // Custom operations // ========================================================================== diff --git a/src/extra/math.cpp b/src/extra/math.cpp index ccad1c88..2168cbd0 100644 --- a/src/extra/math.cpp +++ b/src/extra/math.cpp @@ -93,7 +93,6 @@ DEFINE_MATH_OP(acos) DEFINE_MATH_OP(atan) DEFINE_MATH_OP(sinh) DEFINE_MATH_OP(cosh) -DEFINE_MATH_OP(tanh) DEFINE_MATH_OP(asinh) DEFINE_MATH_OP(acosh) DEFINE_MATH_OP(atanh) @@ -236,3 +235,23 @@ DRJIT_EXTRA_EXPORT uint32_t jit_var_cos(uint32_t i0) { return 0; } } + +DRJIT_EXTRA_EXPORT uint32_t jit_var_tanh(uint32_t i0) { + VarInfo info = jit_set_backend(i0); + + switch (info.type) { + case VarType::Float16: + return dr::tanh(Float16::borrow(i0)).release(); + + case VarType::Float32: + if (info.backend == JitBackend::CUDA) + return jit_var_tanh_intrinsic(i0); + return dr::tanh(Float32::borrow(i0)).release(); + + case VarType::Float64: + return dr::tanh(Float64::borrow(i0)).release(); + default: + jit_fail("jit_var_tanh(): invalid operand!"); + return 0; + } +} diff --git a/src/python/CMakeLists.txt b/src/python/CMakeLists.txt index 9886323d..678ac1fb 100644 --- a/src/python/CMakeLists.txt +++ b/src/python/CMakeLists.txt @@ -21,7 +21,7 @@ configure_file( ) set(PY_FILES - config.py __init__.py ast.py detail.py interop.py dda.py opt.py + config.py __init__.py ast.py detail.py interop.py dda.py opt.py nn.py _sh_eval.py _reduce.py scalar/__init__.py llvm/__init__.py llvm/ad.py cuda/__init__.py cuda/ad.py) @@ -83,6 +83,7 @@ nanobind_add_module( tracker.h tracker.cpp local.h local.cpp resample.h resample.cpp + coop_vec.h coop_vec.cpp # Backends scalar.h scalar.cpp @@ -211,6 +212,13 @@ if (NOT (DRJIT_SANITIZE_ASAN OR DRJIT_SANITIZE_UBSAN)) ${STUB_ARGS} ) + nanobind_add_stub( + drjit-stub-nn + MODULE drjit.nn + OUTPUT ${DRJIT_PYTHON_DST_DIR}/nn.pyi + ${STUB_ARGS} + ) + nanobind_add_stub( drjit-stub-scalar MODULE drjit.scalar diff --git a/src/python/apply.cpp b/src/python/apply.cpp index c611e099..fbfc0c47 100644 --- a/src/python/apply.cpp +++ b/src/python/apply.cpp @@ -446,7 +446,8 @@ NB_NOINLINE PyObject *apply_tensor(ArrayOp op, Slot slot, expanded_shapes_alloc[index] = vector(ndim, 1); vector& expanded_shape = expanded_shapes_alloc[index]; size_t offset = ndim - src_ndim; - memcpy(&expanded_shape[offset], shape->data(), sizeof(size_t) * src_ndim); + if (src_ndim) + memcpy(&expanded_shape[offset], shape->data(), sizeof(size_t) * src_ndim); return (const vector*)&expanded_shape; }; @@ -635,20 +636,18 @@ void traverse(const char *op, TraverseCallback &tc, nb::handle h) { } else if (tp.is(&PyDict_Type)) { for (nb::handle h2 : nb::borrow(h).values()) traverse(op, tc, h2); - } else { - if (nb::dict ds = get_drjit_struct(tp); ds.is_valid()) { - for (auto [k, v] : ds) - traverse(op, tc, nb::getattr(h, k)); - } else if (nb::object df = get_dataclass_fields(tp); df.is_valid()) { - for (nb::handle field : df) { - nb::object k = field.attr(DR_STR(name)); - traverse(op, tc, nb::getattr(h, k)); - } - } else if (nb::object cb = get_traverse_cb_ro(tp); cb.is_valid()) { - cb(h, nb::cpp_function([&](uint64_t index) { tc(index); })); - } else { - tc.traverse_unknown(h); + } else if (nb::dict ds = get_drjit_struct(tp); ds.is_valid()) { + for (auto [k, v] : ds) + traverse(op, tc, nb::getattr(h, k)); + } else if (nb::object df = get_dataclass_fields(tp); df.is_valid()) { + for (nb::handle field : df) { + nb::object k = field.attr(DR_STR(name)); + traverse(op, tc, nb::getattr(h, k)); } + } else if (nb::object cb = get_traverse_cb_ro(tp); cb.is_valid()) { + cb(h, nb::cpp_function([&](uint64_t index) { tc(index); })); + } else { + tc.traverse_unknown(h); } } catch (nb::python_error &e) { nb::raise_from(e, PyExc_RuntimeError, @@ -889,25 +888,23 @@ nb::object transform(const char *op, TransformCallback &tc, nb::handle h) { for (auto [k, v] : nb::borrow(h)) tmp[k] = transform(op, tc, v); result = std::move(tmp); - } else { - if (nb::dict ds = get_drjit_struct(tp); ds.is_valid()) { - nb::object tmp = tp(); - for (auto [k, v] : ds) - nb::setattr(tmp, k, transform(op, tc, nb::getattr(h, k))); - result = std::move(tmp); - } else if (nb::object df = get_dataclass_fields(tp); df.is_valid()) { - nb::object tmp = nb::dict(); - for (nb::handle field : df) { - nb::object k = field.attr(DR_STR(name)); - tmp[k] = transform(op, tc, nb::getattr(h, k)); - } - result = tp(**tmp); - } else if (nb::object cb = get_traverse_cb_rw(tp); cb.is_valid()) { - cb(h, nb::cpp_function([&](uint64_t index) { return tc(index); })); - result = nb::borrow(h); - } else if (!result.is_valid()) { - result = tc.transform_unknown(h); + } else if (nb::dict ds = get_drjit_struct(tp); ds.is_valid()) { + nb::object tmp = tp(); + for (auto [k, v] : ds) + nb::setattr(tmp, k, transform(op, tc, nb::getattr(h, k))); + result = std::move(tmp); + } else if (nb::object df = get_dataclass_fields(tp); df.is_valid()) { + nb::object tmp = nb::dict(); + for (nb::handle field : df) { + nb::object k = field.attr(DR_STR(name)); + tmp[k] = transform(op, tc, nb::getattr(h, k)); } + result = tp(**tmp); + } else if (nb::object cb = get_traverse_cb_rw(tp); cb.is_valid()) { + cb(h, nb::cpp_function([&](uint64_t index) { return tc(index); })); + result = nb::borrow(h); + } else if (!result.is_valid()) { + result = tc.transform_unknown(h); } return result; } catch (nb::python_error &e) { diff --git a/src/python/autodiff.cpp b/src/python/autodiff.cpp index 46f27fd0..e1f654ba 100644 --- a/src/python/autodiff.cpp +++ b/src/python/autodiff.cpp @@ -12,6 +12,7 @@ #include #include #include "autodiff.h" +#include "coop_vec.h" #include "apply.h" #include "meta.h" #include "init.h" @@ -43,6 +44,24 @@ static void set_grad_enabled(nb::handle h, bool enable_) { } } } + + void traverse_unknown(nb::handle h) override { + if (CoopVec *v = nullptr; nb::try_cast(h, v, false), v != nullptr) { + uint64_t index = v->m_index; + bool grad_enabled = ((uint32_t) index) != index; + if (enable != grad_enabled) { + if (enable) { + nb::raise( + "to create a differentiable cooperative vector, " + "construct it from grad-enabled components."); + } else { + jit_var_inc_ref((uint32_t) index); + ad_var_dec_ref(index); + v->m_index = (uint32_t) index; + } + } + } + } }; SetGradEnabled sge(enable_); @@ -88,6 +107,11 @@ bool grad_enabled(nb::handle h) { if (s.is_diff && is_float(s)) result |= ad_grad_enabled(s.index(inst_ptr(h))) != 0; } + + void traverse_unknown(nb::handle h) override { + if (CoopVec *v = nullptr; nb::try_cast(h, v, false), v != nullptr) + result |= ad_grad_enabled(v->m_index); + } }; GradEnabled ge; @@ -139,6 +163,15 @@ static nb::object detach(nb::handle h, bool preserve_type_ = true) { nb::inst_copy(h2, h1); } } + + nb::object transform_unknown(nb::handle h) const override { + if (CoopVec *v = nullptr; nb::try_cast(h, v, false), v != nullptr) { + uint32_t index = (uint32_t) v->m_index; + jit_var_inc_ref(index); + return nb::cast(CoopVec(index, v->m_size, v->m_type)); + } + return nb::borrow(h); + } }; if ((is_drjit_array(h) && !supp(h.type()).is_diff)) diff --git a/src/python/base.cpp b/src/python/base.cpp index 1cc02a86..f5b5761d 100644 --- a/src/python/base.cpp +++ b/src/python/base.cpp @@ -1261,7 +1261,7 @@ void export_base(nb::module_ &m) { m.def("power", [](Py_ssize_t arg0, Py_ssize_t arg1) { return std::pow(arg0, arg1); }, - doc_pow); + doc_power); m.def("power", [](double arg0, double arg1) { return std::pow(arg0, arg1); }); diff --git a/src/python/coop_vec.cpp b/src/python/coop_vec.cpp new file mode 100644 index 00000000..4d7a1b5a --- /dev/null +++ b/src/python/coop_vec.cpp @@ -0,0 +1,740 @@ +/* + src/coop_vec.cpp -- Python bindings for Cooperative CoopVecs + + Copyright (c) 2025 Wenzel Jakob + + All rights reserved. Use of this source code is governed by a BSD-style + license that can be found in the LICENSE file. +*/ + +#include "common.h" +#include "base.h" +#include "init.h" +#include "meta.h" +#include "apply.h" +#include "coop_vec.h" +#include +#include "nanobind/nanobind.h" +#include "nanobind/nb_defs.h" +#include +#include +#include +#include + + +/// Cooperative vector constructor +CoopVec::CoopVec(nb::handle arg) { + construct(arg); +} + +void CoopVec::construct(nb::handle arg) { + nb::handle single_arg = nb::none(); + if (nb::len(arg) == 1) + single_arg = arg[0]; + + if (CoopVec *v = nullptr; nb::try_cast(single_arg, v, false), v != nullptr) { + m_index = ad_var_inc_ref(v->m_index); + m_size = v->m_size; + m_type = v->m_type; + return; + } + + nb::handle arg_tp = single_arg.type(); + if (is_drjit_type(arg_tp)) { + const ArraySupplement &s = supp(arg_tp); + if (s.is_tensor) { + const dr::vector &shape = s.tensor_shape(inst_ptr(single_arg)); + if (shape.size() <= 2) { + construct(nb::list(single_arg)); + return; + } + } + } + + /// Flatten a PyTree into a set of 1D arrays used to construct a cooperative vector + struct Flatten: TraverseCallback { + std::vector result; + + void operator()(nb::handle h) { + if ((JitBackend) supp(h.type()).backend != JitBackend::None) + result.push_back(nb::borrow(h)); + } + + void traverse_unknown(nb::handle h) { + if (PyIter_Check(h.ptr())) + traverse("drjit.nn.CoopVec", *this, nb::list(h)); + else if (PyLong_CheckExact(h.ptr()) || PyFloat_CheckExact(h.ptr())) + result.push_back(nb::borrow(h)); + else + nb::raise("encountered an unknown type \"%s\"", nb::inst_name(h).c_str()); + } + }; + + Flatten cb; + traverse("drjit.nn.CoopVec", cb, arg); + + uint32_t size = (uint32_t) cb.result.size(); + + if (cb.result.empty()) + nb::raise("drjit.nn.CoopVec(): cannot be empty!"); + + // Identify type + for (uint32_t i = 0; i < size; ++i) { + nb::handle tp = cb.result[i].type(); + if (is_drjit_type(tp)) { + m_type = tp; + break; + } + } + + // Check that this type makes sense + if (!m_type.is_valid()) + nb::raise_type_error( + "drjit.nn.CoopVec(): at least one Jit-compiled 1D array is required as input " + "(e.g., of type 'drjit.cuda.Float16')!"); + + const ArraySupplement &s = supp(m_type); + if (s.ndim != 1 || (JitBackend) s.backend == JitBackend::None) + nb::raise_type_error( + "drjit.nn.CoopVec(): expected Jit-compiled 1D arrays as input " + "(e.g., of type 'drjit.cuda.Float16')!"); + + // Check/cast the other arguments + uint64_t *tmp = (uint64_t *) alloca(sizeof(uint64_t) * size); + for (uint32_t i = 0; i < size; ++i) { + nb::object value = cb.result[i]; + try { + if (!value.type().is(m_type)) { + value = m_type(value); + cb.result[i] = value; + } + tmp[i] = s.index(inst_ptr(value)); + } catch (...) { + nb::raise_type_error( + "drjit.nn.CoopVec.__init__(): encountered an incompatible " + "argument of type \"%s\" (expected \"%s\")!", + nb::inst_name(value).c_str(), + nb::type_name(m_type).c_str()); + } + } + + m_index = ad_coop_vec_pack(size, tmp); + m_size = size; +} + +/// Unpack a cooperative vector into a Python list +nb::list CoopVec::expand_to_list() const { + if (m_size == 0) + return nb::list(); + + uint64_t *tmp = (uint64_t *) alloca(m_size * sizeof(uint64_t)); + ad_coop_vec_unpack(m_index, m_size, tmp); + + nb::list result; + const ArraySupplement &s = supp(m_type); + for (uint32_t i = 0; i < m_size; ++i) { + nb::object o = nb::inst_alloc(m_type); + s.init_index(tmp[i], inst_ptr(o)); + ad_var_dec_ref(tmp[i]); + nb::inst_mark_ready(o); + result.append(std::move(o)); + } + return result; +} + +/// Unpack a cooperative vecotr into a Dr.Jit array type like CoopVecXf +nb::object CoopVec::expand_to_vector() const { + ArrayMeta m = supp(m_type); + m.ndim = 2; + m.shape[0] = DRJIT_DYNAMIC; + m.shape[1] = DRJIT_DYNAMIC; + return meta_get_type(m)(expand_to_list()); +} + +/// Perform one of several supported unary operations +template static CoopVec coop_vec_unary_op(const CoopVec &arg) { + if ((JitBackend) supp(arg.m_type).backend == JitBackend::LLVM) { + nb::object unpacked = arg.expand_to_vector(), func; + + switch (Op) { + case JitOp::Exp2: func = array_module.attr("exp2"); break; + case JitOp::Tanh: func = array_module.attr("tanh"); break; + case JitOp::Log2: func = array_module.attr("log2"); break; + default: + nb::raise("Unsupported operation!"); + } + + return CoopVec(func(unpacked)); + } + + return CoopVec( + ad_coop_vec_unary_op(Op, arg.m_index), + arg.m_size, + arg.m_type + ); +} + +/// Perform one of several supported binary operations +template +static nb::object coop_vec_binary_op(nb::handle h0, nb::handle h1) { + nb::object o[2] { nb::borrow(h0), nb::borrow(h1) }; + CoopVec *ptr[2] { }; + CoopVec *c = nullptr; + + for (uint32_t i = 0; i < 2; ++i) { + if (nb::try_cast(o[i], ptr[i], false)) + c = ptr[i]; + } + if (!c) + return nb::steal(NB_NEXT_OVERLOAD); + + for (uint32_t i = 0; i < 2; ++i) { + if (ptr[i]) + continue; + + nb::list args; + nb::object oi = c->m_type(o[i]); + for (uint32_t j = 0; j < c->m_size; ++j) + args.append(oi); + + o[i] = nb::cast(CoopVec(nb::borrow(nb::tuple(args)))); + if (!nb::try_cast(o[i], ptr[i], false)) + nb::raise("CoopVec::binary_op(): internal error"); + } + + return nb::cast(CoopVec( + ad_coop_vec_binary_op( + Op, + ptr[0]->m_index, + ptr[1]->m_index + ), + c->m_size, + c->m_type + )); +} + +/// Perform a ternary operation (currently only FMA) +template +static nb::object coop_vec_ternary_op(nb::handle h0, nb::handle h1, + nb::handle h2) { + nb::object o[3] { nb::borrow(h0), nb::borrow(h1), nb::borrow(h2) }; + CoopVec *ptr[3] { }; + CoopVec *c = nullptr; + + for (uint32_t i = 0; i < 3; ++i) { + if (nb::try_cast(o[i], ptr[i], false)) + c = ptr[i]; + } + if (!c) + return nb::steal(NB_NEXT_OVERLOAD); + + for (uint32_t i = 0; i < 3; ++i) { + if (ptr[i]) + continue; + + nb::list args; + for (uint32_t j = 0; j < c->m_size; ++j) + args.append(c->m_type(o[i])); + + o[i] = nb::cast(CoopVec(nb::borrow(nb::tuple(args)))); + if (!nb::try_cast(o[i], ptr[i], false)) + nb::raise("CoopVec::ternary_op(): internal error"); + } + + return nb::cast(CoopVec( + ad_coop_vec_ternary_op( + Op, + ptr[0]->m_index, + ptr[1]->m_index, + ptr[2]->m_index + ), + c->m_size, + c->m_type + )); +} + +/// Matrix-vector product +static CoopVec matvec(const MatrixView &A, + const CoopVec &x, + std::optional b, + bool transpose) { + + return { + ad_coop_vec_matvec( + A.index(), + &A.descr, + x.m_index, + b.has_value() ? b.value()->index() : 0, + b.has_value() ? &b.value()->descr : nullptr, + ((int) transpose) ^ ((int) A.transpose) + ), + transpose ? A.descr.cols : A.descr.rows, + x.m_type + }; +} + +nb::str MatrixView::repr() const { + const char *layout; + switch (descr.layout) { + case MatrixLayout::InferencingOptimal: layout = "inference"; break; + case MatrixLayout::TrainingOptimal: layout = "training"; break; + case MatrixLayout::RowMajor: layout = "row_major"; break; + default: layout = "unknown"; break; + } + return nb::str( + "drjit.nn.MatrixView[\n" + " dtype={},\n" + " layout={},\n" + " shape=({}, {}),\n" + " stride={},\n" + " offset={}\n" + " size={}\n" + " buffer=<{} instance>\n" + "]" + ).format( + descr.dtype, + layout, + descr.rows, + descr.cols, + descr.stride, + descr.offset, + descr.size, + inst_name(buffer) + ); +} + +uint64_t MatrixView::index() const { + return supp(buffer.type()).index(inst_ptr(buffer)); +} + +MatrixView MatrixView::getitem(nb::object arg) const { + nb::object s[2]; + + if (descr.layout == MatrixLayout::InferencingOptimal || + descr.layout == MatrixLayout::TrainingOptimal) + nb::raise("drjit.MatrixView.__getitem__(): slicing is not permitted for " + "training/inferencing-optimal layouts!"); + + if (nb::isinstance(arg)) { + size_t l = nb::len(arg); + if (l == 0 || l > 2) + nb::raise("drjit.MatrixView.__getitem__(): expected 1 or 2 terms in " + "slice expression (got %zu)!", l); + s[0] = arg[0]; + if (l == 2) + s[1] = arg[1]; + } else { + s[0] = arg; + } + + if (!s[1].is_valid()) + s[1] = nb::slice(nb::none(), nb::none(), nb::none()); + + Py_ssize_t start[2], step[2]; + size_t len[2]; + + for (uint32_t i = 0; i < 2; ++i) { + uint32_t value; + if (nb::try_cast(s[i], value, false)) + s[i] = nb::slice(nb::int_(value), nb::int_(value + 1), nb::int_(1)); + nb::slice sl; + if (!nb::try_cast(s[i], sl, false)) + nb::raise("drjit.MatrixView.__getitem__(): expected 'int' or 'slice' " + "in slice expression, got '%s'!", + nb::inst_name(s[i]).c_str()); + size_t limit = i == 0 ? descr.rows : descr.cols; + auto [start_i, stop_i, step_i, len_i] = + sl.compute(limit); + start[i] = start_i; step[i] = step_i; len[i] = len_i; + } + + if (step[1] != 1) + nb::raise("drjit.MatrixView.__getitem__(): rows elements must be contiguous!"); + + if (len[0] == 0 || len[1] == 0) + nb::raise("drjit.MatrixView.__getitem__(): input array may not be empty!"); + + MatrixView result; + result.descr.rows = len[0]; + result.descr.cols = len[1]; + result.descr.offset = descr.offset + start[0] * descr.stride + start[1]; + result.descr.dtype = descr.dtype; + result.descr.layout = descr.layout; + result.descr.stride = descr.stride * step[0]; + result.descr.size = (len[0] - 1) * result.descr.stride + len[1]; + result.buffer = buffer; + return result; +} + +static MatrixView view(nb::handle_t arg) { + MatrixView result { }; + MatrixDescr &d = result.descr; + + const ArraySupplement &s = supp(arg.type()); + + d.dtype = (VarType) s.type; + d.layout = MatrixLayout::RowMajor; + + if (s.is_tensor) { + const dr::vector &shape = s.tensor_shape(inst_ptr(arg)); + if (shape.size() != 1 && shape.size() != 2) + nb::raise("drjit.view(): tensor must have 1 or 2 dimensions!"); + d.rows = shape[0]; + d.cols = shape.size() > 1 ? shape[1] : 1; + result.buffer = nb::steal(s.tensor_array(arg.ptr())); + } else if (s.ndim == 1 && s.shape[0] == DRJIT_DYNAMIC) { + d.rows = nb::len(arg); + d.cols = 1; + result.buffer = nb::borrow(arg); + } else { + nb::raise("Unsupported input type!"); + } + + d.stride = d.cols; + d.size = d.rows * d.cols; + d.offset = 0; + + if (d.rows == 0 || d.cols == 0) + nb::raise("drjit.view(): input array/tensor may not be empty!"); + + return result; +} + +struct RepackItem { + nb::object in_o; + nb::object out_o; + MatrixView *in; + MatrixView *out; + + RepackItem(nb::handle in_o, nb::handle out_o, MatrixView *in, MatrixView *out) + : in_o(nb::borrow(in_o)), out_o(nb::borrow(out_o)), in(in), out(out) { } + RepackItem(RepackItem&&) = default; + RepackItem(const RepackItem&) = default; +}; + +nb::handle view_type; +nb::handle coop_vector_type; + +static nb::object repack_impl(const char *name, MatrixLayout layout, + nb::handle arg_, uint32_t &offset, + std::vector &items) { + nb::handle arg_tp = arg_.type(); + nb::object arg = nb::borrow(arg_); + + if (is_drjit_type(arg_tp) && layout != MatrixLayout::RowMajor) { + arg = nb::cast(view(nb::handle_t(arg))); + arg_tp = view_type; + } + + if (arg_tp.is(view_type)) { + MatrixView *in_view = nb::cast(arg, false); + uint64_t in_index = supp(in_view->buffer.type()).index(inst_ptr(in_view->buffer)); + MatrixDescr out_descr = + jit_coop_vec_compute_layout(in_index, &in_view->descr, layout, offset); + MatrixView *out_view = new MatrixView{out_descr, nb::none()}; + nb::object result = nb::cast(out_view, nb::rv_policy::take_ownership); + items.emplace_back(arg, result, in_view, out_view); + offset = out_descr.offset + out_descr.size; + return result; + } else if (arg_tp.is(&PyTuple_Type)) { + nb::tuple t = nb::borrow(arg); + nb::list result; + for (nb::handle h : t) + result.append(repack_impl(name, layout, h, offset, items)); + return nb::tuple(result); + } else if (arg_tp.is(&PyList_Type)) { + nb::list l = nb::borrow(arg); + nb::list result; + for (nb::handle h : l) + result.append(repack_impl(name, layout, h, offset, items)); + return std::move(result); + } else if (arg_tp.is(&PyDict_Type)) { + nb::dict d = nb::borrow(arg); + nb::dict result; + for (auto [k, v] : d) + result[k] = repack_impl(name, layout, v, offset, items); + return std::move(result); + } else if (nb::dict ds = get_drjit_struct(arg_tp); ds.is_valid()) { + nb::object tmp = arg_tp(); + for (auto [k, v] : ds) + nb::setattr(tmp, k, repack_impl(name, layout, nb::getattr(arg, k), offset, items)); + return tmp; + } else if (nb::object df = get_dataclass_fields(arg_tp); df.is_valid()) { + nb::object tmp = nb::dict(); + for (nb::handle field : df) { + nb::object k = field.attr(DR_STR(name)); + tmp[k] = repack_impl(name, layout, nb::getattr(arg, k), offset, items); + } + return arg_tp(**tmp); + } else { + return nb::borrow(arg); + } +} + +static std::pair repack(const char *name, const char *layout_str, nb::handle arg) { + uint32_t offset = 0; + std::vector items; + MatrixLayout layout; + + if (layout_str) { + if (strcmp(layout_str, "inference") == 0) + layout = MatrixLayout::InferencingOptimal; + else if (strcmp(layout_str, "training") == 0) + layout = MatrixLayout::TrainingOptimal; + else + nb::raise("drjit.%s(): 'mode' must equal \"inference\" or \"training\"!", name); + } else { + layout = MatrixLayout::RowMajor; + } + + nb::object result = repack_impl(name, layout, arg, offset, items); + nb::object buffer = nb::none(); + + if (items.size() > 0) { + nb::handle buf_cur = items[0].in->buffer, + buf_tp = buf_cur.type(); + + buffer = full("zeros", buf_tp, nb::int_(0), offset, true); + const ArraySupplement &s = supp(buf_tp); + + std::vector in, out; + in.reserve(items.size()); + out.reserve(items.size()); + + auto submit = [&] { + jit_coop_vec_pack_matrices( + (uint32_t) in.size(), + s.index(inst_ptr(buf_cur)), + in.data(), + s.index(inst_ptr(buffer)), + out.data() + ); + }; + + for (size_t i = 0; i < items.size(); ++i) { + nb::handle buf_i = items[i].in->buffer, + buf_i_tp = buf_i.type(); + + if (!buf_i_tp.is(buf_tp)) { + nb::raise_type_error( + "drjit.%s(): encountered different input formats (%s vs %s)", name, + nb::type_name(buf_tp).c_str(), + nb::type_name(buf_i_tp).c_str()); + } + + if (!buf_cur.is(buf_i)) { + submit(); + in.clear(); + out.clear(); + buf_cur = buf_i; + } + + items[i].out->buffer = buffer; + + in.push_back(items[i].in->descr); + out.push_back(items[i].out->descr); + } + + if (!in.empty()) + submit(); + } + + return { buffer, result }; +} + +static CoopVec coopvec_abs_workaround(nb::handle_t &v) { + nb::list result; + for (nb::handle h: v) + result.append(nb::steal(PyNumber_Absolute(h.ptr()))); + return CoopVec(result); +} + +void export_coop_vec(nb::module_ &m) { + nb::module_ nn = m.def_submodule("detail").def_submodule("nn"); + nn.attr("__name__") = "drjit.nn"; + + nn.attr("ArrayT") = nb::type_var("ArrayT", "bound"_a = "drjit.ArrayBase"); + for (const char *name : + { "T", "SelfT", "SelfCpT", "ValT", "ValCpT", "RedT", "PlainT", "MaskT" }) + nn.attr(name) = nb::type_var(name); + + coop_vector_type = nb::class_(nn, "CoopVec", nb::is_generic(), nb::sig("class CoopVec(typing.Generic[T])")) + .def(nb::init(), + nb::sig("def __init__(self, *args: typing.Unpack[typing.Tuple[typing.Union[drjit.ArrayBase[SelfT, SelfCpT, ValT, ValCpT, T, PlainT, MaskT], float, int], ...]]) -> None"), + doc_coop_CoopVec_init) + .def("__iter__", [](const CoopVec &v) { return iter(v.expand_to_list()); }, + nb::sig("def __iter__(self, /) -> typing.Iterator[T]")) + .def("__add__", &coop_vec_binary_op, + nb::sig("def __add__(self, arg: CoopVec[T] | T | float | int, /) -> CoopVec[T]")) + .def("__radd__", &coop_vec_binary_op, + nb::sig("def __radd__(self, arg: CoopVec[T] | T | float | int, /) -> CoopVec[T]")) + .def("__sub__", &coop_vec_binary_op, + nb::sig("def __sub__(self, arg: CoopVec[T] | T | float | int, /) -> CoopVec[T]")) + .def("__rsub__", &coop_vec_binary_op, + nb::sig("def __rsub__(self, arg: CoopVec[T] | T | float | int, /) -> CoopVec[T]")) + .def("__mul__", &coop_vec_binary_op, + nb::sig("def __mul__(self, arg: CoopVec[T] | T | float | int, /) -> CoopVec[T]")) + .def("__rmul__", &coop_vec_binary_op, + nb::sig("def __rmul__(self, arg: CoopVec[T] | T | float | int, /) -> CoopVec[T]")) + .def_prop_ro("index", [](const CoopVec &v) { return v.m_index; }) + .def_prop_ro("type", [](const CoopVec &v) { return v.m_type; }) + .def("__len__", [](const CoopVec &v) { return v.m_size; }) + .def("__abs__", &coopvec_abs_workaround) + .def("__repr__", + [](const CoopVec &v) { + return nb::str("drjit.nn.CoopVec[{}, shape=({}, {})]") + .format(nb::type_name(v.m_type), v.m_size, + jit_var_size(v.m_index)); + }); + + view_type = nb::class_(nn, "MatrixView", doc_coop_MatrixView) + .def(nb::init<>()) + .def("__repr__", &MatrixView::repr) + .def("__getitem__", &MatrixView::getitem, + nb::sig("def __getitem__(self, arg: typing.Union[int, slice, typing.Tuple[typing.Union[int, slice], typing.Union[int, slice]]]) -> MatrixView")) + .def_prop_rw("dtype", + [](MatrixView &v) { return v.descr.dtype; }, + [](MatrixView &v, VarType v2) { v.descr.dtype = v2; }) + .def_prop_rw("offset", + [](MatrixView &v) { return v.descr.offset; }, + [](MatrixView &v, uint32_t v2) { v.descr.offset = v2; }) + .def_prop_rw("stride", + [](MatrixView &v) { return v.descr.stride; }, + [](MatrixView &v, uint32_t v2) { v.descr.stride = v2; }) + .def_prop_rw("size", + [](MatrixView &v) { return v.descr.size; }, + [](MatrixView &v, uint32_t v2) { v.descr.size = v2; }) + .def_prop_rw("layout", + [](MatrixView &v) { + switch (v.descr.layout) { + case MatrixLayout::InferencingOptimal: return "inference"; + case MatrixLayout::TrainingOptimal: return "training"; + case MatrixLayout::RowMajor: return "row_major"; + default: return "unknown"; + } + }, + [](MatrixView &v, const char *s) { + if (strcmp(s, "inference") == 0) + v.descr.layout = MatrixLayout::InferencingOptimal; + else if (strcmp(s, "training") == 0) + v.descr.layout = MatrixLayout::TrainingOptimal; + else if (strcmp(s, "row_major") == 0) + v.descr.layout = MatrixLayout::RowMajor; + else + nb::raise("Unknown layout!"); + }, + nb::for_getter(nb::sig("def layout(self) -> typing.Literal['inference', 'training', 'row_major']")), + nb::for_setter(nb::sig("def layout(self, value: typing.Literal['inference', 'training', 'row_major']) -> None"))) + .def_prop_rw("transpose", + [](MatrixView &v) { return v.transpose; }, + [](MatrixView &v, bool v2) { v.transpose = v2; }) + .def_prop_rw("shape", + [](MatrixView &v) { + return std::make_pair(v.descr.rows, v.descr.cols); + }, + [](MatrixView &v, std::pair v2) { + v.descr.rows = v2.first; + v.descr.cols = v2.second; + }) + .def("__matmul__", [](const MatrixView &self, const CoopVec &x) { return matvec(self, x, {}, false); }, + nb::sig("def __matmul__(self, arg: CoopVec[T], /) -> CoopVec[T]")) + .def_rw("buffer", &MatrixView::buffer) + .def_prop_ro("T", + [](MatrixView &v) { + MatrixView r; + r.descr = v.descr; + r.buffer = v.buffer; + r.transpose = !v.transpose; + return r; + }) + .def_prop_ro("grad", + [](MatrixView &v) { + MatrixView r; + r.descr = v.descr; + r.buffer = v.buffer.attr("grad"); + r.transpose = v.transpose; + return r; + }); + + + nb::dict drjit_struct; + drjit_struct["layout"] = nb::handle(&PyUnicode_Type); + drjit_struct["buffer"] = nb::none(); + drjit_struct["dtype"] = nb::type(); + drjit_struct["shape"] = nb::handle(&PyTuple_Type); + drjit_struct["offset"] = nb::handle(&PyLong_Type); + drjit_struct["size"] = nb::handle(&PyLong_Type); + drjit_struct["stride"] = nb::handle(&PyLong_Type); + drjit_struct["transpose"] = nb::handle(&PyBool_Type); + view_type.attr("DRJIT_STRUCT") = drjit_struct; + + nn.def("view", &view, + doc_coop_view); + + nn.def("pack", [](nb::handle arg, const char *layout) { return repack("pack", layout, arg); }, + nb::arg(), "layout"_a = "inference", + nb::sig("def pack(arg: MatrixView | drjit.AnyArray, *, layout: typing.Literal['inference', 'training'] = 'inference') -> typing.Tuple[drjit.ArrayBase, MatrixView]"), + doc_coop_pack); + + nn.def("pack", + [](nb::args args, const char *layout) { + auto temp = repack("pack", layout, args); + nb::list l; + l.append(temp.first); + l.extend(temp.second); + return nb::tuple(l); + }, + "args"_a, "layout"_a = "inference", + nb::sig("def pack(*args: PyTree, layout: typing.Literal['inference', " + "'training'] = 'inference') -> typing.Tuple[drjit.ArrayBase, " + "typing.Unpack[typing.Tuple[PyTree, ...]]]")); + + nn.def("unpack", [](nb::handle arg) { + return repack("unpack", nullptr, arg); }, + nb::sig("def unpack(arg: MatrixView | drjit.AnyArray, /) -> typing.Tuple[drjit.ArrayBase, MatrixView]"), + doc_coop_unpack); + + nn.def("unpack", + [](nb::args args) { + auto temp = repack("unpack", nullptr, args); + nb::list l; + l.append(temp.first); + l.extend(temp.second); + return nb::tuple(l); + }, + "args"_a, + nb::sig("def unpack(*args: PyTree) -> typing.Tuple[drjit.ArrayBase, " + "typing.Unpack[typing.Tuple[PyTree, ...]]]")); + + nn.def("matvec", &matvec, "A"_a.noconvert(), "x"_a.noconvert(), + "b"_a.noconvert() = nb::none(), "transpose"_a = false, + nb::sig("def matvec(A: MatrixView, x: drjit.nn.CoopVec[T], b: typing.Optional[MatrixView] = " + "None, /, transpose: bool = False) -> drjit.nn.CoopVec[T]"), + doc_coop_matvec); + + nn.def("cast", + [](CoopVec vec, nb::type_object_t tp) { + const ArraySupplement &s = supp(tp); + ArrayMeta m = supp(vec.m_type); + m.type = s.type; + nb::handle new_type = meta_get_type(m); + return CoopVec(ad_coop_vec_cast(vec.m_index, (VarType) s.type), + vec.m_size, new_type); + }, nb::sig("def cast(arg0: CoopVec[T], arg1: typing.Type[ArrayT], /) -> CoopVec[ArrayT]"), + doc_coop_cast + ); + + m.def("fma", &coop_vec_ternary_op); + m.def("minimum", &coop_vec_binary_op); + m.def("maximum", &coop_vec_binary_op); + m.def("step", &coop_vec_binary_op, doc_step); + m.def("log2", &coop_vec_unary_op); + m.def("exp2", &coop_vec_unary_op); + m.def("tanh", &coop_vec_unary_op); + m.def("step", [](nb::handle h0, nb::handle h1) { + return select( + nb::steal(PyObject_RichCompare(h0.ptr(), h1.ptr(), Py_LT)), + nb::int_(0), nb::int_(1)); + }); + m.def("abs", coopvec_abs_workaround); +} diff --git a/src/python/coop_vec.h b/src/python/coop_vec.h new file mode 100644 index 00000000..47b4e076 --- /dev/null +++ b/src/python/coop_vec.h @@ -0,0 +1,83 @@ +/* + src/coop_vec.h -- Python bindings for Cooperative CoopVecs + + Copyright (c) 2025 Wenzel Jakob + + All rights reserved. Use of this source code is governed by a BSD-style + license that can be found in the LICENSE file. +*/ + +#include "common.h" + +extern void export_coop_vec(nb::module_ &m); + +/// Cooperative vector container data structure +struct CoopVec { + /// JIT variable ID + uint64_t m_index = 0; + /// Number of entries + uint32_t m_size = 0; + /// Element type + nb::handle m_type; + + CoopVec(nb::handle arg); + + /// Steals ownership of 'index' + CoopVec(uint64_t index, uint32_t size, nb::handle type) + : m_index(index), m_size(size), m_type(type) { } + + /// Copy constructor + CoopVec(const CoopVec &vec) + : m_index(vec.m_index), m_size(vec.m_size), m_type(vec.m_type) { + ad_var_inc_ref(m_index); + } + CoopVec(CoopVec &&vec) noexcept + : m_index(vec.m_index), m_size(vec.m_size), m_type(vec.m_type) { + vec.m_index = 0; + vec.m_size = 0; + vec.m_type = nb::handle(); + } + CoopVec& operator=(CoopVec &&x) { + ad_var_dec_ref(m_index); + m_index = x.m_index; + m_size = x.m_size; + m_type = x.m_type; + x.m_index = x.m_size = 0; + x.m_type = nb::handle(); + return *this; + } + ~CoopVec() { ad_var_dec_ref(m_index); } + + /// Expand a cooperative vector into a Python list + nb::list expand_to_list() const; + + /// Expand a cooperative vector into a Dr.Jit array type (e.g. ArrayXf) + nb::object expand_to_vector() const; + +private: + void construct(nb::handle arg); +}; + +/// Shared view into a matrix +struct MatrixView { + /// Shape, strides, etc. + MatrixDescr descr{}; + + /// Dr.Jit 1D array holding the data + nb::object buffer; + + /// Should the view be transposed? + bool transpose = false; + + MatrixView() = default; + MatrixView(const MatrixDescr &descr, const nb::handle &buffer) + : descr(descr), buffer(nb::borrow(buffer)), transpose(false) { } + + nb::str repr() const; + MatrixView getitem(nb::object arg) const; + uint64_t index() const; +}; + + +extern nb::handle view_type; +extern nb::handle coop_vector_type; diff --git a/src/python/detail.cpp b/src/python/detail.cpp index ef9f27e9..d9fc2cb2 100644 --- a/src/python/detail.cpp +++ b/src/python/detail.cpp @@ -307,7 +307,14 @@ void export_detail(nb::module_ &) { []() { int major, minor, patch; jit_llvm_version(&major, &minor, &patch); - return nb::str("{}.{}.{}").format(major, minor, patch); + return nb::make_tuple(major, minor, patch); + }) + + .def("cuda_version", + []() { + int major, minor; + jit_cuda_version(&major, &minor); + return nb::make_tuple(major, minor); }) .def("trace_func", &trace_func, "frame"_a, "event"_a, diff --git a/src/python/dlpack.cpp b/src/python/dlpack.cpp index a0ac00ed..b96f7543 100644 --- a/src/python/dlpack.cpp +++ b/src/python/dlpack.cpp @@ -252,6 +252,10 @@ void export_dlpack(nb::module_ &) { [](nb::handle_t h) { return nb::ndarray(dlpack(h, true).handle()); }, doc_array) + .def("to_numpy", // needed for Matplotlib + [](nb::handle_t h) { + return nb::ndarray(dlpack(h, true).handle()); + }, doc_array) .def("torch", [](nb::handle_t h) { nb::module_ torch = nb::module_::import_("torch.utils.dlpack"); diff --git a/src/python/docstr.rst b/src/python/docstr.rst index 0046499e..76697ed1 100644 --- a/src/python/docstr.rst +++ b/src/python/docstr.rst @@ -579,7 +579,7 @@ Returns: object: The result of the operation ``arg*arg`` -.. topic:: pow +.. topic:: power Raise the first argument to a power specified via the second argument. @@ -591,7 +591,7 @@ reduces operation to a sequence of multiplies and adds (potentially followed by a reciprocation operation when ``arg1`` is negative). - The general case involves recursive use of the identity ``pow(arg0, arg1) = + The general case involves recursive use of the identity ``power(arg0, arg1) = exp2(log2(arg0) * arg1)``. There is no difference between using :py:func:`drjit.power()` and the builtin @@ -7311,7 +7311,7 @@ `__. The operation is a no-op when no profile collection tool is attached. - Note the difference between this context manager and :py:ref:`dr.profile_enable() + Note the difference between this context manager and :py:func:`dr.profile_enable() `, which enables targeted profiling of a smaller region of code (as opposed to profiling the entire program). @@ -7328,7 +7328,7 @@ code_to_be_profiled() Note the difference between this context manager and - :py:ref:`dr.profile_range() `, which annotates a profiled + :py:func:`dr.profile_range() `, which annotates a profiled region with a label. .. topic:: ReduceMode @@ -7503,7 +7503,7 @@ >>> from drjit.llvm.ad import TensorXf >>> value = dr.arange(TensorXf, 6) - >>> dr.reshape(dtype=TensorXf, value=value, shape=(3, -1)) + >>> dr.reshape(value, (3, -1)) [[0, 1] [2, 3] [4, 5]] @@ -7511,16 +7511,17 @@ 2. **Reshaping nested arrays**: The function can ravel and unravel nested arrays (which have some static dimensions). This provides a high-level interface that subsumes the functions :py:func:`drjit.ravel` and - :py:func:`drjit.unravel`. + :py:func:`drjit.unravel`. In this case, the target ``dtype`` must be + specified: .. code-block:: pycon >>> from drjit.llvm.ad import Array2f, Array3f >>> value = Array2f([1, 2, 3], [4, 5, 6]) - >>> dr.reshape(dtype=Array3f, value=value, shape=(3, -1), order='C') + >>> dr.reshape(Array3f, value, shape=(3, -1), order='C') [[1, 4, 2], [5, 3, 6]] - >>> dr.reshape(dtype=Array3f, value=value, shape=(3, -1), order='F') + >>> dr.reshape(Array3f, value, shape=(3, -1), order='F') [[1, 3, 5], [2, 4, 6]] @@ -7604,11 +7605,12 @@ f'{size} elements in a queue of size {queue_size}') # Reshape the queue and re-run the loop - state = dr.reshape(dtype=type(state), value=queue, shape=size, shrink=True) + state = dr.reshape(queue, shape=size, shrink=True) Args: dtype (type): Desired output type of the reshaped array. This could equal ``type(value)`` or refer to an entirely different array type. + Must only be specified if the target dtype is different. value (object): An arbitrary Dr.Jit array, tensor, or :ref:`PyTree `. The function returns unknown objects of other types @@ -8104,3 +8106,178 @@ .. topic:: leak_warnings Query whether leak warnings are enabled. See :py:func:`drjit.detail.set_leak_warnings()`. + +.. topic:: step + + Step function. + + This function generates a step function by comparing ``arg0`` to ``arg1``. + The function is equivalent to + + .. code-block:: python + + dr.select( + arg0 < arg1, + 0, # if arg0 < arg1 + 1, # if arg1 >= arg1 + ) + + Args: + arg0 (object): A Dr.Jit array/tensor or Python arithmetic type + + arg1 (object): A Dr.Jit array/tensor or Python arithmetic type + + Returns: + object: The computed array as described above + +.. topic:: coop_CoopVec + + A *cooperative vector* is a dynamically-sized container of elements of a + consistent type. It admits both floating point and integer 1D arrays as + elements (e.g., :py:class:`drjit.cuda.Float16`, + :py:class:`drjit.llvm.UInt32`). + + Seen from a high level, cooperative vectors resemble nested array types, + such as as :py:class:`drjit.cuda.ArrayXf16`. A variety of conversions + between cooperative vectors and regular Dr.Jit arrays are possible. + + .. code-block:: python + + # Pack individual components into a cooperative vector + vec = drjit.nn.CoopVec(x, y, z) + + # Unpack components + x, y, z = vec + + # Unpack directly into 3D array + xyz = Array3f(vec) + + # Convert a 3D array and a 2D array into a 5D cooperative vector + a1: Array3f = ... + a2: Array2f = ... + vec = drjit.nn.CoopVec(a1, a2) + + The main difference between regular Dr.Jit arrays and cooperative vectors is + that they *do not permit indexed element access*. For example, the following + operation raises an Exception: + + .. code-block:: pycon + + >>> vec = drjit.nn.CoopVec(x, y, z) + >>> vec[1] + Traceback (most recent call last): + File "", line 1, in + TypeError: 'drjit.nn.CoopVec' object is not subscriptable + + The compilation stack may arbitrarily redistribute the elements of a + cooperative vector across threads for efficiency (this is what + *cooperative* refers to). Indexed access to a cooperative vector's elements + would interfere with such optimizations. + + To unpack a cooperative vector into its components, use an expression + like ``x, y, z = vec``, ``ArrayXf(vec)``, or ``list(vec)``. + +.. topic:: coop_CoopVec_init + + The constructor accepts a variable number of arguments including Dr.Jit + arrays, scalar Python integers and floating point values, and :ref:`PyTrees + `. It flattens this input into a list of vector components. + + At least one Jit-compiled array must be provided as input so that Dr.Jit can + infer the cooperative vector's element type. An exception will be raised if + the input contains Dr.Jit arrays of inconsistent scalar types (e.g., + :py:class:`drjit.cuda.Array2f` and :py:class:`drjit.cuda.UInt`). + +.. topic:: coop_MatrixView + + The :py:class:`drjit.nn.MatrixView` provides pointer into a buffer along with + shape and type metadata. + + Dr.Jit uses views to tightly pack sequences of matrices and bias vectors + into a joint buffer, and to preserve information about the underlying data + type and layout. The :py:func:`__getitem__` function can be used to slice a + view into smaller sub-blocks. + + The typical process is to pack a PyTree of weight and bias vectors via + :py:func:`drjit.pack()` into an inference or training-optimal + representation. The returned views can then be passed to + :py:func:`drjit.nn.matvec()`. + +.. topic:: coop_view + + Convert a Dr.Jit array or tensor into a *view*. + + This function simply returns a view of the original tensor without + transforming the underlying representation. This is useful to + + - Use :py:func:`drjit.nn.matvec` with a row-major matrix layout (which, + however, is not recommended, since this can be significantly slower + compared to matrices in inference/training-optimal layouts). + + - Slice a larger matrix into sub-blocks before passing them to + :py:func:`drjit.nn.pack` (which also accepts *views* as inputs). + This is useful when several matrices are already packed into a single + matrix (which is, however, still in row-major layout). They can then be + directly re-packed into optimal layouts without performing further + unnecessary copies. + +.. topic:: coop_pack + + A training-optimal layout must be used used if the program + *backpropagates* (as in :py:func:`dr.backward*() `) + gradients through matrix-vector products. Forward derivative propagation (as + in :py:func:`dr.forward*() `) does not require a + training-optimal layout. + + If the input matrices are already packed in a row-major layout, call + :py:func:`dr.nn.view() ` to create an efficient reference + and then pass slices of the view to :py:func:`dr.nn.pack() + `. This avoids additional copies. + + .. code-block:: + + mat: TensorXf = ... + mat_view = dr.nn.view(mat) + + A1_view, A2_view = dr.nn.pack( + mat_view[0:32, :], + mat_view[32:64, :] + ) + +.. topic:: coop_unpack + + The function :py:func:`dr.nn.unpack() ` transforms a + sequence (or :ref:`PyTree `) of vectors and optimal-layout matrices + back into row-major layout. + + .. code-block:: python + + A_out, b_out = dr.nn.unpack(A_opt, b_opt) + + Note that the output of this function are (row-major) *views* into a shared + buffer. These views can be converted back into regular tensors: + + .. code-block:: python + + A = TensorXf16(A) + +.. topic:: coop_matvec + + Evaluate a matrix-vector multiplication involving a cooperative vector. + + This function takes a *matrix view* ``A`` (see :py:func:`drjit.nn.pack` + and :py:func:`drjit.nn.view` for details on views) and a *cooperative + vector* ``x``. It then computes the associated matrix-vector product and + returns it in the form of a new cooperative vector (potentially with a + different size). + + The function can optionally apply an additive bias (i.e., to evaluate ``A@x + + b``). This bias vector ``b`` should also be specified as a view. + + Specify ``tranpose=True`` to multiply by the transpose of the matrix ``A``. + On the CUDA/OptiX backend, this feature requires that ``A`` is inference + or training-optimal layout. + +.. topic:: coop_cast + + Cast the numeric type underlying a cooperative vector diff --git a/src/python/eval.cpp b/src/python/eval.cpp index 01faa5e0..cd0e8afb 100644 --- a/src/python/eval.cpp +++ b/src/python/eval.cpp @@ -11,6 +11,7 @@ #include "eval.h" #include "apply.h" #include "local.h" +#include "coop_vec.h" bool schedule(nb::handle h) { bool result_ = false; @@ -31,6 +32,8 @@ bool schedule(nb::handle h) { for (uint32_t index : local.arrays()) result |= (bool) jit_var_schedule(index); } + if (h.type().is(coop_vector_type)) + nb::raise("Cooperative vectors cannot be evaluated. They must be unpacked into regular variables."); } }; diff --git a/src/python/init.cpp b/src/python/init.cpp index bb1e1b17..98bec4d1 100644 --- a/src/python/init.cpp +++ b/src/python/init.cpp @@ -12,12 +12,14 @@ #include #include #include "../ext/nanobind/src/buffer.h" +#include "drjit/python.h" #include "meta.h" #include "base.h" #include "memop.h" #include "shape.h" #include "dlpack.h" #include "init.h" +#include "coop_vec.h" #include /// Forward declaration @@ -134,7 +136,7 @@ int tp_init_array(PyObject *self, PyObject *args, PyObject *kwds) noexcept { } // Try to construct from an instance created by another - // array programming framework + // array programming framework or a Dr.Jit tensor nb::object converted_complex_scalar; if (is_drjit_tensor || (!arg_is_drjit && !is_builtin(arg_tp) && nb::ndarray_check(arg))) { // For scalar types we want to rely on broadcasting below @@ -142,14 +144,31 @@ int tp_init_array(PyObject *self, PyObject *args, PyObject *kwds) noexcept { // Import flattened array in C-style ordering nb::object flattened; - if (is_drjit_tensor) - flattened = nb::steal(supp(arg_tp).tensor_array(arg)); - else - flattened = import_ndarray(s, arg); - if (s.is_complex) do_flip_axes = true; + if (is_drjit_tensor) { + const ArraySupplement &as = supp(arg_tp); + const dr::vector &shape = as.tensor_shape(inst_ptr(arg)); + if (shape.size() != s.ndim) + nb::raise("dimensionality mismatch (target has %u, " + "source has %zu dimensions)", + s.ndim, shape.size()); + for (uint32_t d = 0; d < s.ndim; ++d) { + if (s.shape[d] == DRJIT_DYNAMIC) + continue; + size_t source_shape = + do_flip_axes ? shape[shape.size() - 1 - d] + : shape[d]; + if (s.shape[d] != source_shape) + nb::raise("mismatched shape (axis %u has size %u in target type, %zu in source tensor)", + d, s.shape[d], source_shape); + } + flattened = nb::steal(as.tensor_array(arg)); + } else { + flattened = import_ndarray(s, arg); + } + nb::object unraveled = unravel( nb::borrow>(self_tp), flattened, do_flip_axes ? 'F' : 'C'); @@ -645,6 +664,27 @@ static void ndarray_keep_alive(JitBackend backend, uint32_t index, nb::detail::n nb::object full_alt(nb::type_object dtype, nb::handle value, size_t size); nb::object empty_alt(nb::type_object dtype, size_t size); +nb::object view_to_tensor(nb::handle h, dr::vector &shape) { + MatrixView &view = nb::cast(nb::handle(h)); + if (view.transpose) + nb::raise("The view is transposed. Conversion into tensor format still " + "needs to be implemented."); + + if (view.descr.layout != MatrixLayout::RowMajor) + nb::raise("This tensor is in an inference/training-optimal layout. To " + "convert it back into tensor form, you must unpack it into a " + "row-major representation via drjit.nn.unpack()."); + + if (view.descr.stride != view.descr.cols) + nb::raise("Unsupported row stride!"); + + shape.push_back(view.descr.rows); + shape.push_back(view.descr.cols); + + return view.buffer[nb::slice(view.descr.offset, + view.descr.offset + view.descr.size, 1u)]; +} + int tp_init_tensor(PyObject *self, PyObject *args, PyObject *kwds) noexcept { PyTypeObject *self_tp = Py_TYPE(self); @@ -660,7 +700,9 @@ int tp_init_tensor(PyObject *self, PyObject *args, PyObject *kwds) noexcept { bool do_flip_axes = flip_axes == Py_True; PyTypeObject *array_tp = array ? Py_TYPE(array) : nullptr; - raise_if(do_flip_axes && (shape || !array_tp || !is_drjit_type(array_tp) || + raise_if(do_flip_axes && (shape || !array_tp || + (!is_drjit_type(array_tp) && + !nb::handle(array_tp).is(coop_vector_type)) || array_tp == self_tp), "flip_axes=True requires that 'shape' is not specified, and " "that the input is a nested Dr.Jit array type (e.g. " @@ -676,6 +718,14 @@ int tp_init_tensor(PyObject *self, PyObject *args, PyObject *kwds) noexcept { // Same type -> copy constructor if (array_tp == self_tp) { + if (shape) + nb::raise( + "use 'Tensor(x.array, shape)' or 'drjit.reshape(Tensor, x, " + "shape)' to reshape a tensor"); + if (do_flip_axes) + nb::raise("The flip_axes argument is only supported when " + "constructing tensors from N-D arrays or cooperative " + "vectors"); nb::detail::nb_inst_copy(self, array); return 0; } @@ -690,6 +740,8 @@ int tp_init_tensor(PyObject *self, PyObject *args, PyObject *kwds) noexcept { // Try to construct from an instance created by another // array programming framework flat = import_ndarray(s, array, &shape_vec); + } else if (nb::isinstance(nb::handle(array))) { + flat = view_to_tensor(array, shape_vec); } else { // Infer the shape of an arbitrary data structure & flatten it VarType vt = (VarType) s.type; @@ -984,12 +1036,24 @@ nb::object linspace(const nb::type_object_t &dtype, if (size == 0) return dtype(); - nb::object result = nb::inst_alloc(counter_tp); - counter_s.init_counter((size_t) size, inst_ptr(result)); - nb::inst_mark_ready(result); + nb::object counter = nb::inst_alloc(counter_tp); + counter_s.init_counter((size_t) size, inst_ptr(counter)); + nb::inst_mark_ready(counter); + + nb::handle dtype_c = dtype; + if ((VarType) s.type == VarType::Float16) { + ArrayMeta m = s; + m.type = (uint16_t) VarType::Float32; + dtype_c = meta_get_type(m); + } double step = (stop - start) / (size - ((endpoint && size > 0) ? 1 : 0)); - return fma(dtype(result), dtype(step), dtype(start)); + nb::object result = fma(dtype_c(counter), dtype_c(step), dtype_c(start)); + + if (!dtype_c.is(dtype)) + result = dtype(result); + + return result; } /// Extract types from typing.Optional[T], typing.Union[T, None], etc. diff --git a/src/python/main.cpp b/src/python/main.cpp index 99712f3e..74fa3239 100644 --- a/src/python/main.cpp +++ b/src/python/main.cpp @@ -40,6 +40,7 @@ #include "tracker.h" #include "local.h" #include "resample.h" +#include "coop_vec.h" static int active_backend = -1; @@ -228,6 +229,7 @@ NB_MODULE(_drjit_ext, m_) { jit_init_async(backends); export_bind(detail); + export_coop_vec(m); export_base(m); export_init(m); export_shape(m); diff --git a/src/python/memop.cpp b/src/python/memop.cpp index 23d9c111..d0e3cc60 100644 --- a/src/python/memop.cpp +++ b/src/python/memop.cpp @@ -559,7 +559,8 @@ static void ravel_recursive(nb::handle result, nb::handle value, nb::object index = arange(nb::borrow>(index_dtype), offset, offset + strides[depth] * shape[depth], strides[depth]); - ::scatter(nb::borrow(result), nb::borrow(value), index, nb::cast(true)); + ::scatter(nb::borrow(result), nb::borrow(value), index, nb::cast(true), + ReduceMode::Permute); } else { result[offset] = value; } @@ -625,6 +626,26 @@ nb::object ravel(nb::handle h, char order, vt = (VarType) s.type; is_dynamic = s.shape[s.ndim - 1] == DRJIT_DYNAMIC; is_diff = s.is_diff; + } else if (nb::isinstance(h)) { + nb::object o = nb::borrow(h); + while (true) { + if (!nb::hasattr(o, "__len__") || nb::len(o) == 0) { + if (vt_in) + vt = (VarType) *vt_in; + break; + } + if (is_drjit_array(o)) { + const ArraySupplement &s = supp(o.type()); + backend = (JitBackend) s.backend; + vt = (VarType) s.type; + is_dynamic = s.ndim != 0 && s.shape[s.ndim - 1] == DRJIT_DYNAMIC; + is_diff = s.is_diff; + break; + } + o = o[0]; + } + } else if (nb::isinstance(h)) { + return ravel(nb::list(h), order, shape_out, strides_out, vt_in); } else if (vt_in) { vt = (VarType) *vt_in; } @@ -1065,6 +1086,18 @@ static nb::object reshape_2(nb::type_object dtype, nb::handle value, return reshape(dtype, value, shape_vec, order, shrink); } +static nb::object reshape_same_dtype(nb::handle value, + const dr::vector &target_shape, + char order, bool shrink) { + return reshape(nb::borrow(value.type()), value, target_shape, + order, shrink); +} + +static nb::object reshape_same_dtype_2(nb::handle value, Py_ssize_t shape, + char order, bool shrink) { + return reshape_2(nb::borrow(value.type()), value, shape, order, shrink); +} + static nb::object repeat_or_tile(nb::handle h, size_t count, bool tile) { struct RepeatOrTileOp : TransformCallback { size_t count; @@ -1149,6 +1182,10 @@ void export_memop(nb::module_ &m) { "shape"_a, "order"_a = 'A', "shrink"_a = false, doc_reshape) .def("reshape", &reshape_2, "dtype"_a, "value"_a, "shape"_a, "order"_a = 'A', "shrink"_a = false) + .def("reshape", &reshape_same_dtype, "value"_a, + "shape"_a, "order"_a = 'A', "shrink"_a = false, doc_reshape) + .def("reshape", &reshape_same_dtype_2, "value"_a, + "shape"_a, "order"_a = 'A', "shrink"_a = false) .def("tile", [](nb::handle h, size_t count) { return repeat_or_tile(h, count, true); diff --git a/src/python/meta.cpp b/src/python/meta.cpp index 940bafb8..ba3b962b 100644 --- a/src/python/meta.cpp +++ b/src/python/meta.cpp @@ -12,7 +12,6 @@ #include "base.h" #include "../ext/nanobind/src/buffer.h" #include -#include /// Check if the given metadata record is valid bool meta_check(ArrayMeta m) noexcept { diff --git a/src/python/random.h b/src/python/random.h index 0480f6ca..43d7799e 100644 --- a/src/python/random.h +++ b/src/python/random.h @@ -54,7 +54,7 @@ void bind_pcg32(nb::module_ &m) { } if (!key) - nb::raise_type_error("Invalid 'dtype'"); + nb::raise_type_error("PCG32.next_float(): invalid 'dtype'"); auto &&fn = self.attr(key); return !mask.is(Py_True) ? fn(mask) : fn(); diff --git a/src/python/reduce.cpp b/src/python/reduce.cpp index 17ed4dbb..9292d5cf 100644 --- a/src/python/reduce.cpp +++ b/src/python/reduce.cpp @@ -17,6 +17,7 @@ #include "init.h" #include "apply.h" #include "detail.h" +#include "coop_vec.h" #include using ReduceInit = nb::object(); @@ -542,6 +543,10 @@ nb::object dot(nb::handle h0, nb::handle h1) { } if (use_fma) { + if (tp0.is(coop_vector_type) || tp1.is(coop_vector_type)) { + nb::list o0 = nb::list(h0), o1 = nb::list(h1); + return dot(o1, o1); + } nb::object result = h0[0] * h1[0], fma = array_module.attr("fma"); for (size_t i = 1; i < lr; ++i) diff --git a/src/python/tracker.cpp b/src/python/tracker.cpp index ca68abf5..e2200bdd 100644 --- a/src/python/tracker.cpp +++ b/src/python/tracker.cpp @@ -17,6 +17,7 @@ #include "base.h" #include "local.h" #include "shape.h" +#include "coop_vec.h" #include #include #include @@ -332,6 +333,8 @@ bool VariableTracker::Impl::traverse(Context &ctx, nb::handle h) { ctx.label.c_str(), nb::inst_name(prev).c_str(), nb::type_name(tp).c_str()); + // Were there any external changes to sub-PyTree variable indices (As + // opposed to changes done by the VariableTracker) bool changed = false; if (is_drjit_type(tp)) { @@ -402,8 +405,7 @@ bool VariableTracker::Impl::traverse(Context &ctx, nb::handle h) { VarInfo vi = jit_set_backend((uint32_t) idx); if (new_variable) { - if (!v->index_orig) - v->index_orig = ad_var_inc_ref(idx); + v->index_orig = ad_var_inc_ref(idx); v->index = ad_var_inc_ref(idx); v->size = vi.size; } else { @@ -521,6 +523,41 @@ bool VariableTracker::Impl::traverse(Context &ctx, nb::handle h) { ScopedAppendLabel guard(ctx, "[", nb::repr(kv[0]).c_str(), "]"); changed |= traverse(ctx, kv[1]); } + } else if (tp.is(coop_vector_type)) { + CoopVec *vec = nb::cast(h, false); + uint32_t idx = vec->m_index; + size_t size = size_valid(v, ctx.label, h, vec->m_size); + + if (new_variable) { + v->index_orig = ad_var_inc_ref(idx); + v->index = ad_var_inc_ref(idx); + } else { + changed = idx != v->index; + if (changed) { + uint64_t old = v->index; + v->index = ad_var_inc_ref(idx); + ad_var_dec_ref(old); + } + } + + if (!ctx.write && !changed && !new_variable) { + for (size_t i = 0; i < size; ++i) { + ScopedAppendLabel guard(ctx, "[", i, "]"); + changed |= traverse(ctx, state.find(ctx.label)->second.value); + } + } else { + nb::list l(h), r; + for (size_t i = 0; i < size; ++i) { + ScopedAppendLabel guard(ctx, "[", i, "]"); + changed |= traverse(ctx, l[i]); + } + if (ctx.write) { + *vec = CoopVec(l); + ad_var_inc_ref(vec->m_index); + ad_var_dec_ref(v->index); + v->index = vec->m_index; + } + } } else { nb::object traverse_cb = nb::getattr( h, ctx.write ? DR_STR(_traverse_1_cb_rw) : DR_STR(_traverse_1_cb_ro), @@ -631,7 +668,7 @@ void VariableTracker::verify_size(size_t size) { strcmp(jit_var_kind_name((uint32_t) v.index), "loop_phi") == 0) continue; - size_t size_2 = jit_var_size((uint32_t)v.index); + size_t size_2 = jit_var_size((uint32_t) v.index); if (size != size_2 && size != 1 && size_2 != 1 && !jit_var_is_dirty((uint32_t)v.index)) nb::raise("this operation processes arrays of size %zu, while " @@ -730,6 +767,11 @@ nb::object VariableTracker::Impl::restore(dr::string &label) { ScopedAppendLabel guard(label, "[", nb::repr(k).c_str(), "]"); d[k] = restore(label); } + } else if (tp.is(coop_vector_type)) { + CoopVec *vec = nb::cast(value, false); + ad_var_inc_ref(v->index_orig); + ad_var_dec_ref(vec->m_index); + vec->m_index = v->index_orig; } else { if (nb::dict ds = get_drjit_struct(tp); ds.is_valid()) { for (auto [k, _] : ds) { @@ -847,48 +889,58 @@ std::pair VariableTracker::Impl::rebuild(dr::string &label) { value = tmp; } } - } else { - if (nb::dict ds = get_drjit_struct(tp); ds.is_valid()) { - nb::object tmp = tp(); - for (auto [k, _] : ds) { - ScopedAppendLabel guard(label, ".", nb::str(k).c_str()); - auto [o, n] = rebuild(label); - nb::setattr(tmp, k, o); - new_object |= n; - } - if (new_object) { - if (mutate) { - for (nb::handle k : ds.keys()) - nb::setattr(value, k, nb::getattr(tmp, k)); - new_object = false; - } else { - value = tmp; - } - } - } else if (nb::object df = get_dataclass_fields(tp); df.is_valid()) { - nb::dict tmp; - for (auto field : df) { - nb::object k = field.attr(DR_STR(name)); - ScopedAppendLabel guard(label, ".", nb::str(k).c_str()); - auto [o, n] = rebuild(label); - tmp[k] = o; - new_object |= n; + } else if (tp.is(coop_vector_type)) { + size_t size = size_valid(v, label, value, nb::len(value)); + nb::list tmp; + + for (size_t i = 0; i < size; ++i) { + ScopedAppendLabel guard(label, "[", i, "]"); + auto [o, n] = rebuild(label); + tmp.append(o); + } + + value = nb::cast(CoopVec(tmp)); + new_object = true; + } else if (nb::dict ds = get_drjit_struct(tp); ds.is_valid()) { + nb::object tmp = tp(); + for (auto [k, _] : ds) { + ScopedAppendLabel guard(label, ".", nb::str(k).c_str()); + auto [o, n] = rebuild(label); + nb::setattr(tmp, k, o); + new_object |= n; + } + if (new_object) { + if (mutate) { + for (nb::handle k : ds.keys()) + nb::setattr(value, k, nb::getattr(tmp, k)); + new_object = false; + } else { + value = tmp; } - if (new_object) { - if (mutate) { - for (auto field : df) { - nb::object k = field.attr(DR_STR(name)); - nb::setattr(value, k, tmp[k]); - } - new_object = false; - } else { - value = tp(**tmp); + } + } else if (nb::object df = get_dataclass_fields(tp); df.is_valid()) { + nb::dict tmp; + for (auto field : df) { + nb::object k = field.attr(DR_STR(name)); + ScopedAppendLabel guard(label, ".", nb::str(k).c_str()); + auto [o, n] = rebuild(label); + tmp[k] = o; + new_object |= n; + } + if (new_object) { + if (mutate) { + for (auto field : df) { + nb::object k = field.attr(DR_STR(name)); + nb::setattr(value, k, tmp[k]); } + new_object = false; + } else { + value = tp(**tmp); } - } else if (!value.is(v->value)) { - value = v->value; - new_object = true; } + } else if (!value.is(v->value)) { + value = v->value; + new_object = true; } return { value, new_object }; diff --git a/tests/test_coop_vec.py b/tests/test_coop_vec.py new file mode 100644 index 00000000..968235e9 --- /dev/null +++ b/tests/test_coop_vec.py @@ -0,0 +1,527 @@ +import drjit as dr +import drjit.nn as nn +import pytest +import sys + +def skip_if_coopvec_not_supported(t): + if dr.backend_v(t) == dr.JitBackend.CUDA: + if dr.detail.cuda_version() < (12, 8): + pytest.skip("CUDA driver does not support cooperative vectors") + +@pytest.test_arrays('jit,float16,shape=(3, *),-diff', 'jit,float32,shape=(3, *),-diff') +def test01_pack_unpack(t): + skip_if_coopvec_not_supported(t) + + # Test coop vector creation and unpacking + m = sys.modules[t.__module__] + v = dr.full(dr.value_t(t), 7, 32) + x = nn.CoopVec(t(1, 2, 3), t(4, 5, 6), v, 8) + assert len(x) == 8 + assert len(nn.CoopVec(*x, 2, (4, 5), *x)) == 19 + y = list(x) + z = m.ArrayXf(x) + result_ok = True + for i in range(8): + result_ok &= dr.all(y[i] == i+1) + result_ok &= dr.all(z[i] == i+1) + assert result_ok + + +@pytest.mark.parametrize('size', [0, 20, 10]) +@pytest.test_arrays('jit,float16,shape=(*),-diff', 'jit,float32,shape=(*),-diff') +def test02_add_sub(t, size): + skip_if_coopvec_not_supported(t) + + # Test addition and subtraction + x = nn.CoopVec(dr.full(t, 5, 32), 6, *tuple(range(size))) + y = x + 15 + z = y - 2 + r0, r1 = list(z)[0:2] + dr.schedule(r0, r1) + assert dr.all((r0 == 18) & (r1 == 19)) + +@pytest.mark.parametrize('size', [0, 20, 100]) +@pytest.test_arrays('jit,float16,shape=(*),-diff', 'jit,float32,shape=(*),-diff') +def test03_add_min_max_fma(t, size): + skip_if_coopvec_not_supported(t) + + # Test min/max/FMA operations + x = nn.CoopVec(t(5), 8, *tuple(range(size))) + x_min = dr.minimum(x, 6) + x_max = dr.maximum(x, 7) + # zero addition needed to work around a constant propagation bug in R570 driver.. + zero = dr.opaque(t, 0) + z = dr.fma(x_min, x_max, 1+zero) + r0, r1 = list(z)[0:2] + dr.schedule(r0, r1) + assert r0 == 36 and r1 == 49 + + +@pytest.mark.parametrize('sub_slice', [False, True]) +@pytest.test_arrays('jit,float16,shape=(*),-diff') +def test04_pack_unpack(t, sub_slice): + skip_if_coopvec_not_supported(t) + + # Test the nn.pack() and nn.unpack() memory operations + m = sys.modules[t.__module__] + extra = 2 if sub_slice else 0 + X = m.TensorXf16(dr.arange(t, 24*(32+extra)), (24, 32+extra)) + Xv = nn.view(X) + + assert Xv.dtype == dr.VarType.Float16 + assert Xv.offset == 0 + assert Xv.size == 24*(32+extra) + assert Xv.shape == (24, 32+extra) + assert Xv.stride == 32+extra + assert Xv.buffer is X.array + + Xv1 = Xv[0:16, 0:32] + Xv2 = Xv[16:, 0:32] + X1 = X[0:16, 0:32] + X2 = X[16:, 0:32] + + assert Xv1.dtype == dr.VarType.Float16 + assert Xv1.offset == 0 + assert Xv1.shape == (16, 32) + assert Xv1.stride == 32+extra + assert Xv1.size == (Xv1.shape[0] - 1) * Xv1.stride + Xv1.shape[1] + assert Xv1.buffer is X.array + + assert Xv2.dtype == dr.VarType.Float16 + assert Xv2.offset == 16*(32+extra) + assert Xv2.size == (Xv2.shape[0] - 1) * Xv2.stride + Xv2.shape[1] + assert Xv2.shape == (8, 32) + assert Xv2.stride == 32+extra + assert Xv2.buffer is X.array + + for i in range(2): + _, *Pa = nn.pack( + Xv1, Xv2, + layout='inference' if i == 0 else 'training' + ) + + _, X1a, X2a = nn.unpack(*Pa) + assert dr.all(m.TensorXf16(X1a) == X1[:, 0:32], axis=None) + assert dr.all(m.TensorXf16(X2a) == X2[:, 0:32], axis=None) + + +@pytest.mark.parametrize('shape', [(2, 8), (5, 2), (16, 16)]) +@pytest.mark.parametrize('transpose', [False, True]) +@pytest.mark.parametrize('bias', [False, True]) +@pytest.mark.parametrize('pack', [False, True]) +@pytest.test_arrays('jit,tensor,float16,-diff', 'jit,tensor,float32,-diff') +def test05_matvec(t, shape, transpose, bias, pack): + skip_if_coopvec_not_supported(t) + + # Test matrix multiplication for various sizes and configurations (primal) + m = sys.modules[t.__module__] + Tensor = t + Float = dr.array_t(t) + + if dr.backend_v(t) == dr.JitBackend.CUDA: + if (not pack and shape[1] == 2) or \ + (not pack and transpose) or \ + dr.type_v(t) == dr.VarType.Float32: + pytest.skip("Unsupported configuration") + + output_size = shape[1] if transpose else shape[0] + input_size = shape[0] if transpose else shape[1] + + A = Tensor(m.PCG32(dr.prod(shape), 1).next_float_normal(Float), shape) + A_n = A.numpy() + + if bias: + b = Tensor(m.PCG32(output_size, 2).next_float_normal(Float)) + b_n = b.numpy() + else: + b = b_n = None + + if pack: + if bias: + _, A, b = nn.pack(A, b) + assert A.buffer is b.buffer + else: + _, A = nn.pack(A) + else: + A = nn.view(A) + if bias: + b = nn.view(b) + + rng_3 = m.PCG32(32, 3) + x = [rng_3.next_float_normal(Float) for _ in range(input_size)] + x_n = Tensor(x).numpy() + + x = nn.CoopVec(x) + r = nn.matvec(A, x, b, transpose=transpose) + r_n = Tensor(r).numpy() + + if transpose: + A_n = A_n.T + ref = A_n @ x_n + + if bias: + ref += b_n[:, None] + + assert dr.allclose(r_n, ref) + + +@pytest.test_arrays('jit,shape=(*),float16,-diff', 'jit,shape=(*),float32,-diff') +@pytest.mark.parametrize('op', ['exp2', 'log2', 'tanh']) +def test06_unary(t, op): + skip_if_coopvec_not_supported(t) + + # Test some special unary operations that are supported by coop vectors + func = getattr(dr, op) + x = nn.CoopVec(t(0.1), t(0.2), t(0.3)) + r = func(x) + x, y, z = r + dr.schedule(x, y, z) + assert dr.allclose(x[0], func(0.1), rtol=1e-3) + assert dr.allclose(y[0], func(0.2), rtol=1e-3) + assert dr.allclose(z[0], func(0.3), rtol=1e-3) + + +@pytest.test_arrays('jit,shape=(*),float16,-diff', 'jit,shape=(*),float32,-diff') +def test07_step(t): + skip_if_coopvec_not_supported(t) + + # Test the dr.step() function on coop vectors + x = nn.CoopVec(t(0.1), t(0.2)) + y = nn.CoopVec(t(0.15), t(0.15)) + z = dr.step(x, y) + r0, r1 = z + dr.schedule(r0, r1) + assert r0 == 0 and r1 == 1 + + +@pytest.test_arrays('jit,shape=(*),float16,diff', 'jit,shape=(*),float32,diff') +def test08_fwd_grad_unpack(t): + skip_if_coopvec_not_supported(t) + + # Test that forward gradients correctly propagate through coop vector creation and unpacking + a, b = t(1), t(2) + dr.enable_grad(a, b) + z = nn.CoopVec(a, b) # pack + assert dr.grad_enabled(z) + assert not dr.grad_enabled(dr.detach(z)) + x, y = z # unpack + a.grad = 4 + b.grad = 5 + dr.forward_to(x, y) + dr.schedule(x.grad, y.grad) + assert x.grad == 4 + assert y.grad == 5 + assert dr.grad_enabled(z) + dr.disable_grad(z) + assert not dr.grad_enabled(z) + + +@pytest.test_arrays('jit,shape=(*),float16,diff', 'jit,shape=(*),float32,diff') +def test09_bwd_grad_unpack(t): + skip_if_coopvec_not_supported(t) + + # Test that backward gradients correctly propagate through coop vector creation and unpacking + a, b = t(1), t(2) + dr.enable_grad(a, b) + z = nn.CoopVec(a, b) # pack + x, y = z # unpack + x.grad = 4 + y.grad = 5 + dr.backward_to(a, b) + dr.schedule(a.grad, b.grad) + assert a.grad == 4 + assert b.grad == 5 + + +@pytest.test_arrays('jit,shape=(*),float16,diff', 'jit,shape=(*),float32,diff') +def test10_fwd_addition(t): + skip_if_coopvec_not_supported(t) + + # Propagate forward gradients through an addition + a, b = t(1), t(1) + c, d = t(1), t(1) + dr.enable_grad(a, b, c, d) + x0 = nn.CoopVec(a, b) + x1 = nn.CoopVec(c, d) + x2 = x0 + x1 + r0, r1 = x2 + a.grad = 1 + b.grad = 2 + c.grad = 100 + d.grad = 200 + dr.forward_to(r0, r1) + dr.schedule(r0.grad, r1.grad) + assert r0.grad == 101 and r1.grad == 202 + + +@pytest.test_arrays('jit,shape=(*),float16,diff', 'jit,shape=(*),float32,diff') +def test11_bwd_mul(t): + skip_if_coopvec_not_supported(t) + + # Propagate forward gradients through an addition + a, b = t(8), t(9) + c, d = t(3), t(2) + dr.enable_grad(a, b, c, d) + x0 = nn.CoopVec(a, b) + x1 = nn.CoopVec(c, d) + x2 = x0 * x1 + r0, r1 = x2 + r0.grad = 1 + r1.grad = 10 + dr.backward_to(a, b, c, d) + dr.schedule(a.grad, b.grad, c.grad, d.grad) + assert a.grad == 3 and b.grad == 20 + assert c.grad == 8 and d.grad == 90 + + +@pytest.test_arrays('jit,shape=(*),float16,diff', 'jit,shape=(*),float32,diff') +def test12_bwd_min_max_fma(t): + skip_if_coopvec_not_supported(t) + + # Check derivatives of supported binary/ternary operations + x = [ t(1), t(2), t(3), t(4) ] + y = t(5) + z = t(6) + minval = t(25) + maxval = t(12) + dr.enable_grad(x, y, z, minval, maxval) + q = nn.CoopVec(x) + + q = dr.fma(q, y, z) + q = dr.minimum(q, minval) + q = dr.maximum(q, maxval) + + a, b, c, d = q + dr.backward_from(a+b*2 + c*4 + d*8) + dr.schedule(x[0].grad, x[1].grad, x[2].grad, x[3].grad, y.grad, + z.grad, minval.grad, maxval.grad, a, b, c, d) + assert a[0] == 12 and b[0] == 16 and c[0] == 21 and d[0] == 25 + assert x[0].grad[0] == 0 and x[1].grad[0] == 10 and x[2].grad[0] == 20 and x[3].grad[0] == 0 + assert minval.grad[0] == 8 and maxval.grad[0] == 1 + +@pytest.test_arrays('jit,shape=(*),float16,diff', 'jit,shape=(*),float32,diff') +def test13_exp2_tanh_fwd(t): + skip_if_coopvec_not_supported(t) + + # Check derivatives of supported unary transcendental operations + x = t(2) + dr.enable_grad(x) + y = nn.CoopVec(x) + r0 = dr.exp2(y) + r1 = dr.tanh(y) + r0, = r0; r1, = r1 + dr.forward_from(x) + dr.schedule(r0, r1, r0.grad, r1.grad) + assert dr.allclose(r0[0], 4) + assert dr.allclose(r1[0], 0.9640275800758169, rtol=1e-3) + assert dr.allclose(r0.grad[0], 2.77259, rtol=1e-3) + assert dr.allclose(r1.grad[0], 0.0706508, rtol=1e-2) + + +@pytest.mark.parametrize('transpose', [False, True]) +@pytest.mark.parametrize('has_A_grad', [False, True]) +@pytest.mark.parametrize('has_x_grad', [False, True]) +@pytest.mark.parametrize('has_b_grad', [None, False, True]) +@pytest.mark.parametrize('layout', ['training', 'inference']) +@pytest.test_arrays('jit,tensor,float16,diff') +def test14_matvec_fwd(t, transpose, has_A_grad, has_x_grad, has_b_grad, layout): + skip_if_coopvec_not_supported(t) + + # Test forward-propagation of derivatives from input through matrix multiplication + m = sys.modules[t.__module__] + Tensor = t + Float = dr.array_t(t) + Matrix2f = m.Matrix2f16 + Array2f = m.Array2f16 + + if not has_A_grad and not has_x_grad and not has_b_grad: + pytest.skip("Trivial configuration") + if dr.backend_v(Float) == dr.JitBackend.LLVM and layout == 'training': + pytest.skip("Layout not used in LLVM backend") + + # Set up 'A' matrix + A = [[4, 2], [5, 1]] + A_grad = [[2, 1], [1, -1]] + _, A_v = nn.pack(Tensor(A), layout=layout) + A_ref = Matrix2f(A) + if has_A_grad: + _, A_grad_v = nn.pack(Tensor(A_grad)) + assert not dr.grad_enabled(A_v) + dr.enable_grad(A_v) + assert dr.grad_enabled(A_v) + assert not dr.grad_enabled(dr.detach(A_v)) + A_v.buffer.grad = A_grad_v.buffer + dr.enable_grad(A_ref) + dr.set_grad(A_ref, A_grad) + + # Set up 'x' vector + x = Array2f(1, 2) + if has_x_grad: + dr.enable_grad(x) + x.grad = [2, 1] + x_v = nn.CoopVec(x) + + # Set up 'b' vector + b_v = None + b_ref = Array2f(0) + if has_b_grad is not None: + b1, b2 = Float(-1), Float(1) + b_ref = Array2f(b1, b2) + _, b_v = nn.pack(Tensor([-1, 1])) + + if has_b_grad is True: + dr.enable_grad(b_ref) + b_ref.grad = [1, -1] + _, b_grad_v = nn.pack(Tensor([1, -1])) + dr.enable_grad(b_v.buffer) + b_v.buffer.grad = b_grad_v.buffer + + # Compute the reference + if transpose: + A_ref = A_ref.T + y_ref = A_ref @ x + b_ref + + y = Array2f(nn.matvec(A_v, x_v, b_v, transpose)) + dr.forward_to(y, y_ref) + dr.schedule(y, y.grad, y_ref, y_ref.grad) + + # print(f"primal: y={y} vs ref={y_ref}") + # print(f"grad: y={y.grad} vs ref={y_ref.grad}") + + assert dr.all((y == y_ref) & (y.grad == y_ref.grad)) + + +@pytest.mark.parametrize('transpose', [False, True]) +@pytest.test_arrays('jit,tensor,float16,-diff') +def test15_matvec_in_vcall(t, transpose): + skip_if_coopvec_not_supported(t) + + # Check that mat-vec products still work as expected when done from a callable + Float = dr.array_t(t) + UInt32 = dr.uint32_array_t(Float) + size = 64 + A = dr.normal(t, (size, size)) + b = dr.normal(t, size) + _, A, b = nn.pack(A, b) + + def mult_it(): + x = nn.CoopVec( + Float(i/(size-1) - 0.5) for i in range(size) + ) + return list(nn.matvec(A, x, b, transpose=transpose))[0] + + r0 = mult_it() + r1 = dr.switch(UInt32(0), [mult_it]) + + dr.schedule(r0, r1) + assert dr.allclose(r0[0], r1[0]) + + # Try again without bias vector + b = None + + r0 = mult_it() + r1 = dr.switch(UInt32(0), [mult_it]) + + dr.schedule(r0, r1) + assert r0[0] == r1[0] + + +@pytest.mark.parametrize('in_vcall', [False, True]) +@pytest.test_arrays('jit,tensor,float16,diff') +def test16_matvec_bwd(t, in_vcall): + skip_if_coopvec_not_supported(t) + + # Test the reverse-mode derivative of a matrix-vector product + # (potentially in a vcall) + + m = sys.modules[t.__module__] + UInt32 = m.UInt32 + A = t([[1, 3], [-2, 4], [3, -2]]) + b = t([0, 0, 0]) + buffer, Av, bv = nn.pack(A, b, layout='training') + x = m.Array2f16(2, 4) + dr.enable_grad(x, buffer) + + def do_mul(x): + xv = nn.CoopVec(x) + yv = nn.matvec(Av, xv, bv) + return m.Array3f16(yv) + + if in_vcall: + y = dr.switch(UInt32(0), [do_mul], x) + else: + y = do_mul(x) + + z = dr.opaque(dr.array_t(t), 0) + + y.grad = (-2+z, 5+z, 10+z) + dr.backward_from(y) + grad_x = x.grad + + # print(f"{y=}") + # print(f"{grad_x=}") + + grad_x_ref = m.Array2f16(18, -6) + assert dr.all(grad_x_ref == grad_x) + + dr.schedule(grad_x) + _, grad_A = nn.unpack(Av.grad) + _, grad_b = nn.unpack(bv.grad) + + grad_A = t(grad_A) + grad_b = t(grad_b)[:, 0] + grad_A_ref = t([[-4, -8], [10, 20], [20, 40]]) + grad_b_ref = t([-2, 5, 10]) + assert dr.all(grad_A_ref == grad_A) + assert dr.all(grad_b_ref == grad_b) + + +@pytest.test_arrays('jit,shape=(*),float16,diff') +def test17_cast(t): + skip_if_coopvec_not_supported(t) + + z = dr.opaque(t, 0) + a = nn.CoopVec( + z + 1, + z + 2, + z + 3 + ) + b = nn.cast(a, dr.float32_array_t(t)) + c = nn.cast(b, dr.float16_array_t(t)) + x, y, z = c + dr.eval(x, y, z) + assert x[0] == 1 and y[0] == 2 and z[0] == 3 + + + +@pytest.test_arrays('jit,shape=(*),float32,-diff') +@dr.syntax +def test18_symbolic_loop_if_stmt(t): + skip_if_coopvec_not_supported(t) + + # Test that cooperative vectors can be passed through + # symbolic loops and conditionals + UInt32 = dr.uint32_array_t(t) + a = nn.CoopVec(t(1), t(2)) + i = UInt32(0) + + while i < 10: + if i > 5: + a += a + i += 1 + + x, y = a + dr.schedule(x, y, i) + assert x[0] == 16 and y[0] == 32 + + +@pytest.test_arrays('jit,shape=(*),float32,-diff') +@dr.syntax +def test19_no_eval(t): + skip_if_coopvec_not_supported(t) + + # Cooperative vectors cannot be evaluted via dr.eval() + UInt32 = dr.uint32_array_t(t) + a = nn.CoopVec(t(1), t(2)) + with pytest.raises(RuntimeError, match="Cooperative vectors cannot be evaluated"): + dr.eval(a) From 83285adeef87b44421324407b552c74f10dc3bc2 Mon Sep 17 00:00:00 2001 From: Wenzel Jakob Date: Fri, 18 Apr 2025 09:39:24 +0900 Subject: [PATCH 3/9] Incorporate review feedback --- docs/coop_vec.rst | 1 + docs/nn.rst | 18 ++++++++++-------- docs/what.rst | 2 +- drjit/nn.py | 20 +++++++++++++++----- include/drjit/extra.h | 2 +- src/python/coop_vec.cpp | 14 +++++++------- src/python/docstr.rst | 35 +++++++++++++++++++---------------- src/python/eval.cpp | 10 ++++++++++ src/python/tracker.cpp | 2 +- tests/test_coop_vec.py | 9 +++++++-- 10 files changed, 72 insertions(+), 41 deletions(-) diff --git a/docs/coop_vec.rst b/docs/coop_vec.rst index 2a5b4489..0ce66789 100644 --- a/docs/coop_vec.rst +++ b/docs/coop_vec.rst @@ -28,6 +28,7 @@ Dr.Jit supports cooperative vectors on both of its backends: `__, leveraging built-in `tensor cores `__ for acceleration. + Driver version R570 or newer is required to use this feature. - On the **CPU (LLVM) backend**, compilation of cooperative vector operations targets the available instruction set extensions (AVX512, NEON, etc.). diff --git a/docs/nn.rst b/docs/nn.rst index aa658594..8e90b02c 100644 --- a/docs/nn.rst +++ b/docs/nn.rst @@ -90,14 +90,14 @@ mixed-precision training. net = net.alloc(TensorXf16, 2) # Convert to training-optimal layout - coeffs, net = nn.pack(net, layout='training') + weights, net = nn.pack(net, layout='training') print(net) - # Optimize a single precision copy of the parameters - opt = Adam(lr=1e-3, params={'coeffs': Float32(coeffs)}) + # Optimize a single-precision copy of the parameters + opt = Adam(lr=1e-3, params={'weights': Float32(weights)}) # This is an adaptive mixed-precision (AMP) optimization, where a half - # precision computation runs within a larger single precision program. + # precision computation runs within a larger single-precision program. # Gradient scaling is required to make this numerically well-behaved. scaler = GradScaler() @@ -105,15 +105,15 @@ mixed-precision training. for i in tqdm(range(40000)): # Update network state from optimizer - coeffs[:] = Float16(opt['coeffs']) + weights[:] = Float16(opt['weights']) # Generate jittered positions on [0, 1]^2 t = dr.arange(Float32, res) - p = (Array2f(dr.meshgrid(t, t)) + dr.rand(Array2f, (2, res*res))) / res + p = (Array2f(dr.meshgrid(t, t)) + dr.rand(Array2f, (2, res * res))) / res # Evaluate neural net + L2 loss img = Array3f(net(nn.CoopVec(p))) - loss = dr.squared_norm(tex.eval(p)-img) + loss = dr.squared_norm(tex.eval(p) - img) # Mixed-precision training: take suitably scaled steps dr.backward(scaler.scale(loss)) @@ -121,8 +121,10 @@ mixed-precision training. # Done optimizing, now let's plot the result t = dr.linspace(Float32, 0, 1, res) - p= Array2f(dr.meshgrid(t, t)) + p = Array2f(dr.meshgrid(t, t)) img = Array3f(net(nn.CoopVec(p))) + + # Convert 'img' with shape 3 x (N*N) into a N x N x 3 tensor img = dr.reshape(TensorXf(img, flip_axes=True), (res, res, 3)) import matplotlib.pyplot as plt diff --git a/docs/what.rst b/docs/what.rst index 59250170..04726ea0 100644 --- a/docs/what.rst +++ b/docs/what.rst @@ -24,7 +24,7 @@ Using Dr.Jit involves two steps: Perhaps the most significant difference to the majority of existing tools is that Dr.Jit is *not primarily* a machine learning library. While it does provide support for neural network :ref:`evaluation and training `, -it its sweet spot are non-neural programs characterized by *embarrassing +its sweet spot are non-neural programs characterized by *embarrassing parallelism*---that is to say, programs with large data-parallel regions. A good example of this are `Monte Carlo `__ methods with their diff --git a/drjit/nn.py b/drjit/nn.py index e0ae987f..b73aeb41 100644 --- a/drjit/nn.py +++ b/drjit/nn.py @@ -61,6 +61,16 @@ def __call__(self, arg: CoopVec, /) -> CoopVec: raise NotImplementedError(f"{type(self).__name__}.__call__() implementation is missing.") def _alloc(self, dtype: Type[drjit.ArrayBase], size: int, /) -> Tuple[Module, int]: + """ + Internal method used to propagate argument sizes and allocate weight + storage of all NN modules. + + The method takes to parameters as input: a weight storage type + ``dtype`` (e.g., :py:class:`drjit.cuda.ad.TensorXf16`) and ``size``, + the number of input arguments of the module. The function returns a + potentially new module instance with allocated weights, plus the number + of outputs. + """ return self, size def alloc(self, dtype: Type[drjit.ArrayBase], size: int = -1) -> Module: @@ -110,7 +120,7 @@ def __len__(self): """Return the number of contained models""" return len(self.layers) - def __getitem__(self, index: Union[int], /) -> Module: # type: ignore + def __getitem__(self, index: int, /) -> Module: # type: ignore """Return the model at position ``index``""" return self.layers[index] @@ -155,8 +165,8 @@ class LeakyReLU(Module): \end{cases} """ - DRJIT_STRUCT = { 'negative_slope': float } - def __init__(self, negative_slope: float = 1e-2): + DRJIT_STRUCT = { 'negative_slope': Union[float, drjit.ArrayBase] } + def __init__(self, negative_slope: Union[float, drjit.ArrayBase] = 1e-2): self.negative_slope = negative_slope def __call__(self, arg: CoopVec, /) -> CoopVec: @@ -449,8 +459,8 @@ def __init__(self, octaves: int = 0, shift: float = 0) -> None: if shift == 0: self.shift = None else: - self.shift = (drjit.sin(shift*2*drjit.pi), - drjit.cos(shift*2*drjit.pi)) + self.shift = (drjit.sin(shift * 2 * drjit.pi), + drjit.cos(shift * 2 * drjit.pi)) def _alloc(self, dtype: Type[drjit.ArrayBase], size : int = -1, /) -> Tuple[Module, int]: return self, size * self.octaves * 2 diff --git a/include/drjit/extra.h b/include/drjit/extra.h index 5d94e57c..a8148d2a 100644 --- a/include/drjit/extra.h +++ b/include/drjit/extra.h @@ -546,7 +546,7 @@ extern DRJIT_EXTRA_EXPORT uint64_t ad_coop_vec_matvec(uint64_t A_index, int transpose); /// Cast a cooperative vector to a different precision -extern JIT_EXPORT uint64_t ad_coop_vec_cast(uint64_t index, VarType vt); +extern DRJIT_EXTRA_EXPORT uint64_t ad_coop_vec_cast(uint64_t index, VarType vt); #if defined(__cplusplus) } diff --git a/src/python/coop_vec.cpp b/src/python/coop_vec.cpp index 4d7a1b5a..d0314c52 100644 --- a/src/python/coop_vec.cpp +++ b/src/python/coop_vec.cpp @@ -561,7 +561,7 @@ void export_coop_vec(nb::module_ &m) { coop_vector_type = nb::class_(nn, "CoopVec", nb::is_generic(), nb::sig("class CoopVec(typing.Generic[T])")) .def(nb::init(), nb::sig("def __init__(self, *args: typing.Unpack[typing.Tuple[typing.Union[drjit.ArrayBase[SelfT, SelfCpT, ValT, ValCpT, T, PlainT, MaskT], float, int], ...]]) -> None"), - doc_coop_CoopVec_init) + doc_nn_CoopVec_init) .def("__iter__", [](const CoopVec &v) { return iter(v.expand_to_list()); }, nb::sig("def __iter__(self, /) -> typing.Iterator[T]")) .def("__add__", &coop_vec_binary_op, @@ -587,7 +587,7 @@ void export_coop_vec(nb::module_ &m) { jit_var_size(v.m_index)); }); - view_type = nb::class_(nn, "MatrixView", doc_coop_MatrixView) + view_type = nb::class_(nn, "MatrixView", doc_nn_MatrixView) .def(nb::init<>()) .def("__repr__", &MatrixView::repr) .def("__getitem__", &MatrixView::getitem, @@ -669,12 +669,12 @@ void export_coop_vec(nb::module_ &m) { view_type.attr("DRJIT_STRUCT") = drjit_struct; nn.def("view", &view, - doc_coop_view); + doc_nn_view); nn.def("pack", [](nb::handle arg, const char *layout) { return repack("pack", layout, arg); }, nb::arg(), "layout"_a = "inference", nb::sig("def pack(arg: MatrixView | drjit.AnyArray, *, layout: typing.Literal['inference', 'training'] = 'inference') -> typing.Tuple[drjit.ArrayBase, MatrixView]"), - doc_coop_pack); + doc_nn_pack); nn.def("pack", [](nb::args args, const char *layout) { @@ -692,7 +692,7 @@ void export_coop_vec(nb::module_ &m) { nn.def("unpack", [](nb::handle arg) { return repack("unpack", nullptr, arg); }, nb::sig("def unpack(arg: MatrixView | drjit.AnyArray, /) -> typing.Tuple[drjit.ArrayBase, MatrixView]"), - doc_coop_unpack); + doc_nn_unpack); nn.def("unpack", [](nb::args args) { @@ -710,7 +710,7 @@ void export_coop_vec(nb::module_ &m) { "b"_a.noconvert() = nb::none(), "transpose"_a = false, nb::sig("def matvec(A: MatrixView, x: drjit.nn.CoopVec[T], b: typing.Optional[MatrixView] = " "None, /, transpose: bool = False) -> drjit.nn.CoopVec[T]"), - doc_coop_matvec); + doc_nn_matvec); nn.def("cast", [](CoopVec vec, nb::type_object_t tp) { @@ -721,7 +721,7 @@ void export_coop_vec(nb::module_ &m) { return CoopVec(ad_coop_vec_cast(vec.m_index, (VarType) s.type), vec.m_size, new_type); }, nb::sig("def cast(arg0: CoopVec[T], arg1: typing.Type[ArrayT], /) -> CoopVec[ArrayT]"), - doc_coop_cast + doc_nn_cast ); m.def("fma", &coop_vec_ternary_op); diff --git a/src/python/docstr.rst b/src/python/docstr.rst index 76697ed1..e06df9fd 100644 --- a/src/python/docstr.rst +++ b/src/python/docstr.rst @@ -8130,12 +8130,14 @@ Returns: object: The computed array as described above -.. topic:: coop_CoopVec +.. topic:: nn_CoopVec A *cooperative vector* is a dynamically-sized container of elements of a consistent type. It admits both floating point and integer 1D arrays as elements (e.g., :py:class:`drjit.cuda.Float16`, - :py:class:`drjit.llvm.UInt32`). + :py:class:`drjit.llvm.UInt32`). Cooperative vectors primarily exist to + enable the compilation of expressions that make use of matrix-vector + multiplication. Seen from a high level, cooperative vectors resemble nested array types, such as as :py:class:`drjit.cuda.ArrayXf16`. A variety of conversions @@ -8177,7 +8179,7 @@ To unpack a cooperative vector into its components, use an expression like ``x, y, z = vec``, ``ArrayXf(vec)``, or ``list(vec)``. -.. topic:: coop_CoopVec_init +.. topic:: nn_CoopVec_init The constructor accepts a variable number of arguments including Dr.Jit arrays, scalar Python integers and floating point values, and :ref:`PyTrees @@ -8188,7 +8190,7 @@ the input contains Dr.Jit arrays of inconsistent scalar types (e.g., :py:class:`drjit.cuda.Array2f` and :py:class:`drjit.cuda.UInt`). -.. topic:: coop_MatrixView +.. topic:: nn_MatrixView The :py:class:`drjit.nn.MatrixView` provides pointer into a buffer along with shape and type metadata. @@ -8203,7 +8205,7 @@ representation. The returned views can then be passed to :py:func:`drjit.nn.matvec()`. -.. topic:: coop_view +.. topic:: nn_view Convert a Dr.Jit array or tensor into a *view*. @@ -8221,13 +8223,13 @@ directly re-packed into optimal layouts without performing further unnecessary copies. -.. topic:: coop_pack +.. topic:: nn_pack - A training-optimal layout must be used used if the program - *backpropagates* (as in :py:func:`dr.backward*() `) - gradients through matrix-vector products. Forward derivative propagation (as - in :py:func:`dr.forward*() `) does not require a - training-optimal layout. + A training-optimal layout must be used used if the program *backpropagates* + (as in :py:func:`dr.backward*() `) gradients through + matrix-vector products. Inference (primal evaluation) and forward derivative + propagation (as in :py:func:`dr.forward*() `) does not + require a training-optimal layout. If the input matrices are already packed in a row-major layout, call :py:func:`dr.nn.view() ` to create an efficient reference @@ -8244,7 +8246,7 @@ mat_view[32:64, :] ) -.. topic:: coop_unpack +.. topic:: nn_unpack The function :py:func:`dr.nn.unpack() ` transforms a sequence (or :ref:`PyTree `) of vectors and optimal-layout matrices @@ -8255,13 +8257,14 @@ A_out, b_out = dr.nn.unpack(A_opt, b_opt) Note that the output of this function are (row-major) *views* into a shared - buffer. These views can be converted back into regular tensors: + buffer. Each view holds a reference to the shared buffer. Views can be + converted back into regular tensors: .. code-block:: python A = TensorXf16(A) -.. topic:: coop_matvec +.. topic:: nn_matvec Evaluate a matrix-vector multiplication involving a cooperative vector. @@ -8275,9 +8278,9 @@ + b``). This bias vector ``b`` should also be specified as a view. Specify ``tranpose=True`` to multiply by the transpose of the matrix ``A``. - On the CUDA/OptiX backend, this feature requires that ``A`` is inference + On the CUDA/OptiX backend, this feature requires that ``A`` is in inference or training-optimal layout. -.. topic:: coop_cast +.. topic:: nn_cast Cast the numeric type underlying a cooperative vector diff --git a/src/python/eval.cpp b/src/python/eval.cpp index cd0e8afb..4bd5a2ff 100644 --- a/src/python/eval.cpp +++ b/src/python/eval.cpp @@ -68,6 +68,16 @@ static void make_opaque(nb::handle h) { ad_var_dec_ref(index_new); } + + void traverse_unknown(nb::handle h) override { + if (h.type().is(local_type)) { + Local & local = nb::cast(h); + for (uint32_t index : local.arrays()) + result |= (bool) jit_var_schedule(index); + } + if (h.type().is(coop_vector_type)) + nb::raise("Cooperative vectors cannot be evaluated. They must be unpacked into regular variables."); + } }; ScheduleForceCallback sfc; diff --git a/src/python/tracker.cpp b/src/python/tracker.cpp index e2200bdd..f974cb0c 100644 --- a/src/python/tracker.cpp +++ b/src/python/tracker.cpp @@ -333,7 +333,7 @@ bool VariableTracker::Impl::traverse(Context &ctx, nb::handle h) { ctx.label.c_str(), nb::inst_name(prev).c_str(), nb::type_name(tp).c_str()); - // Were there any external changes to sub-PyTree variable indices (As + // Were there any external changes to sub-PyTree variable indices (as // opposed to changes done by the VariableTracker) bool changed = false; diff --git a/tests/test_coop_vec.py b/tests/test_coop_vec.py index 968235e9..8efc5f29 100644 --- a/tests/test_coop_vec.py +++ b/tests/test_coop_vec.py @@ -6,7 +6,7 @@ def skip_if_coopvec_not_supported(t): if dr.backend_v(t) == dr.JitBackend.CUDA: if dr.detail.cuda_version() < (12, 8): - pytest.skip("CUDA driver does not support cooperative vectors") + pytest.skip("CUDA driver does not support cooperative vectors (Driver R570) or later is required") @pytest.test_arrays('jit,float16,shape=(3, *),-diff', 'jit,float32,shape=(3, *),-diff') def test01_pack_unpack(t): @@ -20,6 +20,7 @@ def test01_pack_unpack(t): assert len(nn.CoopVec(*x, 2, (4, 5), *x)) == 19 y = list(x) z = m.ArrayXf(x) + assert len(y) == 8 and len(z) == 8 result_ok = True for i in range(8): result_ok &= dr.all(y[i] == i+1) @@ -258,7 +259,7 @@ def test10_fwd_addition(t): def test11_bwd_mul(t): skip_if_coopvec_not_supported(t) - # Propagate forward gradients through an addition + # Propagate forward gradients through a multiplication a, b = t(8), t(9) c, d = t(3), t(2) dr.enable_grad(a, b, c, d) @@ -523,5 +524,9 @@ def test19_no_eval(t): # Cooperative vectors cannot be evaluted via dr.eval() UInt32 = dr.uint32_array_t(t) a = nn.CoopVec(t(1), t(2)) + with pytest.raises(RuntimeError, match="Cooperative vectors cannot be evaluated"): + dr.schedule(a) with pytest.raises(RuntimeError, match="Cooperative vectors cannot be evaluated"): dr.eval(a) + with pytest.raises(RuntimeError, match="Cooperative vectors cannot be evaluated"): + dr.make_opaque(a) From 3529fe435056a352c08585d097d005435101b16c Mon Sep 17 00:00:00 2001 From: Wenzel Jakob Date: Mon, 21 Apr 2025 23:37:01 +0900 Subject: [PATCH 4/9] Evaluated loop derivative improvement This commit improves handling of evaluated loops with grad-enabled state variables. Previously, the AD variable ID of each differentiable state variable changed in every iteration, even if the loop did not touch that variable. This is an implementation detail of the loop evaluation code, that should, however, not leak into user code. This commit fixes this behavior. --- include/drjit/extra.h | 3 +++ src/extra/autodiff.cpp | 14 +++++++++++ src/extra/loop.cpp | 56 +++++++++++++++++++++++++++++++++++------- 3 files changed, 64 insertions(+), 9 deletions(-) diff --git a/include/drjit/extra.h b/include/drjit/extra.h index a8148d2a..f10dd932 100644 --- a/include/drjit/extra.h +++ b/include/drjit/extra.h @@ -488,6 +488,9 @@ extern DRJIT_EXTRA_EXPORT uint64_t ad_var_map_get(uint64_t index); extern DRJIT_EXTRA_EXPORT int ad_leak_warnings(); extern DRJIT_EXTRA_EXPORT void ad_set_leak_warnings(int value); +/// Extract the i-th predecessor of an AD node (or return 0) +extern DRJIT_EXTRA_EXPORT uint32_t ad_pred(uint32_t index, uint32_t i); + #if defined(__GNUC__) DRJIT_INLINE uint64_t ad_var_inc_ref(uint64_t index) JIT_NOEXCEPT { /* If 'index' is known at compile time, it can only be zero, in diff --git a/src/extra/autodiff.cpp b/src/extra/autodiff.cpp index e4d1b128..3b2a365b 100644 --- a/src/extra/autodiff.cpp +++ b/src/extra/autodiff.cpp @@ -2344,6 +2344,20 @@ void ad_mark_loop_boundary(Index index) { } } +uint32_t ad_pred(uint32_t ad_index, uint32_t i_) { + std::lock_guard guard(state.lock); + const Variable *v = state[ad_index]; + uint32_t edge = v->next_bwd; + + for (uint32_t i = 0; i < i_; ++i) { + if (!edge) + return 0; + edge = state.edges[edge].next_bwd; + } + + return state.edges[edge].source; +} + // ========================================================================== // Implementation of arithmetic operations and transcendental functions diff --git a/src/extra/loop.cpp b/src/extra/loop.cpp index ddcb66ab..2a75970a 100644 --- a/src/extra/loop.cpp +++ b/src/extra/loop.cpp @@ -156,6 +156,7 @@ static size_t ad_loop_evaluated_mask(JitBackend backend, const char *name, JitVar active_it; size_t it = 0; bool grad_suspended = ad_grad_suspended(); + dr::vector copy_bit(indices1.size(), true); while (true) { // Evaluate the loop state @@ -190,14 +191,27 @@ static size_t ad_loop_evaluated_mask(JitBackend backend, const char *name, } for (size_t i = 0; i < indices2.size(); ++i) { - // Kernel caching: Must create an AD copy so that gradient - // computation steps involving this variable (even if unchangecd - // & only used as a read-only dependency) are correctly placed - // within their associated loop iterations. This does not create - // a copy of the underlying JIT variable. - - uint64_t i1 = indices2[i], - i2 = grad_suspended ? ad_var_inc_ref(i1) : ad_var_copy(i1); + // Potentially create an AD copy here (i.e., assign a new AD node + // representing a copy of the original loop state). Note that this + // is a symbolic copy in the AD graph that does not consume actual + // device memory. This copy is needed to prevent a degeneracy of + // forward derivative propagation, where a loop variable does not + // change at all, yet all loop iterations depend on this variable in + // a differentiable sense. When the AD traversal reaches this + // variable, this will generate a huge kernel that propagates the + // derivative to every single loop iteration, instead of splitting + // the computation into per-iteration kernels. By creating marked + // (ad_mark_loop-boundary) copies, we can ensure correct sequencing. + // The extra copies are only used within the loop and removed below. + + uint64_t i1 = indices2[i], i2; + + if (!grad_suspended && (i1 >> 32) != 0 && (indices1[i] >> 32) == (i1 >> 32)) { + i2 = ad_var_copy(i1); + } else { + i2 = ad_var_inc_ref(i1); + copy_bit[i] = false; + } ad_var_dec_ref(i1); ad_mark_loop_boundary(i2); @@ -208,7 +222,8 @@ static size_t ad_loop_evaluated_mask(JitBackend backend, const char *name, write_cb(payload, indices2, false); indices1.release(); - indices1.swap(indices2); + indices2.release(); + read_cb(payload, indices1); active_it = JitVar::borrow(cond_cb(payload)); active_it.schedule_(); @@ -216,6 +231,29 @@ static size_t ad_loop_evaluated_mask(JitBackend backend, const char *name, active.schedule_force_(); } + { + bool changed = false; + for (size_t i = 0; i < indices1.size(); ++i) { + if (!copy_bit[i]) + continue; + + // The AD index of this was copied a number of times (see above for + // the rationale). Let's now remove these again. + uint32_t ad_index = (uint32_t) (indices1[i] >> 32); + for (uint32_t j = 0; j < it; ++j) + ad_index = ad_pred(ad_index, 0); + + uint64_t index_new = (((uint64_t) ad_index) << 32) | (uint32_t) indices1[i]; + ad_var_inc_ref(index_new); + ad_var_dec_ref(indices1[i]); + indices1[i] = index_new; + changed = true; + } + + if (changed) + write_cb(payload, indices1, false); + } + return it; } From 12295ed63169cce95cb96af0798049748226af2c Mon Sep 17 00:00:00 2001 From: Wenzel Jakob Date: Mon, 21 Apr 2025 23:38:25 +0900 Subject: [PATCH 5/9] Improved symbolic loop backward derivative This commit fixes bugs in the compilation of reverse-mode derivatives of simple loops (i.e, loops with max_iterations==-1) and updates the test suite to cover problematic cases. --- src/extra/loop.cpp | 80 ++++++++++++++++++++----------------- tests/test_while_loop_ad.py | 27 ++++++++++++- 2 files changed, 69 insertions(+), 38 deletions(-) diff --git a/src/extra/loop.cpp b/src/extra/loop.cpp index 2a75970a..6b246eb6 100644 --- a/src/extra/loop.cpp +++ b/src/extra/loop.cpp @@ -743,8 +743,7 @@ struct LoopOp : public dr::detail::CustomOpBase { before the loop> state_i = + tracking was enabled before the loop> grad_state_o = @@ -756,26 +755,21 @@ struct LoopOp : public dr::detail::CustomOpBase { dr.disable_grad(state) */ - void bwd_body_simple() { - // Create differentiable loop state variables + void bwd_body_simple() { // Create differentiable loop state variables m_state2.release(); - index32_vector tmp; - size_t offset = m_inputs.size(); for (size_t i = 0; i < m_inputs.size(); ++i) { const Input &in = m_inputs[i]; uint64_t index; - if (in.has_grad_out && in.has_grad_in) { + if (in.has_grad_in) { index = ad_var_new((uint32_t) m_state[i]); - tmp.push_back_borrow((uint32_t) m_state[offset]); + ad_var_map_put(combine(m_input_indices[in.grad_in_index]), index); } else { index = ad_var_inc_ref(m_state[i]); } - m_state2.push_back_steal(index); - if (in.has_grad_out) - offset++; + m_state2.push_back_steal(index); } // Run the loop body @@ -792,16 +786,18 @@ struct LoopOp : public dr::detail::CustomOpBase { m_read_cb(m_payload, m_state2); // AD backward propagation pass - offset = m_inputs.size(); + size_t offset = m_inputs.size(); for (size_t i = 0; i < m_inputs.size(); ++i) { const Input &in = m_inputs[i]; - if (!in.has_grad_out) + if (!in.has_grad_in && !in.has_grad_out) continue; - ad_accum_grad(m_state2[i], (uint32_t) m_state[offset]); - - if (!in.has_grad_in) + if (in.has_grad_out && (m_state2[i] >> 32)) { + ad_accum_grad(m_state2[i], (uint32_t) m_state[offset]); ad_enqueue(dr::ADMode::Backward, m_state2[i]); + } else if (in.has_grad_in) { + ad_accum_grad(m_state2[i], (uint32_t) m_state[offset]); + } offset++; } @@ -809,18 +805,23 @@ struct LoopOp : public dr::detail::CustomOpBase { ad_traverse(dr::ADMode::Backward, (uint32_t) dr::ADFlag::ClearNone); // Read the loop output + derivatives copy to loop state vars - m_state.release(); - for (size_t i = 0; i < m_inputs.size(); ++i) - m_state.push_back_borrow((uint32_t) m_state2[i]); + for (size_t i = 0; i < m_inputs.size(); ++i) { + jit_var_inc_ref((uint32_t) m_state2[i]); + ad_var_dec_ref((uint32_t) m_state[i]); + m_state[i] = (uint32_t) m_state2[i]; + } offset = m_inputs.size(); for (size_t i = 0; i < m_inputs.size(); ++i) { const Input &in = m_inputs[i]; - - if (!in.has_grad_out) + if (!in.has_grad_in && !in.has_grad_out) continue; - uint32_t grad = ad_grad(m_state2[i]); - m_state.push_back_steal(grad); + + if (in.has_grad_in) { + ad_var_dec_ref(m_state[offset]); + m_state[offset] = ad_grad(m_state2[i]); + } + offset++; } @@ -834,10 +835,17 @@ struct LoopOp : public dr::detail::CustomOpBase { for (const Input &i : m_inputs) m_state.push_back_borrow(i.index); + uint32_t index = 0; for (const Input &in : m_inputs) { - uint32_t grad; - if (!in.has_grad_out) + index++; + if (!in.has_grad_out && !in.has_grad_in) continue; + uint32_t grad; + + if (in.has_grad_in && in.has_grad_out) + jit_raise("LoopOp::backward_simple(): unsupported " + "configuration. Variable %u (r%u) is marked both as a " + "differentiable output and an input.", index, (uint32_t) in.index); if (in.has_grad_in) { uint64_t zero = 0; @@ -856,29 +864,30 @@ struct LoopOp : public dr::detail::CustomOpBase { [](void *p) { return ((LoopOp *) p)->fwd_cond(); }, [](void *p) { return ((LoopOp *) p)->bwd_body_simple(); }, nullptr, false); - size_t offset = m_inputs.size(); for (const Input &in : m_inputs) { - if (!in.has_grad_out) + if (!in.has_grad_out && !in.has_grad_in) continue; - if (in.has_grad_in) { - ad_accum_grad(combine(m_input_indices[in.grad_in_index]), + if (in.has_grad_out) + ad_accum_grad(combine(m_output_indices[in.grad_out_offset]), (uint32_t) m_state[offset]); - } offset++; } + offset = m_inputs.size(); for (size_t i = 0; i < m_inputs.size(); ++i) { const Input &in = m_inputs[i]; - if (!in.has_grad_out) + if (!in.has_grad_out && !in.has_grad_in) continue; - ad_accum_grad(combine(m_output_indices[in.grad_out_offset]), - (uint32_t) m_state[m_inputs.size() + in.grad_in_offset]); - } + if (in.has_grad_in) + ad_accum_grad(combine(m_input_indices[in.grad_in_index]), + (uint32_t) m_state[offset]); + offset++; + } m_state.release(); } @@ -1006,8 +1015,7 @@ bool ad_loop(JitBackend backend, int symbolic, int compress, vt != VarType::Float64) continue; - if (max_iterations == 0 && - (uint32_t) indices_in[i] == (uint32_t) indices_out[i]) { + if ((uint32_t) indices_in[i] == (uint32_t) indices_out[i]) { // Keep unchanged variables out of the AD system if (indices_in[i] != indices_out[i]) { ad_var_inc_ref(indices_in[i]); diff --git a/tests/test_while_loop_ad.py b/tests/test_while_loop_ad.py index bf12f339..3d0ea36e 100644 --- a/tests/test_while_loop_ad.py +++ b/tests/test_while_loop_ad.py @@ -64,9 +64,10 @@ def test03_sum_loop_fwd(t, mode): @pytest.mark.parametrize('mode', ['evaluated', 'symbolic']) +@pytest.mark.parametrize('make_copy', [True, False]) @pytest.test_arrays('float32,diff,shape=(*)') @dr.syntax -def test04_sum_loop_rev(t, mode): +def test04_sum_loop_rev(t, mode, make_copy): # Test the "sum loop" optimization (max_iterations=-1) for # consistency against test03 UInt32 = dr.uint32_array_t(t) @@ -74,8 +75,11 @@ def test04_sum_loop_rev(t, mode): y, i = Float(0), UInt32(0) x = dr.linspace(Float, .25, 1, 4) - xo = x dr.enable_grad(x) + if make_copy: + xo = Float(x) + else: + xo = x while dr.hint(i < 10, max_iterations=-1, mode=mode): y += x**i @@ -87,6 +91,7 @@ def test04_sum_loop_rev(t, mode): assert dr.allclose(y, [1.33333, 1.99805, 3.77475, 10]) assert dr.allclose(xo.grad, [1.77773, 3.95703, 12.0956, 45]) + @pytest.mark.parametrize('variant', ['fwd', 'bwd']) @pytest.test_arrays('float32,is_diff,shape=(*)') def test05_evaluated_ad_kernel_launch_count(t, variant): @@ -132,6 +137,7 @@ def test05_evaluated_ad_kernel_launch_count(t, variant): for k in h: assert k['operation_count'] < iterations + @pytest.mark.parametrize('variant', [0, 1]) @pytest.mark.parametrize('mode', ['evaluated', 'symbolic']) @pytest.test_arrays('float32,diff,shape=(*)') @@ -240,3 +246,20 @@ def loop(l: list, t, mode): dr.backward(loss) + +@pytest.mark.parametrize('mode', ['evaluated', 'symbolic']) +@pytest.test_arrays('float32,is_diff,shape=(*)') +@dr.syntax +def test32_simple_loop(t, mode): + # Testcase for simple backwards derivatives with gathers + i = dr.uint32_array_t(t)(0) + x = dr.ones(t, 10) + q = dr.zeros(t) + dr.enable_grad(x, 10) + + while dr.hint(i < 10, max_iterations=-1, mode=mode): + q += dr.gather(t, x, i) + i += 1 + + dr.backward(q) + assert dr.all(x.grad == [1]*10) From a45e1c245c0dfa6366280858472628e9120a1abc Mon Sep 17 00:00:00 2001 From: Wenzel Jakob Date: Mon, 21 Apr 2025 23:38:48 +0900 Subject: [PATCH 6/9] Fix derivative of ``nn.matmul()`` in simple symbolic loops This commit fixes bugs and adds tests to ensure that matrix multiplication can be correctly differentiated in reverse-mode when it occurs inside a "simple" loop (i.e., a loop with max_iterations==-1). --- ext/drjit-core | 2 +- src/extra/autodiff.cpp | 17 ++++++---- src/python/coop_vec.cpp | 2 +- tests/test_coop_vec.py | 67 ++++++++++++++++++++++++++++++++++++++++ tests/test_while_loop.py | 1 + 5 files changed, 81 insertions(+), 8 deletions(-) diff --git a/ext/drjit-core b/ext/drjit-core index f6ccae53..01223075 160000 --- a/ext/drjit-core +++ b/ext/drjit-core @@ -1 +1 @@ -Subproject commit f6ccae53828aa7d320b697011d34931c2ab9934c +Subproject commit 01223075552619d13dffb9e0437560ff2bf1a178 diff --git a/src/extra/autodiff.cpp b/src/extra/autodiff.cpp index 3b2a365b..d0328fe6 100644 --- a/src/extra/autodiff.cpp +++ b/src/extra/autodiff.cpp @@ -1233,7 +1233,7 @@ void ad_accum_grad(Index index, JitIndex value) { size_t size_in = value_v.size(); if (v->size != size_in && size_in != 1 && size_in != 0 && v->size != 1) - ad_raise("ad_set_grad(): attempted to store a gradient of size " + ad_raise("ad_accum_grad(): attempted to store a gradient of size " "%zu into AD variable a%u, which has size %zu!", size_in, ad_index, v->size); @@ -2942,6 +2942,7 @@ Index ad_var_cast(Index i0, VarType vt) { void ad_var_map_put(Index source, Index target) { uint32_t ad_index_source = ad_index(source), ad_index_target = ad_index(target); + ad_log("ad_var_map_put(): a%u -> a%u", ad_index_source, ad_index_target); if (ad_index_target == 0) return; @@ -3822,7 +3823,7 @@ class CoopVecPack : public dr::detail::CustomOpBase { void forward() override { std::lock_guard guard(state.lock); - uint32_t size = (uint32_t) m_input_indices.size(); + uint32_t size = (uint32_t) m_inputs.size(); JitIndex *tmp = (JitIndex *) alloca(sizeof(JitIndex) * size); size_t n_valid = 0; @@ -3849,7 +3850,7 @@ class CoopVecPack : public dr::detail::CustomOpBase { void backward() override { std::lock_guard guard(state.lock); - uint32_t n = (uint32_t) m_input_indices.size(); + uint32_t n = (uint32_t) m_inputs.size(); Variable *v = state[m_output_indices[0]]; if (!v->grad.valid()) @@ -3858,9 +3859,13 @@ class CoopVecPack : public dr::detail::CustomOpBase { JitIndex *tmp = (JitIndex *) alloca(sizeof(JitIndex) * n); jit_coop_vec_unpack(v->grad.index(), n, tmp); - for (size_t i = 0; i < m_input_indices.size(); ++i) { - Variable *v2 = state[m_inputs[i]]; - v2->accum(JitVar::steal(tmp[i]), v2->size); + for (size_t i = 0; i < m_inputs.size(); ++i) { + uint32_t index = m_inputs[i]; + JitVar var = JitVar::steal(tmp[i]); + if (!index) + continue; + Variable *v2 = state[index]; + v2->accum(var, v2->size); } } diff --git a/src/python/coop_vec.cpp b/src/python/coop_vec.cpp index d0314c52..4f3a8d0c 100644 --- a/src/python/coop_vec.cpp +++ b/src/python/coop_vec.cpp @@ -284,7 +284,7 @@ nb::str MatrixView::repr() const { return nb::str( "drjit.nn.MatrixView[\n" " dtype={},\n" - " layout={},\n" + " layout=\"{}\",\n" " shape=({}, {}),\n" " stride={},\n" " offset={}\n" diff --git a/tests/test_coop_vec.py b/tests/test_coop_vec.py index 8efc5f29..1e868eed 100644 --- a/tests/test_coop_vec.py +++ b/tests/test_coop_vec.py @@ -530,3 +530,70 @@ def test19_no_eval(t): dr.eval(a) with pytest.raises(RuntimeError, match="Cooperative vectors cannot be evaluated"): dr.make_opaque(a) + + +@pytest.mark.parametrize('mode', ['evaluated', 'symbolic']) +@pytest.test_arrays('jit,shape=(*),float16,diff') +@dr.syntax +def test20_matvec_in_loop(t, mode): + # Check that derivative inference works when + # cooperative vectors are used inside loops + + m = sys.modules[t.__module__] + Float16 = t + TensorXf16 = m.TensorXf16 + Float32 = m.Float32 + UInt32 = m.UInt32 + + A = dr.ones(TensorXf16, shape=(2, 2)) + b = dr.zeros(Float16, shape=(2)) + + _, A_view, b_view = nn.pack(A, b, layout='inference') + + cnt = UInt32(0) + res = Float32(0) + + while dr.hint(cnt < 3, mode=mode): + x = nn.CoopVec(Float16([0.5]), Float16([0.5])) + a, b = nn.matvec(A_view, x, b_view) + res += Float32(a) + cnt += 1 + + assert res == 3 + + +@pytest.mark.parametrize('mode', ['evaluated', 'symbolic']) +@pytest.test_arrays('jit,shape=(*),float16,diff') +@dr.syntax +def test21_optimize_in_loop_bwd(t, mode): + # Check that derivative backpropagation occurs when + # cooperative vectors are used inside loops + + m = sys.modules[t.__module__] + Float16 = t + TensorXf16 = m.TensorXf16 + Float32 = m.Float32 + UInt32 = m.UInt32 + + A = dr.ones(TensorXf16, shape=(2, 2)) + b = dr.zeros(Float16, shape=(2)) + + buf, A_view, b_view = nn.pack(A, b, layout='training') + dr.enable_grad(buf) + + cnt = dr.zeros(UInt32, 2) + res = dr.zeros(Float32, 2) + + while dr.hint(cnt < 3, max_iterations=-1, mode=mode): + x = nn.CoopVec(Float16(0.5), Float16(0.5)) + a, _ = nn.matvec(A_view, x, b_view) + res += Float32(a) + cnt += 1 + + dr.backward(res) + + _, A_view, b_view = nn.unpack(A_view.grad, b_view.grad) + A = TensorXf16(A_view) + b = TensorXf16(b_view) + assert dr.all(A == TensorXf16([[3, 3], [0, 0]])) + assert dr.all(b == TensorXf16([[6], [0]])) diff --git a/tests/test_while_loop.py b/tests/test_while_loop.py index f6fc805a..342ada06 100644 --- a/tests/test_while_loop.py +++ b/tests/test_while_loop.py @@ -736,3 +736,4 @@ def test31_tensor_loop_preserve_shape(t, mode): assert a.shape == (10, 11) assert a.shape == (10, 11) + From a9f6f386e0a2130adc6c6eab8a57193a48ae8d7e Mon Sep 17 00:00:00 2001 From: Wenzel Jakob Date: Tue, 22 Apr 2025 07:13:41 +0900 Subject: [PATCH 7/9] move verbose ``__repr__()`` methods for ``nn.Module`` classes --- drjit/nn.py | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/drjit/nn.py b/drjit/nn.py index b73aeb41..081b5591 100644 --- a/drjit/nn.py +++ b/drjit/nn.py @@ -244,6 +244,8 @@ def __init__(self, dtype: Optional[Type[drjit.ArrayBase]] = None): self.dtype = dtype def __call__(self, arg: CoopVec, /) -> CoopVec: return cast(arg, self.dtype) + def __repr__(self): + return f'Cast(dtype={self.dtype.__name__})' class Linear(Module): r""" @@ -286,7 +288,7 @@ def __init__(self, in_features: int = -1, out_features: int = -1, bias = True) - self.weights = self.bias = None def __repr__(self) -> str: - s = f'Linear({self.config[0]}, {self.config[1]}' + s = f'Linear(in_features={self.config[0]}, out_features={self.config[1]}' if not self.config[2]: s += ', bias=False' s += ')' @@ -391,15 +393,20 @@ class TriEncode(Module): :align: center """ + DRJIT_STRUCT = { 'octaves' : int, 'shift': float, 'channels': int } + def __init__(self, octaves: int = 0, shift: float = 0) -> None: self.octaves = octaves self.shift = shift + self.channels = -1 def _alloc(self, dtype: Type[drjit.ArrayBase], size : int = -1, /) -> Tuple[Module, int]: - return self, size * self.octaves * 2 + r = TriEncode(self.octaves, self.shift) + r.channels = size + return r, size * self.octaves * 2 def __repr__(self) -> str: - return f'TriEncode({self.octaves})' + return f'TriEncode(octaves={self.octaves}, shift={self.shift}, in_channels={self.channels}, out_features={self.channels*self.octaves*2})' def __call__(self, arg: CoopVec, /) -> CoopVec: args, r = list(arg), list() @@ -453,8 +460,11 @@ class SinEncode(Module): :align: center """ + DRJIT_STRUCT = { 'octaves' : int, 'shift': Union[tuple, None], 'channels': int } + def __init__(self, octaves: int = 0, shift: float = 0) -> None: self.octaves = octaves + self.channels = -1 if shift == 0: self.shift = None @@ -463,10 +473,13 @@ def __init__(self, octaves: int = 0, shift: float = 0) -> None: drjit.cos(shift * 2 * drjit.pi)) def _alloc(self, dtype: Type[drjit.ArrayBase], size : int = -1, /) -> Tuple[Module, int]: - return self, size * self.octaves * 2 + r = SinEncode(self.octaves) + r.channels = size + r.shift = self.shift + return r, size * self.octaves * 2 def __repr__(self) -> str: - return f'SinEncode({self.octaves})' + return f'SinEncode(octaves={self.octaves}, shift={self.shift}, in_channels={self.channels}, out_features={self.channels*self.octaves*2})' def __call__(self, arg: CoopVec, /) -> CoopVec: args, r = list(arg), list() From dd7995a9e45811fbb9302c9de946b98432dfd5dc Mon Sep 17 00:00:00 2001 From: Wenzel Jakob Date: Tue, 22 Apr 2025 07:56:08 +0900 Subject: [PATCH 8/9] update Dr.Jit-Core repo --- ext/drjit-core | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/drjit-core b/ext/drjit-core index 01223075..8c88e176 160000 --- a/ext/drjit-core +++ b/ext/drjit-core @@ -1 +1 @@ -Subproject commit 01223075552619d13dffb9e0437560ff2bf1a178 +Subproject commit 8c88e1768b5907f55d6f65e57f87409c75b071ea From a9237035236891c37fa9908e25634bca812a3917 Mon Sep 17 00:00:00 2001 From: Wenzel Jakob Date: Tue, 22 Apr 2025 20:25:36 +0900 Subject: [PATCH 9/9] improve test coverage --- ext/drjit-core | 2 +- src/extra/autodiff.cpp | 3 +++ src/python/init.cpp | 3 ++- tests/test_coop_vec.py | 41 +++++++++++++++++++++++++++++++++++++++++ 4 files changed, 47 insertions(+), 2 deletions(-) diff --git a/ext/drjit-core b/ext/drjit-core index 8c88e176..89c0db27 160000 --- a/ext/drjit-core +++ b/ext/drjit-core @@ -1 +1 @@ -Subproject commit 8c88e1768b5907f55d6f65e57f87409c75b071ea +Subproject commit 89c0db27ea5a1d4b49e310c058386098933bc5bd diff --git a/src/extra/autodiff.cpp b/src/extra/autodiff.cpp index d0328fe6..6e45b62b 100644 --- a/src/extra/autodiff.cpp +++ b/src/extra/autodiff.cpp @@ -2345,6 +2345,9 @@ void ad_mark_loop_boundary(Index index) { } uint32_t ad_pred(uint32_t ad_index, uint32_t i_) { + if (ad_index == 0) + return 0; + std::lock_guard guard(state.lock); const Variable *v = state[ad_index]; uint32_t edge = v->next_bwd; diff --git a/src/python/init.cpp b/src/python/init.cpp index 98bec4d1..be422a40 100644 --- a/src/python/init.cpp +++ b/src/python/init.cpp @@ -676,7 +676,8 @@ nb::object view_to_tensor(nb::handle h, dr::vector &shape) { "row-major representation via drjit.nn.unpack()."); if (view.descr.stride != view.descr.cols) - nb::raise("Unsupported row stride!"); + nb::raise("Unsupported row stride: expected stride %u, found %u.", + view.descr.cols, view.descr.stride); shape.push_back(view.descr.rows); shape.push_back(view.descr.cols); diff --git a/tests/test_coop_vec.py b/tests/test_coop_vec.py index 1e868eed..395ba041 100644 --- a/tests/test_coop_vec.py +++ b/tests/test_coop_vec.py @@ -538,6 +538,7 @@ def test19_no_eval(t): def test20_matvec_in_loop(t, mode): # Check that derivative inference works when # cooperative vectors are used inside loops + skip_if_coopvec_not_supported(t) m = sys.modules[t.__module__] Float16 = t @@ -568,6 +569,7 @@ def test20_matvec_in_loop(t, mode): def test21_optimize_in_loop_bwd(t, mode): # Check that derivative backpropagation occurs when # cooperative vectors are used inside loops + skip_if_coopvec_not_supported(t) m = sys.modules[t.__module__] Float16 = t @@ -597,3 +599,42 @@ def test21_optimize_in_loop_bwd(t, mode): b = TensorXf16(b_view) assert dr.all(A == TensorXf16([[3, 3], [0, 0]])) assert dr.all(b == TensorXf16([[6], [0]])) + + +@pytest.mark.parametrize('mode', ['evaluated', 'symbolic']) +@pytest.test_arrays('jit,shape=(*),float16,diff') +@dr.syntax +def test22_optimize_in_loop_bwd_v2(t, mode): + # Check that derivative backpropagation occurs when + # cooperative vectors are used inside loops, and the + # backprop call is placed there as well + + skip_if_coopvec_not_supported(t) + + m = sys.modules[t.__module__] + Float16 = t + TensorXf16 = m.TensorXf16 + Float32 = m.Float32 + UInt32 = m.UInt32 + + A = dr.ones(TensorXf16, shape=(2, 2)) + b = dr.zeros(Float16, shape=(2)) + + buf, A_view, b_view = nn.pack(A, b, layout='training') + dr.enable_grad(buf) + + cnt = dr.zeros(UInt32, 2) + res = dr.zeros(Float32, 2) + + while dr.hint(cnt < 3, mode=mode, exclude=[A_view, b_view]): + x = nn.CoopVec(Float16(0.5), Float16(0.5)) + a, _ = nn.matvec(A_view, x, b_view) + res = Float32(a) + dr.backward(res) + cnt += 1 + + _, A_view, b_view = nn.unpack(A_view.grad, b_view.grad) + A = TensorXf16(A_view) + b = TensorXf16(b_view) + assert dr.all(A == TensorXf16([[3, 3], [0, 0]])) + assert dr.all(b == TensorXf16([[6], [0]]))