Frozen Functions #336

Open: wants to merge 3 commits into master

DoeringChristian
Contributor

This PR implements the frozen function feature in drjit.
It adds a new FrozenFunction struct that is a wrapper over a Python callable.
When calling the operator() of the object, it traverses the input and determines a unique layout/key.
This key is then used to determine if the function has to be recorded or if it's possible to replay it.
The PR relies on functionality implemented in Frozen ThreadState, and #288.
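
To illustrate the record-or-replay decision described above, here is a minimal pure-Python sketch. It is not Dr.Jit's implementation: `ToyFrozenFunction`, `layout_key`, and the counters are hypothetical names, the key only looks at argument types, and "recording" is reduced to storing the callable.

```python
from typing import Any, Callable, Dict, Tuple


def layout_key(args: Tuple[Any, ...]) -> Tuple[str, ...]:
    """Hypothetical layout key: just the type name of every argument."""
    return tuple(type(a).__name__ for a in args)


class ToyFrozenFunction:
    """Toy stand-in for a wrapper over a Python callable (assumption, not Dr.Jit code)."""

    def __init__(self, fn: Callable[..., Any]) -> None:
        self.fn = fn
        self.recordings: Dict[Tuple[str, ...], Callable[..., Any]] = {}
        self.n_recorded = 0
        self.n_replayed = 0

    def __call__(self, *args: Any) -> Any:
        key = layout_key(args)
        if key not in self.recordings:
            # Unknown layout: "record" the function for this key.
            self.recordings[key] = self.fn
            self.n_recorded += 1
        else:
            # Known layout: reuse the stored recording.
            self.n_replayed += 1
        return self.recordings[key](*args)


add = ToyFrozenFunction(lambda x, y: x + y)
add(1, 2)    # new layout ("int", "int"): recorded
add(3, 4)    # same layout: replayed
add(1.0, 2)  # new layout ("float", "int"): recorded again
```

In this sketch, the third call takes the recording path again because the key changed, which mirrors the behavior the description attributes to the layout key.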

This replaces #317 and changes the source branch.

Member

@njroussel njroussel left a comment

This is looking good so far, I still have to go through some of the main bits like freeze.cpp/freeze.h. I'll come back to this tomorrow.

Comment on lines 1390 to 1505

DR_TRAVERSE_CB(drjit::TraversableBase, m_value, m_shape_opaque,
m_inv_resolution);

This is more of a note (to myself):

This will most likely need to be updated with #329 before this gets merged, or we do it in #329 if we merge this first.

Member

@njroussel njroussel left a comment

I've gone through the last few files I had missed.

Ignoring the few new comments, I'm overall pretty happy with the current state of the source code. As for the documentation, the only part that truly matters is the dr.freeze and I think it's good too - I don't really have the full picture yet, so take that with a pinch of salt. For the C++ documentation, it was plenty enough for me to understand how these mechanisms work and should be used 👍

Member

@merlinND merlinND left a comment

Very partial review, I only looked at the files that had < ~150 lines of diff.

@DoeringChristian DoeringChristian force-pushed the frozen-functions branch 6 times, most recently from 5f7efff to 3af2b69 Compare March 4, 2025 10:25
@DoeringChristian DoeringChristian force-pushed the frozen-functions branch 10 times, most recently from 15220b6 to 3d80462 Compare March 16, 2025 15:03
@DoeringChristian DoeringChristian force-pushed the frozen-functions branch 3 times, most recently from 971b79f to 2232018 Compare March 23, 2025 11:49
@DoeringChristian DoeringChristian force-pushed the frozen-functions branch 2 times, most recently from 02b5351 to 5bab83a Compare March 26, 2025 10:51
Member

@wjakob wjakob left a comment

First comments about docs/docstrings.

F2 = TypeVar("F2")

@overload
def freeze(

Did you try rendering the documentation? (With python -m sphinx docs html and then open html/index.html). I think that you will need to reference any additions in docs/reference.rst with some auto* command.

backend: Optional[JitBackend] = None,
) -> Callable[[F], F]:
"""
Decorator to freeze a function for replaying kernels without re-tracing.

Hi Christian, I did a writing pass over the docstring of the function; here is my suggestion. I shortened the names of some of the decorator parameters, which will also require corresponding updates elsewhere. Are you OK with those changes?

Decorator to "freeze" functions, which improves efficiency by removing
repeated JIT tracing overheads.

In general, Dr.Jit traces computation and then compiles and launches kernels
containing this trace (see the section on :ref:`evaluation <eval>` for
details). While the compilation step can often be skipped via caching, the
tracing cost can still be significant especially when repeatedly evaluating
complex models, e.g., as part of an optimization loop.

The :py:func:`@dr.freeze <drjit.freeze>` decorator addresses this problem by
altogether removing the need to trace repeatedly. For example, consider the
following decorated function:

.. code-block:: python

   @dr.freeze
   def f(x, y, z):
      return ... # Complicated code involving the arguments

Dr.Jit will trace the first call to the decorated function ``f()``, while
collecting additional information regarding the nature of the function's inputs
and regarding the CPU/GPU kernel launches representing the body of ``f()``.

If the function is subsequently called with *compatible* arguments (more on
this below), it will immediately launch the previously made CPU/GPU kernels
without re-tracing, which can substantially improve performance.

When :py:func:`@dr.freeze <drjit.freeze>` detects *incompatibilities* (e.g., ``x``
having a different type compared to the previous call), it will conservatively
re-trace the body and keep track of another potential input configuration.

Frozen functions support arbitrary :ref:`PyTrees <pytrees>` as function
arguments and return values.

The following may trigger re-tracing:

- Changes in the **type** of an argument or :ref:`PyTree <pytrees>` element.
- Changes in the **length** of a container (``list``, ``tuple``, ``dict``).
- Changes of **dictionary keys**  or **field names** of dataclasses.
- Changes in the AD status (:py:func:`dr.grad_enabled() <drjit.grad_enabled>`) of a variable.
- Changes of (non-PyTree) **Python objects**, as detected by mismatching ``id()``.

The following more technical conditions also trigger re-tracing:

- A Dr.Jit variable changes from/to a **scalar** configuration (size ``1``).
- The sets of variables of the same size change. In the example above, this
  would be the case if ``len(x) == len(y)`` in one call, and ``len(x) != len(y)``
  subsequently.
- When Dr.Jit variables reference external memory (e.g. mapped NumPy arrays), the
  memory can be aligned or unaligned. A re-tracing step is needed when this
  status changes.

These all correspond to situations where the generated kernel code may need to
change, and the system conservatively re-traces to ensure correctness.
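
As a rough illustration of the structural conditions listed above (type, container length, dictionary keys, dataclass field names), the following pure-Python sketch computes a comparable key for nested containers. The function name `layout` and the key format are hypothetical, and the sketch deliberately ignores the Dr.Jit-specific conditions (AD status, scalar configuration, size sets, alignment).

```python
from dataclasses import dataclass, fields, is_dataclass
from typing import Any, Tuple


def layout(value: Any) -> Tuple[Any, ...]:
    """Hypothetical structural key for a PyTree-like value (illustration only)."""
    if isinstance(value, (list, tuple)):
        # Container type and length are part of the key via the child tuple.
        return (type(value).__name__, tuple(layout(v) for v in value))
    if isinstance(value, dict):
        # Dictionary keys are part of the key.
        return ("dict", tuple((k, layout(v)) for k, v in value.items()))
    if is_dataclass(value) and not isinstance(value, type):
        # Dataclass field names are part of the key.
        return (
            type(value).__name__,
            tuple((f.name, layout(getattr(value, f.name))) for f in fields(value)),
        )
    # Leaf: only the type matters in this toy model, not the value.
    return (type(value).__name__,)


@dataclass
class Params:
    scale: float
    offset: float


same = layout([1, 2]) == layout([3, 4])                 # values differ, layout matches
longer = layout([1, 2]) == layout([1, 2, 3])            # length differs
retyped = layout([1, 2]) == layout([1.0, 2])            # element type differs
rekeyed = layout({"a": 1}) == layout({"b": 1})          # dictionary key differs
params = layout(Params(1.0, 2.0)) == layout(Params(3.0, 4.0))
```

Two calls whose keys compare equal would be candidates for replay in this model; any mismatch corresponds to one of the re-tracing triggers in the first list.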

Frozen functions support arguments with a different variable *width* (see
:py:func:`dr.width() <drjit.width>`) without re-tracing, as long as the sets of
variables of the same width stay consistent.

Some constructions are problematic and should be avoided in frozen functions.

- The function :py:func:`dr.width() <drjit.width>` returns an integer literal
  that may be merged into the generated code. If the frozen function is later
  rerun with differently-sized arguments, the executed kernels will still
  reference the old size. Constructions like
  ``dr.arange(UInt32, dr.width(a))`` are an exception to this rule, since the
  result only implicitly depends on the width value.

- @Christian: Should we say something about calling frozen functions from
  other frozen functions?

- @Christian: Any other caveats? Is the part about dr.mean() still relevant? I
  had removed that part since I thought we had a solution for dr.mean().
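
The width caveat above can be mimicked in plain Python, without Dr.Jit. In this hypothetical sketch, `record_normalize` plays the role of the first (traced) call and `len(x)` stands in for ``dr.width(x)``: the size is read once and then behaves like a constant in every replay.

```python
from typing import Callable, List


def record_normalize(x: List[float]) -> Callable[[List[float]], List[float]]:
    """Toy 'recording' that divides each entry by the input size (illustration only)."""
    n = len(x)  # read once while "recording"; becomes a baked-in constant

    def replay(y: List[float]) -> List[float]:
        # Still divides by the size seen during recording, not by len(y).
        return [v / n for v in y]

    return replay


replay = record_normalize([2.0, 4.0])
same_size = replay([2.0, 4.0])         # divides by 2, as intended
other_size = replay([3.0, 3.0, 3.0])   # also divides by 2, although the input has 3 entries
```

The second replay silently uses the stale size, which is the kind of mismatch the bullet on ``dr.width()`` warns about.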

**Advanced features**. The :py:func:`@dr.freeze <drjit.freeze>` decorator takes
several optional parameters that are helpful in certain situations.

- **Warning when re-tracing happens too often**: Incompatible arguments trigger
  re-tracing, which can mask issues where *accidentally* incompatible arguments
  keep :py:func:`@dr.freeze <drjit.freeze>` from producing the expected
  performance benefits.

  In such situations, it can be helpful to warn and identify changing
  parameters by name. This feature is enabled and set to ``10`` by default.

  .. code-block:: pycon

     >>> @dr.freeze(warn_after=1)
     ... def f(x):
     ...     return x
     ...
     >>> f(Int(1))
     >>> f(Float(1))
     >>> <<< @Christian: needs an example of what the warning looks like

- **Limiting memory usage**. Storing kernels for many possible input
  configurations requires device memory, which can become problematic. Set the
  ``limit=`` parameter to enable an LRU cache. This is useful when calls to a
  function are mostly compatible but require occasional re-tracing.
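
The eviction behavior of an LRU cache with a ``limit`` can be sketched in a few lines of standard-library Python. This is only a model of the policy, not of how Dr.Jit stores recordings; `LRURecordings` and `get_or_record` are hypothetical names.

```python
from collections import OrderedDict
from typing import Any, Callable, Hashable, Optional


class LRURecordings:
    """Toy LRU store for recordings, keyed by input configuration (illustration only)."""

    def __init__(self, limit: Optional[int] = None) -> None:
        self.limit = limit
        self.entries: "OrderedDict[Hashable, Any]" = OrderedDict()

    def get_or_record(self, key: Hashable, record: Callable[[], Any]) -> Any:
        if key in self.entries:
            # Hit: mark this configuration as most recently used.
            self.entries.move_to_end(key)
            return self.entries[key]
        # Miss: create a new recording and store it.
        value = record()
        self.entries[key] = value
        if self.limit is not None and len(self.entries) > self.limit:
            # Over the limit: drop the least recently used configuration.
            self.entries.popitem(last=False)
        return value


cache = LRURecordings(limit=2)
cache.get_or_record("a", lambda: "rec-a")
cache.get_or_record("b", lambda: "rec-b")
hit = cache.get_or_record("a", lambda: "unused")  # hit, "a" becomes most recent
cache.get_or_record("c", lambda: "rec-c")         # exceeds the limit, "b" is dropped
remaining = list(cache.entries)
```

With ``limit=None`` the store grows without bound, which corresponds to the unbounded case the bullet describes as potentially problematic.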

Args:
    limit (Optional[int]): An optional integer specifying the maximum number of
      stored configurations. Once this limit is reached, incompatible calls
      requiring re-tracing will cause the least recently used configuration to
      be dropped.

    warn_after (int): When the number of re-tracing steps exceeds this value,
      Dr.Jit will generate a warning that explains which variables changed
      between calls to the function.

    state_fn (Optional[Callable]): This optional callable can specify additional
      state that identifies the configuration. ``state_fn`` will be called with
      the same arguments as that of the decorated function. It should return a
      traversable object (e.g., a list or tuple) that is conceptually treated
      as if it was another input of the function.

    backend (Optional[JitBackend]): If no inputs are given when calling the
      frozen function, the backend used has to be specified using this argument.
      It must match the backend used for computation within the function.

Contributor Author

@DoeringChristian DoeringChristian Mar 28, 2025

Regarding

Changes of (non-PyTree) Python objects, as detected by mismatching id().

I'm keeping track of non-PyTree objects based on a hash I calculate of them. Not sure if taking the id() might be better?

Contributor Author

  • @christian: Should we say something about calling frozen functions from
    other frozen functions?

This should not be a problem, as the inner function is simply recorded by the outer one. I'll check that we actually test for that though. I think it should however not be something the user has to be concerned about.

Contributor Author

@DoeringChristian DoeringChristian Apr 2, 2025

  • @christian: Any other caveats? Is the part about dr.mean() still relevant? I
    had removed that part since I thought we had a solution for dr.mean().

The mean case should be covered by #361. We could leave the comment here and remove it in the other PR?

@DoeringChristian DoeringChristian force-pushed the frozen-functions branch 2 times, most recently from 4ca120b to c0ae426 Compare March 31, 2025 14:33
@DoeringChristian DoeringChristian force-pushed the frozen-functions branch 3 times, most recently from 1f08b94 to f57a954 Compare April 12, 2025 11:23
@wjakob wjakob force-pushed the master branch 4 times, most recently from c821068 to dcb5217 Compare April 14, 2025 16:01
@DoeringChristian DoeringChristian force-pushed the frozen-functions branch 4 times, most recently from 7de6d1c to bec6588 Compare April 18, 2025 11:42
Added comment explaining traverse with `rw` arguments

Added comment about c++ traversal

Added recursion guards

Simplified comment on TraverseCallback operator()

Removed redundant profiling loop

Fixed recursion_guard by wrapping it in namespace

Fixed typing import

Improved handling of nullptr variant/domains

Removed warn log

Added JitFlag FreezingTraverseScope

Using try_cast in get_traversable_base

Renamed to EnableObjectTraversal flag

Added comment for traverse_py_cb_ro_impl

Added rw value for traversal of trampolines

Removed outdated function

Excluding texture from traversal outside of frozen function

Testing fix for windows compilation

Fixing windows compilation bugs

Added base value test for custom_type_ext traversal

Added nested traversal test

Exposing frozen function related flags to python

Improved warning about non-drjit types

Added option to specify backend if no input variable is specified

Allow for changes in the last dimension of tensor shapes

Marked payload and fn as used

Reverted comment

Formatting

Removed printing of keys

Fixed backend test

Added method to detect recording frozen functions with the wrong backend

Added comments for arguments of dr.freeze decorator

Added flag test to DR_TRAVERSE_CB_RO/RW macros

Removed test for `EnableObjectTraversal` flag in `DR_TRAVERSE_CB` macro

Fixed tensor freezing indexing issue

Removed deprecated logging code

Added comment about tensor shape inference

Using Wenzel's documentation, and renamed arguments

Added comment regarding `frozen_function_tp_traverse` and `frozen_function_clear`

Fixed typo

Added warning text to documentation