-
Notifications
You must be signed in to change notification settings - Fork 49
Frozen Functions #336
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Frozen Functions #336
Conversation
a62bdf7
to
d47afff
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is looking good so far, I still have to go through some of the main bits like freeze.cpp
/freeze.h
. I'll come back to this tomorrow.
include/drjit/texture.h
Outdated
|
||
DR_TRAVERSE_CB(drjit::TraversableBase, m_value, m_shape_opaque, | ||
m_inv_resolution); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
bfbd9b5
to
bfb7674
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've gone through the last few files I had missed.
Ignoring the few new comments, I'm overall pretty happy with the current state of the source code. As for the documentation, the only part that truly matters is the dr.freeze
and I think it's good too - I don't really have the full picture yet, so take that with a pinch of salt. For the C++ documentation, it was plenty enough for me to understand how these mechanisms work and should be used 👍
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very partial review, I only looked at the files that had < ~150 lines of diff.
1ae57b6
to
c12c89e
Compare
5f7efff
to
3af2b69
Compare
15220b6
to
3d80462
Compare
971b79f
to
2232018
Compare
1f08b94
to
f57a954
Compare
c821068
to
dcb5217
Compare
7de6d1c
to
bec6588
Compare
48daace
to
843f668
Compare
fd8cf08
to
8ef5ddf
Compare
8ef5ddf
to
60a9a4a
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some comments from me on this PR
* assigning to the input. The function takes an optional python-type if | ||
* it is known. | ||
*/ | ||
void FlatVariables::traverse_ad_index(uint64_t index, TraverseContext &ctx, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is something basic I do not understand. My assumption was that inputs/outputs to @dr.freeze
-annotated functions cannot be AD-attached. Yet there is lots of handling related to AD variables.
The documentation (docs/freeze.rst
) you added does not discuss AD variables. It would be good to explain somewhere what behavior is legal and what isn't.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The Auto Opaque PR #361 adds the documentation in freeze.rst
. It should explain that only propagating gradients from inside to the inputs and their dependencies is possible from a frozen function. These traverse_ad_index
functions are required to handle this case.
@@ -0,0 +1,1551 @@ | |||
#include "freeze.h" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Like in the Dr.Jit-Core implementation of kernel freezing, this file needs a detailed comment describing the high level "algorithm". The various functions are individually documented, but the big picture does not easily emerge when going through them.
It should cover the high level ideas like: how are the arguments to a function flattened, how do we detect misses/hits? You also implemented an optimization related to literal variables, that would be good to briefly mention too.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will ad that in the auto opaque PR #361 if ok. Since this changes quite a bit.
382c8b9
to
f31870a
Compare
Added comment explaining traverse with `rw` arguments Added comment about c++ traversal Added recursion guards Simplified comment on TraverseCallback operator() Removed redundant profiling loop Fixed recursion_guard by wrapping it in namespace Fixed typing import Improved handling of nullptr variant/domains Removed warn log Added JitFlag FreezingTraverseScope Using try_cast in get_traversable_base Renamed to EnableObjectTraversal flag Added comment for traverse_py_cb_ro_impl Added rw value for traversal of trampolines Removed outdated function Excluding texture from traversal outside of frozen function Testing fix for windows compilation Fixing windows compilation bugs Added base value test for custom_type_ext traversal Added nested traversal test Exposing frozen function related flags to python Improved warning about non-drjit types Added option to specify backend if no input variable is specified Allow for changes in the last dimension of tensor shapes Marked payload and fn as used Reverted comment Formatting Removed printing of keys Fixed backend test Added method to detect recording frozen functions with the wrong backend Added comments for arguments of dr.freeze decorator Added flag test to DR_TRAVERSE_CB_RO/RW macros Removed test for `EnableObjectTraversal` flag in `DR_TRAVERSE_CB` macro Fixed tensor freezing indexing issue Removed deprecated logging code Added comment about tensor shape inference Using Wenzel's documentation, and renamed arguments Added comment regarding `frozen_function_tp_traverse` and `frozen_function_clear` Fixed typo Added warning text to documentation Fixed freezing drjit optimizers Added optimizer freezing tests Small refactor Suggestion from Wenzel
f31870a
to
09ff550
Compare
This PR implements the frozen function feature in
drjit
.It adds a new
FrozenFunction
struct that is a wrapper over a Python callable.When calling the
operator()
of the object, it traverses the input and determines a unique layout/key.This key is then used to determine if the function has to be recorded or if it's possible to replay it.
The PR relies on functionality implemented in Frozen ThreadState, and #288.
This replaces #317 and changes the source branch.