Commit 01f7352

[moe training] Cast to mixed precision policy param dtype in fsdp_pre_all_gather hook (#2455)
1 parent 19237e3 commit 01f7352

3 files changed: +47, -28 lines
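
For context, the mp_policy.param_dtype referenced in the commit title comes from FSDP2's per-module mixed precision configuration. A minimal sketch of such a policy, with illustrative dtypes that are not taken from this commit:

import torch
from torch.distributed.fsdp import MixedPrecisionPolicy

# Sharded parameters keep their original dtype (e.g. fp32); all-gathered
# parameters and forward/backward compute use bf16; gradients reduce in fp32.
mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)

With this commit, ScaledGroupedMMTensor.fsdp_pre_all_gather casts its inner _data to mp_policy.param_dtype before FSDP all-gathers it.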

torchao/prototype/moe_training/conversion_utils.py

Lines changed: 2 additions & 2 deletions

@@ -84,7 +84,7 @@ def _swap_params(
                 f"Does not support a root nn.Parameter with children: {module}"
             )
         if not isinstance(module.data, ScaledGroupedMMTensor):
-            new_data = ScaledGroupedMMTensor(module.data, module.data.dtype)
+            new_data = ScaledGroupedMMTensor(module.data)
             return nn.Parameter(new_data, requires_grad=module.requires_grad)
         return module

@@ -110,7 +110,7 @@ def post_order_traversal(
         for param_name, param in module.named_parameters(recurse=False):
             if not isinstance(param.data, ScaledGroupedMMTensor):
                 new_param = nn.Parameter(
-                    ScaledGroupedMMTensor(param.data, param.data.dtype),
+                    ScaledGroupedMMTensor(param.data),
                     requires_grad=param.requires_grad,
                 )
                 setattr(module, param_name, new_param)
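
With the dtype argument dropped from the ScaledGroupedMMTensor constructor, the wrapper now simply reports the dtype of the tensor it wraps. A minimal sketch of the resulting behavior (the shape and dtype below are illustrative):

import torch
from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor

# Wrap raw expert weights; the subclass dtype now comes from tensor.dtype
# inside _make_wrapper_subclass rather than from a separately tracked _dtype.
w = torch.randn(8, 1024, 1024, dtype=torch.bfloat16)
wrapped = ScaledGroupedMMTensor(w)
assert wrapped.dtype == w.dtype == torch.bfloat16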

torchao/prototype/moe_training/scaled_grouped_mm.py

Lines changed: 4 additions & 1 deletion

@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 
+import logging
 from typing import Optional
 
 import torch

@@ -18,6 +19,8 @@
     _is_column_major,
 )
 
+logger: logging.Logger = logging.getLogger(__name__)
+
 
 def _scaled_grouped_mm(
     A: torch.Tensor,

@@ -36,8 +39,8 @@ def _scaled_grouped_mm(
             and in column-major memory layout.
         offs (int32 torch.Tensor): The offsets to use to mark the starting index of each group along dim0 of the A tensor.
         out_dtype (Optional[torch.dtype]): The dtype of the output tensor. Currently only torch.bfloat16 is supported.
-        use_triton_for_per_group_scales (bool): Whether to use custom triton kernels to compute per-group scales. Default is True.
     """
+    # logger.info("Using scaled_grouped_mm")
     return _Float8GroupedMM.apply(
         A,
         B_t,

torchao/prototype/moe_training/tensor.py

Lines changed: 41 additions & 25 deletions

@@ -9,13 +9,15 @@
 
 import torch
 import torch.utils._pytree as pytree
+from torch import nn
 from torch._prims_common import suggest_memory_format
+from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.fsdp import MixedPrecisionPolicy
 
 from torchao.prototype.moe_training import _scaled_grouped_mm
 
 logger: logging.Logger = logging.getLogger(__name__)
 
-
 _ops_to_preserve_subclass = {
     torch.ops.aten.empty_like.default,
     torch.ops.aten.new_zeros.default,

@@ -44,15 +46,14 @@ class ScaledGroupedMMTensor(torch.Tensor):
     def __new__(
         cls,
         tensor: torch.Tensor,
-        dtype: torch.dtype,
     ):
         return torch.Tensor._make_wrapper_subclass(
             cls,
             tensor.size(),
             strides=tensor.stride(),
             storage_offset=tensor.storage_offset(),
             memory_format=suggest_memory_format(tensor),
-            dtype=dtype,
+            dtype=tensor.dtype,
             layout=tensor.layout,
             device=tensor.device,
             pin_memory=tensor.is_pinned(),

@@ -62,14 +63,11 @@ def __new__(
     def __init__(
         self,
         tensor: torch.Tensor,
-        dtype: torch.dtype,
     ):
         self._data = tensor
-        self._dtype = dtype
 
     @classmethod
     def __torch_function__(cls, func, types, args, kwargs={}):
-        logger.info(f"{func.__name__}, args: {args}, kwargs: {kwargs}")
         # override the grouped mm op to use the differentiable _scaled_grouped_mm
         if func.__name__ == cls.grouped_mm_func_name:
             # Use torchao scaled grouped mm with dynamic quant for

@@ -98,19 +96,10 @@ def __torch_function__(cls, func, types, args, kwargs={}):
     def __torch_dispatch__(cls, func, types, args, kwargs={}):
         # detach is special case
         if func == torch.ops.aten.detach.default:
-            return ScaledGroupedMMTensor(args[0]._data, args[0]._dtype)
-
-        # unwrap args and kwargs
-        dtype: Optional[torch.dtype] = None
-
-        def unwrap(t):
-            nonlocal dtype
-            if dtype is None:
-                dtype = t._dtype
-            else:
-                assert t._dtype == dtype
-            return t._data
+            return ScaledGroupedMMTensor(args[0]._data)
 
+        # unwrap args/kwargs
+        unwrap = lambda x: x._data if isinstance(x, ScaledGroupedMMTensor) else x
         args, kwargs = pytree.tree_map_only(
             ScaledGroupedMMTensor, unwrap, (args, kwargs or {})
         )

@@ -125,25 +114,33 @@ def unwrap(t):
         # wrap outputs back into ScaledGroupedMMTensor for ops that do preserve subclass
         return pytree.tree_map_only(
             torch.Tensor,
-            lambda x: ScaledGroupedMMTensor(x, dtype),
+            lambda x: ScaledGroupedMMTensor(x),
             out,
         )
 
     def __repr__(self):
-        return f"ScaledGroupedMMTensor(data={self._data}, dtype={self._dtype})"
+        return f"ScaledGroupedMMTensor(data={self._data})"
 
     def __tensor_flatten__(self):
-        return ["_data"], {"_dtype": self._dtype}
+        return ["_data"]
 
     @staticmethod
     def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
         return ScaledGroupedMMTensor(
             inner_tensors["_data"],
-            flatten_spec["_dtype"],
         )
 
-    def fsdp_pre_all_gather(self, mesh):
-        all_gather_inputs = (self._data,)
+    # fsdp hooks based on https://github.com/pytorch/pytorch/blob/20e40492b046b9287726d3ec656117e4dc38f0e2/test/distributed/_composable/fsdp/test_fully_shard_extensions.py#L81
+    def fsdp_pre_all_gather(
+        self,
+        mesh: DeviceMesh,
+        outer_size: torch.Size,
+        outer_stride: tuple[int, ...],
+        module: nn.Module,
+        mp_policy: MixedPrecisionPolicy,
+    ):
+        # cast to mixed precision dtype prior to all-gather
+        all_gather_inputs = (self._data.to(mp_policy.param_dtype),)
         all_gather_metadata = ()
         return all_gather_inputs, all_gather_metadata
 

@@ -156,6 +153,25 @@ def fsdp_post_all_gather(
         out: Optional[torch.Tensor] = None,
     ):
         (data,) = all_gather_outputs
-        output = ScaledGroupedMMTensor(data, param_dtype)
+
+        # For training step 1+, out=unsharded param, so we need to copy data to `out`
+        # if `self._data` and `out` do not share the same storage.
+        # Otherwise, if they do share the same storage, we can just return directly.
+        if out is not None:
+            assert isinstance(out, ScaledGroupedMMTensor), f"{type(out)}"
+            if data.dtype == param_dtype:
+                assert (
+                    data.untyped_storage().data_ptr()
+                    == out._data.untyped_storage().data_ptr()
+                )
+            else:
+                assert out._data.dtype == param_dtype, (
+                    f"{out._data.dtype} {param_dtype}"
+                )
+                out._data.copy_(data)
+            return
+
+        # For training step 0, out=None, so we need to return a new ScaledGroupedMMTensor.
+        output = ScaledGroupedMMTensor(data)
         inner_tensors = (data,)
         return output, inner_tensors
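
Taken together, the hooks let FSDP2 mixed precision work with the subclass: fsdp_pre_all_gather casts the (typically fp32) sharded data to mp_policy.param_dtype before communication, and fsdp_post_all_gather either rewraps the gathered data (step 0) or copies it into the existing unsharded buffer (later steps). A rough usage sketch, assuming torch.distributed is already initialized with a CUDA process group and that fully_shard is importable from torch.distributed.fsdp (recent PyTorch); the module and shapes are illustrative:

import torch
from torch import nn
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard

from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor


class Experts(nn.Module):
    # Toy stand-in for grouped expert weights; a real MoE layer would route
    # tokens and hit the grouped mm op that ScaledGroupedMMTensor overrides.
    def __init__(self, num_experts: int, dim: int):
        super().__init__()
        w = torch.randn(num_experts, dim, dim, dtype=torch.float32, device="cuda")
        # Mirror conversion_utils: store the parameter data as the subclass.
        self.w = nn.Parameter(ScaledGroupedMMTensor(w))


model = Experts(num_experts=8, dim=1024)

# Sharded params stay fp32; fsdp_pre_all_gather casts self._data to bf16
# (mp_policy.param_dtype) before the all-gather.
mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)
fully_shard(model, mp_policy=mp_policy)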
