Skip to content

Commit

Permalink
make library jittable
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Dec 2, 2021
1 parent e562e0f commit 7f26a57
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 5 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'vector_quantize_pytorch',
packages = find_packages(),
version = '0.3.10',
version = '0.3.11',
license='MIT',
description = 'Vector Quantization - Pytorch',
author = 'Phil Wang',
Expand Down
14 changes: 10 additions & 4 deletions vector_quantize_pytorch/vector_quantize_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,11 @@ def __init__(
self.register_buffer('embed', embed)
self.register_buffer('embed_avg', embed.clone())

@torch.jit.ignore
def init_embed_(self, data):
if self.initted:
return

embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
self.embed.data.copy_(embed)
self.embed_avg.data.copy_(embed.clone())
Expand Down Expand Up @@ -115,8 +119,7 @@ def forward(self, x):
flatten = rearrange(x, '... d -> (...) d')
embed = self.embed.t()

if not self.initted:
self.init_embed_(flatten)
self.init_embed_(flatten)

dist = -(
flatten.pow(2).sum(1, keepdim=True)
Expand Down Expand Up @@ -168,7 +171,11 @@ def __init__(
self.register_buffer('cluster_size', torch.zeros(codebook_size))
self.register_buffer('embed', embed)

@torch.jit.ignore
def init_embed_(self, data):
if self.initted:
return

embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters,
use_cosine_sim = True)
self.embed.data.copy_(embed)
Expand Down Expand Up @@ -199,8 +206,7 @@ def forward(self, x):
flatten = rearrange(x, '... d -> (...) d')
flatten = l2norm(flatten)

if not self.initted:
self.init_embed_(flatten)
self.init_embed_(flatten)

embed = l2norm(self.embed)
dist = flatten @ embed.t()
Expand Down

0 comments on commit 7f26a57

Please sign in to comment.