
Commit a81322e

Fix QDQLayout (#2037)
* up
* up
1 parent f788897 commit a81322e

File tree

1 file changed: +185 -5 lines changed


torchao/dtypes/uintx/q_dq_layout.py

+185 -5
@@ -24,19 +24,199 @@
 logger.addHandler(handler)
 
 
-from torchao.dtypes.utils import PlainLayout
+from dataclasses import dataclass
+from typing import Optional, Tuple
 
+from torch.utils._python_dispatch import (
+    return_and_correct_aliasing,
+)
+
+from torchao.dtypes.utils import AQTTensorImpl, Layout
+from torchao.utils import fill_defaults
+
+aten = torch.ops.aten
 
-class QDQLayout(PlainLayout):
+
+@dataclass(frozen=True)
+class QDQLayout(Layout):
     pass
 
 
-from torchao.dtypes.uintx.plain_layout import PlainAQTTensorImpl
+def _same_metadata(self: "QDQTensorImpl", src: "QDQTensorImpl") -> bool:
+    return (
+        isinstance(self, QDQTensorImpl)
+        and isinstance(src, QDQTensorImpl)
+        and self.shape == src.shape
+        and self.int_data.shape == src.int_data.shape
+        and self.scale.shape == src.scale.shape
+        and (self.zero_point is None and src.zero_point is None)
+        or (
+            self.zero_point is not None
+            and src.zero_point is not None
+            and self.zero_point.shape == src.zero_point.shape
+        )
+        and type(self._layout) == type(src._layout)
+    )
 
 
 @register_layout(QDQLayout)
-class _Impl(PlainAQTTensorImpl):
-    pass
+class QDQTensorImpl(AQTTensorImpl):
+    """
+    TensorImpl for QDQLayout layout for affine quantized tensor, it stores int_data, scale, zero_point
+    tensors directly as plain tensors.
+
+    fields:
+      int_data (torch.Tensor): the quantized integer data Tensor
+      scale (torch.Tensor): the scale Tensor used to map between floating point tensor to quantized tensor
+      zero_point (torch.Tensor): the zero_point Tensor used to map between floating point tensor to quantized tensor
+    """
+
+    def __new__(
+        cls,
+        int_data: torch.Tensor,
+        scale: torch.Tensor,
+        zero_point: Optional[torch.Tensor],
+        _layout: Layout,
+    ):
+        kwargs = {}
+        kwargs["device"] = int_data.device
+        kwargs["layout"] = (
+            kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout
+        )
+        kwargs["dtype"] = int_data.dtype
+        kwargs["requires_grad"] = False
+        shape = int_data.shape
+        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
+
+    def __init__(
+        self,
+        int_data: torch.Tensor,
+        scale: torch.Tensor,
+        zero_point: Optional[torch.Tensor],
+        _layout: Layout,
+    ):
+        self.int_data = int_data
+        self.scale = scale
+        self.zero_point = zero_point
+        self._layout = _layout
+
+    def __tensor_flatten__(self):
+        if self.zero_point is None:
+            return ["int_data", "scale"], [self._layout]
+        return ["int_data", "scale", "zero_point"], [self._layout]
+
+    @classmethod
+    def __tensor_unflatten__(
+        cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
+    ):
+        int_data, scale, zero_point = (
+            tensor_data_dict["int_data"],
+            tensor_data_dict["scale"],
+            tensor_data_dict.get("zero_point", None),
+        )
+        (_layout,) = tensor_attributes
+        return cls(int_data, scale, zero_point, _layout)
+
+    def to(self, *args, **kwargs):
+        kwargs = self._get_to_kwargs(*args, **kwargs)
+        return self.__class__(
+            self.int_data.to(kwargs["device"]),
+            self.scale.to(kwargs["device"]),
+            self.zero_point.to(kwargs["device"])
+            if self.zero_point is not None
+            else None,
+            self._layout,
+        )
+
+    def _apply_fn_to_data(self, fn):
+        return self.__class__(
+            fn(self.int_data),
+            fn(self.scale),
+            fn(self.zero_point) if self.zero_point is not None else None,
+            self._layout,
+        )
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args, kwargs):
+        kwargs = {} if kwargs is None else kwargs
+
+        if func is aten.detach.default:
+            return return_and_correct_aliasing(
+                func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
+            )
+
+        elif func is aten.clone.default:
+            return return_and_correct_aliasing(
+                func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
+            )
+
+        elif func is aten.copy_.default:
+            self = args[0]
+            src = args[1]
+            if _same_metadata(self, src):
+                self_tensors = self.__tensor_flatten__()[0]
+                for tensor_name in self_tensors:
+                    getattr(self, tensor_name).copy_(getattr(src, tensor_name))
+                return
+            raise ValueError(
+                f"Not supported args for copy_ due to metadata mistach: {args[0], args[1]}"
+            )
+
+        elif func is aten.t.default:
+            tensor = args[0]
+            new = tensor.__class__(
+                tensor.int_data.t(), tensor.scale, tensor.zero_point, tensor._layout
+            )
+            return return_and_correct_aliasing(func, args, kwargs, new)
+
+        elif func is aten.slice.Tensor:
+            self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
+            if dim in [0, 1]:
+                int_data, scale, zero_point = self.get_plain()
+                data_len = int_data.shape[dim]
+                scale_len = scale.shape[dim]
+                ratio = data_len / scale_len
+                start_scale = int(start / ratio)
+                end_scale = int(end / ratio)
+
+                int_data = aten.slice.Tensor(int_data, dim, start, end, step)
+                scale = aten.slice.Tensor(scale, dim, start_scale, end_scale, step)
+                zero_point = aten.slice.Tensor(
+                    zero_point, dim, start_scale, end_scale, step
+                )
+                # this is to handle padding
+                int_data, scale, zero_point = self._layout.post_process(
+                    int_data, scale, zero_point, self.block_size
+                )
+                sliced = self.from_plain(int_data, scale, zero_point, self._layout)
+                return return_and_correct_aliasing(func, args, kwargs, sliced)
+            else:
+                raise NotImplementedError(
+                    f"QDQTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported"
+                )
+
+        raise NotImplementedError(
+            f"QDQTensorImpl dispatch: attempting to run {func}, this is not supported"
+        )
+
+    __torch_function__ = torch._C._disabled_torch_function_impl
+
+    def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+        return self.int_data, self.scale, self.zero_point
+
+    def get_layout(self) -> Layout:
+        return self._layout
+
+    @classmethod
+    def from_plain(
+        cls,
+        int_data: torch.Tensor,
+        scale: torch.Tensor,
+        zero_point: Optional[torch.Tensor],
+        _layout: Layout,
+    ):
+        assert isinstance(_layout, QDQLayout)
+        return cls(int_data, scale, zero_point, _layout)
 
 
 def _linear_check(input_tensor, weight_tensor, bias):
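
For orientation, here is a minimal sketch of how the QDQTensorImpl added in this diff might be exercised end to end. It assumes the import path of the file touched by the commit (torchao/dtypes/uintx/q_dq_layout.py) and uses made-up shapes and values; it illustrates the from_plain / get_plain round trip and is not code from the PR.

import torch

# Hypothetical usage sketch; module path taken from the file in this diff.
from torchao.dtypes.uintx.q_dq_layout import QDQLayout, QDQTensorImpl

# Illustrative per-row quantization artifacts: int8 data plus one scale/zero_point per row.
int_data = torch.randint(-128, 127, (4, 8), dtype=torch.int8)
scale = torch.rand(4, 1)
zero_point = torch.zeros(4, 1, dtype=torch.int8)

# from_plain only asserts the layout type and stores the three tensors as-is.
impl = QDQTensorImpl.from_plain(int_data, scale, zero_point, QDQLayout())

# get_plain hands back exactly the stored tensors, since QDQLayout keeps them unpacked.
q, s, zp = impl.get_plain()
assert q is int_data and s is scale and zp is zero_point

# detach/clone/copy_/t/slice are routed through __torch_dispatch__ as shown in the diff.
detached = impl.detach()

Because the layout keeps int_data, scale, and zero_point as plain tensors, most of the dispatch branches above reduce to applying the same op to each stored field, which is what the new _apply_fn_to_data helper captures.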
