Skip to content

Commit

Permalink
allow for 1-dimensional channel first training for residual sim vq, for
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Feb 3, 2025
1 parent 55fa8f1 commit 5ae9c79
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "vector-quantize-pytorch"
version = "1.21.5"
version = "1.21.7"
description = "Vector Quantization - Pytorch"
authors = [
{ name = "Phil Wang", email = "[email protected]" }
Expand Down
4 changes: 2 additions & 2 deletions vector_quantize_pytorch/residual_sim_vq.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,9 @@ def forward(
if quant_dropout_multiple_of != 1:
rand_quantize_dropout_index = round_up_multiple(rand_quantize_dropout_index + 1, quant_dropout_multiple_of) - 1

null_indices_shape = (x.shape[0], *x.shape[-2:]) if self.channel_first else tuple(x.shape[:2])
null_indices_shape = (x.shape[0], *x.shape[2:]) if self.channel_first else tuple(x.shape[:2])
null_indices = torch.full(null_indices_shape, -1., device = device, dtype = torch.long)
null_loss = torch.full((1,), 0., device = device, dtype = x.dtype)
null_loss = torch.full((), 0., device = device, dtype = x.dtype)

# save all inputs across layers, for use during expiration at end under shared codebook setting

Expand Down

0 comments on commit 5ae9c79

Please sign in to comment.