diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py
index 5c60567cc17..acbe50e3882 100644
--- a/src/datasets/iterable_dataset.py
+++ b/src/datasets/iterable_dataset.py
@@ -169,6 +169,19 @@ def _convert_to_arrow(
         yield new_key, pa.Table.from_pylist(cast_to_python_objects(examples, only_1d_for_numpy=True))
 
 
+def shift_ex_examples_rngs(ex_iterable: "_BaseExamplesIterable", value: int) -> "_BaseExamplesIterable":
+    """Recursively walk the nested ex_iterables, replace each RNG-bearing iterable with a reseeded copy, and re-attach it to its parent."""
+
+    def set_seed_recursively(ex_iterable):
+        if hasattr(ex_iterable, "shift_rngs"):
+            ex_iterable = ex_iterable.shift_rngs(value)
+        if hasattr(ex_iterable, "ex_iterable"):
+            ex_iterable.ex_iterable = set_seed_recursively(ex_iterable.ex_iterable)
+        return ex_iterable
+
+    return set_seed_recursively(ex_iterable)
+
+
 class _BaseExamplesIterable:
     """Base class for the examples iterable used by an IterableDataset"""
 
@@ -283,6 +296,14 @@ def __init__(
         super().__init__(generate_examples_fn, kwargs)
         self.generator = deepcopy(generator)
 
+    def shift_rngs(self, value: int) -> "_BaseExamplesIterable":
+        new_seed = self.generator.bit_generator.state["state"]["state"] + value
+        return ShuffledDataSourcesExamplesIterable(
+            self.generate_examples_fn,
+            self.kwargs,
+            np.random.default_rng(seed=new_seed),
+        )
+
     def _init_state_dict(self) -> dict:
         self._state_dict = {"shard_idx": 0, "shard_example_idx": 0, "type": self.__class__.__name__}
         return self._state_dict
@@ -390,6 +411,14 @@ def __init__(
         super().__init__(generate_tables_fn, kwargs)
         self.generator = deepcopy(generator)
 
+    def shift_rngs(self, value: int) -> "_BaseExamplesIterable":
+        new_seed = self.generator.bit_generator.state["state"]["state"] + value
+        return ShuffledDataSourcesArrowExamplesIterable(
+            self.generate_tables_fn,
+            self.kwargs,
+            np.random.default_rng(seed=new_seed),
+        )
+
     def _init_state_dict(self) -> dict:
         self._state_dict = {"shard_idx": 0, "shard_example_idx": 0, "type": self.__class__.__name__}
         return self._state_dict
@@ -1031,6 +1060,15 @@ def __init__(
         self.generator = deepcopy(generator)
         self.probabilities = probabilities
 
+    def shift_rngs(self, value: int) -> "_BaseExamplesIterable":
+        new_seed = self.generator.bit_generator.state["state"]["state"] + value
+        return RandomlyCyclingMultiSourcesExamplesIterable(
+            ex_iterables=self.ex_iterables,
+            generator=np.random.default_rng(seed=new_seed),
+            probabilities=self.probabilities,
+            stopping_strategy=self.stopping_strategy,
+        )
+
     @property
     def is_typed(self):
         return self.ex_iterables[0].is_typed
@@ -1628,6 +1666,14 @@ def __init__(self, ex_iterable: _BaseExamplesIterable, buffer_size: int, generat
         self.buffer_size = buffer_size
         self.generator = generator
 
+    def shift_rngs(self, value: int) -> "_BaseExamplesIterable":
+        new_seed = self.generator.bit_generator.state["state"]["state"] + value
+        return BufferShuffledExamplesIterable(
+            ex_iterable=self.ex_iterable,
+            buffer_size=self.buffer_size,
+            generator=np.random.default_rng(seed=new_seed),
+        )
+
     @property
     def is_typed(self):
         return self.ex_iterable.is_typed
@@ -2372,6 +2418,7 @@ def _iter_pytorch(self):
             ex_iterable = ex_iterable.shard_data_sources(
                 num_shards=worker_info.num_workers, index=worker_info.id, contiguous=False
             )
+            ex_iterable = shift_ex_examples_rngs(ex_iterable=ex_iterable, value=worker_info.id)
             self._state_dict = {
                 "examples_iterable": ex_iterable._init_state_dict(),
                 "epoch": self.epoch,
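Why this works: each PyTorch DataLoader worker receives a deepcopy of the dataset, so before this patch every worker held a byte-identical `np.random.Generator` and produced the same shuffle order. Below is a minimal standalone sketch of the reseeding trick `shift_rngs` uses; the helper name `shift_rng` and the toy permutations are illustrative, not library code. The new seed is derived from the shared PCG64 state plus the worker id, so workers diverge while staying deterministic for a fixed base seed.

```python
import numpy as np

def shift_rng(generator: np.random.Generator, value: int) -> np.random.Generator:
    # Same derivation as shift_rngs in the diff: read the PCG64 state dict
    # and offset it by `value` (the DataLoader worker id) to get a new seed.
    new_seed = generator.bit_generator.state["state"]["state"] + value
    return np.random.default_rng(seed=new_seed)

# Before the patch: every "worker" deepcopies the same generator state...
orders = [np.random.default_rng(seed=1234).permutation(100).tolist() for _ in range(3)]
assert orders[0] == orders[1] == orders[2]  # ...so all shuffle orders are identical

# After the patch: each worker shifts the shared state by its own id.
base = np.random.default_rng(seed=1234)
shifted = [shift_rng(base, worker_id).permutation(100).tolist() for worker_id in range(3)]
assert len({tuple(o) for o in shifted}) == 3  # distinct order per worker

# Reading bit_generator.state does not advance `base`, so re-deriving the
# shifted generators (e.g. on a second epoch) reproduces the same orders.
assert shift_rng(base, 1).permutation(100).tolist() == shifted[1]
```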
diff --git a/tests/test_iterable_dataset.py b/tests/test_iterable_dataset.py
index 1bca866bdf8..583f5dab51a 100644
--- a/tests/test_iterable_dataset.py
+++ b/tests/test_iterable_dataset.py
@@ -1553,6 +1553,77 @@ def test_iterable_dataset_from_hub_torch_dataloader_parallel(num_workers, tmp_pa
     assert len(result) == 10
 
 
+@require_torch
+def test_iterable_dataset_shuffle_with_multiple_workers_different_rng():
+    # GH 7567
+    from torch.utils.data import DataLoader, get_worker_info
+
+    def gen(shard):
+        worker_info = get_worker_info()
+        for i in range(100):
+            yield {"value": i, "worker_id": worker_info.id}
+
+    num_workers = 20
+    ds = IterableDataset.from_generator(gen, gen_kwargs={"shard": list(range(num_workers))})
+    ds = ds.shuffle(buffer_size=100, seed=1234)
+    dataloader = DataLoader(ds, batch_size=None, num_workers=num_workers)
+
+    result = list(dataloader)
+    for single_chunk in [result[x : x + num_workers] for x in range(0, len(result), num_workers)]:
+        values = [item["value"] for item in single_chunk]
+        # The chance that all 20 workers emit the same value is (1 / 100) ** 20, so this cannot flake in practice
+        assert len(set(values)) != 1, "Make sure not all values are identical"
+
+
+@require_torch
+def test_iterable_dataset_interleave_dataset_with_multiple_workers():
+    # GH 7567
+    from torch.utils.data import DataLoader
+
+    def gen(shard, value):
+        for i in range(100):
+            yield {"value": value}
+
+    num_workers = 20
+    ds = [
+        IterableDataset.from_generator(gen, gen_kwargs={"shard": list(range(num_workers)), "value": i})
+        for i in range(10)
+    ]
+    ds = interleave_datasets(ds, probabilities=[1 / len(ds)] * len(ds), seed=1234)
+    dataloader = DataLoader(ds, batch_size=None, num_workers=num_workers)
+
+    result = list(dataloader)
+    for single_chunk in [result[x : x + num_workers] for x in range(0, len(result), num_workers)]:
+        values = [item["value"] for item in single_chunk]
+        assert len(set(values)) != 1, "Make sure not all values are identical"
+
+
+@require_torch
+def test_iterable_dataset_interleave_dataset_deterministic_across_iterations():
+    # GH 7567
+    from torch.utils.data import DataLoader
+
+    def gen(shard, value):
+        for i in range(50):
+            yield {"value": value, "id": i}
+
+    num_workers = 10
+    ds = [
+        IterableDataset.from_generator(gen, gen_kwargs={"shard": list(range(num_workers)), "value": i})
+        for i in range(5)
+    ]
+    ds = interleave_datasets(ds, probabilities=[1 / len(ds)] * len(ds), seed=1234)
+    dataloader = DataLoader(ds, batch_size=None, num_workers=num_workers)
+
+    # First iteration
+    first_result = list(dataloader)
+
+    # Second iteration
+    second_result = list(dataloader)
+
+    assert first_result == second_result, "Results should be identical across iterations when using same seed"
+
+
 @pytest.mark.parametrize("batch_size", [4, 5])
 @pytest.mark.parametrize("drop_last_batch", [False, True])
 def test_iterable_dataset_iter_batch(batch_size, drop_last_batch):
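For context, here is a hedged user-level sketch of the two guarantees these tests pin down: distinct per-worker shuffles and seed-stable epochs. The generator `gen` and its `shard` argument are illustrative, and the sketch assumes torch plus a datasets build that includes this patch.

```python
from torch.utils.data import DataLoader

from datasets import IterableDataset

def gen(shard):  # `shard` exists only so from_generator can split work into 4 shards
    for i in range(100):
        yield {"value": i}

ds = IterableDataset.from_generator(gen, gen_kwargs={"shard": list(range(4))})
ds = ds.shuffle(buffer_size=100, seed=1234)
loader = DataLoader(ds, batch_size=None, num_workers=4)

epoch1 = [ex["value"] for ex in loader]
epoch2 = [ex["value"] for ex in loader]
# With the patch, the 4 workers no longer emit identical shuffled streams
# (each worker's RNG seed is shifted by its worker id), yet a fixed seed
# still makes every epoch reproducible:
assert epoch1 == epoch2
```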