Skip to content

Commit

Permalink
🐛Introduced AsyncIterManager
Browse files Browse the repository at this point in the history
  • Loading branch information
carefree0910 committed Oct 22, 2024
1 parent fdf9afc commit a1e62fb
Showing 1 changed file with 13 additions and 5 deletions.
18 changes: 13 additions & 5 deletions core/learn/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,18 @@ def poll(self, cursor: int, index: Any) -> Any:
time.sleep(0.01) # pragma: no cover


class AsyncIterManager:
_cur: Optional["AsyncDataLoaderIter"] = None

@classmethod
def new(cls, fn: Callable[[], "AsyncDataLoaderIter"]) -> "AsyncDataLoaderIter":
if cls._cur is not None:
if not cls._cur._finalized:
cls._cur._cleanup()
cls._cur = fn()
return cls._cur


class AsyncDataLoaderIter(_SingleProcessDataLoaderIter):
_queue: Optional[List[Tuple[int, Any]]]
_drained: bool
Expand All @@ -278,10 +290,6 @@ def __init__(self, loader: "DataLoader"):
self._initialized = False
self._pool = ThreadPoolExecutor(max_workers=self.async_prefetch_factor)

def __del__(self) -> None:
if not self._finalized:
self._cleanup()

def _initialize(self) -> None:
self._queue = None
self._drained = False
Expand Down Expand Up @@ -369,7 +377,7 @@ class DataLoader(TorchDataLoader):

def _get_iterator(self) -> _BaseDataLoaderIter:
if self.num_workers == 0:
return AsyncDataLoaderIter(self)
return AsyncIterManager.new(lambda: AsyncDataLoaderIter(self))
return super()._get_iterator() # pragma: no cover

def __iter__(self) -> Iterator[tensor_dict_type]: # type: ignore
Expand Down

0 comments on commit a1e62fb

Please sign in to comment.