Commit 567dc9c
add channel_last, for ability to define whether input has features at last dimension, default to true
lucidrains committed Oct 27, 2021
1 parent baf249e commit 567dc9c
Showing 2 changed files with 13 additions and 2 deletions.
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'vector_quantize_pytorch',
   packages = find_packages(),
-  version = '0.3.7',
+  version = '0.3.8',
   license='MIT',
   description = 'Vector Quantization - Pytorch',
   author = 'Phil Wang',
13 changes: 12 additions & 1 deletion vector_quantize_pytorch/vector_quantize_pytorch.py
@@ -242,7 +242,8 @@ def __init__(
         kmeans_init = False,
         kmeans_iters = 10,
         use_cosine_sim = False,
-        threshold_ema_dead_code = 0
+        threshold_ema_dead_code = 0,
+        channel_last = True
     ):
         super().__init__()
         n_embed = default(n_embed, codebook_size)
@@ -271,12 +272,18 @@
         )

         self.codebook_size = codebook_size
+        self.channel_last = channel_last

     @property
     def codebook(self):
         return self._codebook.codebook

     def forward(self, x):
+        need_transpose = not self.channel_last
+
+        if need_transpose:
+            x = rearrange(x, 'b n d -> b d n')
+
         x = self.project_in(x)

         quantize, embed_ind = self._codebook(x)
@@ -288,4 +295,8 @@ def forward(self, x):
         quantize = x + (quantize - x).detach()

         quantize = self.project_out(quantize)
+
+        if need_transpose:
+            quantize = rearrange(quantize, 'b n d -> b d n')
+
         return quantize, embed_ind, commit_loss
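
Below is a usage sketch of the new flag (not part of the commit; the class name VectorQuantize and its dim and codebook_size arguments are assumed from the repo's README). With channel_last = False the module transposes a channel-first input, e.g. a Conv1d output of shape (batch, dim, n), before projecting and quantizing, then transposes back on the way out:

import torch
from vector_quantize_pytorch import VectorQuantize

# default behavior: channel_last = True, features at the last dimension
vq = VectorQuantize(dim = 256, codebook_size = 512)    # constructor args assumed, not shown in this diff
x = torch.randn(1, 1024, 256)                          # (batch, n, dim)
quantize, embed_ind, commit_loss = vq(x)               # quantize: (1, 1024, 256)

# new: channel_last = False for channel-first inputs (e.g. Conv1d outputs)
vq = VectorQuantize(dim = 256, codebook_size = 512, channel_last = False)
x = torch.randn(1, 256, 1024)                          # (batch, dim, n)
quantize, embed_ind, commit_loss = vq(x)               # quantize: (1, 256, 1024)

Note that einops' rearrange with the pattern 'b n d -> b d n' simply swaps the last two axes regardless of which axis name matches which, so the same call serves as both the incoming and outgoing transpose.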
