Skip to content

Commit c12c89e

Browse files
Frozen Functions Implementation
1 parent e252532 commit c12c89e

26 files changed

+5837
-61
lines changed

drjit/__init__.py

+178
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from typing import Callable, TypeVar
12
from . import detail
23

34
with detail.scoped_rtld_deepbind():
@@ -1878,6 +1879,183 @@ def binary_search(start, end, pred):
18781879

18791880
return start
18801881

1882+
# Represents the frozen function passed to the decorator without arguments
1883+
F = TypeVar("F")
1884+
# Represents the frozen function passed to the decorator with arguments
1885+
F2 = TypeVar("F2")
1886+
1887+
@overload
1888+
def freeze(
1889+
f: None = None,
1890+
*,
1891+
state: Optional[Callable],
1892+
max_cache_size: Optional[int] = None,
1893+
warn_recording_count: int = 10,
1894+
) -> Callable[[F], F]:
1895+
"""
1896+
Decorator to freeze a function for replaying kernels without re-tracing.
1897+
1898+
This decorator wraps a function, and enables replaying it in order remove the
1899+
need for tracing python operations. The first time a frozen function is called,
1900+
it is executed regularly, while all operations performed are recorded. When the
1901+
frozen function is called again, with compatible arguments, the recorded
1902+
operations are replayed instead of launching the Python function. This saves on
1903+
the overhead that is inherent when tracing and compiling python operations into
1904+
kernels. The frozen function keeps a record of all previously recorded calls,
1905+
and associates them with the input layout. Since the kernels might be compiled
1906+
differently when changing certain attributes of variables, the frozen function
1907+
will be re-recorded if the inputs are no longer compatible with previous
1908+
recordings. We also track the layout of any container and the values of Python
1909+
types of the inputs to the frozen function. The following are the input
1910+
attributes which are tracked and might cause the function to be re-recorded if
1911+
changed.
1912+
1913+
- Type of any variable or container
1914+
- Number of members in a container, such as the length of a list
1915+
- Key or field names, such as dictionary keys or dataclass field names
1916+
- The `Dr.Jit` type of any variable
1917+
- Whether a variable has size $1$
1918+
- Whether the memory, referenced by a variable is unaligned (only applies to
1919+
mapped NumPy arrays)
1920+
- Whether the variable has gradients enabled
1921+
- The sets of variables that have the same size
1922+
- The hash of any Python variable, that is not a Dr.Jit variable or a container
1923+
1924+
Generally, the widths of the DrJit variables in the input can be changed in
1925+
each call without triggering a retracing of the function.
1926+
However, variables that had the same width during recording must keep the
1927+
same width in future calls, otherwise the function must be traced again.
1928+
1929+
Please be particularly wary of using `dr.width()` within a recorded function.
1930+
Usage such as `dr.arange(UInt32, dr.width(a))` is allowed, since `dr.arange`
1931+
uses the width only implicitly, via the launch dimension of the kernel.
1932+
But direct usage, or computations on the width, would no longer be correct
1933+
in subsequent launches of the frozen function if the width of the inputs changes,
1934+
for example:
1935+
1936+
```python
1937+
y = dr.gather(type(x), x, dr.width(x)//2)
1938+
```
1939+
1940+
Similarly, calculating the mean of a variable relies on the number of entries,
1941+
which will be baked into the frozen function. To avoid this, we suggest
1942+
supplying the number of entries as a Dr.Jit literal in the arguments to the
1943+
function.
1944+
"""
1945+
1946+
1947+
@overload
1948+
def freeze(
1949+
f: F,
1950+
*,
1951+
state: Optional[Callable] = None,
1952+
max_cache_size: Optional[int] = None,
1953+
warn_recording_count: int = 10,
1954+
) -> F: ...
1955+
1956+
1957+
def freeze(
1958+
f: Optional[F] = None,
1959+
*,
1960+
state: Optional[Callable] = None,
1961+
max_cache_size: Optional[int] = None,
1962+
warn_recording_count: int = 10,
1963+
) -> Union[F, Callable[[F2], F2]]:
1964+
max_cache_size = max_cache_size if max_cache_size is not None else -1
1965+
1966+
def decorator(f):
1967+
"""
1968+
Internal decorator, returned in ``dr.freeze`` was used with arguments.
1969+
"""
1970+
import functools
1971+
import inspect
1972+
1973+
def inner(closure, *args, **kwargs):
1974+
"""
1975+
This inner function is the one that gets actually frozen. It receives
1976+
any additional state such as closures or state specified with the
1977+
``state`` lambda, and allows for traversal of it.
1978+
"""
1979+
return f(*args, **kwargs)
1980+
1981+
class FrozenFunction:
1982+
def __init__(self, f) -> None:
1983+
closure = inspect.getclosurevars(f)
1984+
self.closure = (closure.nonlocals, closure.globals)
1985+
self.frozen = detail.FrozenFunction(
1986+
inner, max_cache_size, warn_recording_count
1987+
)
1988+
1989+
def __call__(self, *args, **kwargs):
1990+
_state = state(*args, **kwargs) if state is not None else None
1991+
return self.frozen([self.closure, _state], *args, **kwargs)
1992+
1993+
@property
1994+
def n_recordings(self):
1995+
"""
1996+
Represents the number of times the function was recorded. This
1997+
includes occasions where it was recorded due to a dry-run failing.
1998+
It does not necessarily correspond to the number of recordings
1999+
currently cached see ``n_cached_recordings`` for that.
2000+
"""
2001+
return self.frozen.n_recordings
2002+
2003+
@property
2004+
def n_cached_recordings(self):
2005+
"""
2006+
Represents the number of recordings currently cached of the frozen
2007+
function. If a recording fails in dry-run mode, it will not create
2008+
a new recording, but replace the recording that was attemted to be
2009+
replayed. The number of recordings can also be limited with
2010+
the ``max_cache_size`` argument.
2011+
"""
2012+
return self.frozen.n_cached_recordings
2013+
2014+
def clear(self):
2015+
"""
2016+
Clears the recordings of the frozen function, and resets the
2017+
``n_recordings`` counter. The reference to the function is still
2018+
kept, and the frozen function can be called again to re-trace
2019+
new recordings.
2020+
"""
2021+
return self.frozen.clear()
2022+
2023+
def __get__(self, obj, type=None):
2024+
if obj is None:
2025+
return self
2026+
else:
2027+
return FrozenMethod(self.frozen, self.closure, obj)
2028+
2029+
class FrozenMethod(FrozenFunction):
2030+
"""
2031+
A FrozenMethod currying the object into the __call__ method.
2032+
2033+
If the ``freeze`` decorator is applied to a method of some class, it has
2034+
to call the internal frozen function with the ``self`` argument. To this
2035+
end we implement the ``__get__`` method of the frozen function, to
2036+
return a ``FrozenMethod``, which holds a reference to the object.
2037+
The ``__call__`` method of the ``FrozenMethod`` then supplies the object
2038+
in addition to the arguments to the internal function.
2039+
"""
2040+
def __init__(self, frozen, closure, obj) -> None:
2041+
self.obj = obj
2042+
self.frozen = frozen
2043+
self.closure = closure
2044+
2045+
def __call__(self, *args, **kwargs):
2046+
_state = state(self.obj, *args, **kwargs) if state is not None else None
2047+
return self.frozen([self.closure, _state], self.obj, *args, **kwargs)
2048+
2049+
return functools.wraps(f)(FrozenFunction(f))
2050+
2051+
if f is not None:
2052+
return decorator(f)
2053+
else:
2054+
return decorator
2055+
2056+
2057+
del F
2058+
del F2
18812059

