Skip to content

Commit

Permalink
only sync seed during training
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Nov 8, 2024
1 parent cd35a68 commit 35c4c84
Show file tree
Hide file tree
Showing 4 changed files with 4 additions and 5 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "vector-quantize-pytorch"
version = "1.19.2"
version = "1.19.3"
description = "Vector Quantization - Pytorch"
authors = [
{ name = "Phil Wang", email = "[email protected]" }
Expand Down
2 changes: 1 addition & 1 deletion vector_quantize_pytorch/residual_fsq.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ def forward(

forward_kwargs = dict(
return_all_codes = return_all_codes,
rand_quantize_dropout_fixed_seed = get_maybe_sync_seed(device)
rand_quantize_dropout_fixed_seed = get_maybe_sync_seed(device) if self.training else None
)

# invoke residual vq on each group
Expand Down
2 changes: 1 addition & 1 deletion vector_quantize_pytorch/residual_lfq.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ def forward(
forward_kwargs = dict(
mask = mask,
return_all_codes = return_all_codes,
rand_quantize_dropout_fixed_seed = get_maybe_sync_seed(device)
rand_quantize_dropout_fixed_seed = get_maybe_sync_seed(device) if self.training else None
)

# invoke residual vq on each group
Expand Down
3 changes: 1 addition & 2 deletions vector_quantize_pytorch/residual_vq.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,6 @@ def forward(

rand = random.Random(rand_quantize_dropout_fixed_seed)


rand_quantize_dropout_index = rand.randrange(self.quantize_dropout_cutoff_index, num_quant)

if quant_dropout_multiple_of != 1:
Expand Down Expand Up @@ -496,7 +495,7 @@ def forward(
sample_codebook_temp = sample_codebook_temp,
mask = mask,
freeze_codebook = freeze_codebook,
rand_quantize_dropout_fixed_seed = get_maybe_sync_seed(device)
rand_quantize_dropout_fixed_seed = get_maybe_sync_seed(device) if self.training else None
)

# invoke residual vq on each group
Expand Down

0 comments on commit 35c4c84

Please sign in to comment.