@@ -95,6 +95,7 @@ class Float8Tensor(TorchAOBaseTensor):
95
95
96
96
tensor_data_names = ["qdata" , "scale" ]
97
97
tensor_attribute_names = []
98
+ optional_tensor_data_names = ["test_only_data" ]
98
99
optional_tensor_attribute_names = [
99
100
"block_size" ,
100
101
"mm_config" ,
@@ -109,6 +110,7 @@ def __new__(
109
110
cls ,
110
111
qdata : torch .Tensor ,
111
112
scale : torch .Tensor ,
113
+ test_only_data : Optional [torch .Tensor ] = None ,
112
114
block_size : Optional [List [int ]] = None ,
113
115
mm_config : Optional [Float8MMConfig ] = None ,
114
116
hp_value_lb : Optional [float ] = None ,
@@ -128,6 +130,7 @@ def __init__(
128
130
self ,
129
131
qdata : torch .Tensor ,
130
132
scale : torch .Tensor ,
133
+ test_only_data : Optional [torch .Tensor ] = None ,
131
134
block_size : Optional [List [int ]] = None ,
132
135
mm_config : Optional [Float8MMConfig ] = None ,
133
136
hp_value_lb : Optional [float ] = None ,
@@ -138,6 +141,7 @@ def __init__(
138
141
):
139
142
self .qdata = qdata
140
143
self .scale = scale
144
+ self .test_only_data = test_only_data
141
145
self .block_size = block_size
142
146
self .mm_config = mm_config
143
147
self .hp_value_lb = hp_value_lb
0 commit comments