 logger.addHandler(handler)
 
 
-from torchao.dtypes.utils import PlainLayout
+from dataclasses import dataclass
+from typing import Optional, Tuple
 
+from torch.utils._python_dispatch import (
+    return_and_correct_aliasing,
+)
+
+from torchao.dtypes.utils import AQTTensorImpl, Layout
+from torchao.utils import fill_defaults
+
+aten = torch.ops.aten
 
-class QDQLayout(PlainLayout):
+
+@dataclass(frozen=True)
+class QDQLayout(Layout):
     pass
 
 
-from torchao.dtypes.uintx.plain_layout import PlainAQTTensorImpl
+def _same_metadata(self: "QDQTensorImpl", src: "QDQTensorImpl") -> bool:
+    return (
+        isinstance(self, QDQTensorImpl)
+        and isinstance(src, QDQTensorImpl)
+        and self.shape == src.shape
+        and self.int_data.shape == src.int_data.shape
+        and self.scale.shape == src.scale.shape
+        # zero_point must either be absent in both or present in both with the
+        # same shape; the extra parentheses keep this `or` from splitting the
+        # whole expression due to operator precedence
+        and (
+            (self.zero_point is None and src.zero_point is None)
+            or (
+                self.zero_point is not None
+                and src.zero_point is not None
+                and self.zero_point.shape == src.zero_point.shape
+            )
+        )
+        and type(self._layout) == type(src._layout)
+    )
 
 
 @register_layout(QDQLayout)
-class _Impl(PlainAQTTensorImpl):
-    pass
+class QDQTensorImpl(AQTTensorImpl):
| 64 | + """ |
| 65 | + TensorImpl for QDQLayout layout for affine quantized tensor, it stores int_data, scale, zero_point |
| 66 | + tensors directly as plain tensors. |
| 67 | +
|
| 68 | + fields: |
| 69 | + int_data (torch.Tensor): the quantized integer data Tensor |
| 70 | + scale (torch.Tensor): the scale Tensor used to map between floating point tensor to quantized tensor |
| 71 | + zero_point (torch.Tensor): the zero_point Tensor used to map between floating point tensor to quantized tensor |
| 72 | + """ |
+
+    def __new__(
+        cls,
+        int_data: torch.Tensor,
+        scale: torch.Tensor,
+        zero_point: Optional[torch.Tensor],
+        _layout: Layout,
+    ):
+        kwargs = {}
+        kwargs["device"] = int_data.device
+        kwargs["layout"] = (
+            kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout
+        )
+        kwargs["dtype"] = int_data.dtype
+        kwargs["requires_grad"] = False
+        shape = int_data.shape
+        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
+
+    def __init__(
+        self,
+        int_data: torch.Tensor,
+        scale: torch.Tensor,
+        zero_point: Optional[torch.Tensor],
+        _layout: Layout,
+    ):
+        self.int_data = int_data
+        self.scale = scale
+        self.zero_point = zero_point
+        self._layout = _layout
+
+    def __tensor_flatten__(self):
+        if self.zero_point is None:
+            return ["int_data", "scale"], [self._layout]
+        return ["int_data", "scale", "zero_point"], [self._layout]
+
+    @classmethod
+    def __tensor_unflatten__(
+        cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
+    ):
+        int_data, scale, zero_point = (
+            tensor_data_dict["int_data"],
+            tensor_data_dict["scale"],
+            tensor_data_dict.get("zero_point", None),
+        )
+        (_layout,) = tensor_attributes
+        return cls(int_data, scale, zero_point, _layout)
+
+    def to(self, *args, **kwargs):
+        kwargs = self._get_to_kwargs(*args, **kwargs)
+        return self.__class__(
+            self.int_data.to(kwargs["device"]),
+            self.scale.to(kwargs["device"]),
+            self.zero_point.to(kwargs["device"])
+            if self.zero_point is not None
+            else None,
+            self._layout,
+        )
+
+    def _apply_fn_to_data(self, fn):
+        return self.__class__(
+            fn(self.int_data),
+            fn(self.scale),
+            fn(self.zero_point) if self.zero_point is not None else None,
+            self._layout,
+        )
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args, kwargs):
+        kwargs = {} if kwargs is None else kwargs
+
+        if func is aten.detach.default:
+            return return_and_correct_aliasing(
+                func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
+            )
+
+        elif func is aten.clone.default:
+            return return_and_correct_aliasing(
+                func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
+            )
+
+        elif func is aten.copy_.default:
+            self = args[0]
+            src = args[1]
+            if _same_metadata(self, src):
+                self_tensors = self.__tensor_flatten__()[0]
+                for tensor_name in self_tensors:
+                    getattr(self, tensor_name).copy_(getattr(src, tensor_name))
+                return
+            raise ValueError(
+                f"Unsupported args for copy_ due to metadata mismatch: {args[0], args[1]}"
+            )
+
+        elif func is aten.t.default:
+            tensor = args[0]
+            new = tensor.__class__(
+                tensor.int_data.t(), tensor.scale, tensor.zero_point, tensor._layout
+            )
+            return return_and_correct_aliasing(func, args, kwargs, new)
+
+        elif func is aten.slice.Tensor:
+            self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
+            if dim in [0, 1]:
+                int_data, scale, zero_point = self.get_plain()
+                data_len = int_data.shape[dim]
+                scale_len = scale.shape[dim]
+                ratio = data_len / scale_len
+                start_scale = int(start / ratio)
+                end_scale = int(end / ratio)
+
+                int_data = aten.slice.Tensor(int_data, dim, start, end, step)
+                scale = aten.slice.Tensor(scale, dim, start_scale, end_scale, step)
+                zero_point = aten.slice.Tensor(
+                    zero_point, dim, start_scale, end_scale, step
+                )
+                # this is to handle padding
+                int_data, scale, zero_point = self._layout.post_process(
+                    int_data, scale, zero_point, self.block_size
+                )
+                sliced = self.from_plain(int_data, scale, zero_point, self._layout)
+                return return_and_correct_aliasing(func, args, kwargs, sliced)
+            else:
+                raise NotImplementedError(
+                    f"QDQTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported"
+                )
+
+        raise NotImplementedError(
+            f"QDQTensorImpl dispatch: attempting to run {func}, this is not supported"
+        )
+
+    __torch_function__ = torch._C._disabled_torch_function_impl
+
+    def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+        return self.int_data, self.scale, self.zero_point
+
+    def get_layout(self) -> Layout:
+        return self._layout
+
+    @classmethod
+    def from_plain(
+        cls,
+        int_data: torch.Tensor,
+        scale: torch.Tensor,
+        zero_point: Optional[torch.Tensor],
+        _layout: Layout,
+    ):
+        assert isinstance(_layout, QDQLayout)
+        return cls(int_data, scale, zero_point, _layout)
 
 
 def _linear_check(input_tensor, weight_tensor, bias):
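The diff above only adds storage and dispatch plumbing; the PR itself does not show the class being exercised. A minimal, hypothetical sketch of how the pieces fit together, assuming this module's definitions (QDQLayout, QDQTensorImpl) are in scope and torch is imported:

# Hypothetical usage sketch (not part of the PR): round-tripping plain tensors
# through QDQTensorImpl and exercising one dispatched op.
int_data = torch.randint(-8, 8, (32, 64), dtype=torch.int8)  # quantized weight
scale = torch.rand(32, 1)                                     # per-row scales
zero_point = torch.zeros(32, 1, dtype=torch.int8)

impl = QDQTensorImpl.from_plain(int_data, scale, zero_point, QDQLayout())

# get_plain() hands back the stored tensors unchanged.
q, s, zp = impl.get_plain()
assert q is int_data and s is scale and zp is zero_point

# detach() is routed through __torch_dispatch__ (the aten.detach.default branch)
# and returns a fresh QDQTensorImpl wrapping detached component tensors.
detached = impl.detach()
assert isinstance(detached, QDQTensorImpl)

In the real flow these impls are constructed by the affine quantized tensor machinery via register_layout rather than by hand; the sketch is only meant to illustrate what from_plain/get_plain and the dispatch table do.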