Description
Reproducer
When I tried to run the following:

```python
import torch
import torch.nn as nn
from torchao.quantization import quantize_, Float8DynamicActivationFloat8WeightConfig, PerTensor

device = torch.device("cuda")
# device = torch.device("xpu")

# Create a simple linear layer
linear = nn.Linear(32, 48, bias=True).to(device).to(torch.bfloat16).eval()
linear.requires_grad_(False)

print("Quantizing with version=2...")
quantize_(
    linear,
    Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor(), version=2),
    device=device,
)

# Create input
input_tensor = torch.randn(16, 32, dtype=torch.bfloat16, device=device)

# Test forward pass inside inference_mode
print("\nTesting forward pass inside inference_mode...")
with torch.inference_mode():
    output = linear(input_tensor)
```

I got the following error on both CUDA and XPU:
```
Testing forward pass inside inference_mode...
Traceback (most recent call last):
  File "/root/tongsu/test_full_flow.py", line 26, in <module>
    output = linear(input_tensor)
  File "/root/miniforge3/envs/tongsu_1212/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/miniforge3/envs/tongsu_1212/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1787, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/miniforge3/envs/tongsu_1212/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 134, in forward
    return F.linear(input, self.weight, self.bias)
  File "/root/tongsu/ao/torchao/utils.py", line 632, in _dispatch__torch_function__
    return cls._TORCH_FN_TABLE[cls][func](func, types, args, kwargs)
  File "/root/tongsu/ao/torchao/utils.py", line 435, in wrapper
    return _func(f, types, args, kwargs)
  File "/root/tongsu/ao/torchao/quantization/quantize_/workflows/float8/float8_tensor.py", line 266, in _
    return _float8_addmm_impl(input_tensor, weight_tensor.t(), bias)
  File "/root/tongsu/ao/torchao/utils.py", line 634, in _dispatch__torch_function__
    return func(*args, **kwargs)
RuntimeError: Cannot set version_counter for inference tensor
```
This is presumably because the tensor creation and the usage should all happen inside the `with torch.inference_mode():` block:

```python
# This works correctly: everything is inside the with context
with torch.inference_mode():
    linear = nn.Linear(32, 48, bias=True).to(device).to(torch.bfloat16).eval()
    ...
```

However, maybe we could do something further to make the initialization more user-friendly?
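For reference, here is a sketch of the complete working variant, assembled from the reproducer above with every step moved inside the context manager (the approach is what I verified; this exact listing is untested):

```python
# Sketch of the complete workaround: quantization, input creation, and the
# forward pass all run inside inference_mode, so every tensor involved
# (including the quantized weight) is created as an inference tensor.
import torch
import torch.nn as nn
from torchao.quantization import quantize_, Float8DynamicActivationFloat8WeightConfig, PerTensor

device = torch.device("cuda")

with torch.inference_mode():
    linear = nn.Linear(32, 48, bias=True).to(device).to(torch.bfloat16).eval()
    linear.requires_grad_(False)
    quantize_(
        linear,
        Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor(), version=2),
        device=device,
    )
    input_tensor = torch.randn(16, 32, dtype=torch.bfloat16, device=device)
    output = linear(input_tensor)  # succeeds; no version_counter error
```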
Env
torch: 2.10.0.dev20251210+cu128
torchao: 0.16.0+gitff6d9e244
Problem Analysis
As the error message suggests, the failure happens here:

```python
# float8_tensor.py
@implements(aten.linear.default)
@implements_torch_function(torch.nn.functional.linear)
def _(func, types, args, kwargs):
    ...
    return _float8_addmm_impl(input_tensor, weight_tensor.t(), bias)  # <---- This weight_tensor.t()
```

The failure actually happens on the `weight_tensor.t()` call.
Per the traceback, the error is raised during the dispatch fall-through in:

```
  File "/root/.../ao/torchao/utils.py", line 634, in _dispatch__torch_function__
    return func(*args, **kwargs)
```
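For context, the dispatch wrapper presumably looks roughly like this (a simplified sketch reconstructed from the traceback above, not the exact torchao source):

```python
# Simplified sketch of _dispatch__torch_function__, reconstructed from the
# traceback (utils.py lines 632/634); the exact details are assumptions.
@classmethod
def _dispatch__torch_function__(cls, func, types, args=(), kwargs=None):
    kwargs = kwargs or {}
    if cls in cls._TORCH_FN_TABLE and func in cls._TORCH_FN_TABLE[cls]:
        # A registered override handles the op entirely in Python and
        # returns directly to the caller (the success path below).
        return cls._TORCH_FN_TABLE[cls][func](func, types, args, kwargs)
    # No override registered for this op (e.g. torch.Tensor.t): fall through
    # to the regular C++ dispatcher, where ADInplaceOrView::t() runs.
    with torch._C.DisableTorchFunctionSubclass():
        return func(*args, **kwargs)  # <-- line 634 in the traceback
```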
So the quick fix is to avoid this fall-through by registering `implements_torch_function(torch.Tensor.t)` directly.
Possible Fix
In torchao, the fix registers a `__torch_function__` override for the transpose, as shown below. I created a PR at #3488, but I don't know if this is a proper fix.

```diff
 @implements(aten.t.default)
 def _(func, types, args, kwargs):
     ...

+@implements_torch_function(torch.Tensor.t)
+def _(func, types, args, kwargs):
+    print("DEBUG: implements_torch_function(torch.Tensor.t)")
+    assert len(args) == 1
+    self = args[0]
+    assert len(self.block_size) == 2
+    new_tensor = self.__class__(
+        self.qdata.t(),
+        self.scale.t(),
+        (self.block_size[1], self.block_size[0]),
+        self.mm_config,
+        self.act_quant_kwargs,
+        self.kernel_preference,
+        self.dtype,
+    )
+    return new_tensor
```

With this change, the output is:

```
DEBUG: implements_torch_function(torch.Tensor.t)
```
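As a quick sanity check (a hypothetical snippet, not part of the PR), the transpose can also be exercised directly on the quantized weight:

```python
# Hypothetical check: with the override registered, .t() on the quantized
# weight stays on the Python __torch_function__ path, so ADInplaceOrView
# never runs and no version_counter error is raised.
with torch.inference_mode():
    _ = linear.weight.t()          # prints the DEBUG line above
    output = linear(input_tensor)  # forward pass now succeeds
```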
So the difference between the two paths is (analysis aided by GPT):
This is the success path:

```
Python: weight_tensor.t()
    ↓
__torch_function__ intercepts
    ↓
The registered handler runs, returns new_tensor (is_inference=True)
    ↓
Returns DIRECTLY to the Python caller  ← no C++ ADInplaceOrView involved!
    ↓
SUCCESS
```

This is the failure path:

```
Python: weight_tensor.t()
    ↓
__torch_function__ - NOT registered, falls through
    ↓
func(*args, **kwargs) with DisableTorchFunctionSubclass
    ↓
C++ dispatch starts
    ↓
ADInplaceOrView::t() runs  ← this is the problem!
    ↓
Calls at::_ops::t::redispatch() → eventually hits __torch_dispatch__
    ↓
The @implements(aten.t.default) handler runs, returns new_tensor (is_inference=True)
    ↓
Returns to ADInplaceOrView::t()
    ↓
ADInplaceOrView::t() calls as_view(base, output) on the returned tensor
    ↓
as_view() calls set_version_counter() on output (the inference tensor)
    ↓
ERROR: Cannot set version_counter for inference tensor
```
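To illustrate the mechanism independently of torchao, here is a minimal, untested sketch of a wrapper subclass that should hit the same failure path: the subclass is created outside inference mode, its `__torch_dispatch__` handler returns a brand-new tensor, and `.t()` is called inside inference mode, so `ADInplaceOrView::t()` ends up calling `as_view()` on an inference tensor.

```python
# Minimal sketch (illustrative, untested): a wrapper subclass whose
# __torch_dispatch__ returns a freshly constructed tensor, mirroring the
# torchao aten.t.default handler in the failure path above.
import torch
from torch.utils._pytree import tree_map

class Wrap(torch.Tensor):
    # Disable __torch_function__ so .t() always falls through to the C++
    # dispatcher (the "NOT registered" branch of the failure path).
    __torch_function__ = torch._C._disabled_torch_function_impl

    @staticmethod
    def __new__(cls, inner):
        return torch.Tensor._make_wrapper_subclass(
            cls, inner.shape, dtype=inner.dtype, device=inner.device
        )

    def __init__(self, inner):
        self.inner = inner

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        unwrap = lambda x: x.inner if isinstance(x, Wrap) else x
        out = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
        # Construct a brand-new wrapper; inside inference_mode this is an
        # inference tensor, which ADInplaceOrView::t() cannot as_view().
        return Wrap(out) if isinstance(out, torch.Tensor) else out

w = Wrap(torch.randn(4, 4))  # created OUTSIDE inference mode
with torch.inference_mode():
    _ = w.t()  # expected: RuntimeError: Cannot set version_counter for inference tensor
```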