
Commit 7fdba52

don't have dtype param
1 parent 5360aad · commit 7fdba52

2 files changed, +12 −19 lines

torchao/prototype/moe_training/conversion_utils.py

Lines changed: 2 additions & 2 deletions
@@ -84,7 +84,7 @@ def _swap_params(
                 f"Does not support a root nn.Parameter with children: {module}"
             )
         if not isinstance(module.data, ScaledGroupedMMTensor):
-            new_data = ScaledGroupedMMTensor(module.data, module.data.dtype)
+            new_data = ScaledGroupedMMTensor(module.data)
             return nn.Parameter(new_data, requires_grad=module.requires_grad)
         return module

@@ -110,7 +110,7 @@ def post_order_traversal(
             for param_name, param in module.named_parameters(recurse=False):
                 if not isinstance(param.data, ScaledGroupedMMTensor):
                     new_param = nn.Parameter(
-                        ScaledGroupedMMTensor(param.data, param.data.dtype),
+                        ScaledGroupedMMTensor(param.data),
                         requires_grad=param.requires_grad,
                     )
                     setattr(module, param_name, new_param)
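
With this change, callers construct the wrapper from the tensor alone. A minimal sketch of the new behavior (the expert weight below is hypothetical; the real swap is driven by _swap_params / post_order_traversal as shown above):

import torch
from torch import nn
from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor

# hypothetical expert weight; dtype is no longer passed to the wrapper
param = nn.Parameter(torch.randn(8, 16, dtype=torch.bfloat16))

if not isinstance(param.data, ScaledGroupedMMTensor):
    new_param = nn.Parameter(
        ScaledGroupedMMTensor(param.data),  # single-argument constructor
        requires_grad=param.requires_grad,
    )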

torchao/prototype/moe_training/tensor.py

Lines changed: 10 additions & 17 deletions
@@ -46,16 +46,15 @@ class ScaledGroupedMMTensor(torch.Tensor):
     def __new__(
         cls,
         tensor: torch.Tensor,
-        dtype: torch.dtype,
     ):
-        logger.info(f"ScaledGroupedMMTensor __new__: tensor.dtype={tensor.dtype}, dtype: {dtype}, shape: {tensor.shape}")
+        # logger.info(f"ScaledGroupedMMTensor __new__: tensor.dtype={tensor.dtype}, dtype: {dtype}, shape: {tensor.shape}")
         return torch.Tensor._make_wrapper_subclass(
             cls,
             tensor.size(),
             strides=tensor.stride(),
             storage_offset=tensor.storage_offset(),
             memory_format=suggest_memory_format(tensor),
-            dtype=dtype,
+            dtype=tensor.dtype,
             layout=tensor.layout,
             device=tensor.device,
             pin_memory=tensor.is_pinned(),
@@ -65,15 +64,11 @@ def __new__(
     def __init__(
         self,
         tensor: torch.Tensor,
-        dtype: torch.dtype,
     ):
-        logger.info(f"ScaledGroupedMMTensor __init__: tensor.dtype={tensor.dtype}, dtype: {dtype}, shape: {tensor.shape}")
-        self._data = tensor.to(dtype)
-        self._dtype = dtype
+        self._data = tensor

     @classmethod
     def __torch_function__(cls, func, types, args, kwargs={}):
-        logger.info(f"ScaledGroupedMMTensor func: {func.__name__}, args: {args}, kwargs: {kwargs}")
         # override the grouped mm op to use the differentiable _scaled_grouped_mm
         if func.__name__ == cls.grouped_mm_func_name:
             # Use torchao scaled grouped mm with dynamic quant for
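
After these two hunks, the subclass derives its dtype from the tensor it wraps: __new__ passes dtype=tensor.dtype to _make_wrapper_subclass and __init__ stores the tensor unmodified, so the outer and inner dtypes always match. A quick sketch (import path taken from this diff):

import torch
from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor

data = torch.randn(4, 4, dtype=torch.bfloat16)
wrapped = ScaledGroupedMMTensor(data)  # no dtype argument anymore

assert wrapped.dtype == data.dtype == torch.bfloat16  # dtype inferred from data
assert wrapped._data is data                          # no .to(dtype) copy in __init__
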
@@ -102,7 +97,7 @@ def __torch_function__(cls, func, types, args, kwargs={}):
     def __torch_dispatch__(cls, func, types, args, kwargs={}):
         # detach is special case
         if func == torch.ops.aten.detach.default:
-            return ScaledGroupedMMTensor(args[0]._data, args[0]._dtype)
+            return ScaledGroupedMMTensor(args[0]._data)

         # unwrap args/kwargs
         unwrap = lambda x: x._data if isinstance(x, ScaledGroupedMMTensor) else x
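
The detach fast path now rewraps the inner data with the single-argument constructor, so the result keeps whatever dtype the data has. For illustration (same assumed import as above):

import torch
from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor

p = ScaledGroupedMMTensor(torch.randn(2, 3))
d = p.detach()  # handled by the aten.detach.default special case

assert isinstance(d, ScaledGroupedMMTensor)
assert d.dtype == p.dtype
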
@@ -120,21 +115,20 @@ def __torch_dispatch__(cls, func, types, args, kwargs={}):
         # wrap outputs back into ScaledGroupedMMTensor for ops that do preserve subclass
         return pytree.tree_map_only(
             torch.Tensor,
-            lambda x: ScaledGroupedMMTensor(x, x.dtype),
+            lambda x: ScaledGroupedMMTensor(x),
             out,
         )

     def __repr__(self):
-        return f"ScaledGroupedMMTensor(data={self._data}, dtype={self._dtype})"
+        return f"ScaledGroupedMMTensor(data={self._data})"

     def __tensor_flatten__(self):
-        return ["_data"], {"_dtype": self._dtype}
+        return ["_data"]

     @staticmethod
     def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
         return ScaledGroupedMMTensor(
             inner_tensors["_data"],
-            flatten_spec["_dtype"],
         )

     # fsdp hooks based on https://github.com/pytorch/pytorch/blob/20e40492b046b9287726d3ec656117e4dc38f0e2/test/distributed/_composable/fsdp/test_fully_shard_extensions.py#L81
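
With _dtype gone, __tensor_flatten__ returns only the inner tensor name and __tensor_unflatten__ ignores its flatten_spec. A round-trip sketch (calling the dunders directly for illustration; normally torch.compile and FSDP internals do this):

import torch
from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor

t = ScaledGroupedMMTensor(torch.randn(4, 8))

attrs = t.__tensor_flatten__()  # now just ["_data"], no metadata dict
inner = {name: getattr(t, name) for name in attrs}

# flatten_spec, outer_size, and outer_stride are unused by the new unflatten
t2 = ScaledGroupedMMTensor.__tensor_unflatten__(inner, None, t.size(), t.stride())
assert t2.dtype == t.dtype
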
@@ -146,9 +140,9 @@ def fsdp_pre_all_gather(
         module: nn.Module,
         mp_policy: MixedPrecisionPolicy,
     ):
-        all_gather_inputs = (self._data,)
+        # cast to mixed precision dtype prior to all-gather
+        all_gather_inputs = (self._data.to(mp_policy.param_dtype),)
         all_gather_metadata = ()
-        #logger.info(f"ScaledGroupedMMTensor fsdp_pre_all_gather: self._data.dtype={self._data.dtype}, self._data.shape={self._data.shape}, param_dtype: {mp_policy.param_dtype}")
         return all_gather_inputs, all_gather_metadata

     def fsdp_post_all_gather(
@@ -160,11 +154,10 @@ def fsdp_post_all_gather(
         out: Optional[torch.Tensor] = None,
     ):
         (data,) = all_gather_outputs
-        #logger.info(f"ScaledGroupedMMTensor fsdp_post_all_gather: data.dtype={data.dtype}, param_dtype: {param_dtype}")

         if out is not None:
             return

-        output = ScaledGroupedMMTensor(data, param_dtype)
+        output = ScaledGroupedMMTensor(data)
         inner_tensors = (data,)
         return output, inner_tensors
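
Taken together, the two FSDP hooks move the dtype handling into the all-gather path: the shard is cast to the mixed precision param_dtype before the collective, and the gathered tensor is rewrapped without a dtype argument. A single-process sketch of that flow (the collective itself is elided; param_dtype=torch.bfloat16 is an assumed policy value):

import torch
from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor

shard = ScaledGroupedMMTensor(torch.randn(8, 16))  # fp32 shard held by FSDP

# fsdp_pre_all_gather: cast to the mixed precision dtype before the all-gather
all_gather_input = shard._data.to(torch.bfloat16)

gathered = all_gather_input  # stand-in for the gathered output of the collective

# fsdp_post_all_gather: rewrap the gathered data; dtype now follows the data itself
output = ScaledGroupedMMTensor(gathered)
assert output.dtype == torch.bfloat16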
