@@ -115,7 +115,7 @@ def load_pretrained_quantization_parameters(
 
 def apply_quantization_config(
     model: Module, config: Union[QuantizationConfig, None], run_compressed: bool = False
-) -> Dict[str, QuantizationScheme]:
+):
     """
     Initializes the model for quantization in-place based on the given config.
     Optionally converts quantizable modules to compressed_linear modules
@@ -125,26 +125,22 @@ def apply_quantization_config(
     :param run_compressed: Whether the model will be run in compressed mode or
         decompressed fully on load
     """
-    # Workaround for when HF Quantizer passes None, see PR #180
-    if config is None:
-        return dict()
+    from compressed_tensors.linear.compressed_linear import CompressedLinear
 
-    # remove reference to the original `config`
-    # argument. This function can mutate it, and we'd
-    # like to keep the original `config` as it is.
     config = deepcopy(config)
+    if config is None:  # see PR #180
+        return dict()
+
+    # preprocess to support kv cache scheme
+    config = process_quantization_config(config)
+
     # build mapping of targets to schemes for easier matching
     # use ordered dict to preserve target ordering in config
     target_to_scheme = OrderedDict()
-    config = process_quantization_config(config)
-    names_to_scheme = dict()
     for scheme in config.config_groups.values():
         for target in scheme.targets:
             target_to_scheme[target] = scheme
 
-    if run_compressed:
-        from compressed_tensors.linear.compressed_linear import CompressedLinear
-
     # mark appropriate layers for quantization by setting their quantization schemes
     for name, submodule in match_named_modules(
         model, target_to_scheme, config.ignore, warn_on_fail=True
@@ -153,7 +149,12 @@ def apply_quantization_config(
         # quant scheme to the matching layers
         matched_targets = match_targets(name, submodule, target_to_scheme)
         scheme = _scheme_from_targets(target_to_scheme, matched_targets, name)
-        if run_compressed:
+        # target matched - add layer and scheme to target list
+        submodule.quantization_scheme = scheme
+
+        # replace with run compressed if applicable
+        # FUTURE: move this to model compressor
+        if isinstance(submodule, torch.nn.Linear) and run_compressed:
             format = config.format
             if format != CompressionFormat.dense.value:
                 if isinstance(submodule, torch.nn.Linear):
@@ -165,14 +166,8 @@ def apply_quantization_config(
                     )
                     replace_module(model, name, compressed_linear)
 
-        # target matched - add layer and scheme to target list
-        submodule.quantization_scheme = scheme
-
-        names_to_scheme[name] = submodule.quantization_scheme
-
     # apply current quantization status across all targeted layers
     apply_quantization_status(model, config.quantization_status)
-    return names_to_scheme
 
 
 def process_quantization_config(config: QuantizationConfig) -> QuantizationConfig:
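
With this change, `apply_quantization_config` no longer returns the `names_to_scheme` dict; each matched submodule instead carries its scheme as a `quantization_scheme` attribute. Callers that relied on the old return value can rebuild the mapping from the model directly. A minimal sketch, assuming `model` and `config` are an already-constructed `torch.nn.Module` and `QuantizationConfig` (the variable names here are illustrative, not part of the diff):

```python
from compressed_tensors.quantization import apply_quantization_config

# applies schemes in-place; after this change there is no return value to capture
apply_quantization_config(model, config, run_compressed=False)

# rebuild the former names_to_scheme mapping from the annotated submodules
names_to_scheme = {
    name: submodule.quantization_scheme
    for name, submodule in model.named_modules()
    if hasattr(submodule, "quantization_scheme")
}
```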