From 05a2dab9e424a6a7ba91692e908a4fd004c4778b Mon Sep 17 00:00:00 2001
From: carefree0910
Date: Tue, 22 Oct 2024 15:55:03 +0800
Subject: [PATCH] =?UTF-8?q?=E2=9A=A1=EF=B8=8FUtilized=20`ThreadPoolExecuto?=
 =?UTF-8?q?r`?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 core/learn/schema.py | 30 ++++++++++++++++++++++--------
 1 file changed, 22 insertions(+), 8 deletions(-)

diff --git a/core/learn/schema.py b/core/learn/schema.py
index 9c452b9..e7adf9c 100644
--- a/core/learn/schema.py
+++ b/core/learn/schema.py
@@ -38,6 +38,7 @@
 from pydantic.dataclasses import dataclass as pydantic_dataclass
 from accelerate.utils import PrecisionType
 from accelerate.utils import extract_model_from_parallel
+from concurrent.futures import ThreadPoolExecutor
 from torch.optim.lr_scheduler import LRScheduler
 from torch.utils.data import Dataset
 from torch.utils.data import WeightedRandomSampler
@@ -263,6 +264,7 @@ class AsyncDataLoaderIter(_SingleProcessDataLoaderIter):
     _queue_cursor: int
     _dataset: IAsyncDataset
     _finalized: bool
+    _results: Dict[int, Any]
 
     def __init__(self, loader: "DataLoader"):
         super().__init__(loader)
@@ -274,6 +276,7 @@ def __init__(self, loader: "DataLoader"):
             )  # pragma: no cover
         self._finalized = True
         self._initialized = False
+        self._pool = ThreadPoolExecutor(max_workers=self.async_prefetch_factor)
 
     def __del__(self) -> None:
         if not self._finalized:
@@ -283,22 +286,32 @@ def _initialize(self) -> None:
         self._queue = None
         self._drained = False
         self._queue_cursor = 0
+        self._results = {}
         self._dataset.async_reset()
         self._finalized = False
         self._initialized = True
 
     def _finalize(self) -> None:
+        self._pool.shutdown()
+        self._results.clear()
         self._dataset.async_finalize()
         self._finalized = True
         raise StopIteration
 
-    def _submit_next(self) -> None:
-        cursor = self._queue_cursor
-        index = self._next_index()
+    def _async_submit(self, cursor: int, index: Any) -> Any:
         if not self._dataset.async_submit(cursor, index):  # pragma: no cover
             msg = f"failed to submit async task with cursor={cursor} and index={index}"
             console.error(msg)
             raise RuntimeError("failed to submit async task")
+        data = self._dataset.poll(cursor, index)
+        if self._pin_memory:  # pragma: no cover
+            data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)
+        self._results[cursor] = data
+
+    def _submit_next(self) -> None:
+        cursor = self._queue_cursor
+        index = self._next_index()
+        self._pool.submit(self._async_submit, cursor, index)
         self._queue.append((cursor, index))  # type: ignore
         self._queue_cursor = cursor + 1
 
@@ -322,11 +335,12 @@ def _next_data(self) -> Any:
                 self._submit_next()
             except StopIteration:
                 self._drained = True
-        cursor, index = self._queue.pop(0)
-        data = self._dataset.poll(cursor, index)
-        if self._pin_memory:  # pragma: no cover
-            data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)
-        return data
+        cursor, _ = self._queue.pop(0)
+        while True:
+            data = self._results.pop(cursor, None)
+            if data is not None:
+                return data
+            time.sleep(0.005)
 
 
 class DataLoader(TorchDataLoader):
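
Note (not part of the patch): the commit moves sample loading onto a `ThreadPoolExecutor`, with each worker writing its result into a dict keyed by cursor while the consumer polls that dict in submission order. Below is a self-contained sketch of that prefetch-and-poll pattern under assumed names — `PrefetchingLoader`, `_load`, and the plain item list are hypothetical stand-ins for `IAsyncDataset` and the real loader, not the repository's API. The sketch checks key membership instead of popping `None`, which avoids spinning forever if a worker legitimately produces `None`.

```python
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List


class PrefetchingLoader:
    """Illustrative loader: threads prefetch items, consumer yields them in order."""

    def __init__(self, items: List[Any], prefetch_factor: int = 4) -> None:
        self._items = items
        self._results: Dict[int, Any] = {}
        self._pool = ThreadPoolExecutor(max_workers=prefetch_factor)
        self._cursor = 0          # next cursor to submit
        self._next_to_yield = 0   # next cursor the consumer expects

    def _load(self, cursor: int, item: Any) -> None:
        # Simulate slow I/O (decoding, disk reads, ...), then publish the result.
        time.sleep(0.01)
        self._results[cursor] = {"cursor": cursor, "data": item}

    def _submit_next(self) -> None:
        self._pool.submit(self._load, self._cursor, self._items[self._cursor])
        self._cursor += 1

    def __iter__(self):
        # Prime a few in-flight tasks, then keep the pipeline topped up.
        for _ in range(min(4, len(self._items))):
            self._submit_next()
        while self._next_to_yield < len(self._items):
            if self._cursor < len(self._items):
                self._submit_next()
            # Poll with a short sleep until the expected cursor arrives, so
            # results come out in order even if workers finish out of order.
            while self._next_to_yield not in self._results:
                time.sleep(0.005)
            yield self._results.pop(self._next_to_yield)
            self._next_to_yield += 1
        self._pool.shutdown()


if __name__ == "__main__":
    for batch in PrefetchingLoader(list(range(8))):
        print(batch)
```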