if mask is passed into VQ, output 0 for padding
lucidrains committed Jul 26, 2024
1 parent 74a27c8 commit 50ec361
Showing 3 changed files with 12 additions and 3 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "vector-quantize-pytorch"
-version = "1.15.4"
+version = "1.15.5"
 description = "Vector Quantization - Pytorch"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }
2 changes: 1 addition & 1 deletion tests/test_readme.py
@@ -59,7 +59,7 @@ def test_vq_mask():
     assert torch.allclose(quantized, mask_quantized[:, :512])
     assert torch.allclose(indices, mask_indices[:, :512])
 
-    assert torch.allclose(mask_quantized[:, 512:], x[:, 512:])
+    assert (mask_quantized[:, 512:] == 0.).all()
     assert (mask_indices[:, 512:] == -1).all()
 
 def test_residual_vq():
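As the updated test above shows, passing a mask to the quantizer now zeroes out the quantized output at padded positions (and returns -1 indices there) by default. A minimal usage sketch, assuming the VectorQuantize class exported by this repository and illustrative shapes:

import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(dim = 256, codebook_size = 512)

x = torch.randn(1, 1024, 256)

# pretend only the first 512 positions are real tokens, the rest is padding
mask = torch.arange(1024)[None, :] < 512

quantized, indices, commit_loss = vq(x, mask = mask)

assert (quantized[:, 512:] == 0.).all()   # padded positions are zeroed under the new default
assert (indices[:, 512:] == -1).all()     # padded positions get index -1

Passing return_zeros_for_masked_padding = False to the constructor restores the previous behavior of returning the original input at padded positions.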
11 changes: 10 additions & 1 deletion vector_quantize_pytorch/vector_quantize_pytorch.py
@@ -759,6 +759,7 @@ def __init__(
         affine_param_batch_decay = 0.99,
         affine_param_codebook_decay = 0.9,
         sync_update_v = 0., # the v that controls optimistic vs pessimistic update for synchronous update rule (21) https://minyoungg.github.io/vqtorch/assets/draft_050523.pdf
+        return_zeros_for_masked_padding = True
     ):
         super().__init__()
         self.dim = dim
@@ -855,6 +856,9 @@ def __init__(
 
         self.register_buffer('zero', torch.tensor(0.), persistent = False)
 
+        # for variable lengthed sequences, whether to take care of masking out the padding to 0 (or return the original input)
+        self.return_zeros_for_masked_padding = return_zeros_for_masked_padding
+
     @property
     def codebook(self):
         codebook = self._codebook.embed
@@ -1133,11 +1137,16 @@ def calculate_ce_loss(codes):
         # if masking, only return quantized for where mask has True
 
         if exists(mask):
+            masked_out_value = orig_input
+
+            if self.return_zeros_for_masked_padding:
+                masked_out_value = torch.zeros_like(orig_input)
+
             quantize = einx.where(
                 'b n, b n d, b n d -> b n d',
                 mask,
                 quantize,
-                orig_input
+                masked_out_value
             )
 
             embed_ind = einx.where(
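The einx.where call above broadcasts the (batch, seq) mask over the feature dimension, keeping the quantized vectors where the mask is True and filling padded positions with either zeros or the original input. A standalone sketch of the same fill logic in plain PyTorch, with illustrative names rather than the library's internals verbatim:

import torch

def fill_masked_positions(quantize, orig_input, mask, return_zeros_for_masked_padding = True):
    # mask: (b, n) bool, True for real tokens, False for padding
    # quantize, orig_input: (b, n, d)
    if return_zeros_for_masked_padding:
        masked_out_value = torch.zeros_like(orig_input)
    else:
        masked_out_value = orig_input

    # broadcast the (b, n) mask over the feature dimension and select per position
    return torch.where(mask[..., None], quantize, masked_out_value)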
