Cannot set version_counter for inference tensor #3487

@Stonepia

Description

Reproducer

When I tried to run the following:

import torch
import torch.nn as nn
from torchao.quantization import quantize_, Float8DynamicActivationFloat8WeightConfig, PerTensor

device = torch.device("cuda")
#device = torch.device("xpu")

# Create a simple linear layer
linear = nn.Linear(32, 48, bias=True).to(device).to(torch.bfloat16).eval()
linear.requires_grad_(False)

print("Quantizing with version=2...")
quantize_(
    linear,
    Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor(), version=2),
    device=device,
)

# Create input
input_tensor = torch.randn(16, 32, dtype=torch.bfloat16, device=device)

# Test forward pass inside inference_mode
print("\nTesting forward pass inside inference_mode...")
with torch.inference_mode():
    output = linear(input_tensor)

I got the following error on both CUDA and XPU:

Testing forward pass inside inference_mode...
Traceback (most recent call last):
  File "/root/tongsu/test_full_flow.py", line 26, in <module>
    output = linear(input_tensor)
  File "/root/miniforge3/envs/tongsu_1212/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/miniforge3/envs/tongsu_1212/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1787, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/miniforge3/envs/tongsu_1212/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 134, in forward
    return F.linear(input, self.weight, self.bias)
  File "/root/tongsu/ao/torchao/utils.py", line 632, in _dispatch__torch_function__
    return cls._TORCH_FN_TABLE[cls][func](func, types, args, kwargs)
  File "/root/tongsu/ao/torchao/utils.py", line 435, in wrapper
    return _func(f, types, args, kwargs)
  File "/root/tongsu/ao/torchao/quantization/quantize_/workflows/float8/float8_tensor.py", line 266, in _
    return _float8_addmm_impl(input_tensor, weight_tensor.t(), bias)
  File "/root/tongsu/ao/torchao/utils.py", line 634, in _dispatch__torch_function__
    return func(*args, **kwargs)
RuntimeError: Cannot set version_counter for inference tensor

This is presumably because the tensor creation and usage should all be placed inside the with torch.inference_mode(): block:

# This works correctly: everything is created and used inside the context

with torch.inference_mode():
    linear = nn.Linear(32, 48, bias=True).to(device).to(torch.bfloat16).eval()
    quantize_(linear, Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor(), version=2), device=device)
    input_tensor = torch.randn(16, 32, dtype=torch.bfloat16, device=device)
    output = linear(input_tensor)

However, maybe we could do something further to make the initialization more user-friendly?
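
For background, tensors allocated under inference_mode are flagged as inference tensors and carry no version counter, which is exactly what the error message complains about. A quick standalone check (plain PyTorch, independent of torchao):

import torch

a = torch.randn(2, 3)           # normal tensor: has a version counter
with torch.inference_mode():
    b = torch.randn(2, 3)       # inference tensor: no version counter
print(a.is_inference(), b.is_inference())  # False True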

Env

torch: 2.10.0.dev20251210+cu128
torchao: 0.16.0+gitff6d9e244

Problem Analysis

As the traceback shows, the failure happens in this handler:

# float8_tensor.py
@implements(aten.linear.default)
@implements_torch_function(torch.nn.functional.linear)
def _(func, types, args, kwargs):
    ...
    return _float8_addmm_impl(input_tensor, weight_tensor.t(), bias)  # <---- this weight_tensor.t()

The failure actually occurs on the weight_tensor.t() call. Judging from the traceback, it happens in the torch-function dispatch fallback:

  File "/root/.../ao/torchao/utils.py", line 634, in _dispatch__torch_function__
    return func(*args, **kwargs)

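For context, the fallback in torchao's _dispatch__torch_function__ behaves roughly like the sketch below (simplified for illustration; the table name and structure are not torchao's exact code). When no handler is registered for a torch-level function, the call falls through into the C++ dispatcher:

import torch

_TORCH_FN_TABLE = {}  # illustrative stand-in for cls._TORCH_FN_TABLE[cls]

def _dispatch__torch_function__(cls, func, types, args=(), kwargs=None):
    kwargs = kwargs or {}
    if func in _TORCH_FN_TABLE:
        # Registered handler: runs entirely in Python and returns the
        # result directly to the caller, so the C++ dispatcher (and with
        # it ADInplaceOrView) is never entered for this op.
        return _TORCH_FN_TABLE[func](func, types, args, kwargs)
    # No handler: fall through to the real call. This re-enters the C++
    # dispatcher, where the ADInplaceOrView kernel runs for view ops
    # such as t() -- the path that triggers the error above.
    with torch._C.DisableTorchFunctionSubclass():
        return func(*args, **kwargs)
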
So the quick fix is to avoid this fallback by registering a handler for torch.Tensor.t via implements_torch_function directly.

Possible Fix

In torchao, this can be fixed by registering a handler as shown below. I created a PR at #3488, but I don't know whether this is a proper fix:

@implements(aten.t.default)
def _(func, types, args, kwargs):
    ...

+@implements_torch_function(torch.Tensor.t)
+def _(func, types, args, kwargs):
+    print("DEBUG: implements_torch_function(torch.Tensor.t)")
+    assert len(args) == 1
+    self = args[0]
+    assert len(self.block_size) == 2
+    new_tensor = self.__class__(
+        self.qdata.t(),
+        self.scale.t(),
+        (self.block_size[1], self.block_size[0]),
+        self.mm_config,
+        self.act_quant_kwargs,
+        self.kernel_preference,
+        self.dtype,
+    )
+    return new_tensor

With this patch, the reproducer succeeds, and the debug print confirms the new handler is hit:

DEBUG: implements_torch_function(torch.Tensor.t)
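
For completeness, re-running the forward pass from the original reproducer now works under inference_mode:

with torch.inference_mode():
    output = linear(input_tensor)
print(output.shape)  # torch.Size([16, 48])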

So the difference between the two paths is (summarized with GPT's help):

This is the success path:

Python: weight_tensor.t()
    ↓
__torch_function__ intercepts
    ↓
Your handler runs, returns new_tensor (is_inference=True)
    ↓
Returns DIRECTLY to Python caller ← No C++ ADInplaceOrView involved!
    ↓
SUCCESS

This is the failure path:

Python: weight_tensor.t()
    ↓
__torch_function__ - NOT registered, falls through
    ↓
func(*args, **kwargs) with DisableTorchFunctionSubclass
    ↓
C++ dispatch starts
    ↓
ADInplaceOrView::t() runs ← This is the problem!
    ↓
Calls at::_ops::t::redispatch() → eventually hits __torch_dispatch__
    ↓
Your @implements(aten.t.default) handler runs, returns new_tensor (is_inference=True)
    ↓
Returns to ADInplaceOrView::t()
    ↓
ADInplaceOrView::t() calls as_view(base, output) on YOUR returned tensor
    ↓
as_view() calls set_version_counter() on output (the inference tensor)
    ↓
ERROR: Cannot set version_counter for inference tensor
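
To make the mechanism concrete, the same error can be reproduced without torchao using a minimal wrapper subclass (a sketch; the Wrapper class below is hypothetical and only mimics the relevant behavior of Float8Tensor). The subclass instance is created outside inference_mode, so ADInplaceOrView::t() still runs for it, while the handler's output is allocated inside inference_mode and therefore has no version counter:

import torch

class Wrapper(torch.Tensor):
    # Hypothetical minimal wrapper subclass standing in for Float8Tensor.
    @staticmethod
    def __new__(cls, data):
        return torch.Tensor._make_wrapper_subclass(
            cls, data.shape, dtype=data.dtype, device=data.device
        )

    def __init__(self, data):
        self.inner = data

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        if func is torch.ops.aten.t.default:
            # The result is allocated here; under inference_mode it is an
            # inference tensor, so ADInplaceOrView::t() fails when it later
            # calls as_view()/set_version_counter() on it.
            return cls(args[0].inner.t())
        raise NotImplementedError(f"unhandled op: {func}")

w = Wrapper(torch.randn(2, 3))  # created OUTSIDE inference_mode
with torch.inference_mode():
    w.t()  # RuntimeError: Cannot set version_counter for inference tensor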
