From 5ae9c792e77b45751ac636eb72b310f09648aac3 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Mon, 3 Feb 2025 11:13:37 -0800 Subject: [PATCH] allow for 1-dimensional channel first training for residual sim vq, for @zaptrem --- pyproject.toml | 2 +- vector_quantize_pytorch/residual_sim_vq.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 285596e..a4d11b4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "vector-quantize-pytorch" -version = "1.21.5" +version = "1.21.7" description = "Vector Quantization - Pytorch" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" } diff --git a/vector_quantize_pytorch/residual_sim_vq.py b/vector_quantize_pytorch/residual_sim_vq.py index c899c2b..ba96908 100644 --- a/vector_quantize_pytorch/residual_sim_vq.py +++ b/vector_quantize_pytorch/residual_sim_vq.py @@ -169,9 +169,9 @@ def forward( if quant_dropout_multiple_of != 1: rand_quantize_dropout_index = round_up_multiple(rand_quantize_dropout_index + 1, quant_dropout_multiple_of) - 1 - null_indices_shape = (x.shape[0], *x.shape[-2:]) if self.channel_first else tuple(x.shape[:2]) + null_indices_shape = (x.shape[0], *x.shape[2:]) if self.channel_first else tuple(x.shape[:2]) null_indices = torch.full(null_indices_shape, -1., device = device, dtype = torch.long) - null_loss = torch.full((1,), 0., device = device, dtype = x.dtype) + null_loss = torch.full((), 0., device = device, dtype = x.dtype) # save all inputs across layers, for use during expiration at end under shared codebook setting