-
Notifications
You must be signed in to change notification settings - Fork 46
Frozen Functions #336
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Frozen Functions #336
Conversation
a62bdf7
to
d47afff
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is looking good so far, I still have to go through some of the main bits like freeze.cpp
/freeze.h
. I'll come back to this tomorrow.
include/drjit/texture.h
Outdated
|
||
DR_TRAVERSE_CB(drjit::TraversableBase, m_value, m_shape_opaque, | ||
m_inv_resolution); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
bfbd9b5
to
bfb7674
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've gone through the last few files I had missed.
Ignoring the few new comments, I'm overall pretty happy with the current state of the source code. As for the documentation, the only part that truly matters is the dr.freeze
and I think it's good too - I don't really have the full picture yet, so take that with a pinch of salt. For the C++ documentation, it was plenty enough for me to understand how these mechanisms work and should be used 👍
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very partial review, I only looked at the files that had < ~150 lines of diff.
1ae57b6
to
c12c89e
Compare
5f7efff
to
3af2b69
Compare
15220b6
to
3d80462
Compare
971b79f
to
2232018
Compare
02b5351
to
5bab83a
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
First comments about docs/docstrings.
F2 = TypeVar("F2") | ||
|
||
@overload | ||
def freeze( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Did you try rendering the documentation? (With python -m sphinx docs html
and then open html/index.html
). I think that you will need to reference any additions in docs/reference.rst
with some auto*
command.
drjit/__init__.py
Outdated
backend: Optional[JitBackend] = None, | ||
) -> Callable[[F], F]: | ||
""" | ||
Decorator to freeze a function for replaying kernels without re-tracing. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi Christian, I did a writing pass over the docstring of the function, here is my suggestion. I changed the names of some of the decorator parameters to be shorter that will also require corresponding updates elsewhere. Are you OK with those changes??
Decorator to "freeze" functions, which improves efficiency by removing
repeated JIT tracing overheads.
In general, Dr.Jit traces computation and then compiles and launches kernels
containing this trace (see the section on :ref:`evaluation <eval>` for
details). While the compilation step can often be skipped via caching, the
tracing cost can still be significant especially when repeatedly evaluating
complex models, e.g., as part of an optimization loop.
The :py:func:`@dr.freeze <drjit.freeze>` decorator adresses this problem by
altogether removing the need to trace repeatedly. For example, consider the
following decorated function:
.. code-block:: python
@dr.freeze
def f(x, y, z):
return ... # Complicated code involving the arguments
Dr.Jit will trace the first call to the decorated function ``f()``, while
collecting additional information regarding the nature of the function's inputs
and regarding the CPU/GPU kernel launches representing the body of ``f()``.
If the function is subsequently called with *compatible* arguments (more on
this below), it will immediately launch the previously made CPU/GPU kernels
without re-tracing, which can substantially improve performance.
When :py:func:`@dr.freeze <drjit.freeze>` detects *incompatibilities* (e.g., ``x``
having a different type compared to the previous call), it will conservatively
re-trace the body and keep track of another potential input configuration.
Frozen functions support arbitrary :ref:`PyTrees <pytrees>` as function
arguments and return values.
The following may trigger re-tracing:
- Changes in the **type** of an argument or :ref:`PyTree <pytrees>` element.
- Changes in the **length** of a container (``list``, ``tuple``, ``dict``).
- Changes of **dictionary keys** or **field names** of dataclasses.
- Changes in the AD status (:py:`dr.grad_enabled() <drjit.grad_enabled>`) of a variable.
- Changes of (non-PyTree) **Python objects**, as detected by mismatching ``id()``.
The following more technical conditions also trigger re-tracing:
- A Dr.Jit variable changes from/to a **scalar** configuration (size ``1``).
- The sets of variables of the same size change. In the example above, this
would be the case if ``len(x) == len(y)`` in one call, and ``len(x) != len(y)``
subsequently.
- When Dr.Jit variables reference external memory (e.g. mapped NumPy arrays), the
memory can be aligned or unaligned. A re-tracing step is needed when this
status changes.
These all correspond to situations where the generated kernel code may need to
change, and the system conservatively re-traces to ensure correctness.
Frozen functions support arguments with a different variable *width* (see
:py:func:`dr.with() <drjit.width>`) without re-tracing, as long as the sets of
variables of the same width stay consistent.
Some constructions are problematic and should be avoided in frozen functions.
- The function :py:func:`dr.width() <drjit.width>` returns an integer literal
that may be merged into the generated code. If the frozen function is later
rerun with differently-sized arguments, the executed kernels will still
reference the old size. One exception to this rule are constructions like
`dr.arange(UInt32, dr.width(a))`, where the result only implicitly depends on
the width value.
- @Christian: Should we say something about calling frozen functions from
other frozen functions?
- @Christian: Any other caveats? Is the part about dr.mean() still relevant? I
had removed that part since I thought we had a solution for dr.mean().
**Advanced features**. The :py:func:`@dr.freeze <drjit.freeze>` decorator takes
several optional parameters that are helpful in certain situations.
- **Warning when re-tracing happens too often**: Incompatible arguments trigger
re-tracing, which can mask issues where *accidentally* incompatible arguments
keep :py:func:`@dr.freeze <drjit.freeze>` from producing the expected
performance benefits.
In such situations, it can be helpful to warn and identify changing
parameters by name. This feature is enabled and set to ``10`` by default.
.. code-block:: pycon
>>> @dr.freeze(warn_after=1)
>>> def f(x):
... return x
...
>>> f(Int(1))
>>> f(Float(1))
>>> <<< @Christian: needs an example of what the warning looks like
- **Limiting memory usage**. Storing kernels for many possible input
configuration requires device memory, which can become problematic. Set the
``limit=`` parameter to enable a LRU cache. This is useful when calls to a
function are mostly compatible but require occasional re-tracing.
Args:
limit (Optional[int]): An optional integer specifying the maximum number of
stored configurations. Once this limit is reached, incompatible calls
requiring re-tracing will cause the last used configuration to be dropped.
warn_after (int): When the number of re-tracing steps exceeds this value,
Dr.Jit will generate a warning that explains which variables changed
between calls to the function.
state_fn (Optional[Callable]): This optional callable can specify additional
state to identifies the configuration. ``state_fn`` will be called with
the same arguments as that of the decorated function. It should return a
traversable object (e.g., a list or tuple) that is conceptually treated
as if it was another input of the function.
backend (Optional[JitBackend]): If no inputs are given when calling the
frozen function, the backend used has to be specified using this argument.
It must match the backend used for computation within the function.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Regarding
Changes of (non-PyTree) Python objects, as detected by mismatching
id()
.
I'm keeping track of non-PyTree objects based on a hash I calculate of them. Not sure if taking the id()
might be better?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- @christian: Should we say something about calling frozen functions from
other frozen functions?
This should not be a problem, as the inner function is simply recorded by the outer one. I'll check that we actually test for that though. I think it should however not be something the user has to be concerned about.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- @christian: Any other caveats? Is the part about dr.mean() still relevant? I
had removed that part since I thought we had a solution for dr.mean().
The mean
case should be covered by #361 . We could leave the comment here and remove it in the other PR?
4ca120b
to
c0ae426
Compare
1f08b94
to
f57a954
Compare
c821068
to
dcb5217
Compare
7de6d1c
to
bec6588
Compare
Added comment explaining traverse with `rw` arguments Added comment about c++ traversal Added recursion guards Simplified comment on TraverseCallback operator() Removed redundant profiling loop Fixed recursion_guard by wrapping it in namespace Fixed typing import Improved handling of nullptr variant/domains Removed warn log Added JitFlag FreezingTraverseScope Using try_cast in get_traversable_base Renamed to EnableObjectTraversal flag Added comment for traverse_py_cb_ro_impl Added rw value for traversal of trampolines Removed outdated function Excluding texture from traversal outside of frozen function Testing fix for windows compilation Fixing windows compilation bugs Added base value test for custom_type_ext traversal Added nested traversal test Exposing frozen function related flags to python Improved warning about non-drjit types Added option to specify backend if no input variable is specified Allow for changes in the last dimension of tensor shapes Marked payload and fn as used Reverted comment Formatting Removed printing of keys Fixed backend test Added method to detect recording frozen functions with the wrong backend Added comments for arguments of dr.freeze decorator Added flag test to DR_TRAVERSE_CB_RO/RW macros Removed test for `EnableObjectTraversal` flag in `DR_TRAVERSE_CB` macro Fixed tensor freezing indexing issue Removed deprecated logging code Added comment about tensor shape inference Using Wenzel's documentation, and renamed arguments Added comment regarding `frozen_function_tp_traverse` and `frozen_function_clear` Fixed typo Added warning text to documentation
48daace
to
843f668
Compare
This PR implements the frozen function feature in
drjit
.It adds a new
FrozenFunction
struct that is a wrapper over a Python callable.When calling the
operator()
of the object, it traverses the input and determines a unique layout/key.This key is then used to determine if the function has to be recorded or if it's possible to replay it.
The PR relies on functionality implemented in Frozen ThreadState, and #288.
This replaces #317 and changes the source branch.