add code expiry / replacement strategy from soundstream paper
lucidrains committed Oct 19, 2021
1 parent 9ad29ef commit b1f5d8e
Showing 3 changed files with 73 additions and 17 deletions.
20 changes: 19 additions & 1 deletion README.md
@@ -90,7 +90,7 @@ x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x)
```

### Cosine Similarity
### Cosine similarity

The <a href="https://openreview.net/forum?id=pfNyExj7z2">Improved VQGAN paper</a> also proposes to l2-normalize the codes and the encoded vectors, which boils down to using cosine similarity for the distance. They claim that constraining the vectors to a sphere leads to improvements in code usage and downstream reconstruction. You can turn this on by setting `use_cosine_sim = True`.

@@ -108,6 +108,24 @@ x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x)
```
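
For completeness, here is a minimal sketch of what the truncated cosine-similarity example above looks like in full; the constructor arguments other than `use_cosine_sim` are assumed to match the other README examples rather than taken from this diff.

```python
import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(
    dim = 256,
    codebook_size = 512,
    use_cosine_sim = True   # l2-normalize codes and inputs, so distance becomes cosine similarity
)

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x)
```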

### Expiring stale codes

Finally, the SoundStream paper proposes a scheme whereby codes that have not been used in a certain number of consecutive batches are replaced with a randomly selected vector from the current batch. You can set the threshold of consecutive misses before replacement with the `max_codebook_misses_before_expiry` keyword. (I know it is a bit long, but I couldn't think of a better name.)

```python
import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(
dim = 256,
codebook_size = 512,
max_codebook_misses_before_expiry = 5 # should actively replace any codes that were missed 5 times in a row during training
)

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x)
```
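
As a rough sketch of how this might look in a training loop: the expiry check runs inside the quantizer's forward pass whenever the module is in training mode, so nothing extra needs to be called. The encoder and optimizer below are placeholders, not part of this library.

```python
import torch
from torch import nn
from vector_quantize_pytorch import VectorQuantize

encoder = nn.Linear(256, 256)   # stand-in for a real encoder

vq = VectorQuantize(
    dim = 256,
    codebook_size = 512,
    max_codebook_misses_before_expiry = 5
)

opt = torch.optim.Adam(encoder.parameters(), lr = 3e-4)

for _ in range(100):
    x = encoder(torch.randn(1, 1024, 256))

    # codes unused for too many consecutive batches are replaced inside this call (training mode only)
    quantized, indices, commit_loss = vq(x)

    commit_loss.backward()      # in a real setup, add this to the reconstruction / task loss
    opt.step()
    opt.zero_grad()
```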

## Citations

```bibtex
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'vector_quantize_pytorch',
packages = find_packages(),
version = '0.3.2',
version = '0.3.3',
license='MIT',
description = 'Vector Quantization - Pytorch',
author = 'Phil Wang',
68 changes: 53 additions & 15 deletions vector_quantize_pytorch/vector_quantize_pytorch.py
@@ -18,26 +18,29 @@ def ema_inplace(moving_avg, new, decay):
def laplace_smoothing(x, n_categories, eps = 1e-5):
return (x + eps) / (x.sum() + n_categories * eps)

def kmeans(x, num_clusters, num_iters = 10, use_cosine_sim = False):
samples = rearrange(x, '... d -> (...) d')
num_samples, dim, dtype, device = *samples.shape, x.dtype, x.device
def sample_vectors(samples, num):
num_samples, device = samples.shape[0], samples.device

if num_samples >= num_clusters:
indices = torch.randperm(num_samples, device=device)[:num_clusters]
if num_samples >= num:
indices = torch.randperm(num_samples, device = device)[:num]
else:
indices = torch.randint(0, num_samples, (num_clusters,), device=device)
indices = torch.randint(0, num_samples, (num,), device = device)

means = samples[indices]
return samples[indices]

def kmeans(samples, num_clusters, num_iters = 10, use_cosine_sim = False):
dim, dtype, device = samples.shape[-1], samples.dtype, samples.device

means = sample_vectors(samples, num_clusters)

for _ in range(num_iters):
if use_cosine_sim:
dists = samples @ means.t()
buckets = dists.max(dim = -1).indices
else:
diffs = rearrange(samples, 'n d -> n () d') - rearrange(means, 'c d -> () c d')
dists = (diffs ** 2).sum(dim = -1)
buckets = dists.argmin(dim = -1)
dists = -(diffs ** 2).sum(dim = -1)

buckets = dists.max(dim = -1).indices
bins = torch.bincount(buckets, minlength = num_clusters)
zero_mask = bins == 0
bins = bins.masked_fill(zero_mask, 1)
@@ -85,14 +88,18 @@ def init_embed_(self, data):
self.embed_avg.data.copy_(embed.clone())
self.initted.data.copy_(torch.Tensor([True]))

def forward(self, x):
if not self.initted:
self.init_embed_(x)
def replace(self, samples, mask):
modified_codebook = torch.where(mask[..., None], sample_vectors(samples, self.codebook_size), self.embed)
self.embed.data.copy_(modified_codebook)

def forward(self, x):
shape, dtype = x.shape, x.dtype
flatten = rearrange(x, '... d -> (...) d')
embed = self.embed.t()

if not self.initted:
self.init_embed_(flatten)

dist = -(
flatten.pow(2).sum(1, keepdim=True)
- 2 * flatten @ embed
@@ -144,15 +151,20 @@ def init_embed_(self, data):
self.embed.data.copy_(embed)
self.initted.data.copy_(torch.Tensor([True]))

def replace(self, samples, mask):
samples = l2norm(samples)
modified_codebook = torch.where(mask[..., None], sample_vectors(samples, self.codebook_size), self.embed)
self.embed.data.copy_(modified_codebook)

def forward(self, x):
shape, dtype = x.shape, x.dtype
flatten = rearrange(x, '... d -> (...) d')
flatten = l2norm(flatten)
embed = l2norm(self.embed)

if not self.initted:
self.init_embed_(flatten)

embed = l2norm(self.embed)
dist = flatten @ embed.t()
embed_ind = dist.max(dim = -1).indices
embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
@@ -187,7 +199,8 @@ def __init__(
eps = 1e-5,
kmeans_init = False,
kmeans_iters = 10,
use_cosine_sim = False
use_cosine_sim = False,
max_codebook_misses_before_expiry = 0
):
super().__init__()
n_embed = default(n_embed, codebook_size)
@@ -211,20 +224,45 @@ def __init__(
eps = eps
)

self.codebook_size = codebook_size
self.max_codebook_misses_before_expiry = max_codebook_misses_before_expiry

if max_codebook_misses_before_expiry > 0:
codebook_misses = torch.zeros(codebook_size)
self.register_buffer('codebook_misses', codebook_misses)

@property
def codebook(self):
return self._codebook.codebook

def expire_codes_(self, embed_ind, batch_samples):
if self.max_codebook_misses_before_expiry == 0:
return

embed_ind = rearrange(embed_ind, '... -> (...)')
misses = torch.bincount(embed_ind, minlength = self.codebook_size) == 0
self.codebook_misses += misses

expired_codes = self.codebook_misses >= self.max_codebook_misses_before_expiry
if not torch.any(expired_codes):
return

self.codebook_misses.masked_fill_(expired_codes, 0)
batch_samples = rearrange(batch_samples, '... d -> (...) d')
self._codebook.replace(batch_samples, mask = expired_codes)

def forward(self, x):
dtype = x.dtype
x = self.project_in(x)

quantize, embed_ind = self._codebook(x)

commit_loss = 0.

if self.training:
commit_loss = F.mse_loss(quantize.detach(), x) * self.commitment
quantize = x + (quantize - x).detach()
self.expire_codes_(embed_ind, x)

quantize = self.project_out(quantize)
return quantize, embed_ind, commit_loss
