|
19 | 19 | import sys as _sys
|
20 | 20 | if _sys.version_info < (3, 11):
|
21 | 21 | try:
|
22 |
| - from typing_extensions import overload, Optional, Type, Tuple, List, Sequence, Union, Literal, Callable |
| 22 | + from typing_extensions import overload, Optional, Type, Tuple, List, Sequence, Union, Literal, Callable, TypeVar |
23 | 23 | except ImportError:
|
24 | 24 | raise RuntimeError(
|
25 | 25 | "Dr.Jit requires the 'typing_extensions' package on Python <3.11")
|
26 | 26 | else:
|
27 |
| - from typing import overload, Optional, Type, Tuple, List, Sequence, Union, Literal, Callable |
| 27 | + from typing import overload, Optional, Type, Tuple, List, Sequence, Union, Literal, Callable, TypeVar |
28 | 28 |
|
29 | 29 | from .ast import syntax, hint
|
30 | 30 | from .interop import wrap
|
@@ -2401,6 +2401,256 @@ def binary_search(start, end, pred):
|
2401 | 2401 |
|
2402 | 2402 | return start
|
2403 | 2403 |
|
| 2404 | +# Represents the frozen function passed to the decorator without arguments |
| 2405 | +F = TypeVar("F") |
| 2406 | +# Represents the frozen function passed to the decorator with arguments |
| 2407 | +F2 = TypeVar("F2") |
| 2408 | + |
| 2409 | +@overload |
| 2410 | +def freeze( |
| 2411 | + f: None = None, |
| 2412 | + *, |
| 2413 | + state_fn: Optional[Callable], |
| 2414 | + limit: Optional[int] = None, |
| 2415 | + warn_after: int = 10, |
| 2416 | + backend: Optional[JitBackend] = None, |
| 2417 | +) -> Callable[[F], F]: |
| 2418 | + """ |
| 2419 | + Decorator to "freeze" functions, which improves efficiency by removing |
| 2420 | + repeated JIT tracing overheads. |
| 2421 | +
|
| 2422 | + In general, Dr.Jit traces computation and then compiles and launches kernels |
| 2423 | + containing this trace (see the section on :ref:`evaluation <eval>` for |
| 2424 | + details). While the compilation step can often be skipped via caching, the |
| 2425 | + tracing cost can still be significant especially when repeatedly evaluating |
| 2426 | + complex models, e.g., as part of an optimization loop. |
| 2427 | +
|
| 2428 | + The :py:func:`@dr.freeze <drjit.freeze>` decorator adresses this problem by |
| 2429 | + altogether removing the need to trace repeatedly. For example, consider the |
| 2430 | + following decorated function: |
| 2431 | +
|
| 2432 | + .. code-block:: python |
| 2433 | +
|
| 2434 | + @dr.freeze |
| 2435 | + def f(x, y, z): |
| 2436 | + return ... # Complicated code involving the arguments |
| 2437 | +
|
| 2438 | + Dr.Jit will trace the first call to the decorated function ``f()``, while |
| 2439 | + collecting additional information regarding the nature of the function's inputs |
| 2440 | + and regarding the CPU/GPU kernel launches representing the body of ``f()``. |
| 2441 | +
|
| 2442 | + If the function is subsequently called with *compatible* arguments (more on |
| 2443 | + this below), it will immediately launch the previously made CPU/GPU kernels |
| 2444 | + without re-tracing, which can substantially improve performance. |
| 2445 | +
|
| 2446 | + When :py:func:`@dr.freeze <drjit.freeze>` detects *incompatibilities* (e.g., ``x`` |
| 2447 | + having a different type compared to the previous call), it will conservatively |
| 2448 | + re-trace the body and keep track of another potential input configuration. |
| 2449 | +
|
| 2450 | + Frozen functions support arbitrary :ref:`PyTrees <pytrees>` as function |
| 2451 | + arguments and return values. |
| 2452 | +
|
| 2453 | + The following may trigger re-tracing: |
| 2454 | +
|
| 2455 | + - Changes in the **type** of an argument or :ref:`PyTree <pytrees>` element. |
| 2456 | + - Changes in the **length** of a container (``list``, ``tuple``, ``dict``). |
| 2457 | + - Changes of **dictionary keys** or **field names** of dataclasses. |
| 2458 | + - Changes in the AD status (:py:`dr.grad_enabled() <drjit.grad_enabled>`) of a variable. |
| 2459 | + - Changes of (non-PyTree) **Python objects**, as detected by mismatching ``hash()``. |
| 2460 | +
|
| 2461 | + The following more technical conditions also trigger re-tracing: |
| 2462 | + - A Dr.Jit variable changes from/to a **scalar** configuration (size ``1``). |
| 2463 | + - The sets of variables of the same size change. In the example above, this |
| 2464 | + would be the case if ``len(x) == len(y)`` in one call, and ``len(x) != len(y)`` |
| 2465 | + subsequently. |
| 2466 | + - When Dr.Jit variables reference external memory (e.g. mapped NumPy arrays), the |
| 2467 | + memory can be aligned or unaligned. A re-tracing step is needed when this |
| 2468 | + status changes. |
| 2469 | +
|
| 2470 | + These all correspond to situations where the generated kernel code may need to |
| 2471 | + change, and the system conservatively re-traces to ensure correctness. |
| 2472 | +
|
| 2473 | + Frozen functions support arguments with a different variable *width* (see |
| 2474 | + :py:func:`dr.with() <drjit.width>`) without re-tracing, as long as the sets of |
| 2475 | + variables of the same width stay consistent. |
| 2476 | +
|
| 2477 | + Some constructions are problematic and should be avoided in frozen functions. |
| 2478 | +
|
| 2479 | + - The function :py:func:`dr.width() <drjit.width>` returns an integer literal |
| 2480 | + that may be merged into the generated code. If the frozen function is later |
| 2481 | + rerun with differently-sized arguments, the executed kernels will still |
| 2482 | + reference the old size. One exception to this rule are constructions like |
| 2483 | + `dr.arange(UInt32, dr.width(a))`, where the result only implicitly depends on |
| 2484 | + the width value. |
| 2485 | +
|
| 2486 | + **Advanced features**. The :py:func:`@dr.freeze <drjit.freeze>` decorator takes |
| 2487 | + several optional parameters that are helpful in certain situations. |
| 2488 | +
|
| 2489 | + - **Warning when re-tracing happens too often**: Incompatible arguments trigger |
| 2490 | + re-tracing, which can mask issues where *accidentally* incompatible arguments |
| 2491 | + keep :py:func:`@dr.freeze <drjit.freeze>` from producing the expected |
| 2492 | + performance benefits. |
| 2493 | +
|
| 2494 | + In such situations, it can be helpful to warn and identify changing |
| 2495 | + parameters by name. This feature is enabled and set to ``10`` by default. |
| 2496 | +
|
| 2497 | + .. code-block:: pycon |
| 2498 | +
|
| 2499 | + >>> @dr.freeze(warn_after=1) |
| 2500 | + >>> def f(x): |
| 2501 | + ... return x |
| 2502 | + ... |
| 2503 | + >>> f(Int(1)) |
| 2504 | + >>> f(Float(1)) |
| 2505 | + The frozen function has been recorded 2 times, this indicates a problem |
| 2506 | + with how the frozen function is being called. For example, calling it |
| 2507 | + with changing python values such as an index. For more information about |
| 2508 | + which variables changed set the log level to ``LogLevel::Debug``. |
| 2509 | +
|
| 2510 | + - **Limiting memory usage**. Storing kernels for many possible input |
| 2511 | + configuration requires device memory, which can become problematic. Set the |
| 2512 | + ``limit=`` parameter to enable a LRU cache. This is useful when calls to a |
| 2513 | + function are mostly compatible but require occasional re-tracing. |
| 2514 | +
|
| 2515 | + Args: |
| 2516 | + limit (Optional[int]): An optional integer specifying the maximum number of |
| 2517 | + stored configurations. Once this limit is reached, incompatible calls |
| 2518 | + requiring re-tracing will cause the last used configuration to be dropped. |
| 2519 | +
|
| 2520 | + warn_after (int): When the number of re-tracing steps exceeds this value, |
| 2521 | + Dr.Jit will generate a warning that explains which variables changed |
| 2522 | + between calls to the function. |
| 2523 | +
|
| 2524 | + state_fn (Optional[Callable]): This optional callable can specify additional |
| 2525 | + state to identifies the configuration. ``state_fn`` will be called with |
| 2526 | + the same arguments as that of the decorated function. It should return a |
| 2527 | + traversable object (e.g., a list or tuple) that is conceptually treated |
| 2528 | + as if it was another input of the function. |
| 2529 | +
|
| 2530 | + backend (Optional[JitBackend]): If no inputs are given when calling the |
| 2531 | + frozen function, the backend used has to be specified using this argument. |
| 2532 | + It must match the backend used for computation within the function. |
| 2533 | + """ |
| 2534 | + |
| 2535 | + |
| 2536 | +@overload |
| 2537 | +def freeze( |
| 2538 | + f: F, |
| 2539 | + *, |
| 2540 | + state_fn: Optional[Callable] = None, |
| 2541 | + limit: Optional[int] = None, |
| 2542 | + warn_after: int = 10, |
| 2543 | + backend: Optional[JitBackend] = None, |
| 2544 | +) -> F: ... |
| 2545 | + |
| 2546 | + |
| 2547 | +def freeze( |
| 2548 | + f: Optional[F] = None, |
| 2549 | + *, |
| 2550 | + state_fn: Optional[Callable] = None, |
| 2551 | + limit: Optional[int] = None, |
| 2552 | + warn_after: int = 10, |
| 2553 | + backend: Optional[JitBackend] = None, |
| 2554 | +) -> Union[F, Callable[[F2], F2]]: |
| 2555 | + limit = limit if limit is not None else -1 |
| 2556 | + backend = backend if backend is not None else JitBackend.Invalid |
| 2557 | + |
| 2558 | + def decorator(f): |
| 2559 | + """ |
| 2560 | + Internal decorator, returned in ``dr.freeze`` was used with arguments. |
| 2561 | + """ |
| 2562 | + import functools |
| 2563 | + import inspect |
| 2564 | + |
| 2565 | + def inner(closure, *args, **kwargs): |
| 2566 | + """ |
| 2567 | + This inner function is the one that gets actually frozen. It receives |
| 2568 | + any additional state such as closures or state specified with the |
| 2569 | + ``state`` lambda, and allows for traversal of it. |
| 2570 | + """ |
| 2571 | + return f(*args, **kwargs) |
| 2572 | + |
| 2573 | + class FrozenFunction: |
| 2574 | + def __init__(self, f) -> None: |
| 2575 | + closure = inspect.getclosurevars(f) |
| 2576 | + self.closure = (closure.nonlocals, closure.globals) |
| 2577 | + self.frozen = detail.FrozenFunction( |
| 2578 | + inner, |
| 2579 | + limit, |
| 2580 | + warn_after, |
| 2581 | + backend, |
| 2582 | + ) |
| 2583 | + |
| 2584 | + def __call__(self, *args, **kwargs): |
| 2585 | + _state = state_fn(*args, **kwargs) if state_fn is not None else None |
| 2586 | + return self.frozen([self.closure, _state], *args, **kwargs) |
| 2587 | + |
| 2588 | + @property |
| 2589 | + def n_recordings(self): |
| 2590 | + """ |
| 2591 | + Represents the number of times the function was recorded. This |
| 2592 | + includes occasions where it was recorded due to a dry-run failing. |
| 2593 | + It does not necessarily correspond to the number of recordings |
| 2594 | + currently cached see ``n_cached_recordings`` for that. |
| 2595 | + """ |
| 2596 | + return self.frozen.n_recordings |
| 2597 | + |
| 2598 | + @property |
| 2599 | + def n_cached_recordings(self): |
| 2600 | + """ |
| 2601 | + Represents the number of recordings currently cached of the frozen |
| 2602 | + function. If a recording fails in dry-run mode, it will not create |
| 2603 | + a new recording, but replace the recording that was attemted to be |
| 2604 | + replayed. The number of recordings can also be limited with |
| 2605 | + the ``max_cache_size`` argument. |
| 2606 | + """ |
| 2607 | + return self.frozen.n_cached_recordings |
| 2608 | + |
| 2609 | + def clear(self): |
| 2610 | + """ |
| 2611 | + Clears the recordings of the frozen function, and resets the |
| 2612 | + ``n_recordings`` counter. The reference to the function is still |
| 2613 | + kept, and the frozen function can be called again to re-trace |
| 2614 | + new recordings. |
| 2615 | + """ |
| 2616 | + return self.frozen.clear() |
| 2617 | + |
| 2618 | + def __get__(self, obj, type=None): |
| 2619 | + if obj is None: |
| 2620 | + return self |
| 2621 | + else: |
| 2622 | + return FrozenMethod(self.frozen, self.closure, obj) |
| 2623 | + |
| 2624 | + class FrozenMethod(FrozenFunction): |
| 2625 | + """ |
| 2626 | + A FrozenMethod currying the object into the __call__ method. |
| 2627 | +
|
| 2628 | + If the ``freeze`` decorator is applied to a method of some class, it has |
| 2629 | + to call the internal frozen function with the ``self`` argument. To this |
| 2630 | + end we implement the ``__get__`` method of the frozen function, to |
| 2631 | + return a ``FrozenMethod``, which holds a reference to the object. |
| 2632 | + The ``__call__`` method of the ``FrozenMethod`` then supplies the object |
| 2633 | + in addition to the arguments to the internal function. |
| 2634 | + """ |
| 2635 | + def __init__(self, frozen, closure, obj) -> None: |
| 2636 | + self.obj = obj |
| 2637 | + self.frozen = frozen |
| 2638 | + self.closure = closure |
| 2639 | + |
| 2640 | + def __call__(self, *args, **kwargs): |
| 2641 | + _state = state_fn(self.obj, *args, **kwargs) if state_fn is not None else None |
| 2642 | + return self.frozen([self.closure, _state], self.obj, *args, **kwargs) |
| 2643 | + |
| 2644 | + return functools.wraps(f)(FrozenFunction(f)) |
| 2645 | + |
| 2646 | + if f is not None: |
| 2647 | + return decorator(f) |
| 2648 | + else: |
| 2649 | + return decorator |
| 2650 | + |
| 2651 | + |
| 2652 | +del F |
| 2653 | +del F2 |
2404 | 2654 |
|
2405 | 2655 | def assert_true(
|
2406 | 2656 | cond,
|
|
0 commit comments