
Commit

allow for an MLP implicitly parameterized codebook instead of just a single Linear, for SimVQ. seems to converge just fine with rotation trick
lucidrains committed Nov 11, 2024
1 parent cd0fa8e commit 80f4e84
Showing 2 changed files with 13 additions and 2 deletions.
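
For context: in SimVQ the codebook weights themselves stay frozen, and what is learned is a transform from that frozen codebook to the codes actually used for quantization. This commit lets that transform be any module (for example a small MLP) rather than only a single Linear. A minimal sketch of the idea, with hypothetical names and shapes rather than the library's exact internals:

import torch
from torch import nn

# minimal sketch of an implicitly parameterized codebook
# (hypothetical names, not the library's exact internals)

dim, codebook_size = 512, 1024

# the codebook itself is frozen random noise and is never trained directly
frozen_codebook = torch.randn(codebook_size, dim)

# all learning happens in this transform; after this commit it can be an MLP
# instead of only a single Linear
codebook_to_codes = nn.Sequential(
    nn.Linear(dim, 1024),
    nn.ReLU(),
    nn.Linear(1024, dim)
)

# the effective (implicit) codebook used for quantization
implicit_codes = codebook_to_codes(frozen_codebook)  # (codebook_size, dim)

The extra capacity of the MLP is the point of the change; per the commit message, it still converges fine with the rotation trick.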
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "vector-quantize-pytorch"
-version = "1.20.1"
+version = "1.20.2"
 description = "Vector Quantization - Pytorch"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }
13 changes: 12 additions & 1 deletion vector_quantize_pytorch/sim_vq.py
@@ -1,3 +1,4 @@
+from __future__ import annotations
 from typing import Callable
 
 import torch
@@ -38,6 +39,7 @@ def __init__(
         self,
         dim,
         codebook_size,
+        codebook_transform: Module | None = None,
         init_fn: Callable = identity,
         accept_image_fmap = False,
         rotation_trick = True, # works even better with rotation trick turned on, with no straight through and the commit loss from input to quantize
@@ -51,7 +53,11 @@ def __init__(
 
         # the codebook is actually implicit from a linear layer from frozen gaussian or uniform
 
-        self.codebook_to_codes = nn.Linear(dim, dim, bias = False)
+        if not exists(codebook_transform):
+            codebook_transform = nn.Linear(dim, dim, bias = False)
+
+        self.codebook_to_codes = codebook_transform
 
         self.register_buffer('codebook', codebook)

@@ -114,6 +120,11 @@ def forward(
 
 sim_vq = SimVQ(
     dim = 512,
+    codebook_transform = nn.Sequential(
+        nn.Linear(512, 1024),
+        nn.ReLU(),
+        nn.Linear(1024, 512)
+    ),
     codebook_size = 1024,
     accept_image_fmap = True
 )
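
Taken together with the updated example at the bottom of sim_vq.py, the new argument can be exercised end to end roughly as follows. This is a sketch: it assumes SimVQ's forward call returns the quantized tensor, code indices, and a commitment loss, as the repository's README shows for its other quantizers.

import torch
from torch import nn
from vector_quantize_pytorch import SimVQ

# sketch of using the new codebook_transform argument
# (assumes forward returns (quantized, indices, commit_loss))
sim_vq = SimVQ(
    dim = 512,
    codebook_size = 1024,
    codebook_transform = nn.Sequential(   # MLP in place of the default single Linear
        nn.Linear(512, 1024),
        nn.ReLU(),
        nn.Linear(1024, 512)
    ),
    accept_image_fmap = True              # input is (batch, dim, height, width)
)

images = torch.randn(1, 512, 32, 32)
quantized, indices, commit_loss = sim_vq(images)

assert quantized.shape == images.shape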
