Skip to content

Commit

Permalink
⚡️Utilized ThreadPoolExecutor
Browse files Browse the repository at this point in the history
  • Loading branch information
carefree0910 committed Oct 22, 2024
1 parent 2a76276 commit 05a2dab
Showing 1 changed file with 22 additions and 8 deletions.
30 changes: 22 additions & 8 deletions core/learn/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from pydantic.dataclasses import dataclass as pydantic_dataclass
from accelerate.utils import PrecisionType
from accelerate.utils import extract_model_from_parallel
from concurrent.futures import ThreadPoolExecutor
from torch.optim.lr_scheduler import LRScheduler
from torch.utils.data import Dataset
from torch.utils.data import WeightedRandomSampler
Expand Down Expand Up @@ -263,6 +264,7 @@ class AsyncDataLoaderIter(_SingleProcessDataLoaderIter):
_queue_cursor: int
_dataset: IAsyncDataset
_finalized: bool
_results: Dict[int, Any]

def __init__(self, loader: "DataLoader"):
super().__init__(loader)
Expand All @@ -274,6 +276,7 @@ def __init__(self, loader: "DataLoader"):
) # pragma: no cover
self._finalized = True
self._initialized = False
self._pool = ThreadPoolExecutor(max_workers=self.async_prefetch_factor)

def __del__(self) -> None:
if not self._finalized:
Expand All @@ -283,22 +286,32 @@ def _initialize(self) -> None:
self._queue = None
self._drained = False
self._queue_cursor = 0
self._results = {}
self._dataset.async_reset()
self._finalized = False
self._initialized = True

def _finalize(self) -> None:
self._pool.shutdown()
self._results.clear()
self._dataset.async_finalize()
self._finalized = True
raise StopIteration

def _submit_next(self) -> None:
cursor = self._queue_cursor
index = self._next_index()
def _async_submit(self, cursor: int, index: Any) -> Any:
if not self._dataset.async_submit(cursor, index): # pragma: no cover
msg = f"failed to submit async task with cursor={cursor} and index={index}"
console.error(msg)
raise RuntimeError("failed to submit async task")
data = self._dataset.poll(cursor, index)
if self._pin_memory: # pragma: no cover
data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)
self._results[cursor] = data

def _submit_next(self) -> None:
cursor = self._queue_cursor
index = self._next_index()
self._pool.submit(self._async_submit, cursor, index)
self._queue.append((cursor, index)) # type: ignore
self._queue_cursor = cursor + 1

Expand All @@ -322,11 +335,12 @@ def _next_data(self) -> Any:
self._submit_next()
except StopIteration:
self._drained = True
cursor, index = self._queue.pop(0)
data = self._dataset.poll(cursor, index)
if self._pin_memory: # pragma: no cover
data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)
return data
cursor, _ = self._queue.pop(0)
while True:
data = self._results.pop(cursor, None)
if data is not None:
return data
time.sleep(0.005)


class DataLoader(TorchDataLoader):
Expand Down

0 comments on commit 05a2dab

Please sign in to comment.