From ce4a4fcc437f1731e65fd45cd9ebbd6a1b70033a Mon Sep 17 00:00:00 2001
From: Phil Wang <lucidrains@gmail.com>
Date: Tue, 12 Nov 2024 09:29:10 -0800
Subject: [PATCH] commit loss weighting for sim vq

---
 pyproject.toml                    | 2 +-
 vector_quantize_pytorch/sim_vq.py | 7 ++++++-
 2 files changed, 7 insertions(+), 2 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 671334e..ad856ee 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "vector-quantize-pytorch"
-version = "1.20.7"
+version = "1.20.8"
 description = "Vector Quantization - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
diff --git a/vector_quantize_pytorch/sim_vq.py b/vector_quantize_pytorch/sim_vq.py
index f264ffe..a873173 100644
--- a/vector_quantize_pytorch/sim_vq.py
+++ b/vector_quantize_pytorch/sim_vq.py
@@ -44,6 +44,7 @@ def __init__(
         accept_image_fmap = False,
         rotation_trick = True, # works even better with rotation trick turned on, with no straight through and the commit loss from input to quantize
         input_to_quantize_commit_loss_weight = 0.25,
+        commitment_weight = 1.,
         frozen_codebook_dim = None # frozen codebook dim could have different dimensions than projection
     ):
         super().__init__()
@@ -74,6 +75,10 @@ def __init__(
 
         self.input_to_quantize_commit_loss_weight = input_to_quantize_commit_loss_weight
 
+        # total commitment loss weight
+
+        self.commitment_weight = commitment_weight
+
     @property
     def codebook(self):
         return self.code_transform(self.frozen_codebook)
@@ -132,7 +137,7 @@ def forward(
 
         indices = inverse_pack(indices, 'b *')
 
-        return quantized, indices, commit_loss
+        return quantized, indices, commit_loss * self.commitment_weight
 
 
 # main
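
For context, a minimal usage sketch of the new knob, patterned after the library's existing SimVQ example; the dim, codebook_size, and tensor shape below are illustrative assumptions, not part of this patch:

    import torch
    from vector_quantize_pytorch import SimVQ

    sim_vq = SimVQ(
        dim = 512,               # illustrative model dimension (assumed)
        codebook_size = 1024,    # illustrative codebook size (assumed)
        commitment_weight = 0.5  # new in 1.20.8: scales the returned commit_loss (default 1.)
    )

    x = torch.randn(1, 1024, 512)

    quantized, indices, commit_loss = sim_vq(x)

    # commit_loss arrives pre-multiplied by commitment_weight,
    # so it can be added directly to the training objective

Before this change, down-weighting the commitment term meant rescaling commit_loss at the call site; forward now returns commit_loss * self.commitment_weight directly.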