
Commit

still need commit loss from quantize to input to optimize linear projection in SimVQ; best combination is rotation trick without straight through and without commit loss from input to quantize
lucidrains committed Nov 11, 2024
1 parent 72ede73 commit cd0fa8e
Showing 2 changed files with 10 additions and 11 deletions.
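
For context, here is a minimal sketch of the loss arrangement the commit message describes, written as a standalone function rather than taken from the library; the tensor names x and quantized and the 0.25 weight mirror the diff below, while the function and keyword names are illustrative only:

import torch
import torch.nn.functional as F

def simvq_commit_loss(x, quantized, rotation_trick = True, input_to_quantize_weight = 0.25):
    # quantize -> input term: gradients reach the codebook and its linear projection,
    # which is why this term is kept even when the rotation trick is used
    commit_loss = F.mse_loss(x.detach(), quantized)

    if not rotation_trick:
        # without the rotation trick, also pull the input toward the (detached)
        # quantized vectors, down-weighted, alongside the straight-through estimator
        commit_loss = commit_loss + F.mse_loss(x, quantized.detach()) * input_to_quantize_weight

    return commit_loss

# toy usage
x = torch.randn(2, 16, 8, requires_grad = True)
quantized = torch.randn(2, 16, 8, requires_grad = True)
loss = simvq_commit_loss(x, quantized)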
pyproject.toml: 2 changes (1 addition & 1 deletion)
@@ -1,6 +1,6 @@
 [project]
 name = "vector-quantize-pytorch"
-version = "1.20.0"
+version = "1.20.1"
 description = "Vector Quantization - Pytorch"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }
vector_quantize_pytorch/sim_vq.py: 19 changes (9 additions & 10 deletions)
@@ -40,8 +40,8 @@ def __init__(
         codebook_size,
         init_fn: Callable = identity,
         accept_image_fmap = False,
-        rotation_trick = True, # works even better with rotation trick turned on, with no asymmetric commit loss or straight through
-        commit_loss_input_to_quantize_weight = 0.25,
+        rotation_trick = True, # works even better with rotation trick turned on, with no straight through and the commit loss from input to quantize
+        input_to_quantize_commit_loss_weight = 0.25,
     ):
         super().__init__()
         self.accept_image_fmap = accept_image_fmap
@@ -59,11 +59,10 @@ def __init__(
         # https://arxiv.org/abs/2410.06424
 
         self.rotation_trick = rotation_trick
-        self.register_buffer('zero', torch.tensor(0.), persistent = False)
 
         # commit loss weighting - weighing input to quantize a bit less is crucial for it to work
 
-        self.commit_loss_input_to_quantize_weight = commit_loss_input_to_quantize_weight
+        self.input_to_quantize_commit_loss_weight = input_to_quantize_commit_loss_weight
 
     def forward(
         self,
@@ -83,18 +82,18 @@ def forward(
 
         quantized = get_at('[c] d, b n -> b n d', implicit_codebook, indices)
 
+        # commit loss and straight through, as was done in the paper
+
+        commit_loss = F.mse_loss(x.detach(), quantized)
+
         if self.rotation_trick:
             # rotation trick from @cfifty
 
             quantized = rotate_from_to(quantized, x)
-
-            commit_loss = self.zero
         else:
-            # commit loss and straight through, as was done in the paper
-
             commit_loss = (
-                F.mse_loss(x, quantized.detach()) * self.commit_loss_input_to_quantize_weight +
-                F.mse_loss(x.detach(), quantized)
+                commit_loss +
+                F.mse_loss(x, quantized.detach()) * self.input_to_quantize_commit_loss_weight
             )
 
             quantized = (quantized - x).detach() + x
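
A usage sketch under this change, assuming SimVQ is importable from vector_quantize_pytorch, that the constructor also takes a dim argument ahead of codebook_size, and that forward returns (quantized, indices, commit_loss); only the rotation_trick and input_to_quantize_commit_loss_weight keywords are confirmed by the diff above:

import torch
from vector_quantize_pytorch import SimVQ  # assumed import path

sim_vq = SimVQ(
    dim = 512,                                    # assumed constructor argument
    codebook_size = 1024,
    rotation_trick = True,                        # default shown in the diff above
    input_to_quantize_commit_loss_weight = 0.25   # keyword renamed in this commit
)

x = torch.randn(1, 1024, 512)

quantized, indices, commit_loss = sim_vq(x)       # assumed return signature

assert quantized.shape == x.shape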
