diff --git a/ai_edge_quantizer/params_generator.py b/ai_edge_quantizer/params_generator.py index 14cc9ea6..8318d311 100644 --- a/ai_edge_quantizer/params_generator.py +++ b/ai_edge_quantizer/params_generator.py @@ -75,12 +75,10 @@ def generate_quantization_parameters( if model_qsvs is None: model_qsvs = {} - skip_subgraphs = set() + float_subgraphs = set() op_codes = self.flatbuffer_model.operatorCodes for sg_ind, subgraph in enumerate(self.flatbuffer_model.subgraphs): - if sg_ind in skip_subgraphs: - continue - + no_quantize_this_subgraph = sg_ind in float_subgraphs graph_info = qtyping.GraphInfo( subgraph.tensors, self.flatbuffer_model.buffers ) @@ -109,14 +107,20 @@ def generate_quantization_parameters( algorithm_name, op_quant_config = ( model_recipe_manager.get_quantization_configs(op_key, op_scope) ) - if policy.is_non_quantizable_composite_op(op): + if ( + policy.is_non_quantizable_composite_op(op) + or no_quantize_this_subgraph + ): algorithm_name = algorithm_manager.AlgorithmName.NO_QUANTIZE if algorithm_name == algorithm_manager.AlgorithmName.NO_QUANTIZE: - side_effect_subgraphs = ( - tfl_flatbuffer_utils.get_op_side_effect_subgraphs(op) - ) - skip_subgraphs.update(side_effect_subgraphs) + # Add side effect subgraphs to the float subgraphs + # if the op is not no quantized. + if not no_quantize_this_subgraph: + side_effect_subgraphs = ( + tfl_flatbuffer_utils.get_op_side_effect_subgraphs(op) + ) + float_subgraphs.update(side_effect_subgraphs) op_quant_results = self._get_params_for_no_quant_op( subgraph_op_id, op, subgraph.tensors