|
| 1 | +from typing import Callable, TypeVar |
1 | 2 | from . import detail
|
2 | 3 |
|
3 | 4 | with detail.scoped_rtld_deepbind():
|
@@ -1878,6 +1879,183 @@ def binary_search(start, end, pred):
|
1878 | 1879 |
|
1879 | 1880 | return start
|
1880 | 1881 |
|
# Type variable for the decorated function. The overload signatures use it to
# express that ``freeze`` hands back an object of the same static type as the
# function it was given.
F = TypeVar("F")
# Type variable that only appears in the return annotation of the ``freeze``
# implementation, for the form that is called with arguments.
F2 = TypeVar("F2")


@overload
def freeze(
    f: None = None,
    *,
    state: Optional[Callable] = None,
    max_cache_size: Optional[int] = None,
    warn_recording_count: int = 10,
) -> Callable[[F], F]:
    """
    Decorator to freeze a function for replaying kernels without re-tracing.

    This decorator wraps a function, and enables replaying it in order to remove
    the need for tracing Python operations. The first time a frozen function is
    called, it is executed regularly, while all operations performed are
    recorded. When the frozen function is called again with compatible
    arguments, the recorded operations are replayed instead of launching the
    Python function. This saves on the overhead that is inherent when tracing
    and compiling Python operations into kernels. The frozen function keeps a
    record of all previously recorded calls, and associates them with the input
    layout. Since the kernels might be compiled differently when changing
    certain attributes of variables, the frozen function will be re-recorded if
    the inputs are no longer compatible with previous recordings. We also track
    the layout of any container and the values of Python types of the inputs to
    the frozen function. The following are the input attributes which are
    tracked and might cause the function to be re-recorded if changed.

    - Type of any variable or container
    - Number of members in a container, such as the length of a list
    - Key or field names, such as dictionary keys or dataclass field names
    - The `Dr.Jit` type of any variable
    - Whether a variable has size $1$
    - Whether the memory referenced by a variable is unaligned (only applies to
      mapped NumPy arrays)
    - Whether the variable has gradients enabled
    - The sets of variables that have the same size
    - The hash of any Python variable that is not a Dr.Jit variable or a container

    Generally, the widths of the Dr.Jit variables in the input can be changed in
    each call without triggering a retracing of the function.
    However, variables that had the same width during recording must keep the
    same width in future calls, otherwise the function must be traced again.

    Please be particularly wary of using `dr.width()` within a recorded function.
    Usage such as `dr.arange(UInt32, dr.width(a))` is allowed, since `dr.arange`
    uses the width only implicitly, via the launch dimension of the kernel.
    But direct usage, or computations on the width, would no longer be correct
    in subsequent launches of the frozen function if the width of the inputs
    changes, for example:

    ```python
    y = dr.gather(type(x), x, dr.width(x)//2)
    ```

    Similarly, calculating the mean of a variable relies on the number of entries,
    which will be baked into the frozen function. To avoid this, we suggest
    supplying the number of entries as a Dr.Jit literal in the arguments to the
    function.

    Args:
        f: The function to freeze. Omit it to use the decorator with arguments,
          e.g. ``@dr.freeze(max_cache_size=...)``.

        state: Optional callable that is invoked with the arguments of every
          call to the frozen function. Its return value is handed to the
          frozen function as additional state, next to the closure variables
          of ``f``.

        max_cache_size: Optional limit on the number of recordings kept in the
          cache. ``None`` (the default) is passed on to the backend as ``-1``.

        warn_recording_count: Passed on unchanged to the underlying
          ``detail.FrozenFunction``.
    """
