Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions src/lightning/fabric/utilities/seed.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,19 @@ def pl_worker_init_function(worker_id: int, rank: Optional[int] = None) -> None:


def _generate_seed_sequence(base_seed: int, worker_id: int, global_rank: int, count: int) -> list[int]:
"""Generates a sequence of seeds from a base seed, worker id and rank using the linear congruential generator (LCG)
algorithm."""
"""Generates a sequence of seeds from a base seed, worker id and rank using hash-based mixing followed by the
linear congruential generator (LCG) algorithm."""
# Combine base seed, worker id and rank into a unique 64-bit number
combined_seed = (base_seed << 32) | (worker_id << 16) | global_rank

# Apply hash-based mixing (MurmurHash3 finalizer) to distribute bits uniformly
# This ensures that small base seeds don't result in zeros in lower bits
combined_seed ^= combined_seed >> 33
combined_seed = (combined_seed * 0xFF51AFD7ED558CCD) & ((1 << 64) - 1)
combined_seed ^= combined_seed >> 33
combined_seed = (combined_seed * 0xC4CEB9FE1A85EC53) & ((1 << 64) - 1)
combined_seed ^= combined_seed >> 33

seeds = []
for _ in range(count):
# x_(n+1) = (a * x_n + c) mod m. With c=1, m=2^64 and a is D. Knuth's constant
Expand Down
4 changes: 4 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

-

### Fixed

- Fix `_generate_seed_sequence` function not producing unique seeds ([#21399](https://github.com/Lightning-AI/pytorch-lightning/pull/21399))


## [2.6.0] - 2025-11-28

Expand Down
21 changes: 21 additions & 0 deletions tests/tests_fabric/utilities/test_seed.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from lightning.fabric.utilities.seed import (
_collect_rng_states,
_generate_seed_sequence,
_set_rng_states,
pl_worker_init_function,
reset_seed,
Expand Down Expand Up @@ -153,3 +154,23 @@ def test_pl_worker_init_function(base_seed, num_workers, num_ranks):
assert len(stdlib_rands) == num_ranks * num_workers
assert len(numpy_rands) == num_ranks * num_workers
assert len(torch_rands | stdlib_rands | numpy_rands) == 3 * num_workers * num_ranks


def test_generate_seed_sequence_no_collision():
    """Check that distinct base seeds yield distinct derived seeds and distinct random draws."""
    derived = []
    draws = []

    for seed in (0, 1, 42, 123, 999, 12345):
        # Seed the process first, then derive from the seed torch reports back.
        seed_everything(seed)
        worker_seed = _generate_seed_sequence(torch.initial_seed(), worker_id=0, global_rank=0, count=1)[0]
        derived.append(worker_seed)
        # Reseed torch with the derived value and record a sample drawn under it.
        torch.manual_seed(worker_seed)
        draws.append(tuple(torch.randn(10).tolist()))

    assert len(derived) == len(set(derived)), "Generated seeds should be unique for different base seeds"
    assert len(draws) == len(set(draws)), "Random outputs should be unique for different base seeds"
Loading