@@ -13,8 +13,7 @@
 # limitations under the License.
 
 from abc import ABC, abstractmethod
-from collections import defaultdict
-from typing import List, Optional, Set, Tuple
+from typing import List, Optional, Set
 
 import torch
 import torch.nn.utils.parametrize as P
@@ -101,8 +100,6 @@ def apply_to_model(self, model: Module, use_tqdm=True):
         for module, arg in tqdm.tqdm(modules_args, desc=desc, disable=(not use_tqdm)):
             self._apply_to_module(module, arg)
 
-        self._update_tied_weights()
-
     def _apply_to_module(self, module: Module, args: TransformArgs):
         """
         Create transforms and apply them to the module
@@ -165,31 +162,6 @@ def output_hook(_, _input, output):
         else:
             raise NotImplementedError()
 
-    def _update_tied_weights(self):
-        """
-        Populate the `_dynamic_tied_weights_keys` attribute of transforms,
-        which is used by transformers to detect and remove shared pointers
-        during saving
-        """
-        # map from data_ptrs to keys
-        ptr_to_keys: dict[int, List[Tuple[TransformBase, str]]] = defaultdict(list)
-        for transform in self.transforms:
-            for name, param in transform.named_parameters(recurse=False):
-                # NOTE: previously asserted that parent._hf_hook.place_submodules=False
-                if has_offloaded_params(transform):
-                    param = transform._hf_hook.weights_map[name]
-                ptr_to_keys[param.data_ptr()].append((transform, name))
-
-        # populate `_dynamic_tied_weights_keys` if there is more than one key
-        # and ensure that they share tensors
-        for shared_keys in ptr_to_keys.values():
-            if len(shared_keys) > 1:
-                tensor = getattr(shared_keys[0][0], shared_keys[0][1])
-
-                for transform, name in shared_keys:
-                    transform._dynamic_tied_weights_keys.add(name)
-                    setattr(transform, name, tensor)
-
 
 class TransformBase(InternalModule, ABC):
     """
@@ -198,11 +170,7 @@ class TransformBase(InternalModule, ABC):
 
     args: TransformArgs
    weight: Parameter
-    _dynamic_tied_weights_keys: Set[str]
-
-    def __init__(self):
-        super().__init__()
-        self._dynamic_tied_weights_keys = set()
+    _dynamic_tied_weights_keys: List[str] = ["weight"]
 
     @abstractmethod
     def forward(self, value: Tensor) -> Tensor:
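
For context on the change above: the per-instance set that the removed _update_tied_weights pass populated at runtime is replaced by a static class attribute, since every TransformBase exposes exactly one parameter named "weight". A minimal sketch of the underlying idea, using a hypothetical TinyTransform stand-in rather than the repository's actual classes:

import torch
from torch import nn


class TinyTransform(nn.Module):
    # Hypothetical stand-in: declares up front that its "weight" parameter
    # may be shared (tied) with other modules, so a serializer can drop
    # duplicate pointers instead of saving the same tensor twice.
    _dynamic_tied_weights_keys = ["weight"]

    def __init__(self, weight: nn.Parameter):
        super().__init__()
        self.weight = weight

    def forward(self, value: torch.Tensor) -> torch.Tensor:
        return value @ self.weight


shared = nn.Parameter(torch.eye(4))
a, b = TinyTransform(shared), TinyTransform(shared)

# Tied parameters share storage; comparing data pointers is how the removed
# _update_tied_weights pass detected sharing at runtime.
assert a.weight.data_ptr() == b.weight.data_ptr()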