11 changes: 11 additions & 0 deletions src/datasets/iterable_dataset.py
@@ -198,6 +198,16 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "_BaseExamplesIterable":
"""
raise NotImplementedError(f"{type(self)} doesn't implement shuffle_data_sources yet")

def shift_rngs(self, value: int) -> None:
    """Offset by `value` the seed of every numpy generator in this iterable's chain of wrapped `ex_iterable`s."""

    def set_seed_recursively(ex_iterable):
        if hasattr(ex_iterable, "generator"):
            # Read the current PCG64 internal state and reseed with it shifted by `value`,
            # so that each caller (e.g. each dataloader worker) gets a distinct random stream.
            new_seed = ex_iterable.generator.bit_generator.state["state"]["state"] + value
            ex_iterable.generator = np.random.default_rng(seed=new_seed)
        if hasattr(ex_iterable, "ex_iterable"):
            set_seed_recursively(ex_iterable.ex_iterable)

    set_seed_recursively(self)

def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "_BaseExamplesIterable":
"""Either keep only the requested shard, or propagate the request to the underlying iterable."""
raise NotImplementedError(f"{type(self)} doesn't implement shard_data_sources yet")
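
For context, a minimal standalone sketch (not part of the diff) of the state access that shift_rngs relies on: NumPy's default_rng is backed by the PCG64 bit generator, whose .state dict exposes the internal state integer under ["state"]["state"]. Reseeding with that integer plus a small per-worker offset yields a distinct stream per worker:

import numpy as np

base = np.random.default_rng(seed=1234)
seed_state = base.bit_generator.state["state"]["state"]  # arbitrary-precision Python int

# Generators seeded with seed_state + 0 and seed_state + 1 (almost surely) diverge
worker0 = np.random.default_rng(seed=seed_state + 0)
worker1 = np.random.default_rng(seed=seed_state + 1)
assert worker0.permutation(10).tolist() != worker1.permutation(10).tolist()
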
@@ -2372,6 +2382,7 @@ def _iter_pytorch(self):
ex_iterable = ex_iterable.shard_data_sources(
num_shards=worker_info.num_workers, index=worker_info.id, contiguous=False
)
ex_iterable.shift_rngs(value=worker_info.id)
self._state_dict = {
"examples_iterable": ex_iterable._init_state_dict(),
"epoch": self.epoch,
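
The shift_rngs call above is the core of the fix for GH 7567: each DataLoader worker receives a copy of the dataset, so without it every worker's shuffle buffer starts from the same seed and yields the same order. A minimal sketch of the setup that used to misbehave (illustrative names and sizes):

from torch.utils.data import DataLoader
from datasets import IterableDataset

def gen(shard):
    for i in range(100):
        yield {"value": i}

num_workers = 4
ds = IterableDataset.from_generator(gen, gen_kwargs={"shard": list(range(num_workers))})
ds = ds.shuffle(buffer_size=100, seed=42)
# Before this change, every worker emitted the same shuffled sequence; with the
# per-worker shift_rngs(value=worker_info.id), the streams diverge.
loader = DataLoader(ds, batch_size=None, num_workers=num_workers)
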
45 changes: 45 additions & 0 deletions tests/test_iterable_dataset.py
@@ -1553,6 +1553,51 @@ def test_iterable_dataset_from_hub_torch_dataloader_parallel(num_workers, tmp_path):
assert len(result) == 10


@require_torch
def test_iterable_dataset_shuffle_with_multiple_workers_different_rng():
# GH 7567
from torch.utils.data import DataLoader, get_worker_info

def gen(shard):
    # `shard` is unused; passing one shard per worker via gen_kwargs makes the dataset shardable
    worker_info = get_worker_info()
    for i in range(100):
        yield {"value": i, "worker_id": worker_info.id}

num_workers = 20
ds = IterableDataset.from_generator(gen, gen_kwargs={"shard": list(range(num_workers))})
ds = ds.shuffle(buffer_size=100, seed=1234)
dataloader = DataLoader(ds, batch_size=None, num_workers=num_workers)

result = list(dataloader)
for single_chunk in [result[x : x + num_workers] for x in range(0, len(result), num_workers)]:
values = [item["value"] for item in single_chunk]
# With independent per-worker RNGs, the chance that all 20 values in a given chunk
# are identical is roughly (1/100) ** 19, so this assertion is effectively never flaky.
assert len(set(values)) != 1, "all workers yielded the same value, so the worker RNGs are not independent"


@require_torch
def test_iterable_dataset_interleave_dataset_with_multiple_workers():
# GH 7567
from torch.utils.data import DataLoader

def gen(shard, value):
    # `shard` is unused; it only makes the dataset shardable across workers
    for i in range(100):
        yield {"value": value}

num_workers = 20
ds = [
IterableDataset.from_generator(gen, gen_kwargs={"shard": list(range(num_workers)), "value": i})
for i in range(10)
]
ds = interleave_datasets(ds, probabilities=[1 / len(ds)] * len(ds), seed=1234)
dataloader = DataLoader(ds, batch_size=None, num_workers=num_workers)

result = list(dataloader)
for single_chunk in [result[x : x + num_workers] for x in range(0, len(result), num_workers)]:
values = [item["value"] for item in single_chunk]
assert len(set(values)) != 1, "all workers yielded the same value, so the worker RNGs are not independent"
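
Both tests slice the loader output as result[x : x + num_workers], which relies on the DataLoader's round-robin worker scheduling for iterable datasets: with batch_size=None it fetches one item per worker in worker order, so each chunk of num_workers consecutive items contains exactly one item from each worker. A minimal sketch of that property (standalone, assuming every worker yields the same number of items):

from torch.utils.data import DataLoader, get_worker_info
from datasets import IterableDataset

def gen(shard):
    for _ in range(5):
        yield {"worker_id": get_worker_info().id}

num_workers = 4
ds = IterableDataset.from_generator(gen, gen_kwargs={"shard": list(range(num_workers))})
result = list(DataLoader(ds, batch_size=None, num_workers=num_workers))

# Each chunk holds one item per worker, in worker order
for x in range(0, len(result), num_workers):
    chunk = result[x : x + num_workers]
    assert [item["worker_id"] for item in chunk] == list(range(num_workers))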


@pytest.mark.parametrize("batch_size", [4, 5])
@pytest.mark.parametrize("drop_last_batch", [False, True])
def test_iterable_dataset_iter_batch(batch_size, drop_last_batch):