PT: add randomization to bucket batching #1697

Merged · 5 commits · Feb 28, 2025
Changes from all commits
41 changes: 38 additions & 3 deletions returnn/torch/data/pipeline.py
@@ -337,17 +337,31 @@ class BucketOrderingIterDataPipe(torch.utils.data.IterDataPipe):
"""

def __init__(
self, dataset: torch.utils.data.IterableDataset, *, buckets: Sequence[Tuple[int, int]], length_key: str
self,
dataset: torch.utils.data.IterableDataset,
*,
buckets: Sequence[Tuple[int, int]],
length_key: str,
random_bucket_prob: float = 0.0,
seed: Optional[int] = None,
):
"""
:param dataset: dataset to apply bucket batching to
:param buckets: Bucket configuration as tuples of seq length and max number of seqs in that bucket.
Segments longer than the largest size limit configured in the buckets are dropped. To avoid dropping
any segments make sure your largest bucket allows segments larger than your longest training segment.
:param length_key: data key to take as length measure
:param random_bucket_prob: Probability of putting a segment not into the best-fitting bucket, but into
a randomly chosen still-fitting bucket.
This increases seq length variation within the buckets at the cost of slighly more padding.
:param seed: random seed
"""
self._dataset = dataset
self._length_key = length_key
assert random_bucket_prob >= 0.0
self._random_bucket_prob = random_bucket_prob
self._rng = numpy.random.RandomState()
self._seed = seed % (2**32) if seed is not None else None

assert buckets, "empty bucket batching configuration"
if not all(size > 0 and max_seqs > 0 for size, max_seqs in buckets):
@@ -367,6 +381,12 @@ def __iter__(self):
             if bucket_idx >= len(self._max_seq_lens):
                 # seg is too long, drop it
                 continue
+            if (
+                self._random_bucket_prob > 0.0
+                and bucket_idx < len(self._max_seq_lens) - 1
+                and self._rng.rand() < self._random_bucket_prob
+            ):
+                bucket_idx = self._rng.randint(bucket_idx, len(self._max_bucket_sizes))
             buckets[bucket_idx].append(data_dict)
             if len(buckets[bucket_idx]) >= self._max_bucket_sizes[bucket_idx]:
                 yield buckets[bucket_idx]
@@ -383,6 +403,21 @@ def __iter__(self):
     def __getitem__(self, index):
         raise Exception(f"{self.__class__.__name__}.__getitem__ is not supported")

+    def set_seed(self, seed: int) -> BucketOrderingIterDataPipe:
+        """
+        Sets the seed for the next invocation of ``__iter__``, for compatibility with
+        ``torch.utils.data.graph_settings.apply_random_seed``.
+        """
+        self._seed = seed % (2**32)  # seed must be within [0, 2**32) for seeding RandomState
+        return self
+
+    def reset(self):
+        """resets the internal state of the data pipe"""
+        if self._seed is None:
+            self._seed = int(2**31 + torch.empty((), dtype=torch.int32).random_().item())
+        self._rng.seed(self._seed)
+        self._seed = None
+


 def get_batching_iterable_dataset_from_config(
     *, dataset: torch.utils.data.IterableDataset, config: Config, train: bool
@@ -497,7 +532,7 @@ def __init__(
         self._buffer_size = buffer_size
         self._monotonic_data_keys = monotonic_data_keys
         self._rng = numpy.random.RandomState()
-        self._seed = seed
+        self._seed = seed % (2**32) if seed is not None else None

     def __iter__(self):
         # The implementation is very similar to the PostprocessingDataset's combinator LaplaceOrdering.
@@ -548,7 +583,7 @@ def set_seed(self, seed: int) -> ShufflingDataPipe:
     def reset(self):
         """resets the internal state of the data pipe"""
         if self._seed is None:
-            self._seed = int(torch.empty((), dtype=torch.int32).random_().item())
+            self._seed = int(2**31 + torch.empty((), dtype=torch.int32).random_().item())
         self._rng.seed(self._seed)
         self._seed = None

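For context, here is a minimal usage sketch of the changed constructor, based only on the signature and docstring in this diff. It is not part of the PR: the `dataset` variable and the "data" length key are placeholder assumptions, and in RETURNN the pipe is normally set up via get_batching_iterable_dataset_from_config rather than instantiated by hand.

# A minimal sketch (not part of the PR), assuming `dataset` is some
# torch.utils.data.IterableDataset that yields RETURNN data dicts and that
# "data" is the key whose length should select the bucket.
from returnn.torch.data.pipeline import BucketOrderingIterDataPipe

batching = BucketOrderingIterDataPipe(
    dataset,
    buckets=[(100, 32), (200, 16), (400, 8)],  # (max seq length, max seqs in that bucket)
    length_key="data",  # data key taken as the length measure
    random_bucket_prob=0.1,  # ~10% chance per seq to land in a random still-fitting bucket
    seed=42,  # optional; without it, reset() draws a random seed
)
for batch in batching:  # each batch is a list of data dicts drawn from a single bucket
    ...

With these example buckets, sequences longer than 400 would be dropped, as the docstring above notes; the random_bucket_prob of 0.1 trades slightly more padding for more length variation within each bucket.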