Skip to content

Commit

Permalink
move the rotation operation from rotation trick to a function for pot…
Browse files Browse the repository at this point in the history
…ential reuse for other research
  • Loading branch information
lucidrains committed Nov 8, 2024
1 parent e7ff7d7 commit 723ea9f
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 30 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "vector-quantize-pytorch"
version = "1.19.4"
version = "1.19.5"
description = "Vector Quantization - Pytorch"
authors = [
{ name = "Phil Wang", email = "[email protected]" }
Expand Down
63 changes: 34 additions & 29 deletions vector_quantize_pytorch/vector_quantize_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,19 +76,6 @@ def lens_to_mask(lens, max_length):
seq = torch.arange(max_length, device = lens.device)
return seq < lens[:, None]

def efficient_rotation_trick_transform(u, q, e):
    """
    4.2 in https://arxiv.org/abs/2410.06424

    Evaluates `e - 2 (e w) w^T + 2 (e u) q^T` per row without building a
    `d x d` matrix, where `w = l2norm(u + q)`.

    `u`, `q` and `e` are each `(b, d)`; the result keeps the singleton row
    axis and is `(b, 1, d)`. `w`, `u` and `q` are all detached, so gradients
    reach the output only through `e`.
    """
    e_row = rearrange(e, 'b d -> b 1 d')

    # reflection direction, treated as a constant w.r.t. autograd
    w = l2norm(u + q, dim = 1).detach()
    w_col = rearrange(w, 'b d -> b d 1')
    w_row = rearrange(w, 'b d -> b 1 d')

    u_col = rearrange(u, 'b d -> b d 1').detach()
    q_row = rearrange(q, 'b d -> b 1 d').detach()

    # (b, 1, d) @ (b, d, 1) gives a per-row scalar first, then scales a row vector
    reflect_term = e_row @ w_col @ w_row
    map_term = e_row @ u_col @ q_row

    return e_row - 2 * reflect_term + 2 * map_term

def uniform_init(*shape):
t = torch.empty(shape)
nn.init.kaiming_uniform_(t)
Expand Down Expand Up @@ -248,6 +235,39 @@ def kmeans(

return means, bins

# rotation trick related

def efficient_rotation_trick_transform(u, q, e):
    """
    4.2 in https://arxiv.org/abs/2410.06424

    Evaluates `e - 2 (e w) w^T + 2 (e u) q^T` per row without building a
    `d x d` matrix, where `w = l2norm(u + q)`.

    `u`, `q` and `e` are each `(b, d)`; the result keeps the singleton row
    axis and is `(b, 1, d)`. `w`, `u` and `q` are all detached, so gradients
    reach the output only through `e`.
    """
    e_row = rearrange(e, 'b d -> b 1 d')

    # reflection direction, treated as a constant w.r.t. autograd
    w = l2norm(u + q, dim = 1).detach()
    w_col = rearrange(w, 'b d -> b d 1')
    w_row = rearrange(w, 'b d -> b 1 d')

    u_col = rearrange(u, 'b d -> b d 1').detach()
    q_row = rearrange(q, 'b d -> b 1 d').detach()

    # (b, 1, d) @ (b, d, 1) gives a per-row scalar first, then scales a row vector
    reflect_term = e_row @ w_col @ w_row
    map_term = e_row @ u_col @ q_row

    return e_row - 2 * reflect_term + 2 * map_term

def rotate_from_to(src, tgt):
    """
    rotation trick STE (https://arxiv.org/abs/2410.06424) to get gradients through VQ layer.

    `tgt` is pushed through the transform that maps the unit direction of
    `tgt` onto the unit direction of `src`, then rescaled by the detached
    ratio `|src| / |tgt|`. The unit directions are detached inside
    `efficient_rotation_trick_transform`, so gradients reach the output only
    through `tgt`.

    NOTE(review): despite the name, the returned tensor follows the direction
    and norm of `src`, not `tgt` - confirm this is the intended reading
    before reusing the helper elsewhere.

    Args:
        src: tensor with feature dimension last (the VQ call site passes the quantized output).
        tgt: tensor of the same shape as `src` (the VQ call site passes the encoder output `x`).

    Returns:
        Tensor restored through `inverse` from `pack_one` - presumably the
        original shape of `tgt`; the helper is not visible here.
    """
    # presumably flattens all leading dims so both are (n, d) - pack_one is defined elsewhere
    tgt, inverse = pack_one(tgt, '* d')
    src, _ = pack_one(src, '* d')

    norm_tgt = tgt.norm(dim = -1, keepdim = True)
    norm_src = src.norm(dim = -1, keepdim = True)

    # the transform returns (n, 1, d); drop only that inserted axis.
    # a bare `.squeeze()` would also collapse `d == 1` (and `n == 1`), which
    # then broadcasts incorrectly against the (n, 1) norm ratio below.
    rotated_src = efficient_rotation_trick_transform(
        safe_div(tgt, norm_tgt),
        safe_div(src, norm_src),
        tgt
    ).squeeze(1)

    # rescale to the norm of `src`; detached so the ratio acts as a constant
    rotated = rotated_src * safe_div(norm_src, norm_tgt).detach()

    return inverse(rotated)

# distributed helpers

@cache
Expand Down Expand Up @@ -1098,22 +1118,7 @@ def forward(
commit_quantize = maybe_detach(quantize)

if self.rotation_trick:
# rotation trick STE (https://arxiv.org/abs/2410.06424) to get gradients through VQ layer.
x, inverse = pack_one(x, '* d')
quantize, _ = pack_one(quantize, '* d')

norm_x = x.norm(dim = -1, keepdim = True)
norm_quantize = quantize.norm(dim = -1, keepdim = True)

rot_quantize = efficient_rotation_trick_transform(
safe_div(x, norm_x),
safe_div(quantize, norm_quantize),
x
).squeeze()

quantize = rot_quantize * safe_div(norm_quantize, norm_x).detach()

x, quantize = inverse(x), inverse(quantize)
quantize = rotate_from_to(quantize, x)
else:
# standard STE to get gradients through VQ layer.
quantize = x + (quantize - x).detach()
Expand Down

0 comments on commit 723ea9f

Please sign in to comment.