From cd0fa8ec48810d98944de7f9e06d67ff1ec93dc5 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Mon, 11 Nov 2024 09:41:14 -0800
Subject: [PATCH] still need commit loss from quantize to input to optimize
 linear projection in SimVQ, best combination is rotation trick without
 straight through and without commit loss from input to quantize

---
 pyproject.toml                    |  2 +-
 vector_quantize_pytorch/sim_vq.py | 19 +++++++++----------
 2 files changed, 10 insertions(+), 11 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index d65aa3c..3a4ae17 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "vector-quantize-pytorch"
-version = "1.20.0"
+version = "1.20.1"
 description = "Vector Quantization - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
diff --git a/vector_quantize_pytorch/sim_vq.py b/vector_quantize_pytorch/sim_vq.py
index 7e16da1..29ad136 100644
--- a/vector_quantize_pytorch/sim_vq.py
+++ b/vector_quantize_pytorch/sim_vq.py
@@ -40,8 +40,8 @@ def __init__(
         codebook_size,
         init_fn: Callable = identity,
         accept_image_fmap = False,
-        rotation_trick = True, # works even better with rotation trick turned on, with no asymmetric commit loss or straight through
-        commit_loss_input_to_quantize_weight = 0.25,
+        rotation_trick = True, # works even better with rotation trick turned on, with no straight through and the commit loss from input to quantize
+        input_to_quantize_commit_loss_weight = 0.25,
     ):
         super().__init__()
         self.accept_image_fmap = accept_image_fmap
@@ -59,11 +59,10 @@ def __init__(
         # https://arxiv.org/abs/2410.06424
 
         self.rotation_trick = rotation_trick
-        self.register_buffer('zero', torch.tensor(0.), persistent = False)
 
         # commit loss weighting - weighing input to quantize a bit less is crucial for it to work
 
-        self.commit_loss_input_to_quantize_weight = commit_loss_input_to_quantize_weight
+        self.input_to_quantize_commit_loss_weight = input_to_quantize_commit_loss_weight
 
     def forward(
         self,
@@ -83,18 +82,18 @@ def forward(
         quantized = get_at('[c] d, b n -> b n d', implicit_codebook, indices)
 
+        # commit loss and straight through, as was done in the paper
+
+        commit_loss = F.mse_loss(x.detach(), quantized)
+
         if self.rotation_trick:
             # rotation trick from @cfifty
-
             quantized = rotate_from_to(quantized, x)
-
-            commit_loss = self.zero
         else:
-            # commit loss and straight through, as was done in the paper
             commit_loss = (
-                F.mse_loss(x, quantized.detach()) * self.commit_loss_input_to_quantize_weight +
-                F.mse_loss(x.detach(), quantized)
+                commit_loss +
+                F.mse_loss(x, quantized.detach()) * self.input_to_quantize_commit_loss_weight
             )
 
             quantized = (quantized - x).detach() + x
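
What follows is a minimal, self-contained sketch of the loss logic this patch arrives at, paraphrased as free functions rather than the SimVQ class methods. The rotation_trick_ste helper below is a hypothetical stand-in written from the rotation trick paper (https://arxiv.org/abs/2410.06424); it is not the repo's rotate_from_to implementation, and its argument order and numerics may differ. The point it illustrates is the one in the commit message: the quantize-to-input commit loss is always computed, so the linear projection behind the implicit codebook keeps receiving gradients, while the straight through estimator and the down-weighted input-to-quantize term apply only when the rotation trick is off.

# Sketch only - paraphrases the post-patch commit loss logic from sim_vq.py.
# rotation_trick_ste is an assumed helper written from the rotation trick paper,
# not the repo's rotate_from_to.

import torch
import torch.nn.functional as F

def rotation_trick_ste(x, quantized, eps = 1e-12):
    # forward value equals `quantized`; gradients reach `x` through a detached
    # rotation and rescaling instead of a straight-through copy
    x_hat = F.normalize(x, dim = -1, eps = eps)
    q_hat = F.normalize(quantized, dim = -1, eps = eps)
    r = F.normalize(x_hat + q_hat, dim = -1, eps = eps).detach()

    # Householder-style transform: x - 2 (x . r) r + 2 (x . x_hat) q_hat
    rotated = (
        x
        - 2 * (x * r).sum(dim = -1, keepdim = True) * r
        + 2 * (x * x_hat.detach()).sum(dim = -1, keepdim = True) * q_hat.detach()
    )
    scale = (quantized.norm(dim = -1, keepdim = True) /
             x.norm(dim = -1, keepdim = True).clamp(min = eps)).detach()
    return rotated * scale

def simvq_quantize_with_commit_loss(
    x,                                           # continuous input, shape (batch, n, dim)
    quantized,                                   # selected implicit-codebook vectors, same shape
    rotation_trick = True,
    input_to_quantize_commit_loss_weight = 0.25
):
    # quantize -> input commit loss is now computed unconditionally, so the
    # linear projection producing the implicit codebook always gets a gradient
    commit_loss = F.mse_loss(x.detach(), quantized)

    if rotation_trick:
        # best combination per the commit message: rotation trick, no straight
        # through, no input -> quantize commit term
        quantized = rotation_trick_ste(x, quantized)
    else:
        # otherwise add the down-weighted input -> quantize term and use the
        # usual straight-through estimator
        commit_loss = (
            commit_loss +
            F.mse_loss(x, quantized.detach()) * input_to_quantize_commit_loss_weight
        )
        quantized = (quantized - x).detach() + x

    return quantized, commit_loss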