| 1945 | + |
| 1946 | + |
# Signature for the direct form (``@dr.freeze`` or ``dr.freeze(fn)``): the
# function is passed as ``f`` and the annotated result has the same type ``F``.
@overload
def freeze(
    f: F,
    *,
    state: Optional[Callable] = None,
    max_cache_size: Optional[int] = None,
    warn_recording_count: int = 10,
) -> F: ...
| 1955 | + |
| 1956 | + |
def freeze(
    f: Optional[F] = None,
    *,
    state: Optional[Callable] = None,
    max_cache_size: Optional[int] = None,
    warn_recording_count: int = 10,
) -> Union[F, Callable[[F2], F2]]:
    """
    Freeze a function so that later calls go through ``detail.FrozenFunction``.

    This is the runtime implementation behind the ``freeze`` overloads. It can
    be applied directly (``@dr.freeze``, in which case ``f`` is the function)
    or with arguments (``@dr.freeze(...)``, in which case ``f`` is ``None`` and
    a decorator is returned).

    Args:
        f: The function to freeze, or ``None`` when used with arguments.

        state: Optional callable that is invoked with the arguments of every
          call. Its return value is passed to the frozen function as
          additional state, together with the closure variables of ``f``.

        max_cache_size: Optional limit on the number of cached recordings.
          ``None`` is passed on to ``detail.FrozenFunction`` as ``-1``.

        warn_recording_count: Passed on unchanged to
          ``detail.FrozenFunction``.

    Returns:
        The wrapper object around ``f`` if ``f`` was given, otherwise a
        decorator that produces such a wrapper.
    """
    # ``detail.FrozenFunction`` is handed -1 when no limit was requested.
    max_cache_size = max_cache_size if max_cache_size is not None else -1

    def decorator(f):
        """
        Internal decorator. It is applied to ``f`` right away, or returned to
        the caller if ``dr.freeze`` was used with arguments.
        """
        import functools
        import inspect

        def inner(closure, *args, **kwargs):
            """
            This inner function is the one that gets actually frozen. It receives
            any additional state such as closures or state specified with the
            ``state`` lambda, and allows for traversal of it.
            """
            # ``closure`` is deliberately unused here: it only has to be part
            # of the arguments seen by the frozen function.
            return f(*args, **kwargs)

        class FrozenFunction:
            """
            Callable wrapper that forwards calls to ``detail.FrozenFunction``,
            prepending the closure variables of ``f`` and the optional state.
            """

            def __init__(self, f) -> None:
                # Snapshot of the nonlocal and global names referenced by
                # ``f``, taken when the decorator is applied.
                closure = inspect.getclosurevars(f)
                self.closure = (closure.nonlocals, closure.globals)
                self.frozen = detail.FrozenFunction(
                    inner, max_cache_size, warn_recording_count
                )

            def __call__(self, *args, **kwargs):
                _state = state(*args, **kwargs) if state is not None else None
                return self.frozen([self.closure, _state], *args, **kwargs)

            @property
            def n_recordings(self):
                """
                Represents the number of times the function was recorded. This
                includes occasions where it was recorded due to a dry-run failing.
                It does not necessarily correspond to the number of recordings
                currently cached; see ``n_cached_recordings`` for that.
                """
                return self.frozen.n_recordings

            @property
            def n_cached_recordings(self):
                """
                Represents the number of recordings currently cached of the frozen
                function. If a recording fails in dry-run mode, it will not create
                a new recording, but replace the recording that was attempted to
                be replayed. The number of recordings can also be limited with
                the ``max_cache_size`` argument.
                """
                return self.frozen.n_cached_recordings

            def clear(self):
                """
                Clears the recordings of the frozen function, and resets the
                ``n_recordings`` counter. The reference to the function is still
                kept, and the frozen function can be called again to re-trace
                new recordings.
                """
                return self.frozen.clear()

            def __get__(self, obj, objtype=None):
                # Descriptor hook: access through the class yields the wrapper
                # itself, access through an instance yields a bound variant.
                if obj is None:
                    return self
                return FrozenMethod(self.frozen, self.closure, obj)

        class FrozenMethod(FrozenFunction):
            """
            A FrozenMethod currying the object into the __call__ method.

            If the ``freeze`` decorator is applied to a method of some class, it has
            to call the internal frozen function with the ``self`` argument. To this
            end we implement the ``__get__`` method of the frozen function, to
            return a ``FrozenMethod``, which holds a reference to the object.
            The ``__call__`` method of the ``FrozenMethod`` then supplies the object
            in addition to the arguments to the internal function.
            """

            def __init__(self, frozen, closure, obj) -> None:
                # Reuses the ``frozen`` and ``closure`` objects of the
                # FrozenFunction it was created from instead of building new
                # ones.
                self.obj = obj
                self.frozen = frozen
                self.closure = closure

            def __call__(self, *args, **kwargs):
                _state = state(self.obj, *args, **kwargs) if state is not None else None
                return self.frozen([self.closure, _state], self.obj, *args, **kwargs)

        # Copy name, docstring and other metadata of ``f`` onto the wrapper.
        return functools.wraps(f)(FrozenFunction(f))

    if f is not None:
        return decorator(f)
    else:
        return decorator
| 2055 | + |
| 2056 | + |
# Remove the helper type variables from the module namespace again; they were
# only needed while the ``freeze`` signatures above were being defined.
del F
del F2
1881 | 2059 |
|
1882 | 2060 | def assert_true(
|
1883 | 2061 | cond,
|
|
0 commit comments