Commit fd933ea

handle out != None
1 parent 7fdba52

File tree

1 file changed: +15 -1 lines

torchao/prototype/moe_training/tensor.py

Lines changed: 15 additions & 1 deletion
@@ -47,7 +47,6 @@ def __new__(
         cls,
         tensor: torch.Tensor,
     ):
-        # logger.info(f"ScaledGroupedMMTensor __new__: tensor.dtype={tensor.dtype}, dtype: {dtype}, shape: {tensor.shape}")
         return torch.Tensor._make_wrapper_subclass(
             cls,
             tensor.size(),
@@ -155,9 +154,24 @@ def fsdp_post_all_gather(
     ):
         (data,) = all_gather_outputs

+        # For epoch 1+, out=unsharded param, so we need to copy data to `out`
+        # if `self._data` and `out` do not share the same storage.
+        # Otherwise, if they do share the same storage, we can just return directly.
         if out is not None:
+            assert isinstance(out, ScaledGroupedMMTensor), f"{type(out)}"
+            if data.dtype == param_dtype:
+                assert (
+                    data.untyped_storage().data_ptr()
+                    == out._data.untyped_storage().data_ptr()
+                )
+            else:
+                assert out._data.dtype == param_dtype, (
+                    f"{out._data.dtype} {param_dtype}"
+                )
+                out._data.copy_(data)
             return

+        # For epoch 0, out=None, so we need to return a new ScaledGroupedMMTensor.
         output = ScaledGroupedMMTensor(data)
         inner_tensors = (data,)
         return output, inner_tensors
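
For context: `out` comes from FSDP2's tensor-subclass all-gather extension point. On the first iteration FSDP calls `fsdp_post_all_gather` with `out=None` and caches the returned unsharded tensor; on later iterations it passes that cached tensor back as `out` so its storage can be reused. The sketch below simulates both calls outside of FSDP, with a hypothetical `ToyWrapperTensor` standing in for `ScaledGroupedMMTensor`; it illustrates the contract the diff handles, not the torchao implementation itself.

# Minimal, self-contained sketch of the fsdp_post_all_gather contract.
# ToyWrapperTensor is a hypothetical stand-in for ScaledGroupedMMTensor.
import torch


class ToyWrapperTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, tensor: torch.Tensor):
        # Same wrapper-subclass construction pattern as the real class.
        return torch.Tensor._make_wrapper_subclass(
            cls,
            tensor.size(),
            strides=tensor.stride(),
            dtype=tensor.dtype,
            device=tensor.device,
        )

    def __init__(self, tensor: torch.Tensor):
        self._data = tensor

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # Not exercised by this sketch; the real subclass unwraps and redispatches.
        raise NotImplementedError(f"{func} is not supported in this sketch")

    def fsdp_post_all_gather(
        self,
        all_gather_outputs,
        metadata,
        param_dtype,
        *,
        out=None,
    ):
        (data,) = all_gather_outputs
        if out is not None:
            # Epoch 1+: FSDP hands back the cached unsharded param.
            if data.dtype == param_dtype:
                # The all-gather wrote directly into out's storage; nothing to do.
                assert (
                    data.untyped_storage().data_ptr()
                    == out._data.untyped_storage().data_ptr()
                )
            else:
                # Different storage (e.g. mixed-precision gather): copy into out.
                out._data.copy_(data)
            return
        # Epoch 0: return a fresh wrapper plus the inner tensors FSDP may
        # reuse as `out` on later iterations.
        output = ToyWrapperTensor(data)
        inner_tensors = (data,)
        return output, inner_tensors


# Simulated driver: first call with out=None, second call reusing the result.
shard = ToyWrapperTensor(torch.randn(2, 4))  # local shard (placeholder)
gathered = torch.randn(4, 4)                 # pretend all-gather output
param, inner = shard.fsdp_post_all_gather((gathered,), None, torch.float32)
shard.fsdp_post_all_gather((gathered,), None, torch.float32, out=param)  # no copy needed

The dtype check is what distinguishes the two epoch-1+ cases: when the gather runs in the param dtype, FSDP2 can all-gather straight into the cached storage, so only the mixed-dtype path needs the explicit copy_.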
