diff --git a/pyproject.toml b/pyproject.toml index 1d61dcc..1ebcd0c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "vector-quantize-pytorch" -version = "1.15.0" +version = "1.15.1" description = "Vector Quantization - Pytorch" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" } diff --git a/vector_quantize_pytorch/finite_scalar_quantization.py b/vector_quantize_pytorch/finite_scalar_quantization.py index 20d5d36..b59b697 100644 --- a/vector_quantize_pytorch/finite_scalar_quantization.py +++ b/vector_quantize_pytorch/finite_scalar_quantization.py @@ -4,7 +4,8 @@ """ from __future__ import annotations -from functools import wraps +from functools import wraps, partial +from contextlib import nullcontext from typing import List, Tuple import torch @@ -61,6 +62,7 @@ def __init__( channel_first: bool = False, projection_has_bias: bool = True, return_indices = True, + force_quantization_f32 = True ): super().__init__() _levels = torch.tensor(levels, dtype=int32) @@ -99,6 +101,7 @@ def __init__( self.register_buffer("implicit_codebook", implicit_codebook, persistent = False) self.allowed_dtypes = allowed_dtypes + self.force_quantization_f32 = force_quantization_f32 def bound(self, z, eps: float = 1e-3): """ Bound `z`, an array of shape (..., d). """ @@ -166,7 +169,6 @@ def forward(self, z): c - number of codebook dim """ - orig_dtype = z.dtype is_img_or_video = z.ndim >= 4 need_move_channel_last = is_img_or_video or self.channel_first @@ -182,25 +184,28 @@ def forward(self, z): z = rearrange(z, 'b n (c d) -> b n c d', c = self.num_codebooks) - # make sure allowed dtype before quantizing + # whether to force quantization step to be full precision or not - if z.dtype not in self.allowed_dtypes: - z = z.float() + force_f32 = self.force_quantization_f32 + quantization_context = partial(autocast, enabled = False) if force_f32 else nullcontext - codes = self.quantize(z) + with quantization_context(): + orig_dtype = z.dtype - # returning indices could be optional + if force_f32 and orig_dtype not in self.allowed_dtypes: + z = z.float() - indices = None + codes = self.quantize(z) - if self.return_indices: - indices = self.codes_to_indices(codes) + # returning indices could be optional - codes = rearrange(codes, 'b n c d -> b n (c d)') + indices = None - # cast codes back to original dtype + if self.return_indices: + indices = self.codes_to_indices(codes) + + codes = rearrange(codes, 'b n c d -> b n (c d)') - if codes.dtype != orig_dtype: codes = codes.type(orig_dtype) # project out