217 changes: 152 additions & 65 deletions backends/arm/quantizer/arm_quantizer.py
@@ -72,7 +72,25 @@ def get_symmetric_quantization_config(
act_qmax: int = 127,
weight_qmin: int = -127,
weight_qmax: int = 127,
):
) -> QuantizationConfig:
"""Create symmetric quantization config for activations and weights.

Args:
is_per_channel (bool): Whether to use per-channel quantization for
weights.
is_qat (bool): Whether the configuration targets quantization aware
training.
is_dynamic (bool): Whether to generate dynamic activation observers.
act_qmin (int): Minimum activation quantization value.
act_qmax (int): Maximum activation quantization value.
weight_qmin (int): Minimum weight quantization value.
weight_qmax (int): Maximum weight quantization value.

Returns:
QuantizationConfig: Quantization settings for activations, weights, and
bias.

"""
extra_args: Dict[str, Any] = {"eps": 2**-12}
if is_qat:
if is_dynamic:
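As a point of reference for the docstring added in this hunk, here is a minimal usage sketch; the import path follows this file's location, and the compile_spec variable and per-channel choice are illustrative assumptions rather than part of this change.

# Illustrative sketch: build a symmetric INT8 config and register it globally.
from executorch.backends.arm.quantizer.arm_quantizer import (
    TOSAQuantizer,
    get_symmetric_quantization_config,
)

config = get_symmetric_quantization_config(is_per_channel=True)
quantizer = TOSAQuantizer(compile_spec)  # compile_spec: TosaSpecification or ArmCompileSpec, assumed to exist
quantizer.set_global(config)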
@@ -169,23 +187,26 @@ def get_symmetric_a16w8_quantization_config(
weight_qmin: int = -127,
weight_qmax: int = 127,
epsilon: float = 2**-12,
):
"""
16A8W quantization config: 16-bit activations, 8-bit weights.
) -> QuantizationConfig:
"""16A8W quantization config: 16-bit activations, 8-bit weights.

This configuration provides better accuracy than 8A8W while maintaining
reasonable memory usage through 8-bit weights.

Args:
is_per_channel: Whether to use per-channel quantization for weights
is_qat: Whether this is for Quantization Aware Training
is_dynamic: Whether to use dynamic quantization
weight_qmin: Minimum quantization value for weights
weight_qmax: Maximum quantization value for weights
epsilon: Value used to pad observed [qmin, qmax] before initial zero point and scale calculation
is_per_channel (bool): Whether to use per-channel quantization for
weights.
is_qat (bool): Whether this is for quantization aware training.
is_dynamic (bool): Whether to use dynamic quantization.
weight_qmin (int): Minimum quantization value for weights.
weight_qmax (int): Maximum quantization value for weights.
epsilon (float): Value used to pad observed [qmin, qmax] before initial
zero-point and scale calculation.

Returns:
QuantizationConfig with 16-bit activations and 8-bit weights
QuantizationConfig: Configuration with 16-bit activations and 8-bit
weights.

"""
extra_args: Dict[str, Any] = {"eps": epsilon}
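A short illustrative follow-up to the docstring above: selecting the 16A8W configuration instead of the default 8A8W one (quantizer is the same assumed TOSAQuantizer as in the earlier sketch).

# Illustrative: 16-bit activations for better accuracy, weights stay at 8 bits.
a16w8_config = get_symmetric_a16w8_quantization_config(is_per_channel=True)
quantizer.set_global(a16w8_config)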

Expand Down Expand Up @@ -244,27 +265,39 @@ def get_symmetric_a16w8_quantization_config(


NodeFilterType = Callable[[Node], bool]
"""Type for a Node Filter used by annotators. A Node filter is a function that takes
a Node and returns whether the node should be annotated or not.
"""Type for a Node Filter used by annotators.

A Node filter is a function that takes a Node and returns whether the node
should be annotated or not.

"""


def _get_module_type_filter(tp: Callable) -> NodeFilterType:
"""Get the module_type_filter function for a given module type, the filter accepts
a node and checks if the node comes from a module that has certain module type
"""Get the module_type_filter function for a given module type.

For example:
node: linear_op = call_function[...](...) # comes from a module with type Block -> Sub -> Linear
The filter accepts a node and checks if the node comes from a module that
has a certain module type.

Args:
tp (Callable): Module class to match against the graph node metadata.

Returns:
NodeFilterType: Predicate that returns True for nodes from the module
type.

For example:
node: linear_op = call_function[...](...) # type Block -> Sub -> Linear

>> module_type_filter = _get_module_type_filter(Sub) # submodule with type `Sub`, under the `Block` submodule
>> module_type_filter = _get_module_type_filter(Sub)
>> print(module_type_filter(node))
True # the node is from the submodule `Sub` (same for `Block` and `Linear` as well)
"""
True # the node is from the submodule `Sub` (same for `Block` and `Linear`)

"""
tp_str = tp.__module__ + "." + tp.__qualname__

def module_type_filter(n: Node) -> bool:
"""Return True if the node originates from the target module type."""
# node_stack example: {
# 'L__self___sub': ("L['self'].sub", <class '....Sub'>),
# 'L__self___sub_linear': ("L['self'].sub.linear", <class 'torch.nn.modules.linear.Linear'>)
@@ -279,16 +312,29 @@ def module_type_filter(n: Node) -> bool:
def _get_not_module_type_or_name_filter(
tp_list: List[Callable], module_name_list: List[str]
) -> NodeFilterType:
"""Create a filter that excludes provided module types and names.

Args:
tp_list (List[Callable]): Module types to exclude from annotation.
module_name_list (List[str]): Module names to exclude from annotation.

Returns:
NodeFilterType: Filter that returns True when the node does not match
any provided module type or name.

"""
module_type_filters = [_get_module_type_filter(tp) for tp in tp_list]
module_name_list_filters = [get_module_name_filter(m) for m in module_name_list]

def not_module_type_or_name_filter(n: Node) -> bool:
"""Return True when the node matches none of the blocked filters."""
return not any(f(n) for f in module_type_filters + module_name_list_filters)

return not_module_type_or_name_filter


class TOSAQuantizer(Quantizer):
"""Manage quantization annotations for TOSA-compatible backends."""

def __init__(
self, compile_spec_or_tosa_spec: TosaSpecification | ArmCompileSpec
@@ -314,41 +360,47 @@ def __init__(
self.module_name_config: Dict[str, Optional[QuantizationConfig]] = {}

def set_global(self, quantization_config: QuantizationConfig) -> TOSAQuantizer:
"""
Set quantization_config for submodules that are not already annotated by name or type filters.
"""Set quantization_config for submodules not matched by other filters.

Args:
quantization_config: The QuantizationConfig to set as global configuration.
quantization_config (QuantizationConfig): Configuration to apply to
modules that are not captured by name or type filters.

"""
self.global_config = quantization_config
return self

def set_module_type(
self, module_type: Callable, quantization_config: QuantizationConfig
) -> TOSAQuantizer:
"""
Set quantization_config for a submodule with type: `module_type`, for example:
quantizer.set_module_name(Sub) or quantizer.set_module_name(nn.Linear), it will quantize all supported operator/operator
patterns in the submodule with this module type with the given `quantization_config`.
"""Set quantization_config for submodules with a given module type.

For example, calling set_module_type(Sub) quantizes supported patterns
in each Sub instance with the provided quantization_config.

Args:
module_type: The type of the submodule to set the quantization config for.
quantization_config: The QuantizationConfig to set for the submodule.
module_type (Callable): Type whose submodules should use the
provided quantization configuration.
quantization_config (QuantizationConfig): Configuration to apply to
submodules of the given type.

"""
self.module_type_config[module_type] = quantization_config
return self

def set_module_name(
self, module_name: str, quantization_config: Optional[QuantizationConfig]
) -> TOSAQuantizer:
"""
Set quantization_config for a submodule with name: `module_name`, for example:
quantizer.set_module_name("blocks.sub"), it will quantize all supported operator/operator
patterns in the submodule with this module name with the given `quantization_config`
"""Set quantization_config for submodules with a given module name.

For example, calling set_module_name("blocks.sub") quantizes supported
patterns for that submodule with the provided quantization_config.

Args:
module_name: The name of the submodule to set the quantization config for.
quantization_config: The QuantizationConfig to set for the submodule.
module_name (str): Fully qualified module name to configure.
quantization_config (QuantizationConfig): Configuration applied to
the named submodule.

"""
# Validate that quantization_config is provided
if quantization_config is None:
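The setters documented in this hunk compose; a hedged sketch follows, where the torch.nn.Linear type and the "blocks.sub" name mirror the docstring examples and are purely illustrative.

# Illustrative: a global default, overridden per module type and per module name.
import torch

quantizer = TOSAQuantizer(compile_spec)  # compile_spec assumed to exist
quantizer.set_global(get_symmetric_quantization_config())
quantizer.set_module_type(torch.nn.Linear, get_symmetric_quantization_config(is_per_channel=True))
quantizer.set_module_name("blocks.sub", get_symmetric_a16w8_quantization_config())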
@@ -357,26 +409,28 @@ def set_module_name(
return self

def set_io(self, quantization_config: QuantizationConfig) -> TOSAQuantizer:
"""
Set quantization_config for input and output nodes.
"""Set quantization_config for input and output nodes.

Args:
quantization_config: The QuantizationConfig to set for input and output nodes.
quantization_config (QuantizationConfig): Configuration describing
activation quantization for model inputs and outputs.

"""
self.io_config = quantization_config
return self

def transform_for_annotation(self, model: GraphModule) -> GraphModule:
"""
An initial pass for transforming the graph to prepare it for annotation.
"""Transform the graph to prepare it for quantization annotation.

Currently transforms scalar values to tensor attributes.

Args:
model: The model to transform.
model (GraphModule): Model whose graph will be transformed.

Returns:
The transformed model.
"""
GraphModule: Transformed model prepared for annotation.

"""
# TODO: Fix the need to lazily import this.
from executorch.backends.arm._passes import ArmPassManager
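To round out the setters documented in this hunk, a one-line illustrative sketch of configuring the model's inputs and outputs with the same symmetric config.

# Illustrative: annotate graph inputs and outputs as well.
quantizer.set_io(get_symmetric_quantization_config())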

@@ -385,12 +439,16 @@ def transform_for_annotation(self, model: GraphModule) -> GraphModule:
)

def annotate(self, model: GraphModule) -> GraphModule:
"""Performs the quantization annotation on the graph.
Currently only does static quantization annotation.
"""Annotate the graph with the configured quantization settings.

Currently only does static quantization annotation.

Args:
model: The model to annotate statically.
model (GraphModule): Model to annotate statically.

Returns:
The annotated model.
GraphModule: Annotated model ready for export.

"""
model = self._annotate_for_static_quantization_config(model)
return model
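For orientation, transform_for_annotation and annotate are normally driven indirectly through the PT2E entry points referenced later in this file; a hedged sketch of that flow is shown below, where the export step and example_inputs are assumptions.

# Hedged sketch: prepare_pt2e calls transform_for_annotation and annotate internally.
import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

exported = torch.export.export(model, example_inputs).module()  # example_inputs assumed
prepared = prepare_pt2e(exported, quantizer)
prepared(*example_inputs)  # calibration pass with representative data
quantized = convert_pt2e(prepared)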
@@ -401,14 +459,19 @@ def _annotate_all_static_patterns(
quantization_config: Optional[QuantizationConfig],
filter_fn: Optional[Callable[[Node], bool]] = None,
) -> GraphModule:
"""Loops over all STATIC_OPS and runs the corresponding registered annotator.
"""Annotate all static patterns registered for the backend.

Args:
model: The model to annotate statically.
quantization_config: Specifies the QuantizationSpecs for the model's
input activations, output activations, weights and biases.
filter_fn: An optional filter function that takes a node and returns whether the node should be annotated.
model (GraphModule): Model to annotate statically.
quantization_config (Optional[QuantizationConfig]): Quantization
specs for input activations, output activations, weights, and
biases.
filter_fn (Optional[Callable[[Node], bool]]): Optional node filter
specifying which nodes to annotate.

Returns:
The annotated model.
GraphModule: Model populated with quantization annotations.

"""
# TODO: implement the support for None to be canceling out previous annotations
if quantization_config is None:
@@ -420,8 +483,15 @@ def _annotate_for_static_quantization_config(
def _annotate_for_static_quantization_config(
self, model: GraphModule
) -> GraphModule:
"""Matches the correct QuantizationConfig with the correct module using a filter
when running _annotate_all_static_patterns.
"""Match QuantizationConfigs to modules before annotating patterns.

Args:
model (GraphModule): Model whose modules are being matched to
quantization configs.

Returns:
GraphModule: Annotated model after applying configured filters.

"""
if self.io_config:
self._annotate_io(model, self.io_config)
@@ -451,6 +521,14 @@ def _annotate_io(
model: GraphModule,
quantization_config: QuantizationConfig,
):
"""Annotate graph inputs and outputs with the provided configuration.

Args:
model (GraphModule): GraphModule being annotated.
quantization_config (QuantizationConfig): Activation qspecs to apply
to IO nodes.

"""
for node in model.graph.nodes:
if is_annotated(node):
continue
@@ -468,6 +546,7 @@ def _annotate_io(
mark_node_as_annotated(node)

def validate(self, model: GraphModule) -> None:
"""TODO: Implement validation of annotated graph for TOSA backend."""
pass

def quantize_with_submodules(
@@ -479,10 +558,16 @@ def quantize_with_submodules(
"""Quantizes a GraphModule in a way such that conditional submodules are handled properly.

Args:
model: GraphModule, the model to quantize.
calibration_samples: list[tuple], a list of inputs to used to calibrate the model during quantization.
To properly calibrate a model with submodules, at least one sample per code path is needed.
is_qat: bool, whether to do quantization aware training or not.
model (GraphModule): The model to quantize.
calibration_samples (list[tuple]): A list of inputs used to
calibrate the model during quantization. To properly calibrate a
model with submodules, at least one sample per code path is
needed.
is_qat (bool): Whether to do quantization aware training or not.

Returns:
GraphModule: The quantized model.

"""
prepare_fn = prepare_qat_pt2e if is_qat else prepare_pt2e
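A hedged usage sketch for quantize_with_submodules; the tensor shapes are illustrative, and the key point is supplying at least one calibration sample per conditional code path.

# Illustrative: one calibration sample per conditional branch in the model.
calibration_samples = [
    (torch.randn(1, 16),),   # exercises one code path (shape assumed)
    (-torch.randn(1, 16),),  # exercises the other code path (shape assumed)
]
quantized = quantizer.quantize_with_submodules(
    exported_module,         # GraphModule assumed to come from torch.export
    calibration_samples,
    is_qat=False,
)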

@@ -499,23 +584,25 @@ def quantize_with_submodules(


class EthosUQuantizer(TOSAQuantizer):
"""
Quantizer supported by the Arm Ethos-U backend.
"""Quantizer supported by the Arm Ethos-U backend.

Args:
compile_spec: A EthosUCompileSpec instance.
compile_spec (EthosUCompileSpec): Backend compile specification for
Ethos-U targets.

"""

def __init__(self, compile_spec: EthosUCompileSpec) -> None:
super().__init__(compile_spec)


class VgfQuantizer(TOSAQuantizer):
"""
Quantizer supported by the Arm Vgf backend.
"""Quantizer supported by the Arm Vgf backend.

Args:
compile_spec: A VgfCompileSpec instance.
compile_spec (VgfCompileSpec): Backend compile specification for Vgf
targets.

"""

def __init__(self, compile_spec: VgfCompileSpec) -> None:
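Finally, a brief illustrative note on the backend-specific subclasses introduced at the end of this diff; the compile-spec variables below are assumed to be constructed elsewhere.

# Illustrative: the subclasses differ only in the compile spec they accept.
ethos_u_quantizer = EthosUQuantizer(ethos_u_compile_spec)  # EthosUCompileSpec, assumed
vgf_quantizer = VgfQuantizer(vgf_compile_spec)             # VgfCompileSpec, assumed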