Commit 074b423: remove excessive logging
Parent: a80b9a0

File tree: 2 files changed (+1, -34 lines)

torchao/prototype/moe_training/scaled_grouped_mm.py

Lines changed: 1 addition & 29 deletions
@@ -60,7 +60,7 @@ def forward(
         offs: Optional[torch.Tensor] = None,
         out_dtype: Optional[torch.dtype] = torch.bfloat16,
     ) -> torch.Tensor:
-        # torchao _scaled_grouped_mm only supports A=2D|3D + B=3D.
+        # torchao _scaled_grouped_mm only supports A=2D|3D and B=3D.
         assert A.ndim == 2 or A.ndim == 3, "A must be 2D or 3D"
         assert B_t.ndim == 3, "B must be 3D"

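As a quick illustration of the rank constraints asserted in this hunk (the concrete sizes and the (experts, K, N) reading of B_t are assumptions for the example; only the 2D/3D rank requirements come from the code):

import torch

# A may be 2D (e.g. (tokens, K)) or 3D; B_t must be 3D.
A_2d = torch.randn(128, 256)
A_3d = torch.randn(4, 32, 256)
B_t = torch.randn(4, 256, 512)  # assumed (experts, K, N) layout

for A in (A_2d, A_3d):
    assert A.ndim == 2 or A.ndim == 3, "A must be 2D or 3D"
    assert B_t.ndim == 3, "B must be 3D"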

@@ -150,17 +150,6 @@ def forward(
         assert _is_column_major(B_t_fp8_col_major), (
             "B must be column-major for output = A @ B"
         )
-
-        # TODO: remove excessive logging once prototype is more mature.
-        logger.debug(
-            (
-                f"forward scaled_grouped_mm: A_fp8_row_major.shape={A_fp8_row_major.shape}, "
-                f"A_scale.shape={A_scales.squeeze(-1).shape}, "
-                f"B_t_fp8_col_major.shape={B_t_fp8_col_major.shape}, "
-                f"B_t_scale.shape={B_t_scales.squeeze(1).shape}, "
-                f"offs={offs if offs is not None else None}"
-            )
-        )
         return torch._scaled_grouped_mm(
             A_fp8_row_major,
             B_t_fp8_col_major,
@@ -205,14 +194,6 @@ def backward(ctx, grad_output: torch.Tensor):
         assert _is_column_major(B_fp8_col_major), (
             "B must be column-major for grad_A = grad_output @ B"
         )
-        logger.debug(
-            (
-                f"backward grad_A: grad_output_fp8_row_major.shape={grad_output_fp8_row_major.shape}, "
-                f"grad_output_scale.shape={grad_output_scales.shape}, "
-                f"B_fp8_col_major.shape={B_fp8_col_major.shape}, "
-                f"B_scale.shape={B_scales.shape}, "
-            )
-        )
         grad_A = torch._scaled_grouped_mm(
             grad_output_fp8_row_major,
             B_fp8_col_major,
@@ -258,15 +239,6 @@ def backward(ctx, grad_output: torch.Tensor):
         assert _is_column_major(A_fp8_col_major), (
             "A must be column-major for grad_B = grad_output_t @ A"
         )
-
-        logger.debug(
-            (
-                f"backward grad_B: grad_output_t_fp8_row_major.shape={grad_output_t_fp8_row_major.shape}, "
-                f"grad_output_t_scale.shape={grad_output_t_scales.shape}, "
-                f"A_fp8_col_major.shape={A_fp8_col_major.shape}, "
-                f"A_scale.shape={A_scales.shape}, "
-            )
-        )
         grad_B = torch._scaled_grouped_mm(
             grad_output_t_fp8_row_major,
             A_fp8_col_major,
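
The removed logger.debug calls format their f-strings eagerly on every forward/backward pass, whether or not DEBUG logging is enabled. If similar shape logging is ever wanted again while debugging, a lighter pattern is to pass lazy %-style arguments and guard on the level. A minimal sketch using only the standard logging module (not part of this commit; the tensor and scale names are taken from the hunks above and assumed to be in scope):

import logging

logger = logging.getLogger(__name__)

def _log_forward_shapes(A_fp8_row_major, A_scales, B_t_fp8_col_major, B_t_scales, offs):
    # Build the message only when DEBUG is actually enabled.
    if not logger.isEnabledFor(logging.DEBUG):
        return
    logger.debug(
        "forward scaled_grouped_mm: A=%s, A_scale=%s, B_t=%s, B_t_scale=%s, offs=%s",
        A_fp8_row_major.shape,
        A_scales.squeeze(-1).shape,
        B_t_fp8_col_major.shape,
        B_t_scales.squeeze(1).shape,
        offs,
    )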

torchao/prototype/moe_training/tensor.py

Lines changed: 0 additions & 5 deletions
@@ -13,8 +13,6 @@

 from torchao.prototype.moe_training import _scaled_grouped_mm

-logger: logging.Logger = logging.getLogger(__name__)
-
 _ops_to_preserve_subclass = {
     torch.ops.aten.empty_like.default,
     torch.ops.aten.new_zeros.default,
@@ -77,9 +75,6 @@ def __torch_function__(cls, func, types, args, kwargs={}):
         A, B = args[0], args[1]
         A_is_2d_or_3d = A.dim() in (2, 3)
         B_is_3d = B.dim() == 3
-        has_offs = kwargs.get(cls.offs_arg_name) is not None
-        logger.debug(f"A.shape={A.shape}, B.shape={B.shape}, has_offs={has_offs}")
-
         if A_is_2d_or_3d and B_is_3d:
             return _scaled_grouped_mm(
                 *args,
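
For context, the lines kept in this hunk gate dispatch on operand rank: only a 2D-or-3D A paired with a 3D B is routed to _scaled_grouped_mm. A generic sketch of that __torch_function__ dispatch pattern is below; the subclass name, the intercepted op, and the pass-through branch are illustrative stand-ins, not the actual code in tensor.py:

import torch

class RankGatedTensor(torch.Tensor):
    """Illustrative tensor subclass that reroutes matmul-like ops based on operand rank."""

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is torch.matmul and len(args) >= 2:  # illustrative target op
            A, B = args[0], args[1]
            if A.dim() in (2, 3) and B.dim() == 3:
                # The prototype would call its grouped GEMM here,
                # e.g. _scaled_grouped_mm(*args, **kwargs); this sketch falls through.
                pass
        # Default behavior for everything not handled above.
        return super().__torch_function__(func, types, args, kwargs)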
