enable smoothquant for int8 static tensor #3468
@@ -3,7 +3,6 @@
 #
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.

 from dataclasses import dataclass
 from typing import List, Optional
@@ -60,7 +59,7 @@ class Int8Tensor(TorchAOBaseTensor):
     """
     tensor_data_names = ["qdata", "scale"]
-    optional_tensor_data_names = ["act_scale"]
+    optional_tensor_data_names = ["act_scale", "act_pre_scale"]
     tensor_attribute_names = ["block_size", "dtype"]
     optional_tensor_attribute_names = [
         "act_quant_kwargs",
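For context (not part of the diff): the new `act_pre_scale` entry carries the per-input-channel factor that SmoothQuant/AWQ fold out of the weight and apply to the activation at runtime; registering it in `optional_tensor_data_names` presumably lets it travel with the weight subclass like `qdata` and `scale`. A minimal, self-contained PyTorch sketch of the underlying identity, with all values illustrative:

```python
import torch

x = torch.randn(4, 8)            # activations [batch, in_features]
w = torch.randn(16, 8)           # weight      [out_features, in_features]
s = x.abs().amax(dim=0) + 1e-6   # per-input-channel smoothing factor (illustrative)

# SmoothQuant folds s into the weight offline and keeps 1/s as the
# activation pre-scale that is applied at runtime.
w_smoothed = w * s               # broadcast over in_features
act_pre_scale = 1.0 / s

y_ref = torch.nn.functional.linear(x, w)
y_smq = torch.nn.functional.linear(x * act_pre_scale, w_smoothed)
assert torch.allclose(y_ref, y_smq, atol=1e-5)
```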
@@ -73,6 +72,7 @@ def __new__(
         block_size: List[int],
         dtype: torch.dtype,
         act_scale=None,
+        act_pre_scale: Optional[torch.Tensor] = None,
         act_quant_kwargs: Optional[QuantizeTensorToInt8Kwargs] = None,
     ):
         kwargs = {
@@ -89,6 +89,7 @@ def __init__(
         block_size: List[int],
         dtype: torch.dtype,
         act_scale=None,
+        act_pre_scale: Optional[torch.Tensor] = None,
         act_quant_kwargs: Optional[QuantizeTensorToInt8Kwargs] = None,
     ):
         super().__init__()
@@ -98,6 +99,7 @@ def __init__(
         # don't set dtype because this gets done in __new__
         self.act_quant_kwargs = act_quant_kwargs
         self.act_scale = act_scale
+        self.act_pre_scale = act_pre_scale

     def __repr__(self):
         return (
@@ -106,6 +108,7 @@ def __repr__(self):
             f"qdata={self.qdata}, "
             f"scale={self.scale}, "
             f"act_scale={self.act_scale}, "
+            f"act_pre_scale={self.act_pre_scale}, "
             f"block_size={self.block_size}, "
             f"shape={self.shape}, "
             f"device={self.device}, "
@@ -121,6 +124,7 @@ def from_hp(
         scale: Optional[torch.Tensor] = None,
         act_quant_kwargs: Optional[QuantizeTensorToInt8Kwargs] = None,
         act_scale: Optional[torch.Tensor] = None,
+        act_pre_scale: Optional[torch.Tensor] = None,
     ):
         """Create Int8Tensor from high-precision tensor"""
         block_size = get_block_size(hp_tensor.shape, granularity)
@@ -161,6 +165,7 @@ def from_hp(
             block_size,
             hp_tensor.dtype,
             act_scale=act_scale,
+            act_pre_scale=act_pre_scale,
             act_quant_kwargs=act_quant_kwargs,
         )
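A hedged usage sketch of how a SmoothQuant-style flow might call `from_hp` with the new argument; `PerRow`, the positional `granularity` argument, and the bare `QuantizeTensorToInt8Kwargs()` construction are assumptions for illustration, not taken from this PR:

```python
# Hypothetical calibration outputs (names assumed):
#   float_weight         - original high-precision weight [out_features, in_features]
#   smoothing_factor     - per-input-channel SmoothQuant factor s
#   calibrated_act_scale - static activation scale observed during calibration
weight_int8 = Int8Tensor.from_hp(
    float_weight * smoothing_factor,       # weight with s folded in
    PerRow(),
    act_quant_kwargs=QuantizeTensorToInt8Kwargs(),
    act_scale=calibrated_act_scale,        # fixed scale for static activation quant
    act_pre_scale=1.0 / smoothing_factor,  # applied to activations in the linear op
)
```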
@@ -198,13 +203,18 @@ def _(func, types, args, kwargs):
     output_dtype = activation_tensor.dtype

+    # Apply activation pre-scaling if present (for AWQ, SmoothQuant, etc.)
+    if weight_tensor.act_pre_scale is not None:
+        activation_tensor = activation_tensor * weight_tensor.act_pre_scale
+
     if weight_tensor.act_quant_kwargs is not None:
+        # for int8 dynamic + static quantization path
         activation_tensor = _choose_quant_func_and_quantize_tensor(
             activation_tensor,
             weight_tensor.act_quant_kwargs,
+            scale=weight_tensor.act_scale,
         )
-        # Dynamic activation quantization path

     # 1. do the matrix form of dot(X_i, W_j)
     #
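To make the static path above concrete, here is a small self-contained PyTorch sketch of "pre-scale the activation, quantize it with a fixed scale, run an int8 matmul, dequantize"; the helper below is illustrative and is not the `_choose_quant_func_and_quantize_tensor` used in the diff:

```python
import torch

def quantize_int8_static(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Symmetric int8 quantization with a precomputed (static) scale."""
    return torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)

x = torch.randn(4, 8)                      # activations
w_int8 = torch.randint(-127, 128, (16, 8), dtype=torch.int8)
w_scale = torch.tensor(0.02)               # weight scale
act_scale = torch.tensor(0.05)             # static activation scale from calibration
act_pre_scale = torch.rand(8) + 0.5        # SmoothQuant 1/s factor

x_q = quantize_int8_static(x * act_pre_scale, act_scale)   # pre-scale, then quantize
y_int32 = torch.matmul(x_q.to(torch.int32), w_int8.to(torch.int32).t())
y = y_int32.to(torch.float32) * (act_scale * w_scale)      # dequantize the accumulator
```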
@@ -292,6 +302,8 @@ def _(func, types, args, kwargs):
             block_size,
             self.dtype,
             act_quant_kwargs=self.act_quant_kwargs,
+            act_scale=self.act_scale,
Contributor: I guess slice didn't work for static quant int8 before; can you add a test for that?
+            act_pre_scale=self.act_pre_scale,
         ),
     )
@@ -322,6 +334,8 @@ def _(func, types, args, kwargs):
             old_int8_tensor.scale[index],
             old_int8_tensor.block_size[1:],
             old_int8_tensor.dtype,
+            old_int8_tensor.act_scale,
Contributor: Same for this one; it seems like the select op broke with static quant before.
+            old_int8_tensor.act_pre_scale,
             old_int8_tensor.act_quant_kwargs,
         )
         return return_and_correct_aliasing(func, args, kwargs, new_int8_tensor)
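Regarding the two review comments asking for tests, a hedged sketch of what such a test could check; the construction mirrors the fields in this diff, but the exact `from_hp` call and kwargs are assumptions:

```python
import torch

# Hypothetical static-quant weight; values are only illustrative.
weight = Int8Tensor.from_hp(
    torch.randn(16, 8),
    PerRow(),
    act_quant_kwargs=QuantizeTensorToInt8Kwargs(),
    act_scale=torch.tensor(0.05),
    act_pre_scale=torch.rand(8) + 0.5,
)

sliced = weight[:, :4]   # dispatches to the aten.slice overload patched above
selected = weight[0]     # dispatches to the aten.select overload patched above

for t in (sliced, selected):
    assert t.act_scale is not None
    assert t.act_pre_scale is not None
```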
I think we shouldn't have a specific config here; maybe change this to a similar protocol, like `SupportsActivationPreScaling`, for the config?

I think figuring out how to do this generally will need a bit more design: we'd need to figure out how to map to the appropriate `QuantizeTensorToInt/FloatXKwargs` object. Agreed that we should be able to do this, but can I address it in a later PR?
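For readers unfamiliar with the suggestion, a hedged sketch of what a `SupportsActivationPreScaling`-style protocol could look like; this only illustrates the idea being discussed and is not an existing torchao API:

```python
from typing import Optional, Protocol, runtime_checkable

import torch


@runtime_checkable
class SupportsActivationPreScaling(Protocol):
    """Tensor subclasses that accept a per-channel activation pre-scale
    (e.g. the 1/s factor produced by SmoothQuant or AWQ)."""

    act_pre_scale: Optional[torch.Tensor]


def maybe_set_act_pre_scale(weight, pre_scale: torch.Tensor) -> None:
    # A config/transform could check the protocol instead of special-casing
    # Int8Tensor, Float8Tensor, etc.
    if isinstance(weight, SupportsActivationPreScaling):
        weight.act_pre_scale = pre_scale
```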