Merge pull request #176 from EmmettBicker/master
Changed frac_per_sample_entropy to take up less memory
lucidrains authored Nov 28, 2024
2 parents 76ec1de + cdd0141 commit b48946f
Showing 2 changed files with 99 additions and 14 deletions.
77 changes: 77 additions & 0 deletions tests/test_lfq.py
@@ -0,0 +1,77 @@
import torch
import pytest
from vector_quantize_pytorch import LFQ
import math
"""
testing_strategy:
subdivisions: using masks, using frac_per_sample_entropy < 1
"""

torch.manual_seed(0)

@pytest.mark.parametrize('frac_per_sample_entropy', (1., 0.5))
@pytest.mark.parametrize('mask', (torch.tensor([False, False]),
                                  torch.tensor([True, False]),
                                  torch.tensor([True, True])))
def test_masked_lfq(
    frac_per_sample_entropy,
    mask
):
    # you can specify either dim or codebook_size
    # if both specified, will be validated against each other

    quantizer = LFQ(
        codebook_size = 65536,      # codebook size, must be a power of 2
        dim = 16,                   # this is the input feature dimension, defaults to log2(codebook_size) if not defined
        entropy_loss_weight = 0.1,  # how much weight to place on entropy loss
        diversity_gamma = 1.,       # within entropy loss, how much weight to give to diversity
        frac_per_sample_entropy = frac_per_sample_entropy
    )

    image_feats = torch.randn(2, 16, 32, 32)

    ret, loss_breakdown = quantizer(image_feats, inv_temperature=100., return_loss_breakdown=True, mask=mask) # you may want to experiment with temperature

    quantized, indices, _ = ret
    assert (quantized == quantizer.indices_to_codes(indices)).all()

@pytest.mark.parametrize('frac_per_sample_entropy', (0.1,))
@pytest.mark.parametrize('iters', (10,))
@pytest.mark.parametrize('mask', (None, torch.tensor([True, False])))
def test_lfq_bruteforce_frac_per_sample_entropy(frac_per_sample_entropy, iters, mask):
    image_feats = torch.randn(2, 16, 32, 32)

    full_per_sample_entropy_quantizer = LFQ(
        codebook_size = 65536,      # codebook size, must be a power of 2
        dim = 16,                   # this is the input feature dimension, defaults to log2(codebook_size) if not defined
        entropy_loss_weight = 0.1,  # how much weight to place on entropy loss
        diversity_gamma = 1.,       # within entropy loss, how much weight to give to diversity
        frac_per_sample_entropy = 1
    )

    partial_per_sample_entropy_quantizer = LFQ(
        codebook_size = 65536,      # codebook size, must be a power of 2
        dim = 16,                   # this is the input feature dimension, defaults to log2(codebook_size) if not defined
        entropy_loss_weight = 0.1,  # how much weight to place on entropy loss
        diversity_gamma = 1.,       # within entropy loss, how much weight to give to diversity
        frac_per_sample_entropy = frac_per_sample_entropy
    )

    ret, loss_breakdown = full_per_sample_entropy_quantizer(
        image_feats, inv_temperature=100., return_loss_breakdown=True, mask=mask)
    true_per_sample_entropy = loss_breakdown.per_sample_entropy

    per_sample_losses = torch.zeros(iters)
    for iter in range(iters):
        ret, loss_breakdown = partial_per_sample_entropy_quantizer(
            image_feats, inv_temperature=100., return_loss_breakdown=True, mask=mask) # you may want to experiment with temperature

        quantized, indices, _ = ret
        assert (quantized == partial_per_sample_entropy_quantizer.indices_to_codes(indices)).all()
        per_sample_losses[iter] = loss_breakdown.per_sample_entropy
    # 95% confidence interval
    assert abs(per_sample_losses.mean() - true_per_sample_entropy) \
        < (1.96*(per_sample_losses.std() / math.sqrt(iters)))

    print("difference: ", abs(per_sample_losses.mean() - true_per_sample_entropy))
    print("std error:", (1.96*(per_sample_losses.std() / math.sqrt(iters))))
36 changes: 22 additions & 14 deletions vector_quantize_pytorch/lookup_free_quantization.py
@@ -335,26 +335,34 @@ def forward(
 
         codebook = self.maybe_l2norm(codebook)
 
-        # the same as euclidean distance up to a constant
-        distance = -2 * einsum('... i d, j d -> ... i j', original_input, codebook)
-
-        prob = (-distance * inv_temperature).softmax(dim = -1)
-
-        # account for mask
-
-        if exists(mask):
-            prob = prob[mask]
-        else:
-            prob = rearrange(prob, 'b n ... -> (b n) ...')
-
         # whether to only use a fraction of probs, for reducing memory
 
         if self.frac_per_sample_entropy < 1.:
-            num_tokens = prob.shape[0]
+            # account for mask
+            if exists(mask):
+                original_input = original_input[mask]
+            original_input = rearrange(original_input, 'b n ... -> (b n) ...')
+
+            num_tokens = original_input.size(0)
             num_sampled_tokens = int(num_tokens * self.frac_per_sample_entropy)
             rand_mask = torch.randn(num_tokens).argsort(dim = -1) < num_sampled_tokens
-            per_sample_probs = prob[rand_mask]
+
+            sampled_input = original_input[rand_mask]
+
+            sampled_distance = -2 * einsum('... i d, j d -> ... i j', sampled_input, codebook)
+
+            sampled_prob = (-sampled_distance * inv_temperature).softmax(dim = -1)
+
+            per_sample_probs = sampled_prob
         else:
+            if exists(mask):
+                original_input = original_input[mask]
+            original_input = rearrange(original_input, 'b n ... -> (b n) ...')
+            # the same as euclidean distance up to a constant
+            distance = -2 * einsum('... i d, j d -> ... i j', original_input, codebook)
+
+            prob = (-distance * inv_temperature).softmax(dim = -1)
+
             per_sample_probs = prob
 
         # calculate per sample entropy
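The memory saving comes from reordering the subsampling: previously the full (num_tokens, codebook_size) distance and probability matrices were materialized before a random fraction of rows was selected, whereas now the tokens are selected first and the softmax over the codebook is computed only for that subset. A rough sketch of the two orderings under illustrative shapes (the tensor names and sizes below are made up for the example, not the library's API):

import torch

# illustrative sizes, not taken from the library
num_tokens, dim, codebook_size, frac = 1024, 16, 4096, 0.25
inv_temperature = 100.

tokens = torch.randn(num_tokens, dim)
codebook = torch.randn(codebook_size, dim)

# random subset of token rows, analogous to rand_mask in the diff above
rand_mask = torch.randn(num_tokens).argsort(dim = -1) < int(num_tokens * frac)

# old ordering: materialize the full (num_tokens, codebook_size) matrix, then subsample rows
distance_full = -2 * (tokens @ codebook.t())
prob_full = (-distance_full * inv_temperature).softmax(dim = -1)
per_sample_probs_old = prob_full[rand_mask]

# new ordering: subsample first, so only a (frac * num_tokens, codebook_size) matrix is allocated
sampled_tokens = tokens[rand_mask]
distance_sampled = -2 * (sampled_tokens @ codebook.t())
per_sample_probs_new = (-distance_sampled * inv_temperature).softmax(dim = -1)

# the same rows are selected either way, so both orderings give the same probabilities
assert torch.allclose(per_sample_probs_old, per_sample_probs_new, atol = 1e-5)

In the quantizer itself the same idea is applied to original_input before the einsum against the codebook, so the peak activation memory for the entropy term scales with the sampled fraction rather than with the full token count.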