Commit f57a954

Frozen Functions Implementation
Added comment explaining traverse with `rw` arguments
Added comment about c++ traversal
Added recursion guards
Simplified comment on TraverseCallback operator()
Removed redundant profiling loop
Fixed recursion_guard by wrapping it in namespace
Fixed typing import
Improved handling of nullptr variant/domains
Removed warn log
Added JitFlag FreezingTraverseScope
Using try_cast in get_traversable_base
Renamed to EnableObjectTraversal flag
Added comment for traverse_py_cb_ro_impl
Added rw value for traversal of trampolines
Removed outdated function
Excluding texture from traversal outside of frozen function
Testing fix for windows compilation
Fixing windows compilation bugs
Added base value test for custom_type_ext traversal
Added nested traversal test
Exposing frozen function related flags to python
Improved warning about non-drjit types
Added option to specify backend if no input variable is specified
Allow for changes in the last dimension of tensor shapes
Marked payload and fn as used
Reverted comment
Formatting
Removed printing of keys
Fixed backend test
Added method to detect recording frozen functions with the wrong backend
Added comments for arguments of dr.freeze decorator
Added flag test to DR_TRAVERSE_CB_RO/RW macros
Removed test for `EnableObjectTraversal` flag in `DR_TRAVERSE_CB` macro
Fixed tensor freezing indexing issue
Removed deprecated logging code
Added comment about tensor shape inference
Using Wenzel's documentation, and renamed arguments
Added comment regarding `frozen_function_tp_traverse` and `frozen_function_clear`
Fixed typo
Added warning text to documentation
1 parent 73261b6 commit f57a954

26 files changed, +6070 -62 lines changed

drjit/__init__.py

+252-2
Original file line number | Diff line number | Diff line change
@@ -19,12 +19,12 @@
1919
import sys as _sys
2020
if _sys.version_info < (3, 11):
2121
try:
22-
from typing_extensions import overload, Optional, Type, Tuple, List, Sequence, Union, Literal, Callable
22+
from typing_extensions import overload, Optional, Type, Tuple, List, Sequence, Union, Literal, Callable, TypeVar
2323
except ImportError:
2424
raise RuntimeError(
2525
"Dr.Jit requires the 'typing_extensions' package on Python <3.11")
2626
else:
27-
from typing import overload, Optional, Type, Tuple, List, Sequence, Union, Literal, Callable
27+
from typing import overload, Optional, Type, Tuple, List, Sequence, Union, Literal, Callable, TypeVar
2828

2929
from .ast import syntax, hint
3030
from .interop import wrap
@@ -2401,6 +2401,256 @@ def binary_search(start, end, pred):
24012401

24022402
return start
24032403

2404+
# Represents the frozen function passed to the decorator without arguments
2405+
F = TypeVar("F")
2406+
# Represents the frozen function passed to the decorator with arguments
2407+
F2 = TypeVar("F2")
2408+
2409+
@overload
2410+
def freeze(
2411+
f: None = None,
2412+
*,
2413+
state_fn: Optional[Callable] = None,
2414+
limit: Optional[int] = None,
2415+
warn_after: int = 10,
2416+
backend: Optional[JitBackend] = None,
2417+
) -> Callable[[F], F]:
2418+
"""
2419+
Decorator to "freeze" functions, which improves efficiency by removing
2420+
repeated JIT tracing overheads.
2421+
2422+
In general, Dr.Jit traces computation and then compiles and launches kernels
2423+
containing this trace (see the section on :ref:`evaluation <eval>` for
2424+
details). While the compilation step can often be skipped via caching, the
2425+
tracing cost can still be significant especially when repeatedly evaluating
2426+
complex models, e.g., as part of an optimization loop.
2427+
2428+
The :py:func:`@dr.freeze <drjit.freeze>` decorator addresses this problem by
2429+
altogether removing the need to trace repeatedly. For example, consider the
2430+
following decorated function:
2431+
2432+
.. code-block:: python
2433+
2434+
@dr.freeze
2435+
def f(x, y, z):
2436+
return ... # Complicated code involving the arguments
2437+
2438+
Dr.Jit will trace the first call to the decorated function ``f()``, while
2439+
collecting additional information regarding the nature of the function's inputs
2440+
and regarding the CPU/GPU kernel launches representing the body of ``f()``.
2441+
2442+
If the function is subsequently called with *compatible* arguments (more on
2443+
this below), it will immediately launch the previously made CPU/GPU kernels
2444+
without re-tracing, which can substantially improve performance.
2445+
2446+
When :py:func:`@dr.freeze <drjit.freeze>` detects *incompatibilities* (e.g., ``x``
2447+
having a different type compared to the previous call), it will conservatively
2448+
re-trace the body and keep track of another potential input configuration.
2449+
2450+
Frozen functions support arbitrary :ref:`PyTrees <pytrees>` as function
2451+
arguments and return values.
2452+
2453+
The following may trigger re-tracing:
2454+
2455+
- Changes in the **type** of an argument or :ref:`PyTree <pytrees>` element.
2456+
- Changes in the **length** of a container (``list``, ``tuple``, ``dict``).
2457+
- Changes of **dictionary keys** or **field names** of dataclasses.
2458+
- Changes in the AD status (:py:func:`dr.grad_enabled() <drjit.grad_enabled>`) of a variable.
2459+
- Changes of (non-PyTree) **Python objects**, as detected by mismatching ``hash()``.
2460+
2461+
The following more technical conditions also trigger re-tracing:
2462+
- A Dr.Jit variable changes from/to a **scalar** configuration (size ``1``).
2463+
- The sets of variables of the same size change. In the example above, this
2464+
would be the case if ``len(x) == len(y)`` in one call, and ``len(x) != len(y)``
2465+
subsequently.
2466+
- When Dr.Jit variables reference external memory (e.g. mapped NumPy arrays), the
2467+
memory can be aligned or unaligned. A re-tracing step is needed when this
2468+
status changes.
2469+
2470+
These all correspond to situations where the generated kernel code may need to
2471+
change, and the system conservatively re-traces to ensure correctness.
2472+
2473+
Frozen functions support arguments with a different variable *width* (see
2474+
:py:func:`dr.width() <drjit.width>`) without re-tracing, as long as the sets of
2475+
variables of the same width stay consistent.
2476+
2477+
Some constructions are problematic and should be avoided in frozen functions.
2478+
2479+
- The function :py:func:`dr.width() <drjit.width>` returns an integer literal
2480+
that may be merged into the generated code. If the frozen function is later
2481+
rerun with differently-sized arguments, the executed kernels will still
2482+
reference the old size. One exception to this rule are constructions like
2483+
``dr.arange(UInt32, dr.width(a))``, where the result only implicitly depends on
2484+
the width value.
2485+
2486+
**Advanced features**. The :py:func:`@dr.freeze <drjit.freeze>` decorator takes
2487+
several optional parameters that are helpful in certain situations.
2488+
2489+
- **Warning when re-tracing happens too often**: Incompatible arguments trigger
2490+
re-tracing, which can mask issues where *accidentally* incompatible arguments
2491+
keep :py:func:`@dr.freeze <drjit.freeze>` from producing the expected
2492+
performance benefits.
2493+
2494+
In such situations, it can be helpful to warn and identify changing
2495+
parameters by name. This feature is enabled and set to ``10`` by default.
2496+
2497+
.. code-block:: pycon
2498+
2499+
>>> @dr.freeze(warn_after=1)
2500+
... def f(x):
2501+
... return x
2502+
...
2503+
>>> f(Int(1))
2504+
>>> f(Float(1))
2505+
The frozen function has been recorded 2 times, this indicates a problem
2506+
with how the frozen function is being called. For example, calling it
2507+
with changing python values such as an index. For more information about
2508+
which variables changed set the log level to ``LogLevel::Debug``.
2509+
2510+
- **Limiting memory usage**. Storing kernels for many possible input
2511+
configurations requires device memory, which can become problematic. Set the
2512+
``limit=`` parameter to enable an LRU cache. This is useful when calls to a
2513+
function are mostly compatible but require occasional re-tracing.
2514+
2515+
Args:
2516+
limit (Optional[int]): An optional integer specifying the maximum number of
2517+
stored configurations. Once this limit is reached, incompatible calls
2518+
requiring re-tracing will cause the least recently used configuration to be dropped.
2519+
2520+
warn_after (int): When the number of re-tracing steps exceeds this value,
2521+
Dr.Jit will generate a warning that explains which variables changed
2522+
between calls to the function.
2523+
2524+
state_fn (Optional[Callable]): This optional callable can specify additional
2525+
state that identifies the configuration. ``state_fn`` will be called with
2526+
the same arguments as the decorated function. It should return a
2527+
traversable object (e.g., a list or tuple) that is conceptually treated
2528+
as if it was another input of the function.
2529+
2530+
backend (Optional[JitBackend]): If no inputs are given when calling the
2531+
frozen function, the backend used has to be specified using this argument.
2532+
It must match the backend used for computation within the function.
2533+
"""
2534+
2535+
2536+
@overload
2537+
def freeze(
2538+
f: F,
2539+
*,
2540+
state_fn: Optional[Callable] = None,
2541+
limit: Optional[int] = None,
2542+
warn_after: int = 10,
2543+
backend: Optional[JitBackend] = None,
2544+
) -> F: ...
2545+
2546+
2547+
def freeze(
2548+
f: Optional[F] = None,
2549+
*,
2550+
state_fn: Optional[Callable] = None,
2551+
limit: Optional[int] = None,
2552+
warn_after: int = 10,
2553+
backend: Optional[JitBackend] = None,
2554+
) -> Union[F, Callable[[F2], F2]]:
2555+
limit = limit if limit is not None else -1
2556+
backend = backend if backend is not None else JitBackend.Invalid
2557+
2558+
def decorator(f):
2559+
"""
2560+
Internal decorator, returned if ``dr.freeze`` was used with arguments.
2561+
"""
2562+
import functools
2563+
import inspect
2564+
2565+
def inner(closure, *args, **kwargs):
2566+
"""
2567+
This inner function is the one that gets actually frozen. It receives
2568+
any additional state such as closures or state specified with the
2569+
``state_fn`` callable, and allows for traversal of it.
2570+
"""
2571+
return f(*args, **kwargs)
2572+
2573+
class FrozenFunction:
2574+
def __init__(self, f) -> None:
2575+
closure = inspect.getclosurevars(f)
2576+
self.closure = (closure.nonlocals, closure.globals)
2577+
self.frozen = detail.FrozenFunction(
2578+
inner,
2579+
limit,
2580+
warn_after,
2581+
backend,
2582+
)
2583+
2584+
def __call__(self, *args, **kwargs):
2585+
_state = state_fn(*args, **kwargs) if state_fn is not None else None
2586+
return self.frozen([self.closure, _state], *args, **kwargs)
2587+
2588+
@property
2589+
def n_recordings(self):
2590+
"""
2591+
Represents the number of times the function was recorded. This
2592+
includes occasions where it was recorded due to a dry-run failing.
2593+
It does not necessarily correspond to the number of recordings
2594+
currently cached; see ``n_cached_recordings`` for that.
2595+
"""
2596+
return self.frozen.n_recordings
2597+
2598+
@property
2599+
def n_cached_recordings(self):
2600+
"""
2601+
Represents the number of recordings currently cached of the frozen
2602+
function. If a recording fails in dry-run mode, it will not create
2603+
a new recording, but replace the recording that was attempted to be
2604+
replayed. The number of recordings can also be limited with
2605+
the ``limit`` argument.
2606+
"""
2607+
return self.frozen.n_cached_recordings
2608+
2609+
def clear(self):
2610+
"""
2611+
Clears the recordings of the frozen function, and resets the
2612+
``n_recordings`` counter. The reference to the function is still
2613+
kept, and the frozen function can be called again to re-trace
2614+
new recordings.
2615+
"""
2616+
return self.frozen.clear()
2617+
2618+
def __get__(self, obj, type=None):
2619+
if obj is None:
2620+
return self
2621+
else:
2622+
return FrozenMethod(self.frozen, self.closure, obj)
2623+
2624+
class FrozenMethod(FrozenFunction):
2625+
"""
2626+
A FrozenMethod currying the object into the __call__ method.
2627+
2628+
If the ``freeze`` decorator is applied to a method of some class, it has
2629+
to call the internal frozen function with the ``self`` argument. To this
2630+
end we implement the ``__get__`` method of the frozen function, to
2631+
return a ``FrozenMethod``, which holds a reference to the object.
2632+
The ``__call__`` method of the ``FrozenMethod`` then supplies the object
2633+
in addition to the arguments to the internal function.
2634+
"""
2635+
def __init__(self, frozen, closure, obj) -> None:
2636+
self.obj = obj
2637+
self.frozen = frozen
2638+
self.closure = closure
2639+
2640+
def __call__(self, *args, **kwargs):
2641+
_state = state_fn(self.obj, *args, **kwargs) if state_fn is not None else None
2642+
return self.frozen([self.closure, _state], self.obj, *args, **kwargs)
2643+
2644+
return functools.wraps(f)(FrozenFunction(f))
2645+
2646+
if f is not None:
2647+
return decorator(f)
2648+
else:
2649+
return decorator
2650+
2651+
2652+
del F
2653+
del F2
24042654

24052655
def assert_true(
24062656
cond,
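
The re-tracing rules and the ``limit`` parameter described in the docstring above can be illustrated without Dr.Jit. The following is a minimal pure-Python sketch, not the actual implementation: `layout_key` and `ToyFrozen` are hypothetical names, the key only captures types, container lengths, and dictionary keys, and the "recording" is simply the original function rather than a set of compiled kernels.

```python
from collections import OrderedDict


def layout_key(value):
    """Hypothetical layout summary: types, container lengths, and dict keys,
    but never the leaf values themselves."""
    if isinstance(value, (list, tuple)):
        return (type(value).__name__, tuple(layout_key(v) for v in value))
    if isinstance(value, dict):
        return ("dict", tuple((k, layout_key(value[k])) for k in sorted(value)))
    return type(value).__name__


class ToyFrozen:
    """Keeps one 'recording' per input layout, with an optional LRU limit."""

    def __init__(self, fn, limit=None):
        self.fn = fn
        self.limit = limit
        self.recordings = OrderedDict()  # layout key -> recording
        self.n_recordings = 0            # total number of (re-)traces

    def __call__(self, *args):
        key = layout_key(args)
        if key in self.recordings:
            # Compatible call: reuse the recording, mark it most recently used.
            self.recordings.move_to_end(key)
        else:
            # Incompatible call: "re-trace" and store a new recording.
            self.n_recordings += 1
            self.recordings[key] = self.fn
            if self.limit is not None and len(self.recordings) > self.limit:
                self.recordings.popitem(last=False)  # evict least recently used
        return self.fn(*args)


f = ToyFrozen(lambda x, y: x + y, limit=2)
f(1, 2)
f(3, 4)      # same layout as the first call: no new recording
f(1.0, 2)    # the type of ``x`` changed: second recording
print(f.n_recordings, len(f.recordings))  # prints: 2 2
```

As in the docstring, the total number of recordings and the number of cached recordings are tracked separately, since evictions make the two diverge once the limit is reached.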

include/drjit/array_traverse.h

+57 -6
@@ -15,6 +15,8 @@
1515

1616
#pragma once
1717

18+
#include <type_traits>
19+
1820
#define DRJIT_STRUCT_NODEF(Name, ...) \
1921
Name(const Name &) = default; \
2022
Name(Name &&) = default; \
@@ -140,6 +142,18 @@ namespace detail {
140142
using det_traverse_1_cb_rw =
141143
decltype(T(nullptr)->traverse_1_cb_rw(nullptr, nullptr));
142144

145+
template <typename T>
146+
using det_get = decltype(std::declval<T&>().get());
147+
148+
template <typename T>
149+
using det_const_get = decltype(std::declval<const T &>().get());
150+
151+
template<typename T>
152+
using det_begin = decltype(std::declval<T &>().begin());
153+
154+
template<typename T>
155+
using det_end = decltype(std::declval<T &>().end());
156+
143157
inline drjit::string get_label(const char *s, size_t i) {
144158
auto skip = [](char c) {
145159
return c == ' ' || c == '\r' || c == '\n' || c == '\t' || c == ',';
@@ -180,10 +194,17 @@ template <typename T> auto labels(const T &v) {
180194
}
181195

182196
template <typename Value>
183-
void traverse_1_fn_ro(const Value &value, void *payload, void (*fn)(void *, uint64_t)) {
184-
(void) payload; (void) fn;
197+
void traverse_1_fn_ro(const Value &value, void *payload,
198+
void (*fn)(void *, uint64_t, const char *,
199+
const char *)) {
200+
DRJIT_MARK_USED(payload);
201+
DRJIT_MARK_USED(fn);
185202
if constexpr (is_jit_v<Value> && depth_v<Value> == 1) {
186-
fn(payload, value.index_combined());
203+
if constexpr(Value::IsClass)
204+
fn(payload, value.index_combined(), Value::CallSupport::Variant,
205+
Value::CallSupport::Domain);
206+
else
207+
fn(payload, value.index_combined(), "", "");
187208
} else if constexpr (is_traversable_v<Value>) {
188209
traverse_1(fields(value), [payload, fn](auto &x) {
189210
traverse_1_fn_ro(x, payload, fn);
@@ -198,14 +219,34 @@ void traverse_1_fn_ro(const Value &value, void *payload, void (*fn)(void *, uint
198219
is_detected_v<detail::det_traverse_1_cb_ro, Value>) {
199220
if (value)
200221
value->traverse_1_cb_ro(payload, fn);
222+
223+
} else if constexpr (is_detected_v<detail::det_begin, Value> &&
224+
is_detected_v<detail::det_end, Value>) {
225+
for (const auto &elem : value) {
226+
traverse_1_fn_ro(elem, payload, fn);
227+
}
228+
} else if constexpr (is_detected_v<detail::det_const_get, Value>) {
229+
const auto *tmp = value.get();
230+
traverse_1_fn_ro(tmp, payload, fn);
231+
} else if constexpr (is_detected_v<detail::det_traverse_1_cb_ro, Value *>) {
232+
value.traverse_1_cb_ro(payload, fn);
201233
}
202234
}
203235

204236
template <typename Value>
205-
void traverse_1_fn_rw(Value &value, void *payload, uint64_t (*fn)(void *, uint64_t)) {
206-
(void) payload; (void) fn;
237+
void traverse_1_fn_rw(Value &value, void *payload,
238+
uint64_t (*fn)(void *, uint64_t, const char *,
239+
const char *)) {
240+
DRJIT_MARK_USED(payload);
241+
DRJIT_MARK_USED(fn);
207242
if constexpr (is_jit_v<Value> && depth_v<Value> == 1) {
208-
value = Value::borrow((typename Value::Index) fn(payload, value.index_combined()));
243+
if constexpr(Value::IsClass)
244+
value = Value::borrow((typename Value::Index) fn(
245+
payload, value.index_combined(), Value::CallSupport::Variant,
246+
Value::CallSupport::Domain));
247+
else
248+
value = Value::borrow((typename Value::Index) fn(
249+
payload, value.index_combined(), "", ""));
209250
} else if constexpr (is_traversable_v<Value>) {
210251
traverse_1(fields(value), [payload, fn](auto &x) {
211252
traverse_1_fn_rw(x, payload, fn);
@@ -220,6 +261,16 @@ void traverse_1_fn_rw(Value &value, void *payload, uint64_t (*fn)(void *, uint64
220261
is_detected_v<detail::det_traverse_1_cb_rw, Value>) {
221262
if (value)
222263
value->traverse_1_cb_rw(payload, fn);
264+
} else if constexpr (is_detected_v<detail::det_begin, Value> &&
265+
is_detected_v<detail::det_end, Value>) {
266+
for (auto &elem : value) {
267+
traverse_1_fn_rw(elem, payload, fn);
268+
}
269+
} else if constexpr (is_detected_v<detail::det_get, Value>) {
270+
auto *tmp = value.get();
271+
traverse_1_fn_rw(tmp, payload, fn);
272+
} else if constexpr (is_detected_v<detail::det_traverse_1_cb_rw, Value *>) {
273+
value.traverse_1_cb_rw(payload, fn);
223274
}
224275
}
225276
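
The dispatch order of the extended read-only traversal in this hunk (JIT leaf, traversable struct, iterable container, pointer-like object with ``get()``) can be sketched in Python. This is an illustrative analogue under stated assumptions, not the C++ code: `Leaf`, `Ptr`, `Scene`, and the variant/domain strings are hypothetical, and only the callback signature `fn(payload, index, variant, domain)` mirrors the diff.

```python
from dataclasses import dataclass, fields, is_dataclass


@dataclass
class Leaf:
    """Hypothetical stand-in for a depth-1 JIT array."""
    index: int
    variant: str = ""  # only non-empty for class (instance) arrays
    domain: str = ""


@dataclass
class Ptr:
    """Hypothetical smart-pointer analogue exposing ``get()``."""
    target: object

    def get(self):
        return self.target


def traverse_ro(value, payload, fn):
    """Visit every leaf and call fn(payload, index, variant, domain)."""
    if isinstance(value, Leaf):
        fn(payload, value.index, value.variant, value.domain)
    elif isinstance(value, Ptr):
        if value.get() is not None:  # null pointers are skipped
            traverse_ro(value.get(), payload, fn)
    elif isinstance(value, (list, tuple)):
        for elem in value:
            traverse_ro(elem, payload, fn)
    elif is_dataclass(value) and not isinstance(value, type):
        for field in fields(value):
            traverse_ro(getattr(value, field.name), payload, fn)
    # Any other value is not traversable and is ignored.


@dataclass
class Scene:
    albedo: Leaf
    shapes: list
    emitter: Ptr


def record(payload, index, variant, domain):
    payload.append((index, variant, domain))


seen = []
scene = Scene(Leaf(1), [Leaf(2), Leaf(3, "some_variant", "Shape")], Ptr(Leaf(4)))
traverse_ro(scene, seen, record)
```

After the call, ``seen`` holds one tuple per leaf in field order, with empty variant and domain strings for every leaf that is not an instance array.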

include/drjit/autodiff.h

+5 -2
@@ -973,7 +973,9 @@ NAMESPACE_BEGIN(detail)
973973
/// Internal operations for traversing nested data structures and fetching or
974974
/// storing indices. Used in ``call.h`` and ``loop.h``.
975975

976-
template <bool IncRef> void collect_indices_fn(void *p, uint64_t index) {
976+
template <bool IncRef>
977+
void collect_indices_fn(void *p, uint64_t index, const char * /*variant*/,
978+
const char * /*domain*/) {
977979
vector<uint64_t> &indices = *(vector<uint64_t> *) p;
978980
if constexpr (IncRef)
979981
index = ad_var_inc_ref(index);
@@ -985,7 +987,8 @@ struct update_indices_payload {
985987
size_t &pos;
986988
};
987989

988-
inline uint64_t update_indices_fn(void *p, uint64_t) {
990+
inline uint64_t update_indices_fn(void *p, uint64_t, const char * /*variant*/,
991+
const char * /*domain*/) {
989992
update_indices_payload &payload = *(update_indices_payload *) p;
990993
return payload.indices[payload.pos++];
991994
}
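
The two callbacks changed in this hunk form a collect-then-update round trip: a read-only pass gathers indices, and a later read-write pass replaces them positionally. A small Python sketch of that pattern follows. It assumes nested lists of integers as the data structure, and `UpdatePayload`, `traverse_ro`, and `traverse_rw` are simplified, hypothetical counterparts of the C++ helpers. As in the diff, both callbacks accept the new variant and domain arguments but ignore them.

```python
from dataclasses import dataclass


def collect_indices_fn(payload, index, variant, domain):
    """Read-only callback: append each visited index (variant/domain unused)."""
    payload.append(index)


@dataclass
class UpdatePayload:
    indices: list
    pos: int = 0


def update_indices_fn(payload, index, variant, domain):
    """Read-write callback: ignore the old index and return the next new one."""
    new_index = payload.indices[payload.pos]
    payload.pos += 1
    return new_index


def traverse_ro(value, payload, fn):
    for elem in value:
        if isinstance(elem, list):
            traverse_ro(elem, payload, fn)
        else:
            fn(payload, elem, "", "")


def traverse_rw(value, payload, fn):
    for i, elem in enumerate(value):
        if isinstance(elem, list):
            traverse_rw(elem, payload, fn)
        else:
            value[i] = fn(payload, elem, "", "")  # write back in place


state = [1, [2, 3], [[4]]]
indices = []
traverse_ro(state, indices, collect_indices_fn)
traverse_rw(state, UpdatePayload([i * 10 for i in indices]), update_indices_fn)
print(indices, state)  # prints: [1, 2, 3, 4] [10, [20, 30], [[40]]]
```

The read-write pass only works because both passes visit leaves in the same order and because the result is written back into the container slot rather than into a temporary copy.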
