Skip to content

Commit

Permalink
fix residual VQ in eval mode
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Nov 5, 2021
1 parent 2352c87 commit 826c620
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'vector_quantize_pytorch',
packages = find_packages(),
version = '0.3.8',
version = '0.3.9',
license='MIT',
description = 'Vector Quantization - Pytorch',
author = 'Phil Wang',
Expand Down
4 changes: 2 additions & 2 deletions vector_quantize_pytorch/vector_quantize_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,11 +288,11 @@ def forward(self, x):

quantize, embed_ind = self._codebook(x)

commit_loss = 0.

if self.training:
commit_loss = F.mse_loss(quantize.detach(), x) * self.commitment
quantize = x + (quantize - x).detach()
else:
commit_loss = torch.tensor([0.], device = x.device)

quantize = self.project_out(quantize)

Expand Down

0 comments on commit 826c620

Please sign in to comment.