
Commit

fix a warning
lucidrains committed Sep 2, 2024
1 parent 7c7c69f commit 5930411
Showing 6 changed files with 13 additions and 13 deletions.
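
The change is mechanical across all six files: every use of the deprecated torch.cuda.amp.autocast is replaced by torch.amp.autocast with an explicit 'cuda' device type. As a rough illustration (not part of the commit, and assuming PyTorch 2.4+, where constructing the old context emits a FutureWarning, which is presumably the warning being fixed):

import warnings
import torch

# old API: constructing torch.cuda.amp.autocast emits a FutureWarning on PyTorch 2.4+
with warnings.catch_warnings(record = True) as caught:
    warnings.simplefilter('always')
    old_ctx = torch.cuda.amp.autocast(enabled = False)

print([str(w.message) for w in caught])  # expect a FutureWarning pointing to torch.amp.autocast

# new API: pass the device type explicitly; no deprecation warning is emitted
new_ctx = torch.amp.autocast('cuda', enabled = False)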
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "vector-quantize-pytorch"
version = "1.16.3"
version = "1.17.1"
description = "Vector Quantization - Pytorch"
authors = [
{ name = "Phil Wang", email = "[email protected]" }
6 changes: 3 additions & 3 deletions vector_quantize_pytorch/finite_scalar_quantization.py
@@ -12,7 +12,7 @@
import torch.nn as nn
from torch.nn import Module
from torch import Tensor, int32
-from torch.cuda.amp import autocast
+from torch.amp import autocast

from einops import rearrange, pack, unpack

@@ -159,7 +159,7 @@ def indices_to_codes(self, indices):

return codes

-@autocast(enabled = False)
+@autocast('cuda', enabled = False)
def forward(self, z):
"""
einstein notation
@@ -187,7 +187,7 @@ def forward(self, z):
# whether to force quantization step to be full precision or not

force_f32 = self.force_quantization_f32
-quantization_context = partial(autocast, enabled = False) if force_f32 else nullcontext
+quantization_context = partial(autocast, 'cuda', enabled = False) if force_f32 else nullcontext

with quantization_context():
orig_dtype = z.dtype
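
For reference, a minimal sketch (not taken from the repository) of the two call patterns migrated in this file, autocast as a decorator and autocast bound through functools.partial as a context factory, both now receiving 'cuda' as an explicit first argument:

from contextlib import nullcontext
from functools import partial

import torch
from torch.amp import autocast  # replaces the deprecated torch.cuda.amp.autocast

# decorator form: the device type is now a required positional argument
@autocast('cuda', enabled = False)
def quantize(z):
    return z.round()

# context-factory form: partial binds 'cuda' alongside enabled = False
force_f32 = True
quantization_context = partial(autocast, 'cuda', enabled = False) if force_f32 else nullcontext

with quantization_context():
    out = quantize(torch.randn(2, 4))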
4 changes: 2 additions & 2 deletions vector_quantize_pytorch/lookup_free_quantization.py
@@ -16,7 +16,7 @@
from torch import nn, einsum
import torch.nn.functional as F
from torch.nn import Module
-from torch.cuda.amp import autocast
+from torch.amp import autocast

from einops import rearrange, reduce, pack, unpack

@@ -293,7 +293,7 @@ def forward(

force_f32 = self.force_quantization_f32

-quantization_context = partial(autocast, enabled = False) if force_f32 else nullcontext
+quantization_context = partial(autocast, 'cuda', enabled = False) if force_f32 else nullcontext

with quantization_context():

4 changes: 2 additions & 2 deletions vector_quantize_pytorch/residual_fsq.py
@@ -8,7 +8,7 @@
from torch import nn
from torch.nn import Module, ModuleList
import torch.nn.functional as F
-from torch.cuda.amp import autocast
+from torch.amp import autocast

from vector_quantize_pytorch.finite_scalar_quantization import FSQ

@@ -167,7 +167,7 @@ def forward(

# go through the layers

-with autocast(enabled = False):
+with autocast('cuda', enabled = False):
for quantizer_index, (layer, scale) in enumerate(zip(self.layers, self.scales)):

if should_quantize_dropout and quantizer_index > rand_quantize_dropout_index:
4 changes: 2 additions & 2 deletions vector_quantize_pytorch/residual_lfq.py
@@ -6,7 +6,7 @@
from torch import nn
from torch.nn import Module, ModuleList
import torch.nn.functional as F
-from torch.cuda.amp import autocast
+from torch.amp import autocast

from vector_quantize_pytorch.lookup_free_quantization import LFQ

@@ -156,7 +156,7 @@ def forward(

# go through the layers

-with autocast(enabled = False):
+with autocast('cuda', enabled = False):
for quantizer_index, layer in enumerate(self.layers):

if should_quantize_dropout and quantizer_index > rand_quantize_dropout_index:
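
The residual quantizers use the context-manager form instead. A short sketch of that pattern with the new API (again not from the repository; it assumes a CUDA device and made-up layer sizes): the surrounding model can run under mixed precision while autocast('cuda', enabled = False) keeps the inner quantization loop in full precision.

import torch
from torch.amp import autocast

# assumes a CUDA device is available
layers = [torch.nn.Linear(8, 8).cuda() for _ in range(3)]
x = torch.randn(4, 8, device = 'cuda')

with autocast('cuda', dtype = torch.float16):       # outer mixed-precision region
    h = layers[0](x)                                 # matmul runs in fp16 here
    with autocast('cuda', enabled = False):          # quantization forced back to full precision
        residual = h.float()
        for layer in layers[1:]:
            residual = residual - layer(residual)    # stays in fp32

assert residual.dtype == torch.float32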
6 changes: 3 additions & 3 deletions vector_quantize_pytorch/vector_quantize_pytorch.py
@@ -9,7 +9,7 @@
import torch.nn.functional as F
import torch.distributed as distributed
from torch.optim import Optimizer
-from torch.cuda.amp import autocast
+from torch.amp import autocast

import einx
from einops import rearrange, repeat, reduce, pack, unpack
@@ -458,7 +458,7 @@ def expire_codes_(self, batch_samples):
batch_samples = rearrange(batch_samples, 'h ... d -> h (...) d')
self.replace(batch_samples, batch_mask = expired_codes)

-@autocast(enabled = False)
+@autocast('cuda', enabled = False)
def forward(
self,
x,
@@ -671,7 +671,7 @@ def expire_codes_(self, batch_samples):
batch_samples = rearrange(batch_samples, 'h ... d -> h (...) d')
self.replace(batch_samples, batch_mask = expired_codes)

-@autocast(enabled = False)
+@autocast('cuda', enabled = False)
def forward(
self,
x,
