@@ -46,16 +46,15 @@ class ScaledGroupedMMTensor(torch.Tensor):
     def __new__(
         cls,
         tensor: torch.Tensor,
-        dtype: torch.dtype,
     ):
-        logger.info(f"ScaledGroupedMMTensor __new__: tensor.dtype={tensor.dtype}, dtype: {dtype}, shape: {tensor.shape}")
+        # logger.info(f"ScaledGroupedMMTensor __new__: tensor.dtype={tensor.dtype}, dtype: {dtype}, shape: {tensor.shape}")
         return torch.Tensor._make_wrapper_subclass(
             cls,
             tensor.size(),
             strides=tensor.stride(),
             storage_offset=tensor.storage_offset(),
             memory_format=suggest_memory_format(tensor),
-            dtype=dtype,
+            dtype=tensor.dtype,
             layout=tensor.layout,
             device=tensor.device,
             pin_memory=tensor.is_pinned(),
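
With the `dtype` argument gone, all of the wrapper's outer metadata is derived from the wrapped tensor itself. A minimal self-contained sketch of that pattern (`MinimalWrapper` is illustrative, not part of this PR):

```python
import torch
import torch.utils._pytree as pytree

class MinimalWrapper(torch.Tensor):
    def __new__(cls, tensor: torch.Tensor):
        # outer metadata (dtype/layout/device) mirrors the inner tensor 1:1
        return torch.Tensor._make_wrapper_subclass(
            cls,
            tensor.size(),
            strides=tensor.stride(),
            storage_offset=tensor.storage_offset(),
            dtype=tensor.dtype,
            layout=tensor.layout,
            device=tensor.device,
        )

    def __init__(self, tensor: torch.Tensor):
        self._data = tensor

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # run every op on the inner tensor; return plain tensors for brevity
        unwrap = lambda x: x._data if isinstance(x, MinimalWrapper) else x
        return func(*pytree.tree_map(unwrap, args),
                    **pytree.tree_map(unwrap, kwargs or {}))

w = MinimalWrapper(torch.randn(4, 4, dtype=torch.bfloat16))
assert w.dtype == torch.bfloat16  # no separate _dtype bookkeeping needed
```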
@@ -65,15 +64,11 @@ def __new__(
     def __init__(
         self,
         tensor: torch.Tensor,
-        dtype: torch.dtype,
     ):
-        logger.info(f"ScaledGroupedMMTensor __init__: tensor.dtype={tensor.dtype}, dtype: {dtype}, shape: {tensor.shape}")
-        self._data = tensor.to(dtype)
-        self._dtype = dtype
+        self._data = tensor
 
     @classmethod
     def __torch_function__(cls, func, types, args, kwargs={}):
-        logger.info(f"ScaledGroupedMMTensor func: {func.__name__}, args: {args}, kwargs: {kwargs}")
         # override the grouped mm op to use the differentiable _scaled_grouped_mm
         if func.__name__ == cls.grouped_mm_func_name:
             # Use torchao scaled grouped mm with dynamic quant for
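
For context, the interception pattern used by `__torch_function__` above matches the target op by name and reroutes it, while everything else falls through to the default handler. A toy sketch of the same pattern (`InterceptTensor` and `noisy_mm` are made up for illustration):

```python
import torch

def noisy_mm(a, b):
    print("intercepted mm")
    return torch.matmul(a, b)  # name differs from "mm", so no recursion

class InterceptTensor(torch.Tensor):
    target_func_name = "mm"

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # reroute the matched op; everything else uses default behavior
        if func.__name__ == cls.target_func_name:
            return noisy_mm(*args, **kwargs)
        return super().__torch_function__(func, types, args, kwargs)

a = torch.randn(2, 3).as_subclass(InterceptTensor)
b = torch.randn(3, 2)
print(torch.mm(a, b))  # routed through noisy_mm
```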
@@ -102,7 +97,7 @@ def __torch_function__(cls, func, types, args, kwargs={}):
     def __torch_dispatch__(cls, func, types, args, kwargs={}):
         # detach is special case
         if func == torch.ops.aten.detach.default:
-            return ScaledGroupedMMTensor(args[0]._data, args[0]._dtype)
+            return ScaledGroupedMMTensor(args[0]._data)
 
         # unwrap args/kwargs
         unwrap = lambda x: x._data if isinstance (x, ScaledGroupedMMTensor) else x
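
The detach special case matters for FSDP: detach must hand back the subclass (wrapping the same inner data) so the `fsdp_pre_all_gather`/`fsdp_post_all_gather` hooks remain discoverable on the detached parameter. A quick illustrative check, assuming `ScaledGroupedMMTensor` is in scope:

```python
import torch

t = ScaledGroupedMMTensor(torch.randn(8, 8))
d = t.detach()
assert isinstance(d, ScaledGroupedMMTensor)  # subclass preserved
assert d._data is t._data                    # same inner tensor, new wrapper
```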
@@ -120,21 +115,20 @@ def __torch_dispatch__(cls, func, types, args, kwargs={}):
         # wrap outputs back into ScaledGroupedMMTensor for ops that do preserve subclass
         return pytree.tree_map_only(
             torch.Tensor,
-            lambda x: ScaledGroupedMMTensor(x, x.dtype),
+            lambda x: ScaledGroupedMMTensor(x),
             out,
         )
 
     def __repr__(self):
-        return f"ScaledGroupedMMTensor(data={self._data}, dtype={self._dtype})"
+        return f"ScaledGroupedMMTensor(data={self._data})"
 
     def __tensor_flatten__(self):
-        return ["_data"], {"_dtype": self._dtype}
+        return ["_data"]
 
     @staticmethod
     def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
         return ScaledGroupedMMTensor(
             inner_tensors["_data"],
-            flatten_spec["_dtype"],
         )
 
     # fsdp hooks based on https://github.com/pytorch/pytorch/blob/20e40492b046b9287726d3ec656117e4dc38f0e2/test/distributed/_composable/fsdp/test_fully_shard_extensions.py#L81
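
Dropping the `_dtype` entry reduces the traceable-subclass protocol to a single inner tensor. As a point of reference, torch's contract for `__tensor_flatten__` is usually written as a pair of (inner tensor attr names, metadata); a minimal no-metadata version would look like this (a sketch, not the PR's code):

```python
def __tensor_flatten__(self):
    # (inner tensor attribute names, metadata); nothing extra to carry now
    return ["_data"], None

@staticmethod
def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
    # flatten_spec is the metadata returned above (None here)
    return ScaledGroupedMMTensor(inner_tensors["_data"])
```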
@@ -146,9 +140,9 @@ def fsdp_pre_all_gather(
         module: nn.Module,
         mp_policy: MixedPrecisionPolicy,
     ):
-        all_gather_inputs = (self._data,)
+        # cast to mixed precision dtype prior to all-gather
+        all_gather_inputs = (self._data.to(mp_policy.param_dtype),)
         all_gather_metadata = ()
-        #logger.info(f"ScaledGroupedMMTensor fsdp_pre_all_gather: self._data.dtype={self._data.dtype}, self._data.shape={self._data.shape}, param_dtype: {mp_policy.param_dtype}")
         return all_gather_inputs, all_gather_metadata
 
     def fsdp_post_all_gather(
@@ -160,11 +154,10 @@ def fsdp_post_all_gather(
         out: Optional[torch.Tensor] = None,
     ):
         (data,) = all_gather_outputs
-        #logger.info(f"ScaledGroupedMMTensor fsdp_post_all_gather: data.dtype={data.dtype}, param_dtype: {param_dtype}")
 
         if out is not None:
             return
 
-        output = ScaledGroupedMMTensor(data, param_dtype)
+        output = ScaledGroupedMMTensor(data)
         inner_tensors = (data,)
         return output, inner_tensors
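
A hedged sketch of how these hooks get exercised: under FSDP2 (`fully_shard`) with a `MixedPrecisionPolicy`, `fsdp_pre_all_gather` now casts the inner data to `param_dtype` before the all-gather, and `fsdp_post_all_gather` simply rewraps the gathered tensor with no extra `_dtype` state. Assumes a distributed process group is already initialized; on recent PyTorch these names are public under `torch.distributed.fsdp` (older builds expose them under `torch.distributed._composable.fsdp`).

```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy

def shard_with_bf16_params(model: nn.Module) -> nn.Module:
    mp_policy = MixedPrecisionPolicy(
        param_dtype=torch.bfloat16,  # dtype of all-gathered (unsharded) params
        reduce_dtype=torch.float32,  # dtype used for gradient reduction
    )
    # fsdp_pre_all_gather receives this policy and casts self._data to
    # mp_policy.param_dtype before communication
    fully_shard(model, mp_policy=mp_policy)
    return model
```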