@@ -14,6 +14,7 @@
     is_preset_scheme,
     preset_name_to_scheme,
 )
+from compressed_tensors.utils import match_named_modules
 from pydantic import Field, PrivateAttr, field_validator
 from torch.utils.hooks import RemovableHandle
 
@@ -121,12 +122,15 @@ def initialize_quantization(self, model: torch.nn.Module):
 
         :param model: model to attach schemes and observers to
         """
-        reset_quantization_status(model)  # reset any previously applied qconfigs
-
         # apply scheme and status to model
         config = self.resolve_quantization_config()
+
+        for _, module in match_named_modules(model, self.targets, self.ignore):
+            reset_quantization_status(module)  # reset any previously applied qconfigs
+
         apply_quantization_config(model, config)
 
+        # TODO should we disable for entire model or just matching modules?
         # disable quantization until calibration
         model.apply(disable_quantization)
 
@@ -138,8 +142,11 @@ def start_calibration(self, model: torch.nn.Module):
         :param model: model to prepare for calibration
         """
         self._calibration_hooks = self._initialize_hooks(model)
-        model.apply(self._initialize_observers)
-        model.apply(apply_calibration_status)
+        for _, module in match_named_modules(model, self.targets, self.ignore):
+            self._initialize_observers(module)
+            apply_calibration_status(module)
+
+        # TODO should we disable for entire model or just matching modules?
         model.apply(enable_quantization)  # quantize at the same time as calibrate
 
     def end_calibration(self, model: torch.nn.Module):
@@ -150,7 +157,9 @@ def end_calibration(self, model: torch.nn.Module):
         :param model: model to end calibration for
         """
         self.remove_hooks(self._calibration_hooks)
-        model.apply(freeze_module_quantization)  # remove observers
+        for _, module in match_named_modules(model, self.targets, self.ignore):
+            freeze_module_quantization(module)  # remove observers
+
         model.apply(enable_quantization)  # keep quantization enabled
 
     def has_config(self) -> bool:
@@ -240,7 +249,7 @@ def _initialize_observers(self, module: torch.nn.Module):
 
     def _initialize_hooks(self, model: torch.nn.Module) -> Set[RemovableHandle]:
         hooks = set()
-        for module in model.modules():
+        for _, module in match_named_modules(model, self.targets, self.ignore):
             if not hasattr(module, "quantization_scheme"):
                 continue
 
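Net effect of the change: the per-module lifecycle calls (`reset_quantization_status`, `_initialize_observers`, `apply_calibration_status`, `freeze_module_quantization`) now run only on modules matched by `self.targets` and `self.ignore`, instead of on every submodule via `model.apply(...)`; only `enable_quantization` / `disable_quantization` still sweep the whole model, which the TODOs flag as an open question. Below is a minimal sketch of the matching pattern, assuming `match_named_modules(model, targets, ignore)` yields `(name, module)` pairs as the call sites above suggest; the toy model and the `"Linear"` target string are illustrative, not taken from this PR.

```python
import torch
from compressed_tensors.utils import match_named_modules

# Toy model standing in for the real one; "Linear" as a target string and
# the empty ignore list are assumptions for illustration.
model = torch.nn.Sequential(
    torch.nn.Linear(8, 8),
    torch.nn.ReLU(),
    torch.nn.Linear(8, 4),
)

# model.apply(fn) would visit every submodule (ReLU included); the matched
# iteration touches only the modules the modifier is configured to target.
for name, module in match_named_modules(model, ["Linear"], []):
    print(f"matched {name}: {type(module).__name__}")
```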