18822060
def assert_true(
18832061
cond,

include/drjit/array_traverse.h

+57-6
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515

1616
#pragma once
1717

18+
#include <type_traits>
19+
1820
#define DRJIT_STRUCT_NODEF(Name, ...) \
1921
Name(const Name &) = default; \
2022
Name(Name &&) = default; \
@@ -140,6 +142,18 @@ namespace detail {
140142
using det_traverse_1_cb_rw =
141143
decltype(T(nullptr)->traverse_1_cb_rw(nullptr, nullptr));
142144

145+
template <typename T>
146+
using det_get = decltype(std::declval<T&>().get());
147+
148+
template <typename T>
149+
using det_const_get = decltype(std::declval<const T &>().get());
150+
151+
template<typename T>
152+
using det_begin = decltype(std::declval<T &>().begin());
153+
154+
template<typename T>
155+
using det_end = decltype(std::declval<T &>().end());
156+
143157
inline drjit::string get_label(const char *s, size_t i) {
144158
auto skip = [](char c) {
145159
return c == ' ' || c == '\r' || c == '\n' || c == '\t' || c == ',';
@@ -180,10 +194,17 @@ template <typename T> auto labels(const T &v) {
180194
}
181195

182196
template <typename Value>
183-
void traverse_1_fn_ro(const Value &value, void *payload, void (*fn)(void *, uint64_t)) {
184-
(void) payload; (void) fn;
197+
void traverse_1_fn_ro(const Value &value, void *payload,
198+
void (*fn)(void *, uint64_t, const char *,
199+
const char *)) {
200+
DRJIT_MARK_USED(payload);
201+
DRJIT_MARK_USED(fn);
185202
if constexpr (is_jit_v<Value> && depth_v<Value> == 1) {
186-
fn(payload, value.index_combined());
203+
if constexpr(Value::IsClass)
204+
fn(payload, value.index_combined(), Value::CallSupport::Variant,
205+
Value::CallSupport::Domain);
206+
else
207+
fn(payload, value.index_combined(), "", "");
187208
} else if constexpr (is_traversable_v<Value>) {
188209
traverse_1(fields(value), [payload, fn](auto &x) {
189210
traverse_1_fn_ro(x, payload, fn);
@@ -198,14 +219,34 @@ void traverse_1_fn_ro(const Value &value, void *payload, void (*fn)(void *, uint
198219
is_detected_v<detail::det_traverse_1_cb_ro, Value>) {
199220
if (value)
200221
value->traverse_1_cb_ro(payload, fn);
222+
223+
} else if constexpr (is_detected_v<detail::det_begin, Value> &&
224+
is_detected_v<detail::det_end, Value>) {
225+
for (auto elem : value) {
226+
traverse_1_fn_ro(elem, payload, fn);
227+
}
228+
} else if constexpr (is_detected_v<detail::det_const_get, Value>) {
229+
const auto *tmp = value.get();
230+
traverse_1_fn_ro(tmp, payload, fn);
231+
} else if constexpr (is_detected_v<detail::det_traverse_1_cb_ro, Value *>) {
232+
value.traverse_1_cb_ro(payload, fn);
201233
}
202234
}
203235

204236
template <typename Value>
205-
void traverse_1_fn_rw(Value &value, void *payload, uint64_t (*fn)(void *, uint64_t)) {
206-
(void) payload; (void) fn;
237+
void traverse_1_fn_rw(Value &value, void *payload,
238+
uint64_t (*fn)(void *, uint64_t, const char *,
239+
const char *)) {
240+
DRJIT_MARK_USED(payload);
241+
DRJIT_MARK_USED(fn);
207242
if constexpr (is_jit_v<Value> && depth_v<Value> == 1) {
208-
value = Value::borrow((typename Value::Index) fn(payload, value.index_combined()));
243+
if constexpr(Value::IsClass)
244+
value = Value::borrow((typename Value::Index) fn(
245+
payload, value.index_combined(), Value::CallSupport::Variant,
246+
Value::CallSupport::Domain));
247+
else
248+
value = Value::borrow((typename Value::Index) fn(
249+
payload, value.index_combined(), "", ""));
209250
} else if constexpr (is_traversable_v<Value>) {
210251
traverse_1(fields(value), [payload, fn](auto &x) {
211252
traverse_1_fn_rw(x, payload, fn);
@@ -220,6 +261,16 @@ void traverse_1_fn_rw(Value &value, void *payload, uint64_t (*fn)(void *, uint64
220261
is_detected_v<detail::det_traverse_1_cb_rw, Value>) {
221262
if (value)
222263
value->traverse_1_cb_rw(payload, fn);
264+
} else if constexpr (is_detected_v<detail::det_begin, Value> &&
265+
is_detected_v<detail::det_end, Value>) {
266+
for (auto elem : value) {
267+
traverse_1_fn_rw(elem, payload, fn);
268+
}
269+
} else if constexpr (is_detected_v<detail::det_get, Value>) {
270+
auto *tmp = value.get();
271+
traverse_1_fn_rw(tmp, payload, fn);
272+
} else if constexpr (is_detected_v<detail::det_traverse_1_cb_rw, Value *>) {
273+
value.traverse_1_cb_rw(payload, fn);
223274
}
224275
}
225276

include/drjit/autodiff.h

+5-2
Original file line numberDiff line numberDiff line change
@@ -971,7 +971,9 @@ NAMESPACE_BEGIN(detail)
971971
/// Internal operations for traversing nested data structures and fetching or
972972
/// storing indices. Used in ``call.h`` and ``loop.h``.
973973

974-
template <bool IncRef> void collect_indices_fn(void *p, uint64_t index) {
974+
template <bool IncRef>
975+
void collect_indices_fn(void *p, uint64_t index, const char * /*variant*/,
976+
const char * /*domain*/) {
975977
vector<uint64_t> &indices = *(vector<uint64_t> *) p;
976978
if constexpr (IncRef)
977979
index = ad_var_inc_ref(index);
@@ -983,7 +985,8 @@ struct update_indices_payload {
983985
size_t &pos;
984986
};
985987

986-
inline uint64_t update_indices_fn(void *p, uint64_t) {
988+
inline uint64_t update_indices_fn(void *p, uint64_t, const char * /*variant*/,
989+
const char * /*domain*/) {
987990
update_indices_payload &payload = *(update_indices_payload *) p;
988991
return payload.indices[payload.pos++];
989992
}

include/drjit/extra.h

+5
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,11 @@ extern DRJIT_EXTRA_EXPORT bool ad_release_one_output(drjit::detail::CustomOpBase
247247
extern DRJIT_EXTRA_EXPORT void ad_copy_implicit_deps(drjit::vector<uint32_t> &,
248248
bool input);
249249

250+
/// Retrieve a list of ad indices, that are the target of edges, that have been
251+
/// postponed by the current scope
252+
extern DRJIT_EXTRA_EXPORT void ad_scope_postponed(drjit::vector<uint32_t> *dst);
253+
254+
250255
/// Kahan-compensated floating point atomic scatter-addition
251256
extern DRJIT_EXTRA_EXPORT void
252257
ad_var_scatter_add_kahan(uint64_t *target_1, uint64_t *target_2, uint64_t value,

0 commit comments

Comments
 (0)