Skip to content

Commit

Permalink
✨Enhanced async_reset
Browse files (browse the repository at this point in the history)
  • Loading branch information
carefree0910 committed Oct 23, 2024
1 parent 32aa09d commit 81989bb
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
3 changes: 2 additions & 1 deletion core/learn/data/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from ..schema import IDataset
from ..schema import DataBundle
from ..schema import IAsyncDataset
from ..schema import AsyncDataLoaderIter
from ..schema import TDs
from ..toolkit import np_batch_to_tensor
from ..constants import INPUT_KEY
Expand Down Expand Up @@ -63,7 +64,7 @@ def __init__(self, x: arr_type, y: Optional[arr_type] = None):
def __len__(self) -> int:
return len(self.x)

def async_reset(self) -> None:
def async_reset(self, _: AsyncDataLoaderIter) -> None:
self._map: Dict[int, Tuple[arr_type, Optional[arr_type]]] = {}

def async_submit(self, cursor: int, index: Any) -> bool:
Expand Down
12 changes: 9 additions & 3 deletions core/learn/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,8 +235,14 @@ def __getitems__(self, indices: List[int]) -> Any: # pragma: no cover
raise NotImplementedError("should not call `__getitems__` of an async dataset")

@abstractmethod
def async_reset(self) -> None:
"""reset the dataset at the beginning of each epoch"""
def async_reset(self, async_iter: "AsyncDataLoaderIter") -> None:
"""
reset the dataset at the beginning of each epoch
> the `async_iter` object is passed in so that the dataset can interact with the
> `AsyncDataLoaderIter` object. this can be very helpful when exceptions are raised
> and the dataset needs to do some cleanup work.
"""

@abstractmethod
def async_submit(self, cursor: int, index: Any) -> bool:
Expand Down Expand Up @@ -309,7 +315,7 @@ def _initialize(self) -> None:
self._drained = False
self._queue_cursor = 0
self._results = {}
self._dataset.async_reset()
self._dataset.async_reset(self)
self._finalized = False
self._initialized = True

Expand Down

0 comments on commit 81989bb

Please sign in to comment.