torchao/prototype/moe_training: 2 files changed, +16 additions, -2 deletions
File 1:

@@ -40,7 +40,7 @@ def _scaled_grouped_mm(
         offs (int32 torch.Tensor): The offsets to use to mark the starting index of each group along dim0 of the A tensor.
         out_dtype (Optional[torch.dtype]): The dtype of the output tensor. Currently only torch.bfloat16 is supported.
     """
-    logger.info("Using scaled_grouped_mm")
+    # logger.info("Using scaled_grouped_mm")
     return _Float8GroupedMM.apply(
         A,
         B_t,
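
For context, this entry point takes the high-precision operands and dispatches to the _Float8GroupedMM autograd function. Below is a rough usage sketch; the import path, tensor shapes, the transposed layout of B_t, and the cumulative construction of offs are illustrative assumptions, not part of this diff.

# Hypothetical usage sketch (not part of this PR). Assumes a float8-capable GPU,
# A = token activations grouped by expert along dim0, B = per-expert weights.
import torch
from torchao.prototype.moe_training.scaled_grouped_mm import _scaled_grouped_mm  # import path assumed

num_experts, dim, hidden = 4, 256, 512
tokens_per_expert = torch.tensor([8, 16, 4, 12], dtype=torch.int32)

A = torch.randn(int(tokens_per_expert.sum()), dim, dtype=torch.bfloat16, device="cuda")
B = torch.randn(num_experts, hidden, dim, dtype=torch.bfloat16, device="cuda")  # expert weights
B_t = B.transpose(-2, -1)  # (num_experts, dim, hidden) view; transposed layout assumed

# Group boundaries along dim0 of A, built here as cumulative token counts;
# see the docstring above for the exact offset convention.
offs = torch.cumsum(tokens_per_expert, dim=0).to(torch.int32)

out = _scaled_grouped_mm(A, B_t, offs=offs, out_dtype=torch.bfloat16)  # (sum(tokens), hidden)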
File 2:

@@ -47,7 +47,6 @@ def __new__(
         cls,
         tensor: torch.Tensor,
     ):
-        # logger.info(f"ScaledGroupedMMTensor __new__: tensor.dtype={tensor.dtype}, dtype: {dtype}, shape: {tensor.shape}")
         return torch.Tensor._make_wrapper_subclass(
             cls,
             tensor.size(),
@@ -155,9 +154,24 @@ def fsdp_post_all_gather(
     ):
         (data,) = all_gather_outputs

+        # For training step 1+, out=unsharded param, so we need to copy data to `out`
+        # if `self._data` and `out` do not share the same storage.
+        # Otherwise, if they do share the same storage, we can just return directly.
         if out is not None:
+            assert isinstance(out, ScaledGroupedMMTensor), f"{type(out)}"
+            if data.dtype == param_dtype:
+                assert (
+                    data.untyped_storage().data_ptr()
+                    == out._data.untyped_storage().data_ptr()
+                )
+            else:
+                assert out._data.dtype == param_dtype, (
+                    f"{out._data.dtype} {param_dtype}"
+                )
+                out._data.copy_(data)
             return

+        # For training step 0, out=None, so we need to return a new ScaledGroupedMMTensor.
         output = ScaledGroupedMMTensor(data)
         inner_tensors = (data,)
         return output, inner_tensors
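
The added comments describe the FSDP2 all-gather extension contract: on the first training step the hook is called with out=None and must return a freshly constructed unsharded ScaledGroupedMMTensor plus its inner tensors, while on later steps FSDP2 passes the persistent unsharded param as out and the hook writes into it only when the all-gathered data does not already alias its storage. The sketch below is a minimal, self-contained illustration of that contract with made-up names; it collapses the dtype/storage assertions above into a single aliasing check and is not torchao code.

# Minimal sketch of the step-0 vs. step-1+ contract (illustrative, not torchao code).
from typing import Any, Optional, Tuple

import torch


class MyWrapperTensor(torch.Tensor):
    """Toy wrapper subclass holding one inner tensor in self._data."""

    @staticmethod
    def __new__(cls, tensor: torch.Tensor):
        return torch.Tensor._make_wrapper_subclass(
            cls,
            tensor.size(),
            strides=tensor.stride(),
            dtype=tensor.dtype,
            device=tensor.device,
            requires_grad=tensor.requires_grad,
        )

    def __init__(self, tensor: torch.Tensor):
        self._data = tensor

    def fsdp_post_all_gather(
        self,
        all_gather_outputs: Tuple[torch.Tensor, ...],
        metadata: Any,
        param_dtype: torch.dtype,
        *,
        out: Optional[torch.Tensor] = None,
    ):
        (data,) = all_gather_outputs
        if out is not None:
            # Training step 1+: reuse the unsharded param built on step 0;
            # copy only when the all-gather output does not alias its storage.
            if (
                data.untyped_storage().data_ptr()
                != out._data.untyped_storage().data_ptr()
            ):
                out._data.copy_(data.to(out._data.dtype))
            return
        # Training step 0: build the unsharded param and hand back its inner
        # tensors so later all-gathers can target the same storage.
        return MyWrapperTensor(data), (data,)


# Simplified stand-in for what FSDP2 does with this hook (single process, no sharding):
step0_data = torch.randn(4, 8)
unsharded, inner_tensors = MyWrapperTensor(step0_data).fsdp_post_all_gather(
    (step0_data,), None, torch.float32
)
step1_data = torch.randn(4, 8)  # a fresh all-gather output with different storage
MyWrapperTensor(step1_data).fsdp_post_all_gather(
    (step1_data,), None, torch.float32, out=unsharded
)
assert torch.equal(unsharded._data, step1_data)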