Skip to content

Commit 6ca070d

Browse files
handle out != None
1 parent 7fdba52 commit 6ca070d

File tree

2 files changed

+16
-2
lines changed

2 files changed

+16
-2
lines changed

torchao/prototype/moe_training/scaled_grouped_mm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def _scaled_grouped_mm(
4040
offs (int32 torch.Tensor): The offsets to use to mark the starting index of each group along dim0 of the A tensor.
4141
out_dtype (Optional[torch.dtype]): The dtype of the output tensor. Currently only torch.bfloat16 is supported.
4242
"""
43-
logger.info("Using scaled_grouped_mm")
43+
#logger.info("Using scaled_grouped_mm")
4444
return _Float8GroupedMM.apply(
4545
A,
4646
B_t,

torchao/prototype/moe_training/tensor.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@ def __new__(
4747
cls,
4848
tensor: torch.Tensor,
4949
):
50-
# logger.info(f"ScaledGroupedMMTensor __new__: tensor.dtype={tensor.dtype}, dtype: {dtype}, shape: {tensor.shape}")
5150
return torch.Tensor._make_wrapper_subclass(
5251
cls,
5352
tensor.size(),
@@ -155,9 +154,24 @@ def fsdp_post_all_gather(
155154
):
156155
(data,) = all_gather_outputs
157156

157+
# For training step 1+, out=unsharded param, so we need to copy data to `out`
158+
# if `self._data` and `out` do not share the same storage.
159+
# Otherwise, if they do share the same storage, we can just return directly.
158160
if out is not None:
161+
assert isinstance(out, ScaledGroupedMMTensor), f"{type(out)}"
162+
if data.dtype == param_dtype:
163+
assert (
164+
data.untyped_storage().data_ptr()
165+
== out._data.untyped_storage().data_ptr()
166+
)
167+
else:
168+
assert out._data.dtype == param_dtype, (
169+
f"{out._data.dtype} {param_dtype}"
170+
)
171+
out._data.copy_(data)
159172
return
160173

174+
# For training step 0, out=None, so we need to return a new ScaledGroupedMMTensor.
161175
output = ScaledGroupedMMTensor(data)
162176
inner_tensors = (data,)
163177
return output, inner_tensors

0 commit comments

Comments
 (0)