@@ -60,7 +60,7 @@ def forward(
         offs: Optional[torch.Tensor] = None,
         out_dtype: Optional[torch.dtype] = torch.bfloat16,
     ) -> torch.Tensor:
-        # torchao _scaled_grouped_mm only supports A=2D|3D + B=3D.
+        # torchao _scaled_grouped_mm only supports A=2D|3D and B=3D.
         assert A.ndim == 2 or A.ndim == 3, "A must be 2D or 3D"
         assert B_t.ndim == 3, "B must be 3D"

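For context on the A=2D|3D and B=3D contract: in the 2D-A case, the rows of A are a ragged concatenation of per-group tokens and `offs` marks the cumulative end of each group. Below is a minimal reference sketch of the assumed semantics, written in bf16 rather than fp8 and using a hypothetical helper name; it is not the kernel itself.

```python
import torch

# Reference for the assumed grouped-mm semantics (hypothetical helper,
# bf16 instead of fp8): A is 2D [total_tokens, K], B_t is 3D [G, K, N],
# offs holds the cumulative end index of each token group.
def grouped_mm_reference(A, B_t, offs):
    out = torch.empty(A.shape[0], B_t.shape[-1], dtype=A.dtype)
    start = 0
    for g, end in enumerate(offs.tolist()):
        out[start:end] = A[start:end] @ B_t[g]  # each group uses its own matrix
        start = end
    return out

A = torch.randn(6, 8, dtype=torch.bfloat16)        # ragged token groups
B_t = torch.randn(3, 8, 4, dtype=torch.bfloat16)   # one [K, N] matrix per group
offs = torch.tensor([2, 5, 6], dtype=torch.int32)  # cumulative group ends
out = grouped_mm_reference(A, B_t, offs)           # -> [6, 4]
```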
@@ -150,17 +150,6 @@ def forward(
         assert _is_column_major(B_t_fp8_col_major), (
             "B must be column-major for output = A @ B"
         )
-
-        # TODO: remove excessive logging once prototype is more mature.
-        logger.debug(
-            (
-                f"forward scaled_grouped_mm: A_fp8_row_major.shape={A_fp8_row_major.shape}, "
-                f"A_scale.shape={A_scales.squeeze(-1).shape}, "
-                f"B_t_fp8_col_major.shape={B_t_fp8_col_major.shape}, "
-                f"B_t_scale.shape={B_t_scales.squeeze(1).shape}, "
-                f"offs={offs if offs is not None else None}"
-            )
-        )
         return torch._scaled_grouped_mm(
             A_fp8_row_major,
             B_t_fp8_col_major,
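The column-major assertion reflects a cuBLAS-style layout constraint: the second operand must have unit stride along the row dimension of each innermost matrix. Here is a sketch of the kind of check `_is_column_major` performs (an assumed implementation, not copied from torchao), together with the usual way a column-major tensor arises, via a stride-only transpose:

```python
import torch

def _is_column_major_sketch(x: torch.Tensor) -> bool:
    # Column-major per matrix: unit stride along the second-to-last dim.
    assert x.ndim in (2, 3), "expected a 2D or 3D tensor"
    return x.stride(-2) == 1 and x.stride(-1) > 1

# A row-major [G, N, K] weight, transposed to [G, K, N], is column-major
# with no data movement:
B = torch.randn(3, 4, 8)    # strides (32, 8, 1)
B_t = B.transpose(-2, -1)   # [3, 8, 4], strides (32, 1, 8)
assert _is_column_major_sketch(B_t)
```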
@@ -205,14 +194,6 @@ def backward(ctx, grad_output: torch.Tensor):
         assert _is_column_major(B_fp8_col_major), (
             "B must be column-major for grad_A = grad_output @ B"
         )
-        logger.debug(
-            (
-                f"backward grad_A: grad_output_fp8_row_major.shape={grad_output_fp8_row_major.shape}, "
-                f"grad_output_scale.shape={grad_output_scales.shape}, "
-                f"B_fp8_col_major.shape={B_fp8_col_major.shape}, "
-                f"B_scale.shape={B_scales.shape}, "
-            )
-        )
         grad_A = torch._scaled_grouped_mm(
             grad_output_fp8_row_major,
             B_fp8_col_major,
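The assertion message ("grad_A = grad_output @ B") implies the per-group math sketched below, where B is B_t with its last two dims swapped. This is a hedged high-precision reference under the same hypothetical `offs` convention as above, not the fp8 kernel call:

```python
import torch

# Assumed per-group semantics of grad_A = grad_output @ B (reference only):
# grad_output is [total_tokens, N], B_t is [G, K, N].
def grad_A_reference(grad_output, B_t, offs):
    grad_A = torch.empty(grad_output.shape[0], B_t.shape[1], dtype=grad_output.dtype)
    start = 0
    for g, end in enumerate(offs.tolist()):
        # [tokens_g, N] @ [N, K] -> [tokens_g, K]
        grad_A[start:end] = grad_output[start:end] @ B_t[g].transpose(-2, -1)
        start = end
    return grad_A
```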
@@ -258,15 +239,6 @@ def backward(ctx, grad_output: torch.Tensor):
         assert _is_column_major(A_fp8_col_major), (
             "A must be column-major for grad_B = grad_output_t @ A"
         )
-
-        logger.debug(
-            (
-                f"backward grad_B: grad_output_t_fp8_row_major.shape={grad_output_t_fp8_row_major.shape}, "
-                f"grad_output_t_scale.shape={grad_output_t_scales.shape}, "
-                f"A_fp8_col_major.shape={A_fp8_col_major.shape}, "
-                f"A_scale.shape={A_scales.shape}, "
-            )
-        )
         grad_B = torch._scaled_grouped_mm(
             grad_output_t_fp8_row_major,
             A_fp8_col_major,
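Similarly, "grad_B = grad_output_t @ A" suggests each group's weight gradient is the product of its transposed gradient slice and its input slice. A hedged sketch under the same assumptions:

```python
import torch

# Assumed per-group semantics of grad_B = grad_output_t @ A (reference
# only): grad_output is [total_tokens, N], A is [total_tokens, K], and
# the result stacks one [N, K] weight gradient per group.
def grad_B_reference(grad_output, A, offs):
    G = offs.numel()
    grad_B = torch.empty(G, grad_output.shape[1], A.shape[1], dtype=A.dtype)
    start = 0
    for g, end in enumerate(offs.tolist()):
        # [N, tokens_g] @ [tokens_g, K] -> [N, K]
        grad_B[g] = grad_output[start:end].t() @ A[start:end]
        start = end
    return grad_B
```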