1 file changed in torchao/prototype/moe_training: +15 -1 lines changed.

```diff
@@ -47,7 +47,6 @@ def __new__(
         cls,
         tensor: torch.Tensor,
     ):
-        # logger.info(f"ScaledGroupedMMTensor __new__: tensor.dtype={tensor.dtype}, dtype: {dtype}, shape: {tensor.shape}")
         return torch.Tensor._make_wrapper_subclass(
             cls,
             tensor.size(),
```
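The removed line was a leftover debug log in `__new__`, which constructs the subclass via `torch.Tensor._make_wrapper_subclass`. For readers unfamiliar with that pattern, below is a minimal, hypothetical sketch (the name `MinimalWrapperTensor` is illustrative, not torchao's code): the wrapper is a metadata-only "shell" whose real storage lives in an inner `_data` tensor, mirroring how `ScaledGroupedMMTensor` wraps `tensor`.

```python
import torch
from torch.utils._pytree import tree_map

# Hypothetical minimal sketch of the wrapper-subclass pattern used above:
# _make_wrapper_subclass allocates a tensor "shell" carrying only metadata
# (shape, strides, dtype, device), while the data lives in an inner tensor.
class MinimalWrapperTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, tensor: torch.Tensor):
        return torch.Tensor._make_wrapper_subclass(
            cls,
            tensor.size(),
            strides=tensor.stride(),
            dtype=tensor.dtype,
            device=tensor.device,
            requires_grad=tensor.requires_grad,
        )

    def __init__(self, tensor: torch.Tensor):
        self._data = tensor  # inner tensor holding the real storage

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # Unwrap to inner tensors, run the op, and rewrap tensor results.
        unwrap = lambda x: x._data if isinstance(x, cls) else x
        out = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {}))
        rewrap = lambda x: cls(x) if isinstance(x, torch.Tensor) else x
        return tree_map(rewrap, out)

x = MinimalWrapperTensor(torch.randn(2, 3))
print((x + x).shape)  # ops dispatch through __torch_dispatch__
```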
```diff
@@ -155,9 +154,24 @@ def fsdp_post_all_gather(
     ):
         (data,) = all_gather_outputs
 
+        # For epoch 1+, out=unsharded param, so we need to copy data to `out`
+        # if `self._data` and `out` do not share the same storage.
+        # Otherwise, if they do share the same storage, we can just return directly.
         if out is not None:
+            assert isinstance(out, ScaledGroupedMMTensor), f"{type(out)}"
+            if data.dtype == param_dtype:
+                assert (
+                    data.untyped_storage().data_ptr()
+                    == out._data.untyped_storage().data_ptr()
+                )
+            else:
+                assert out._data.dtype == param_dtype, (
+                    f"{out._data.dtype} {param_dtype}"
+                )
+                out._data.copy_(data)
             return
 
+        # For epoch 0, out=None, so we need to return a new ScaledGroupedMMTensor.
         output = ScaledGroupedMMTensor(data)
         inner_tensors = (data,)
         return output, inner_tensors
```
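The new asserts encode a storage-aliasing invariant: when the all-gather output already has the parameter dtype, FSDP2 hands back a view of the unsharded parameter's storage, so no copy is needed; when the dtypes differ, the two buffers cannot alias, and the gathered data must be copied (with a cast) into `out._data`. A standalone, hypothetical demo of the two cases using plain tensors (not `ScaledGroupedMMTensor` or FSDP2):

```python
import torch

# Hypothetical demo of the two cases the new asserts distinguish.
param = torch.zeros(8, dtype=torch.bfloat16)

# Case 1 (dtypes match): a view aliases its base storage, so the
# untyped_storage().data_ptr() check passes and no copy is required.
same_dtype_output = param.view(2, 4)
assert (
    same_dtype_output.untyped_storage().data_ptr()
    == param.untyped_storage().data_ptr()
)

# Case 2 (dtypes differ): separate storage, so the gathered data must be
# copied into the param; copy_ performs the float32 -> bfloat16 cast in place.
gathered = torch.randn(8, dtype=torch.float32)
assert (
    gathered.untyped_storage().data_ptr()
    != param.untyped_storage().data_ptr()
)
param.copy_(gathered)
assert param.dtype == torch.bfloat16
```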