@@ -9,13 +9,15 @@
 
 import torch
 import torch.utils._pytree as pytree
+from torch import nn
 from torch._prims_common import suggest_memory_format
+from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.fsdp import MixedPrecisionPolicy
 
 from torchao.prototype.moe_training import _scaled_grouped_mm
 
 logger: logging.Logger = logging.getLogger(__name__)
 
-
 _ops_to_preserve_subclass = {
     torch.ops.aten.empty_like.default,
     torch.ops.aten.new_zeros.default,
@@ -44,15 +46,14 @@ class ScaledGroupedMMTensor(torch.Tensor):
     def __new__(
         cls,
         tensor: torch.Tensor,
-        dtype: torch.dtype,
     ):
         return torch.Tensor._make_wrapper_subclass(
             cls,
             tensor.size(),
             strides=tensor.stride(),
             storage_offset=tensor.storage_offset(),
             memory_format=suggest_memory_format(tensor),
-            dtype=dtype,
+            dtype=tensor.dtype,
             layout=tensor.layout,
             device=tensor.device,
             pin_memory=tensor.is_pinned(),
@@ -62,14 +63,11 @@ def __new__(
     def __init__(
         self,
         tensor: torch.Tensor,
-        dtype: torch.dtype,
     ):
         self._data = tensor
-        self._dtype = dtype
 
     @classmethod
     def __torch_function__(cls, func, types, args, kwargs={}):
-        logger.info(f"{func.__name__}, args: {args}, kwargs: {kwargs}")
         # override the grouped mm op to use the differentiable _scaled_grouped_mm
         if func.__name__ == cls.grouped_mm_func_name:
             # Use torchao scaled grouped mm with dynamic quant for
@@ -98,19 +96,10 @@ def __torch_function__(cls, func, types, args, kwargs={}):
     def __torch_dispatch__(cls, func, types, args, kwargs={}):
         # detach is special case
         if func == torch.ops.aten.detach.default:
-            return ScaledGroupedMMTensor(args[0]._data, args[0]._dtype)
-
-        # unwrap args and kwargs
-        dtype: Optional[torch.dtype] = None
-
-        def unwrap(t):
-            nonlocal dtype
-            if dtype is None:
-                dtype = t._dtype
-            else:
-                assert t._dtype == dtype
-            return t._data
+            return ScaledGroupedMMTensor(args[0]._data)
 
+        # unwrap args/kwargs
+        unwrap = lambda x: x._data if isinstance(x, ScaledGroupedMMTensor) else x
         args, kwargs = pytree.tree_map_only(
             ScaledGroupedMMTensor, unwrap, (args, kwargs or {})
         )
@@ -125,25 +114,33 @@ def unwrap(t):
         # wrap outputs back into ScaledGroupedMMTensor for ops that do preserve subclass
         return pytree.tree_map_only(
             torch.Tensor,
-            lambda x: ScaledGroupedMMTensor(x, dtype),
+            lambda x: ScaledGroupedMMTensor(x),
             out,
         )
 
     def __repr__(self):
-        return f"ScaledGroupedMMTensor(data={self._data}, dtype={self._dtype})"
+        return f"ScaledGroupedMMTensor(data={self._data})"
 
     def __tensor_flatten__(self):
-        return ["_data"], {"_dtype": self._dtype}
+        return ["_data"]
 
     @staticmethod
     def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
         return ScaledGroupedMMTensor(
             inner_tensors["_data"],
-            flatten_spec["_dtype"],
         )
 
-    def fsdp_pre_all_gather(self, mesh):
-        all_gather_inputs = (self._data,)
+    # fsdp hooks based on https://github.com/pytorch/pytorch/blob/20e40492b046b9287726d3ec656117e4dc38f0e2/test/distributed/_composable/fsdp/test_fully_shard_extensions.py#L81
+    def fsdp_pre_all_gather(
+        self,
+        mesh: DeviceMesh,
+        outer_size: torch.Size,
+        outer_stride: tuple[int, ...],
+        module: nn.Module,
+        mp_policy: MixedPrecisionPolicy,
+    ):
+        # cast to mixed precision dtype prior to all-gather
+        all_gather_inputs = (self._data.to(mp_policy.param_dtype),)
         all_gather_metadata = ()
         return all_gather_inputs, all_gather_metadata
 
@@ -156,6 +153,25 @@ def fsdp_post_all_gather(
         out: Optional[torch.Tensor] = None,
     ):
         (data,) = all_gather_outputs
-        output = ScaledGroupedMMTensor(data, param_dtype)
+
+        # For training step 1+, out=unsharded param, so we need to copy data to `out`
+        # if `self._data` and `out` do not share the same storage.
+        # Otherwise, if they do share the same storage, we can just return directly.
+        if out is not None:
+            assert isinstance(out, ScaledGroupedMMTensor), f"{type(out)}"
+            if data.dtype == param_dtype:
+                assert (
+                    data.untyped_storage().data_ptr()
+                    == out._data.untyped_storage().data_ptr()
+                )
+            else:
+                assert out._data.dtype == param_dtype, (
+                    f"{out._data.dtype} {param_dtype}"
+                )
+            out._data.copy_(data)
+            return
+
+        # For training step 0, out=None, so we need to return a new ScaledGroupedMMTensor.
+        output = ScaledGroupedMMTensor(data)
         inner_tensors = (data,)
         return output, inner_tensors
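
Not part of the patch: a minimal usage sketch of the revised wrapper, exercising only behavior visible in the hunks above. It assumes torchao is installed and that `ScaledGroupedMMTensor` is importable from `torchao.prototype.moe_training.tensor` (the module path is not shown in this diff).

```python
# Minimal sketch, not from the PR. The import path below is an assumption.
import torch

from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor

# Expert weights for a grouped GEMM, e.g. (num_experts, dim_in, dim_out).
w = torch.randn(8, 1024, 2048, dtype=torch.bfloat16)

# The constructor no longer takes a separate dtype; the wrapper now inherits
# the dtype of the tensor it wraps (dtype=tensor.dtype in __new__).
wrapped = ScaledGroupedMMTensor(w)
assert wrapped.dtype == torch.bfloat16

# detach and the ops listed in _ops_to_preserve_subclass still return the
# subclass via the simplified unwrap/re-wrap path in __torch_dispatch__.
assert isinstance(wrapped.detach(), ScaledGroupedMMTensor)
assert isinstance(torch.empty_like(wrapped), ScaledGroupedMMTensor)
```

Under FSDP2's `fully_shard` with a `MixedPrecisionPolicy`, the new `fsdp_pre_all_gather` hook casts `self._data` to `mp_policy.param_dtype` before the all-gather, and `fsdp_post_all_gather` either copies the gathered data into the existing unsharded parameter (training step 1+) or wraps it in a new `ScaledGroupedMMTensor` (step 0), as shown in the hunks above.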