Skip to content

Commit c5db9f0

Browse files
committed
[test only] testing adding optioanl tensor arg to float8 tensor
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: stack-info: PR: #2840, branch: jerryzh168/stack/33
1 parent 1dd6d07 commit c5db9f0

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

torchao/quantization/quantize_/workflows/float8/float8_tensor.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ class Float8Tensor(TorchAOBaseTensor):
9595

9696
tensor_data_names = ["qdata", "scale"]
9797
tensor_attribute_names = []
98+
optional_tensor_data_names = ["test_only_data"]
9899
optional_tensor_attribute_names = [
99100
"block_size",
100101
"mm_config",
@@ -109,6 +110,7 @@ def __new__(
109110
cls,
110111
qdata: torch.Tensor,
111112
scale: torch.Tensor,
113+
test_only_data: Optional[torch.Tensor] = None,
112114
block_size: Optional[List[int]] = None,
113115
mm_config: Optional[Float8MMConfig] = None,
114116
hp_value_lb: Optional[float] = None,
@@ -128,6 +130,7 @@ def __init__(
128130
self,
129131
qdata: torch.Tensor,
130132
scale: torch.Tensor,
133+
test_only_data: Optional[torch.Tensor] = None,
131134
block_size: Optional[List[int]] = None,
132135
mm_config: Optional[Float8MMConfig] = None,
133136
hp_value_lb: Optional[float] = None,
@@ -138,6 +141,7 @@ def __init__(
138141
):
139142
self.qdata = qdata
140143
self.scale = scale
144+
self.test_only_data = test_only_data
141145
self.block_size = block_size
142146
self.mm_config = mm_config
143147
self.hp_value_lb = hp_value_lb

0 commit comments

Comments
 (0)