66import math
77import pickle
88import types
9- from collections .abc import Callable , Generator , Iterable
9+ from collections .abc import Callable , Generator , Iterable , Iterator
1010from functools import wraps
1111from types import ModuleType
1212from typing import (
@@ -512,13 +512,24 @@ class _AutoJITWrapper(Generic[T]): # numpydoc ignore=PR01
512512 convert them to/from PyTrees.
513513 """
514514
515- obj : T
515+ _obj : Any
516+ _is_iter : bool
516517 _registered : ClassVar [bool ] = False
517- __slots__ : tuple [str , ...] = ("obj" , )
518+ __slots__ : tuple [str , ...] = ("_is_iter" , "_obj" )
518519
519520 def __init__ (self , obj : T ) -> None : # numpydoc ignore=GL08
520521 self ._register ()
521- self .obj = obj
522+ if isinstance (obj , Iterator ):
523+ self ._obj = list (obj )
524+ self ._is_iter = True
525+ else :
526+ self ._obj = obj
527+ self ._is_iter = False
528+
529+ @property
530+ def obj (self ) -> T : # numpydoc ignore=RT01
531+ """Return wrapped object."""
532+ return iter (self ._obj ) if self ._is_iter else self ._obj
522533
523534 @classmethod
524535 def _register (cls ) -> None : # numpydoc ignore=SS06
@@ -531,7 +542,7 @@ def _register(cls) -> None: # numpydoc ignore=SS06
531542
532543 jax .tree_util .register_pytree_node (
533544 cls ,
534- lambda obj : pickle_flatten (obj , jax .Array ), # pyright: ignore[reportUnknownArgumentType]
545+ lambda instance : pickle_flatten (instance , jax .Array ), # pyright: ignore[reportUnknownArgumentType]
535546 lambda aux_data , children : pickle_unflatten (children , aux_data ), # pyright: ignore[reportUnknownArgumentType]
536547 )
537548 cls ._registered = True
@@ -556,6 +567,7 @@ def jax_autojit(
556567 - Automatically descend into non-array return values and find ``jax.Array`` objects
557568 inside them, then rebuild them downstream of exiting the JIT, swapping the JAX
558569 tracer objects with concrete arrays.
570+ - Returned iterators are immediately completely consumed.
559571
560572 See Also
561573 --------
0 commit comments