-
Notifications
You must be signed in to change notification settings - Fork 12
Expand file tree
/
Copy pathdistributed_iterable_dataset.py
More file actions
61 lines (48 loc) · 2.16 KB
/
distributed_iterable_dataset.py
File metadata and controls
61 lines (48 loc) · 2.16 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
import random
import torch
class DistributedIterableDataset(torch.utils.data.IterableDataset):
    """Base iterable dataset whose files are sharded across ranks and DataLoader workers.

    Subclasses are expected to populate ``self.data_paths`` and implement
    ``__iter__`` (and presumably ``get_data_paths``). ``data_paths`` must be a
    list of path strings or a list of tuples. For the tuple form, a note from
    the original author suggests ``(file_path, num_row_groups)``, e.g.
    ``("chunk0.parquet", 0)``. NOTE(review): confirm the tuple layout against
    subclasses.

    Call ``set_epoch`` to compute this rank's shard before calling
    ``get_data_paths_per_worker``.
    """

    def __init__(self, dataset_name, local_rank=0, world_size=1, num_workers=8):
        self.dataset_name = dataset_name
        self.local_rank = local_rank
        self.world_size = world_size
        self.num_workers = num_workers
        self.rng = random.Random()
        self.data_paths = None
        self.epoch_seed = None
        # Populated by set_epoch(); None means this rank has not been sharded yet.
        self.num_files_per_rank = None
        self.data_paths_per_rank = None

    def get_data_paths(self, *args, **kwargs):
        """Return the dataset's file list. Must be implemented by subclasses."""
        raise NotImplementedError

    def set_epoch(self, seed=42):
        """Shuffle ``data_paths`` deterministically with ``seed`` and keep this rank's shard.

        Every rank receives ``len(data_paths) // world_size`` files, as one
        contiguous slice of the shuffled list. Leftover files beyond
        ``world_size * num_files_per_rank`` are dropped for this epoch. This
        gives every rank an equal file count.

        Args:
            seed: Seed for this instance's RNG. It is also stored as ``epoch_seed``.

        Raises:
            ValueError: If the elements of ``data_paths`` are neither ``tuple`` nor ``str``.
        """
        self.epoch_seed = seed
        if self.data_paths is None:
            return
        if not self.data_paths:
            # No files to shard; avoid indexing into an empty list below.
            self.num_files_per_rank = 0
            self.data_paths_per_rank = []
            return

        first = self.data_paths[0]
        # Sort before shuffling. The seeded shuffle then depends only on the set
        # of paths, not on their listing order, so all ranks using the same
        # seed compute the same permutation.
        if isinstance(first, tuple):
            data_paths = sorted(self.data_paths, key=lambda x: (x[0], x[1]))
        elif isinstance(first, str):
            data_paths = sorted(self.data_paths)
        else:
            raise ValueError(f"Unknown data_paths type: {type(first)}")
        self.rng.seed(seed)
        self.rng.shuffle(data_paths)

        num_files_per_rank = len(data_paths) // self.world_size
        local_start = self.local_rank * num_files_per_rank
        local_end = local_start + num_files_per_rank
        self.num_files_per_rank = num_files_per_rank
        self.data_paths_per_rank = data_paths[local_start:local_end]

    def get_data_paths_per_worker(self):
        """Return the files assigned to the current DataLoader worker.

        Returns:
            ``None`` if ``data_paths`` is ``None``. Otherwise a tuple
            ``(paths, worker_id)``:

            - Outside a DataLoader worker process, this is the whole rank shard
              with worker id ``0``.
            - Inside a worker, this is an equal contiguous slice of the rank
              shard, returned in reverse order. Files beyond
              ``num_workers * (num_files_per_rank // num_workers)`` are not
              assigned to any worker.

        Raises:
            RuntimeError: If ``set_epoch`` has not been called yet.
        """
        if self.data_paths is None:
            return None
        if getattr(self, "data_paths_per_rank", None) is None:
            raise RuntimeError(
                "set_epoch() must be called before get_data_paths_per_worker()"
            )

        info = torch.utils.data.get_worker_info()
        if info is None:
            # Single-process loading: use every file assigned to this rank.
            return self.data_paths_per_rank, 0

        worker_id = info.id
        num_files_per_worker = self.num_files_per_rank // info.num_workers
        start = num_files_per_worker * worker_id
        end = start + num_files_per_worker
        data_paths_per_worker = self.data_paths_per_rank[start:end]
        # NOTE(review): only this multi-worker path reverses the order; the
        # single-process path above does not. Confirm the asymmetry is intended.
        return data_paths_per_worker[::-1], worker_id

    def __iter__(self):
        """Yield samples. Must be implemented by subclasses."""
        raise NotImplementedError