fix eval mode for implicit neural codebooks and fix tests
lucidrains committed Aug 15, 2024
1 parent c3f33b0 commit 3ad24a3
Showing 3 changed files with 18 additions and 14 deletions.
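
The commit targets the eval-mode code path when implicit neural codebooks are enabled: as the diff below shows, training already used the transformed (per-position) codebooks, but eval fell back to a plain codebook lookup. A minimal sketch of the scenario, reusing the constructor arguments from the updated test further down; the snippet is illustrative and not part of the commit:

import torch
from vector_quantize_pytorch import ResidualVQ

residual_vq = ResidualVQ(
    dim = 32,
    num_quantizers = 8,
    codebook_size = 128,
    implicit_neural_codebook = True   # later codebooks come from learned transforms of earlier quantizer outputs
)

residual_vq.eval()   # the code path fixed by this commit

with torch.no_grad():
    x = torch.randn(1, 256, 32)
    quantized, indices, commit_loss = residual_vq(x)

assert quantized.shape == x.shape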
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "vector-quantize-pytorch"
-version = "1.16.0"
+version = "1.16.1"
 description = "Vector Quantization - Pytorch"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }
6 changes: 3 additions & 3 deletions tests/test_readme.py
@@ -71,14 +71,14 @@ def test_residual_vq(
     from vector_quantize_pytorch import ResidualVQ

     residual_vq = ResidualVQ(
-        dim = 256,
+        dim = 32,
         num_quantizers = 8,
-        codebook_size = 1024,
+        codebook_size = 128,
         implicit_neural_codebook = implicit_neural_codebook,
         use_cosine_sim = use_cosine_sim,
     )

-    x = torch.randn(1, 1024, 256)
+    x = torch.randn(1, 256, 32)

     quantized, indices, commit_loss = residual_vq(x)
     quantized, indices, commit_loss, all_codes = residual_vq(x, return_all_codes = True)
24 changes: 14 additions & 10 deletions vector_quantize_pytorch/vector_quantize_pytorch.py
@@ -245,12 +245,6 @@ def kmeans(

     return means, bins

-def batched_embedding(indices, embeds):
-    batch, dim = indices.shape[1], embeds.shape[-1]
-    indices = repeat(indices, 'h b n -> h b n d', d = dim)
-    embeds = repeat(embeds, 'h c d -> h b c d', b = batch)
-    return embeds.gather(2, indices)
-
 # distributed helpers

@cache
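
The removed batched_embedding helper expanded the indices and codebook and then gathered along the code axis; the forward passes below now do the same lookup inline with einx.get_at. A small equivalence sketch (shape names h, b, n, c, d follow the einops patterns in the helper; the check is illustrative, not part of the commit):

import torch
import einx
from einops import repeat

h, b, n, c, d = 2, 3, 5, 7, 4
embeds = torch.randn(h, c, d)
indices = torch.randint(0, c, (h, b, n))

# old helper: broadcast both tensors, then gather along the code axis
expanded_indices = repeat(indices, 'h b n -> h b n d', d = d)
expanded_embeds = repeat(embeds, 'h c d -> h b c d', b = b)
old = expanded_embeds.gather(2, expanded_indices)

# new inline form used in the hunks below
new = einx.get_at('h [c] d, h b n -> h b n d', embeds, indices)

assert torch.allclose(old, new)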
@@ -521,17 +515,22 @@ def forward(

         embed_ind = unpack_one(embed_ind, 'h *')

+        if exists(codebook_transform_fn):
+            transformed_embed = unpack_one(transformed_embed, 'h * c d')
+
         if self.training:
             unpacked_onehot = unpack_one(embed_onehot, 'h * c')

             if exists(codebook_transform_fn):
-                transformed_embed = unpack_one(transformed_embed, 'h * c d')
                 quantize = einsum('h b n c, h b n c d -> h b n d', unpacked_onehot, transformed_embed)
             else:
                 quantize = einsum('h b n c, h c d -> h b n d', unpacked_onehot, embed)

         else:
-            quantize = batched_embedding(embed_ind, embed)
+            if exists(codebook_transform_fn):
+                quantize = einx.get_at('h b n [c] d, h b n -> h b n d', transformed_embed, embed_ind)
+            else:
+                quantize = einx.get_at('h [c] d, h b n -> h b n d', embed, embed_ind)

         if self.training and self.ema_update and not freeze_codebook:

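In eval mode with an implicit neural codebook, the new branch looks up, for each position, the selected code from that position's own transformed codebook. What einx.get_at('h b n [c] d, h b n -> h b n d', ...) computes, written out with plain torch.gather (shapes follow the einops patterns above; illustrative only):

import torch

h, b, n, c, d = 2, 1, 4, 8, 16
transformed_embed = torch.randn(h, b, n, c, d)   # per-position codebook from codebook_transform_fn
embed_ind = torch.randint(0, c, (h, b, n))       # chosen code per position

# gather along the code axis c, independently for every (head, batch, position)
index = embed_ind[..., None, None].expand(h, b, n, 1, d)
quantize = transformed_embed.gather(3, index).squeeze(3)   # (h, b, n, d)
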
@@ -715,17 +714,22 @@ def forward(
         embed_ind, embed_onehot = self.gumbel_sample(dist, dim = -1, temperature = sample_codebook_temp, training = self.training)
         embed_ind = unpack_one(embed_ind, 'h *')

+        if exists(codebook_transform_fn):
+            transformed_embed = unpack_one(transformed_embed, 'h * c d')
+
         if self.training:
             unpacked_onehot = unpack_one(embed_onehot, 'h * c')

             if exists(codebook_transform_fn):
-                transformed_embed = unpack_one(transformed_embed, 'h * c d')
                 quantize = einsum('h b n c, h b n c d -> h b n d', unpacked_onehot, transformed_embed)
             else:
                 quantize = einsum('h b n c, h c d -> h b n d', unpacked_onehot, embed)

         else:
-            quantize = batched_embedding(embed_ind, embed)
+            if exists(codebook_transform_fn):
+                quantize = einx.get_at('h b n [c] d, h b n -> h b n d', transformed_embed, embed_ind)
+            else:
+                quantize = einx.get_at('h [c] d, h b n -> h b n d', embed, embed_ind)

         if self.training and self.ema_update and not freeze_codebook:
             if exists(mask):
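With this change, the eval branch selects the same vectors as the training branch whenever the sampled embed_onehot is a hard one-hot (the deterministic / straight-through case), which is what restores train/eval consistency for implicit neural codebooks. A small consistency check under that assumption, illustrative only:

import torch
import torch.nn.functional as F
import einx

h, b, n, c, d = 2, 1, 4, 8, 16
transformed_embed = torch.randn(h, b, n, c, d)
embed_ind = torch.randint(0, c, (h, b, n))
embed_onehot = F.one_hot(embed_ind, c).type_as(transformed_embed)

train_path = torch.einsum('hbnc,hbncd->hbnd', embed_onehot, transformed_embed)
eval_path = einx.get_at('h b n [c] d, h b n -> h b n d', transformed_embed, embed_ind)

assert torch.allclose(train_path, eval_path)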
