Merge pull request #185 from lucasnewman/fsq_preserve_symmetry
Add symmetry-preserving and noise-approximated quantization for FSQ from arxiv:2411.19842
lucidrains authored Jan 3, 2025
2 parents 8c27a0e + 09c33f3 commit e2abbe5
Showing 2 changed files with 43 additions and 8 deletions.
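For orientation before the diff: a minimal usage sketch of the two options this commit adds to FSQ, based only on the constructor arguments and the test shown below. The noise_approx_prob value here is illustrative, not a recommended setting.

import torch
from vector_quantize_pytorch import FSQ

levels = [8, 5, 5, 5]

quantizer = FSQ(
    levels,
    preserve_symmetry = True,   # use the symmetry-preserving bound from arXiv:2411.19842, section 3.2
    noise_approx_prob = 0.5,    # illustrative: chance of replacing rounding with uniform noise during training
)

x = torch.randn(1, 1024, 4)     # last dim matches len(levels)
xhat, indices = quantizer(x)    # xhat has the same shape as x; indices are the code ids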
7 changes: 5 additions & 2 deletions tests/test_readme.py
@@ -219,11 +219,14 @@ def test_tiger():

assert torch.allclose(quantized, quantized_out, atol = 1e-5)

def test_fsq():
@pytest.mark.parametrize('preserve_symmetry', (True, False))
def test_fsq(
preserve_symmetry
):
from vector_quantize_pytorch import FSQ

levels = [8,5,5,5] # see 4.1 and A.4.1 in the paper
quantizer = FSQ(levels)
quantizer = FSQ(levels, preserve_symmetry = preserve_symmetry)

x = torch.randn(1, 1024, 4) # 4 since there are 4 levels
xhat, indices = quantizer(x)
44 changes: 38 additions & 6 deletions vector_quantize_pytorch/finite_scalar_quantization.py
@@ -16,6 +16,8 @@

from einops import rearrange, pack, unpack

import random

# helper functions

def exists(v):
@@ -62,9 +64,12 @@ def __init__(
channel_first: bool = False,
projection_has_bias: bool = True,
return_indices = True,
force_quantization_f32 = True
force_quantization_f32 = True,
preserve_symmetry: bool = False,
noise_approx_prob = 0.0,
):
super().__init__()

_levels = torch.tensor(levels, dtype=int32)
self.register_buffer("_levels", _levels, persistent = False)

@@ -73,6 +78,9 @@

self.scale = scale

self.preserve_symmetry = preserve_symmetry
self.noise_approx_prob = noise_approx_prob

codebook_dim = len(levels)
self.codebook_dim = codebook_dim

@@ -110,12 +118,36 @@ def bound(self, z, eps: float = 1e-3):
shift = (offset / half_l).atanh()
return (z + shift).tanh() * half_l - offset

def quantize(self, z):
# symmetry-preserving and noise-approximated quantization, section 3.2 in https://arxiv.org/abs/2411.19842

def symmetry_preserving_bound(self, z):
"""
QL(x) = 2 / (L - 1) * [(L - 1) * (tanh(x) + 1) / 2 + 0.5] - 1
"""
levels_minus_1 = (self._levels - 1)
scale = 2.0 / levels_minus_1
bracket = (levels_minus_1 * (torch.tanh(z) + 1) / 2.0) + 0.5
return scale * bracket - 1.0

def noise_approx_bound(self, z):
"""
simulates quantization using noise -> Q_L(x) ~= tanh(x) + U{-1,1} / (L-1)
"""
noise = torch.empty_like(z).uniform_(-1, 1)
return torch.tanh(z) + noise / (self._levels - 1)

def quantize(self, z, preserve_symmetry = False):
""" Quantizes z, returns quantized zhat, same shape as z. """
quantized = round_ste(self.bound(z))
if self.training and random.random() < self.noise_approx_prob:
bounded = self.noise_approx_bound(z)
elif preserve_symmetry:
bounded = self.symmetry_preserving_bound(z)
else:
bounded = self.bound(z)
quantized = round_ste(bounded)
half_width = self._levels // 2 # Renormalize to [-1, 1].
return quantized / half_width

def _scale_and_shift(self, zhat_normalized):
half_width = self._levels // 2
return (zhat_normalized * half_width) + half_width
@@ -194,7 +226,7 @@ def forward(self, z):
if force_f32 and orig_dtype not in self.allowed_dtypes:
z = z.float()

codes = self.quantize(z)
codes = self.quantize(z, preserve_symmetry=self.preserve_symmetry)

# returning indices could be optional

@@ -205,7 +237,7 @@

codes = rearrange(codes, 'b n c d -> b n (c d)')

codes = codes.type(orig_dtype)
codes = codes.to(orig_dtype)

# project out

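As a reading aid for the quantize() path added above, here is a self-contained sketch of the training-time control flow it introduces. This is not the library API: round_ste is assumed to be the usual straight-through rounding helper (not shown in this diff), and a plain tanh stands in for FSQ's original bound(), which additionally handles per-level offsets.

import random
import torch

def round_ste(z):
    # straight-through rounding: round in the forward pass, identity gradient in the backward pass
    return z + (z.round() - z).detach()

def quantize_sketch(z, levels, preserve_symmetry = False, noise_approx_prob = 0.0, training = True):
    # levels: per-dimension level counts, e.g. torch.tensor([8, 5, 5, 5]), broadcast over the last dim of z
    if training and random.random() < noise_approx_prob:
        # noise-approximated quantization: tanh(z) + U(-1, 1) / (L - 1), no rounding
        noise = torch.empty_like(z).uniform_(-1, 1)
        bounded = torch.tanh(z) + noise / (levels - 1)
    elif preserve_symmetry:
        # symmetry-preserving bound: 2 / (L - 1) * ((L - 1) * (tanh(z) + 1) / 2 + 0.5) - 1
        levels_minus_1 = levels - 1
        bounded = 2.0 / levels_minus_1 * (levels_minus_1 * (torch.tanh(z) + 1) / 2.0 + 0.5) - 1.0
    else:
        bounded = torch.tanh(z)  # stand-in for FSQ's original bound() with per-level offsets
    quantized = round_ste(bounded)
    return quantized / (levels // 2)  # renormalize to [-1, 1]

A call like quantize_sketch(torch.randn(1, 1024, 4), torch.tensor([8, 5, 5, 5]), preserve_symmetry = True) roughly mirrors the quantization step inside FSQ.forward when preserve_symmetry is enabled.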
