From b5cb143d1bdd8069f444e2c5a8aa11667562186b Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Wed, 25 Sep 2024 14:28:30 -0700
Subject: [PATCH] address
 https://github.com/lucidrains/vector-quantize-pytorch/issues/162

---
 pyproject.toml                                 |  2 +-
 tests/test_readme.py                           |  5 ++-
 vector_quantize_pytorch/residual_vq.py         | 12 ++++++
 .../vector_quantize_pytorch.py                 | 41 +++++++++++++------
 4 files changed, 45 insertions(+), 15 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 2160d25..9a69b77 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "vector-quantize-pytorch"
-version = "1.17.4"
+version = "1.17.5"
 description = "Vector Quantization - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
diff --git a/tests/test_readme.py b/tests/test_readme.py
index ed5a874..259210b 100644
--- a/tests/test_readme.py
+++ b/tests/test_readme.py
@@ -65,10 +65,12 @@ def test_vq_mask():
 @pytest.mark.parametrize('implicit_neural_codebook', (True, False))
 @pytest.mark.parametrize('use_cosine_sim', (True, False))
 @pytest.mark.parametrize('train', (True, False))
+@pytest.mark.parametrize('shared_codebook', (True, False))
 def test_residual_vq(
     implicit_neural_codebook,
     use_cosine_sim,
-    train
+    train,
+    shared_codebook
 ):
     from vector_quantize_pytorch import ResidualVQ
 
@@ -78,6 +80,7 @@ def test_residual_vq(
         codebook_size = 128,
         implicit_neural_codebook = implicit_neural_codebook,
         use_cosine_sim = use_cosine_sim,
+        shared_codebook = shared_codebook
     )
 
     x = torch.randn(1, 256, 32)
diff --git a/vector_quantize_pytorch/residual_vq.py b/vector_quantize_pytorch/residual_vq.py
index 83e262b..75de5ee 100644
--- a/vector_quantize_pytorch/residual_vq.py
+++ b/vector_quantize_pytorch/residual_vq.py
@@ -137,6 +137,11 @@ def __init__(
             ema_update = False
         )
 
+        if shared_codebook:
+            vq_kwargs.update(
+                manual_ema_update = True
+            )
+
         self.layers = ModuleList([VectorQuantize(dim = codebook_dim, codebook_dim = codebook_dim, accept_image_fmap = accept_image_fmap, **vq_kwargs) for _ in range(num_quantizers)])
 
         assert all([not vq.has_projections for vq in self.layers])
@@ -157,6 +162,8 @@ def __init__(
 
         # sharing codebook logic
 
+        self.shared_codebook = shared_codebook
+
         if not shared_codebook:
             return
 
@@ -349,6 +356,11 @@ def forward(
             all_indices.append(embed_indices)
             all_losses.append(loss)
 
+        # if shared codebook, update ema only at end
+
+        if self.shared_codebook:
+            first(self.layers)._codebook.update_ema()
+
         # project out, if needed
 
         quantized_out = self.project_out(quantized_out)
diff --git a/vector_quantize_pytorch/vector_quantize_pytorch.py b/vector_quantize_pytorch/vector_quantize_pytorch.py
index 9f14e58..13ede02 100644
--- a/vector_quantize_pytorch/vector_quantize_pytorch.py
+++ b/vector_quantize_pytorch/vector_quantize_pytorch.py
@@ -280,6 +280,7 @@ def __init__(
         gumbel_sample = gumbel_sample,
         sample_codebook_temp = 1.,
         ema_update = True,
+        manual_ema_update = False,
         affine_param = False,
         sync_affine_param = False,
         affine_param_batch_decay = 0.99,
@@ -290,6 +291,7 @@ def __init__(
 
         self.decay = decay
         self.ema_update = ema_update
+        self.manual_ema_update = manual_ema_update
 
         init_fn = uniform_init if not kmeans_init else torch.zeros
         embed = init_fn(num_codebooks, codebook_size, dim)
@@ -458,6 +460,12 @@ def expire_codes_(self, batch_samples):
         batch_samples = rearrange(batch_samples, 'h ... d -> h (...) d')
         self.replace(batch_samples, batch_mask = expired_codes)
 
+    def update_ema(self):
+        cluster_size = laplace_smoothing(self.cluster_size, self.codebook_size, self.eps) * self.cluster_size.sum(dim = -1, keepdim = True)
+
+        embed_normalized = self.embed_avg / rearrange(cluster_size, '... -> ... 1')
+        self.embed.data.copy_(embed_normalized)
+
     @autocast('cuda', enabled = False)
     def forward(
         self,
@@ -551,11 +559,9 @@ def forward(
 
             ema_inplace(self.embed_avg.data, embed_sum, self.decay)
 
-            cluster_size = laplace_smoothing(self.cluster_size, self.codebook_size, self.eps) * self.cluster_size.sum(dim = -1, keepdim = True)
-
-            embed_normalized = self.embed_avg / rearrange(cluster_size, '... -> ... 1')
-            self.embed.data.copy_(embed_normalized)
-            self.expire_codes_(x)
+            if not self.manual_ema_update:
+                self.update_ema()
+                self.expire_codes_(x)
 
         if needs_codebook_dim:
             quantize, embed_ind = map(lambda t: rearrange(t, '1 ... -> ...'), (quantize, embed_ind))
@@ -582,11 +588,14 @@ def __init__(
         gumbel_sample = gumbel_sample,
         sample_codebook_temp = 1.,
         ema_update = True,
+        manual_ema_update = False
     ):
         super().__init__()
         self.transform_input = l2norm
 
         self.ema_update = ema_update
+        self.manual_ema_update = manual_ema_update
+
         self.decay = decay
 
         if not kmeans_init:
@@ -671,6 +680,14 @@ def expire_codes_(self, batch_samples):
         batch_samples = rearrange(batch_samples, 'h ... d -> h (...) d')
         self.replace(batch_samples, batch_mask = expired_codes)
 
+    def update_ema(self):
+        cluster_size = laplace_smoothing(self.cluster_size, self.codebook_size, self.eps) * self.cluster_size.sum(dim = -1, keepdim = True)
+
+        embed_normalized = self.embed_avg / rearrange(cluster_size, '... -> ... 1')
+        embed_normalized = l2norm(embed_normalized)
+
+        self.embed.data.copy_(embed_normalized)
+
     @autocast('cuda', enabled = False)
     def forward(
         self,
@@ -746,13 +763,9 @@ def forward(
 
             ema_inplace(self.embed_avg.data, embed_sum, self.decay)
 
-            cluster_size = laplace_smoothing(self.cluster_size, self.codebook_size, self.eps) * self.cluster_size.sum(dim = -1, keepdim = True)
-
-            embed_normalized = self.embed_avg / rearrange(cluster_size, '... -> ... 1')
-            embed_normalized = l2norm(embed_normalized)
-
-            self.embed.data.copy_(embed_normalized)
-            self.expire_codes_(x)
+            if not self.manual_ema_update:
+                self.update_ema()
+                self.expire_codes_(x)
 
         if needs_codebook_dim:
             quantize, embed_ind = map(lambda t: rearrange(t, '1 ... -> ...'), (quantize, embed_ind))
@@ -802,6 +815,7 @@ def __init__(
         sync_codebook = None,
         sync_affine_param = False,
         ema_update = True,
+        manual_ema_update = False,
         learnable_codebook = False,
         in_place_codebook_optimizer: Callable[..., Optimizer] = None, # Optimizer used to update the codebook embedding if using learnable_codebook
         affine_param = False,
@@ -881,7 +895,8 @@ def __init__(
             learnable_codebook = has_codebook_orthogonal_loss or learnable_codebook,
             sample_codebook_temp = sample_codebook_temp,
             gumbel_sample = gumbel_sample_fn,
-            ema_update = ema_update
+            ema_update = ema_update,
+            manual_ema_update = manual_ema_update
         )
 
         if affine_param:
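
Note (not part of the patch): a minimal usage sketch of the shared-codebook path this patch touches, based on the updated test above and the library README. With shared_codebook = True every quantizer layer reuses one codebook, and after this change its EMA update runs once per ResidualVQ forward pass, via manual_ema_update and the new update_ema(), instead of once per layer. dim = 32 and codebook_size = 128 mirror the test; num_quantizers = 8 and the printed shapes follow the README and are illustrative.

    import torch
    from vector_quantize_pytorch import ResidualVQ

    residual_vq = ResidualVQ(
        dim = 32,
        num_quantizers = 8,        # illustrative value, as in the README
        codebook_size = 128,
        shared_codebook = True     # all layers reuse the first layer's codebook
    )

    x = torch.randn(1, 256, 32)

    # in training mode, the shared codebook's EMA update now happens once,
    # at the end of the forward pass, rather than after every quantizer layer
    quantized, indices, commit_loss = residual_vq(x)

    print(quantized.shape)    # torch.Size([1, 256, 32])
    print(indices.shape)      # torch.Size([1, 256, 8])
    print(commit_loss.shape)  # torch.Size([1, 8])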