Frozen Functions #336

Merged
merged 6 commits into from
Jun 19, 2025

Conversation

DoeringChristian
Contributor

This PR implements the frozen function feature in drjit.
It adds a new FrozenFunction struct that wraps a Python callable.
When the object's operator() is called, it traverses the inputs and derives a unique layout/key from them.
This key then determines whether the function has to be recorded or whether an existing recording can be replayed.
The PR relies on functionality implemented in Frozen ThreadState and in #288.
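
For readers skimming the thread, the record-or-replay decision described above can be illustrated with a small pure-Python toy. This is only a sketch under stated assumptions, not the code in this PR: `ToyFrozenFunction`, `flatten`, and the `hits`/`misses` counters are hypothetical names, the "recording" is just the original callable standing in for recorded kernels, and the key here captures only container structure and leaf types.

```python
from typing import Any, Callable, Dict, List, Tuple


def flatten(value: Any, leaves: List[Any]) -> Tuple:
    """Collect the leaves of a nested input and return a hashable layout."""
    if isinstance(value, dict):
        # Sort keys so that insertion order does not change the layout.
        return ("dict", tuple((k, flatten(value[k], leaves)) for k in sorted(value)))
    if isinstance(value, (list, tuple)):
        return (type(value).__name__, tuple(flatten(v, leaves) for v in value))
    leaves.append(value)
    return ("leaf", type(value).__name__)


class ToyFrozenFunction:
    """Hypothetical stand-in: caches one 'recording' per input layout."""

    def __init__(self, fn: Callable):
        self.fn = fn
        self.recordings: Dict[Tuple, Callable] = {}
        self.misses = 0  # calls that required a new recording
        self.hits = 0    # calls that reused an existing recording

    def __call__(self, *args, **kwargs):
        leaves: List[Any] = []  # a real replay would consume these leaves
        key = flatten((args, kwargs), leaves)
        if key not in self.recordings:
            self.misses += 1
            self.recordings[key] = self.fn  # assumption: placeholder for a recording
        else:
            self.hits += 1
        return self.recordings[key](*args, **kwargs)


@ToyFrozenFunction
def scale(x, factor):
    return [v * factor for v in x["data"]]


scale({"data": [1.0, 2.0]}, 2.0)  # unseen layout: record
scale({"data": [3.0, 4.0]}, 0.5)  # same layout, new values: replay
scale({"data": [1.0, 2.0]}, 2)    # int instead of float: new layout, record again
print(scale.misses, scale.hits)
```

In this toy, changing only the values of the inputs keeps the key stable, while changing the structure or a leaf type produces a new key and therefore a new recording.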

This replaces #317 and changes the source branch.

Member

@njroussel njroussel left a comment

This is looking good so far. I still have to go through some of the main bits, like freeze.cpp/freeze.h, so I'll come back to this tomorrow.

Comment on lines 1390 to 1505

DR_TRAVERSE_CB(drjit::TraversableBase, m_value, m_shape_opaque,
m_inv_resolution);
Member

This is more of a note (to myself):

This will most likely need to be updated with #329 before this gets merged, or we do it in #329 if we merge this first.

Member

@njroussel njroussel left a comment

I've gone through the last few files I had missed.

Ignoring the few new comments, I'm overall pretty happy with the current state of the source code. As for the documentation, the only part that truly matters is dr.freeze, and I think it's good too. I don't really have the full picture yet, so take that with a pinch of salt. The C++ documentation was plenty for me to understand how these mechanisms work and how they should be used 👍

Member

@merlinND merlinND left a comment

Very partial review, I only looked at the files that had < ~150 lines of diff.

@DoeringChristian DoeringChristian force-pushed the frozen-functions branch 6 times, most recently from 5f7efff to 3af2b69 Compare March 4, 2025 10:25
@DoeringChristian DoeringChristian force-pushed the frozen-functions branch 10 times, most recently from 15220b6 to 3d80462 Compare March 16, 2025 15:03
@DoeringChristian DoeringChristian force-pushed the frozen-functions branch 3 times, most recently from 971b79f to 2232018 Compare March 23, 2025 11:49
@DoeringChristian DoeringChristian force-pushed the frozen-functions branch 2 times, most recently from 1f08b94 to f57a954 Compare April 12, 2025 11:23
@wjakob wjakob force-pushed the master branch 4 times, most recently from c821068 to dcb5217 Compare April 14, 2025 16:01
@DoeringChristian DoeringChristian force-pushed the frozen-functions branch 4 times, most recently from 7de6d1c to bec6588 Compare April 18, 2025 11:42
@DoeringChristian DoeringChristian force-pushed the frozen-functions branch 2 times, most recently from fd8cf08 to 8ef5ddf Compare June 4, 2025 13:15
Member

@wjakob wjakob left a comment

Some comments from me on this PR

* assigning to the input. The function takes an optional python-type if
* it is known.
*/
void FlatVariables::traverse_ad_index(uint64_t index, TraverseContext &ctx,
Member

There is something basic I do not understand. My assumption was that inputs/outputs to @dr.freeze-annotated functions cannot be AD-attached. Yet there is lots of handling related to AD variables.

The documentation (docs/freeze.rst) you added does not discuss AD variables. It would be good to explain somewhere what behavior is legal and what isn't.

Contributor Author

@DoeringChristian DoeringChristian Jun 11, 2025

The Auto Opaque PR #361 adds the documentation in freeze.rst. It should explain that, for a frozen function, gradients can only be propagated from inside the function to its inputs and their dependencies. These traverse_ad_index functions are required to handle this case.

@@ -0,0 +1,1551 @@
#include "freeze.h"
Member

Like in the Dr.Jit-Core implementation of kernel freezing, this file needs a detailed comment describing the high level "algorithm". The various functions are individually documented, but the big picture does not easily emerge when going through them.

It should cover the high level ideas like: how are the arguments to a function flattened, how do we detect misses/hits? You also implemented an optimization related to literal variables, that would be good to briefly mention too.

Contributor Author

Will add that in the auto opaque PR #361 if that's OK, since this changes quite a bit.

@DoeringChristian DoeringChristian force-pushed the frozen-functions branch 5 times, most recently from 382c8b9 to f31870a Compare June 18, 2025 13:11
Added comment explaining traverse with `rw` arguments

Added comment about c++ traversal

Added recursion guards

Simplified comment on TraverseCallback operator()

Removed redundant profiling loop

Fixed recursion_guard by wrapping it in namespace

Fixed typing import

Improved handling of nullptr variant/domains

Removed warn log

Added JitFlag FreezingTraverseScope

Using try_cast in get_traversable_base

Renamed to EnableObjectTraversal flag

Added comment for traverse_py_cb_ro_impl

Added rw value for traversal of trampolines

Removed outdated function

Excluding texture from traversal outside of frozen function

Testing fix for windows compilation

Fixing windows compilation bugs

Added base value test for custom_type_ext traversal

Added nested traversal test

Exposing frozen function related flags to python

Improved warning about non-drjit types

Added option to specify backend if no input variable is specified

Allow for changes in the last dimension of tensor shapes

Marked payload and fn as used

Reverted comment

Formatting

Removed printing of keys

Fixed backend test

Added method to detect recording frozen functions with the wrong backend

Added comments for arguments of dr.freeze decorator

Added flag test to DR_TRAVERSE_CB_RO/RW macros

Removed test for `EnableObjectTraversal` flag in `DR_TRAVERSE_CB` macro

Fixed tensor freezing indexing issue

Removed deprecated logging code

Added comment about tensor shape inference

Using Wenzel's documentation, and renamed arguments

Added comment regarding `frozen_function_tp_traverse` and `frozen_function_clear`

Fixed typo

Added warning text to documentation

Fixed freezing drjit optimizers

Added optimizer freezing tests

Small refactor

Suggestion from Wenzel
@wjakob wjakob merged commit e1f88e2 into master Jun 19, 2025
@wjakob wjakob deleted the frozen-functions branch June 19, 2025 12:06
4 participants