Skip to content

Commit

Permalink
Irie et al. notices that the original Oord implementation of VQ sets …
Browse files Browse the repository at this point in the history
…cluster sizes of 0 initially, leading to worse convergence. not an issue if kmeans init is turned on
  • Loading branch information
lucidrains committed Jul 10, 2024
1 parent cb3dd32 commit 4c514db
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 3 deletions.
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -679,3 +679,12 @@ assert loss.item() >= 0
primaryClass = {cs.LG}
}
```

```bibtex
@inproceedings{Irie2023SelfOrganisingND,
title = {Self-Organising Neural Discrete Representation Learning \`a la Kohonen},
author = {Kazuki Irie and R'obert Csord'as and J{\"u}rgen Schmidhuber},
year = {2023},
url = {https://api.semanticscholar.org/CorpusID:256901024}
}
```
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "vector-quantize-pytorch"
version = "1.15.2"
version = "1.15.3"
description = "Vector Quantization - Pytorch"
authors = [
{ name = "Phil Wang", email = "[email protected]" }
Expand Down
4 changes: 2 additions & 2 deletions vector_quantize_pytorch/vector_quantize_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ def __init__(
self.all_reduce_fn = distributed.all_reduce if use_ddp else noop

self.register_buffer('initted', torch.Tensor([not kmeans_init]))
self.register_buffer('cluster_size', torch.zeros(num_codebooks, codebook_size))
self.register_buffer('cluster_size', torch.ones(num_codebooks, codebook_size))
self.register_buffer('embed_avg', embed.clone())

self.learnable_codebook = learnable_codebook
Expand Down Expand Up @@ -582,7 +582,7 @@ def __init__(
self.all_reduce_fn = distributed.all_reduce if use_ddp else noop

self.register_buffer('initted', torch.Tensor([not kmeans_init]))
self.register_buffer('cluster_size', torch.zeros(num_codebooks, codebook_size))
self.register_buffer('cluster_size', torch.ones(num_codebooks, codebook_size))
self.register_buffer('embed_avg', embed.clone())

self.learnable_codebook = learnable_codebook
Expand Down

0 comments on commit 4c514db

Please sign in to comment.