
address #188
lucidrains committed Jan 28, 2025
1 parent 59a30b6 commit ec2f4f6
Showing 1 changed file with 9 additions and 9 deletions.
vector_quantize_pytorch/lookup_free_quantization.py (9 additions, 9 deletions)
@@ -337,29 +337,29 @@ def forward(
 
         # whether to only use a fraction of probs, for reducing memory
 
+        if exists(mask):
+            input_for_entropy = original_input[mask]
+
+        input_for_entropy = rearrange(input_for_entropy, 'b n ... -> (b n) ...')
+
         if self.frac_per_sample_entropy < 1.:
-            # account for mask
-            if exists(mask):
-                original_input = original_input[mask]
-            original_input = rearrange(original_input, 'b n ... -> (b n) ...')
-
-            num_tokens = original_input.size(0)
+            num_tokens = input_for_entropy.size(0)
             num_sampled_tokens = int(num_tokens * self.frac_per_sample_entropy)
             rand_mask = torch.randn(num_tokens).argsort(dim = -1) < num_sampled_tokens
 
-            sampled_input = original_input[rand_mask]
+            sampled_input = input_for_entropy[rand_mask]
 
             sampled_distance = -2 * einsum('... i d, j d -> ... i j', sampled_input, codebook)
 
             sampled_prob = (-sampled_distance * inv_temperature).softmax(dim = -1)
 
             per_sample_probs = sampled_prob
         else:
-            if exists(mask):
-                original_input = original_input[mask]
-            original_input = rearrange(original_input, 'b n ... -> (b n) ...')
-
             # the same as euclidean distance up to a constant
-            distance = -2 * einsum('... i d, j d -> ... i j', original_input, codebook)
+            distance = -2 * einsum('... i d, j d -> ... i j', input_for_entropy, codebook)
 
             prob = (-distance * inv_temperature).softmax(dim = -1)
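For readers following the change: the fix introduces input_for_entropy so the entropy branch no longer rebinds original_input, which later loss terms still need at full shape. Below is a minimal standalone sketch of the resulting pattern, assuming token shape (b, n, d). It is not the library's actual forward(): the helper name entropy_probs, its signature, and the explicit else branch for the unmasked case are illustrative assumptions (the diff above only shows the masked path).

    import torch
    from torch import einsum
    from einops import rearrange

    def entropy_probs(original_input, codebook, mask = None, frac_per_sample_entropy = 1., inv_temperature = 100.):
        # hypothetical helper sketching the commit's pattern, not library API

        # gather the tokens used for the entropy aux loss into a *new* tensor,
        # leaving original_input untouched for later loss terms
        if mask is not None:
            input_for_entropy = original_input[mask]  # bool mask (b, n) -> (num_kept, d)
        else:
            # assumption: flatten batch and sequence when no mask is given
            input_for_entropy = rearrange(original_input, 'b n d -> (b n) d')

        if frac_per_sample_entropy < 1.:
            # subsample a random fraction of tokens to bound the size of the distance matrix
            num_tokens = input_for_entropy.size(0)
            num_sampled = int(num_tokens * frac_per_sample_entropy)
            rand_mask = torch.randn(num_tokens).argsort(dim = -1) < num_sampled  # exactly num_sampled True entries
            input_for_entropy = input_for_entropy[rand_mask]

        # the same as euclidean distance up to a constant
        distance = -2 * einsum('... i d, j d -> ... i j', input_for_entropy, codebook)
        return (-distance * inv_temperature).softmax(dim = -1)

Example usage (shapes assumed):

    tokens   = torch.randn(2, 8, 16)   # (b, n, d)
    codebook = torch.randn(32, 16)     # (num_codes, d)
    mask     = torch.rand(2, 8) > 0.5  # (b, n) bool
    probs    = entropy_probs(tokens, codebook, mask = mask, frac_per_sample_entropy = 0.5)

Each row of probs is a softmax over codes for one kept (and sampled) token; the per-sample entropy term averages -sum(p * log p) over those rows.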
