diff --git a/examples/autoencoder_sim_vq.py b/examples/autoencoder_sim_vq.py
index 3681ba9..67e7f0b 100644
--- a/examples/autoencoder_sim_vq.py
+++ b/examples/autoencoder_sim_vq.py
@@ -15,7 +15,7 @@
 num_codes = 256
 seed = 1234

-rotation_trick = True       # rotation trick instead ot straight-through
+rotation_trick = True       # rotation trick instead of straight-through
 use_mlp = True              # use a one layer mlp with relu instead of linear

 device = "cuda" if torch.cuda.is_available() else "cpu"
diff --git a/pyproject.toml b/pyproject.toml
index ad856ee..a84b852 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "vector-quantize-pytorch"
-version = "1.20.8"
+version = "1.20.9"
 description = "Vector Quantization - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
diff --git a/tests/test_readme.py b/tests/test_readme.py
index 49d541f..0f4dfa5 100644
--- a/tests/test_readme.py
+++ b/tests/test_readme.py
@@ -385,7 +385,7 @@ def test_residual_sim_vq():
         dim = 512,
         num_quantizers = 4,
         codebook_size = 1024,
-        accept_image_fmap = True
+        channel_first = True
     )

     x = torch.randn(1, 512, 32, 32)
diff --git a/vector_quantize_pytorch/residual_sim_vq.py b/vector_quantize_pytorch/residual_sim_vq.py
index 9d767a0..2aea9ca 100644
--- a/vector_quantize_pytorch/residual_sim_vq.py
+++ b/vector_quantize_pytorch/residual_sim_vq.py
@@ -58,20 +58,20 @@ def __init__(
         quantize_dropout = False,
         quantize_dropout_cutoff_index = 0,
         quantize_dropout_multiple_of = 1,
-        accept_image_fmap = False,
+        channel_first = False,
         rotation_trick = True,  # rotation trick from @cfifty, on top of sim vq
         **sim_vq_kwargs
    ):
        super().__init__()
        assert heads == 1, 'residual vq is not compatible with multi-headed codes'

-        self.accept_image_fmap = accept_image_fmap
+        self.channel_first = channel_first
        self.num_quantizers = num_quantizers

        # define sim vq across layers

-        self.layers = ModuleList([SimVQ(dim = dim, codebook_size = codebook_size, rotation_trick = rotation_trick, accept_image_fmap = accept_image_fmap, **sim_vq_kwargs) for _ in range(num_quantizers)])
+        self.layers = ModuleList([SimVQ(dim = dim, codebook_size = codebook_size, rotation_trick = rotation_trick, channel_first = channel_first, **sim_vq_kwargs) for _ in range(num_quantizers)])

        # quantize dropout

@@ -100,7 +100,7 @@ def get_codes_from_indices(self, indices):

        batch, quantize_dim = indices.shape[0], indices.shape[-1]

-        # may also receive indices in the shape of 'b h w q' (accept_image_fmap)
+        # may also receive indices in the shape of 'b h w q' (images)

        indices, inverse = pack_one(indices, 'b * q')

@@ -122,11 +122,11 @@ def get_codes_from_indices(self, indices):
            all_codes = all_codes.masked_fill(rearrange(mask, 'b n q -> q b n 1'), 0.)

-        # if (accept_image_fmap = True) then return shape (quantize, batch, height, width, dimension)
+        # if (channel_first = True) then return shape (quantize, batch, height, width, dimension)

        all_codes = inverse(all_codes, 'q b * d')

-        if self.accept_image_fmap:
+        if self.channel_first:
            all_codes = rearrange(all_codes, 'q b ... d -> q b d ...')

        return all_codes

@@ -139,13 +139,10 @@ def get_output_from_indices(self, indices):
    def forward(
        self,
        x,
-        indices: Tensor | list[Tensor] | None = None,
        return_all_codes = False,
        rand_quantize_dropout_fixed_seed = None
    ):
-        num_quant, quant_dropout_multiple_of, return_loss, device = self.num_quantizers, self.quantize_dropout_multiple_of, exists(indices), x.device
-
-        assert not (self.accept_image_fmap and exists(indices))
+        num_quant, quant_dropout_multiple_of, device = self.num_quantizers, self.quantize_dropout_multiple_of, x.device

        quantized_out = 0.
        residual = x
@@ -153,9 +150,6 @@ def forward(
        all_losses = []
        all_indices = []

-        if isinstance(indices, list):
-            indices = torch.stack(indices)
-
-        should_quantize_dropout = self.training and self.quantize_dropout and not return_loss
+        should_quantize_dropout = self.training and self.quantize_dropout

        # sample a layer index at which to dropout further residual quantization
@@ -175,7 +169,7 @@ def forward(
            if quant_dropout_multiple_of != 1:
                rand_quantize_dropout_index = round_up_multiple(rand_quantize_dropout_index + 1, quant_dropout_multiple_of) - 1

-            null_indices_shape = (x.shape[0], *x.shape[-2:]) if self.accept_image_fmap else tuple(x.shape[:2])
+            null_indices_shape = (x.shape[0], *x.shape[-2:]) if self.channel_first else tuple(x.shape[:2])
            null_indices = torch.full(null_indices_shape, -1., device = device, dtype = torch.long)
            null_loss = torch.full((1,), 0., device = device, dtype = x.dtype)

diff --git a/vector_quantize_pytorch/sim_vq.py b/vector_quantize_pytorch/sim_vq.py
index a873173..427760b 100644
--- a/vector_quantize_pytorch/sim_vq.py
+++ b/vector_quantize_pytorch/sim_vq.py
@@ -41,7 +41,7 @@ def __init__(
        codebook_size,
        codebook_transform: Module | None = None,
        init_fn: Callable = identity,
-        accept_image_fmap = False,
+        channel_first = False,
        rotation_trick = True,  # works even better with rotation trick turned on, with no straight through and the commit loss from input to quantize
        input_to_quantize_commit_loss_weight = 0.25,
        commitment_weight = 1.,
@@ -49,7 +49,7 @@ def __init__(
    ):
        super().__init__()
        self.codebook_size = codebook_size
-        self.accept_image_fmap = accept_image_fmap
+        self.channel_first = channel_first

        frozen_codebook_dim = default(frozen_codebook_dim, dim)
        codebook = torch.randn(codebook_size, frozen_codebook_dim) * (frozen_codebook_dim ** -0.5)
@@ -92,7 +92,7 @@ def indices_to_codes(
        frozen_codes = get_at('[c] d, b ... -> b ... d', self.frozen_codebook, indices)
        quantized = self.code_transform(frozen_codes)

-        if self.accept_image_fmap:
+        if self.channel_first:
            quantized = rearrange(quantized, 'b ... d -> b d ...')

        return quantized
@@ -101,9 +101,10 @@ def forward(
        self,
        x
    ):
-        if self.accept_image_fmap:
-            x = rearrange(x, 'b d h w -> b h w d')
-            x, inverse_pack = pack_one(x, 'b * d')
+        if self.channel_first:
+            x = rearrange(x, 'b d ... -> b ... d')
+
+        x, inverse_pack = pack_one(x, 'b * d')

        implicit_codebook = self.codebook

@@ -131,11 +132,11 @@ def forward(

            quantized = (quantized - x).detach() + x

-        if self.accept_image_fmap:
-            quantized = inverse_pack(quantized)
-            quantized = rearrange(quantized, 'b h w d-> b d h w')
+        quantized = inverse_pack(quantized)
+        indices = inverse_pack(indices, 'b *')

-            indices = inverse_pack(indices, 'b *')
+        if self.channel_first:
+            quantized = rearrange(quantized, 'b ... d-> b d ...')

        return quantized, indices, commit_loss * self.commitment_weight

@@ -153,7 +154,7 @@ def forward(
            nn.Linear(1024, 512)
        ),
        codebook_size = 1024,
-        accept_image_fmap = True
+        channel_first = True
    )

    quantized, indices, commit_loss = sim_vq(x)
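For reference, a minimal usage sketch of the renamed flag after this patch. It assumes SimVQ and ResidualSimVQ are both exported from the package root (as exercised in tests/test_readme.py) and that ResidualSimVQ.forward still returns a (quantized, indices, commit_loss) triple; the input shape follows the 32 x 32 feature map used in the tests above.

import torch
from vector_quantize_pytorch import SimVQ, ResidualSimVQ

x = torch.randn(1, 512, 32, 32)   # channel-first feature map: (batch, dim, height, width)

# channel_first replaces the old accept_image_fmap flag on SimVQ
sim_vq = SimVQ(
    dim = 512,
    codebook_size = 1024,
    channel_first = True
)

quantized, indices, commit_loss = sim_vq(x)   # quantized is returned channel-first: (1, 512, 32, 32)

# same rename on the residual wrapper; its forward no longer accepts indices
residual_sim_vq = ResidualSimVQ(
    dim = 512,
    num_quantizers = 4,
    codebook_size = 1024,
    channel_first = True
)

quantized, indices, commit_loss = residual_sim_vq(x)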