diff --git a/backends/arm/quantizer/arm_quantizer.py b/backends/arm/quantizer/arm_quantizer.py index 2e43b924cc8..10c555ae929 100644 --- a/backends/arm/quantizer/arm_quantizer.py +++ b/backends/arm/quantizer/arm_quantizer.py @@ -39,11 +39,7 @@ ) from executorch.backends.cortex_m.quantizer.pattern_matcher import PatternMatcher -from executorch.backends.cortex_m.quantizer.quantizer_reporter import ( - QuantizerReporter, - SUPPORTED_QCONFIGS, - SUPPORTED_QSPECS, -) +from executorch.backends.cortex_m.quantizer_reporter import QuantizerReporter from torch._ops import OpOverload @@ -219,20 +215,28 @@ def get_symmetric_quantization_config( bias_quantization_spec = _get_int32_bias_qspec if is_dynamic: - quantization_config = TOSAQuantizationConfig( - act_quantization_spec, - None, - weight_quantization_spec, - bias_quantization_spec, - ) + output_activation = None else: - quantization_config = TOSAQuantizationConfig( - act_quantization_spec, - act_quantization_spec, - weight_quantization_spec, - bias_quantization_spec, - ) - return quantization_config + output_activation = act_quantization_spec + + module_name = __name__.rsplit(".", maxsplit=1)[-1] + label = ( + f"{module_name}.get_symmetric_quantization_config(" + f"per_channel={int(is_per_channel)}, " + f"qat={int(is_qat)}, " + f"dynamic={int(is_dynamic)}, " + f"act_range=[{act_qmin}, {act_qmax}], " + f"weight_range=[{weight_qmin}, {weight_qmax}]" + ")" + ) + + return TOSAQuantizationConfig( + act_quantization_spec, + output_activation, + weight_quantization_spec, + bias_quantization_spec, + label, + ) @functools.lru_cache @@ -357,59 +361,32 @@ def get_symmetric_a16w8_quantization_config( is_qat=is_qat, is_dynamic=is_dynamic, ) - # Replace activation quantization spec with 16-bit version + if is_dynamic: - quantization_config = TOSAQuantizationConfig( - act_quantization_spec, # 16-bit input activations - None, - base_config.weight, # 8-bit weights from base config - base_config.bias, # bias from base config - ) + output_activation = None else: - quantization_config = TOSAQuantizationConfig( - act_quantization_spec, # 16-bit input activations - act_quantization_spec, # 16-bit output activations - base_config.weight, # 8-bit weights from base config - base_config.bias, # bias from base config - ) - return quantization_config - + output_activation = act_quantization_spec + + module_name = __name__.rsplit(".", maxsplit=1)[-1] + label = ( + f"{module_name}.get_symmetric_a16w8_quantization_config(" + f"per_channel={int(is_per_channel)}, " + f"qat={int(is_qat)}, " + f"dynamic={int(is_dynamic)}, " + f"act_range=[{act_quantization_spec.quant_min}, {act_quantization_spec.quant_max}], " + f"weight_range=[{weight_qmin}, {weight_qmax}]" + ")" + ) -# Register supported quantization configs and qspecs in the reporter for human-readable reporting -# MLETORCH-1854: Temporary solution, refactor to automatically register these instead -_symmetric_a8w4_config_per_channel = get_symmetric_a8w4_quantization_config() -_symmetric_a8w8_config_per_channel = get_symmetric_quantization_config() -_symmetric_a16w8_config_per_channel = get_symmetric_a16w8_quantization_config() -_symmetric_a8w4_config_per_tensor = get_symmetric_a8w4_quantization_config( - is_per_channel=False -) -_symmetric_a8w8_config_per_tensor = get_symmetric_quantization_config( - is_per_channel=False -) -_symmetric_a16w8_config_per_tensor = get_symmetric_a16w8_quantization_config( - is_per_channel=False -) -SUPPORTED_QCONFIGS.update( - { - _symmetric_a8w8_config_per_channel: f"{__name__}.get_symmetric_quantization_config(is_per_channel=True)", - _symmetric_a16w8_config_per_channel: f"{__name__}.get_symmetric_a16w8_quantization_config(is_per_channel=True)", - _symmetric_a8w4_config_per_channel: f"{__name__}.get_symmetric_a8w4_quantization_config(is_per_channel=True)", - _symmetric_a8w8_config_per_tensor: f"{__name__}.get_symmetric_quantization_config(is_per_channel=False)", - _symmetric_a16w8_config_per_tensor: f"{__name__}.get_symmetric_a16w8_quantization_config(is_per_channel=False)", - _symmetric_a8w4_config_per_tensor: f"{__name__}.get_symmetric_a8w4_quantization_config(is_per_channel=False)", - } -) + # Replace activation quantization spec with 16-bit version + return TOSAQuantizationConfig( + act_quantization_spec, # 16-bit input activations + output_activation, + base_config.weight, # 8-bit weights from base config + base_config.bias, # bias from base config + label, + ) -SUPPORTED_QSPECS.update( - { - _symmetric_a8w4_config_per_channel.get_weight_qspec(): "INT4_PER_CHANNEL_QSPEC", - _symmetric_a8w8_config_per_channel.get_weight_qspec(): "INT8_PER_CHANNEL_QSPEC", - _symmetric_a8w8_config_per_tensor.get_weight_qspec(): "INT8_PER_TENSOR_QSPEC", - _symmetric_a8w4_config_per_tensor.get_weight_qspec(): "INT4_PER_TENSOR_QSPEC", - _symmetric_a8w8_config_per_tensor.get_input_act_qspec(): "INT8_PER_TENSOR_QSPEC", - _symmetric_a16w8_config_per_tensor.get_input_act_qspec(): "INT16_PER_TENSOR_QSPEC", - } -) NodeFilterType = Callable[[Node], bool] """Type for a Node Filter used by annotators. diff --git a/backends/arm/quantizer/arm_quantizer_utils.py b/backends/arm/quantizer/arm_quantizer_utils.py index fb4f363d6b0..453c0d3f4cc 100644 --- a/backends/arm/quantizer/arm_quantizer_utils.py +++ b/backends/arm/quantizer/arm_quantizer_utils.py @@ -21,6 +21,10 @@ from executorch.backends.arm.common.annotation_meta import ArmAnnotationInfo from executorch.backends.arm.constants import DISALLOW_TFA_META_KEY from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig +from executorch.backends.cortex_m.quantizer_reporter import ( + QuantizerInfo, + QuantizerReporterUser, +) from torch.fx import Node from torchao.quantization.pt2e.quantizer import ( @@ -160,25 +164,6 @@ def _get_int32_per_channel_bias_qspec(node): ) -class _QuantizerReporterUserMixin: - def __init__(self): - self.reporter = None - - def register_reporter(self, reporter) -> None: - self.reporter = reporter - - def report_reject(self, pattern: list[Node], reason: str) -> None: - if self.reporter is not None: - self.reporter.report_reject(self, pattern, reason) - - def report_accept(self, pattern: list[Node]) -> None: - if self.reporter is not None: - self.reporter.report_accept(self, pattern) - - def get_quantizer_info(self): - raise NotImplementedError("Quantizer must implement get_quantizer_info method.") - - class PatternCheck: """Base class for pattern checks. @@ -248,7 +233,7 @@ def find_nodes(self, model: torch.fx.GraphModule) -> Iterator[Node]: pass -class PatternQuantizer(Quantizer, _QuantizerReporterUserMixin): +class PatternQuantizer(Quantizer, QuantizerReporterUser): """Quantizes a graph according to an OperatorConfig. Args: @@ -265,28 +250,28 @@ def __init__( pattern_matcher: "PatternMatcher", ) -> None: super().__init__() - _QuantizerReporterUserMixin.__init__(self) + QuantizerReporterUser.__init__(self) self.quantization_config: QuantizationConfig | None = quantization_config self.node_finder: "NodeFinder" = node_finder self.pattern_matcher: "PatternMatcher" = pattern_matcher def get_quantizer_info(self): - from executorch.backends.cortex_m.quantizer.quantizer_reporter import ( - QuantizerInfo, - SUPPORTED_QCONFIGS, - ) - name = self.__class__.__name__ targeted_nodes_description = str(self.node_finder) - quantization_config_path = SUPPORTED_QCONFIGS.get( - self.quantization_config, "UNREGISTERED_QCONFIG" - ) + if self.quantization_config is None: + qconfig_label = "NO_QCONFIG" + else: + qconfig_label = ( + self.quantization_config.label + if self.quantization_config.label is not None + else self.quantization_config.__class__.__name__ # no label, fallback to class name + ) support_config_path = self.pattern_matcher.support_dict_name return QuantizerInfo( name, targeted_nodes_description, - quantization_config_path, + qconfig_label, support_config_path, ) @@ -397,7 +382,7 @@ def validate(self, model: torch.fx.GraphModule) -> bool: # type: ignore[overrid return True -class SharedQspecQuantizer(Quantizer, _QuantizerReporterUserMixin): +class SharedQspecQuantizer(Quantizer, QuantizerReporterUser): """Assures that specific ops share quantization parameters on all inputs/outputs. """ @@ -495,7 +480,7 @@ class SharedQspecQuantizer(Quantizer, _QuantizerReporterUserMixin): def __init__(self, targets: Optional[list[Callable[..., object]]] = None) -> None: super().__init__() - _QuantizerReporterUserMixin.__init__(self) + QuantizerReporterUser.__init__(self) if targets is None: self.targets = self.SHARED_QSPEC_OPS_DEFAULT self.support_config_path = ( @@ -508,18 +493,14 @@ def __init__(self, targets: Optional[list[Callable[..., object]]] = None) -> Non ) def get_quantizer_info(self): - from executorch.backends.cortex_m.quantizer.quantizer_reporter import ( - QuantizerInfo, - ) - name = self.__class__.__name__ targeted_nodes_description = "" - quantization_config_path = "SHARED_QCONFIG" + qconfig_label = "shared qparams for connected targeted nodes" support_config_path = self.support_config_path return QuantizerInfo( name, targeted_nodes_description, - quantization_config_path, + qconfig_label, support_config_path, ) diff --git a/backends/arm/quantizer/quantization_config.py b/backends/arm/quantizer/quantization_config.py index e6c53ebf966..d06203cede3 100644 --- a/backends/arm/quantizer/quantization_config.py +++ b/backends/arm/quantizer/quantization_config.py @@ -46,6 +46,7 @@ class QuantizationConfig: output_activation: Optional[QuantizationSpecBase] weight: Optional[QuantizationSpecBase] bias: Optional[QuantizationSpecBase] | Callable[[Any], Any] + label: Optional[str] = None # Optional label for debugging/visualization purposes def get_input_act_qspec( self, node: Optional[Node] = None, input_node: Optional[Node] = None diff --git a/backends/cortex_m/TARGETS b/backends/cortex_m/TARGETS new file mode 100644 index 00000000000..98d006cb48a --- /dev/null +++ b/backends/cortex_m/TARGETS @@ -0,0 +1,21 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +load("@fbcode_macros//build_defs:python_library.bzl", "python_library") + +oncall("executorch") + +python_library( + name = "quantizer_reporter", + srcs = [ + "quantizer_reporter.py", + ], + deps = [ + "//caffe2:torch", + "//pytorch/ao:torchao", + "fbsource//third-party/pypi/tabulate:tabulate", + ], +) diff --git a/backends/cortex_m/quantizer/TARGETS b/backends/cortex_m/quantizer/TARGETS index 7a4ef5cd78a..d765e59cf8e 100644 --- a/backends/cortex_m/quantizer/TARGETS +++ b/backends/cortex_m/quantizer/TARGETS @@ -17,7 +17,6 @@ python_library( "pattern_matcher.py", "quantization_configs.py", "quantizer.py", - "quantizer_reporter.py", "quantizer_support.py", ], deps = [ @@ -27,6 +26,7 @@ python_library( "//executorch/backends/arm/quantizer:arm_quantizer_utils", "//executorch/backends/arm/quantizer:quantization_annotator", "//executorch/backends/arm/quantizer:quantization_config", + "//executorch/backends/cortex_m:quantizer_reporter", "//pytorch/ao:torchao", "fbsource//third-party/pypi/tabulate:tabulate", ], @@ -42,19 +42,7 @@ python_library( "//caffe2:torch", "//executorch/backends/arm/quantizer:arm_quantizer_utils", "//executorch/backends/arm/quantizer:quantization_config", + "//executorch/backends/cortex_m:quantizer_reporter", "//pytorch/ao:torchao", - ":quantizer_reporter", - ], -) - -python_library( - name = "quantizer_reporter", - srcs = [ - "quantizer_reporter.py", - ], - deps = [ - "//caffe2:torch", - "//pytorch/ao:torchao", - "fbsource//third-party/pypi/tabulate:tabulate", ], ) diff --git a/backends/cortex_m/quantizer/quantization_configs.py b/backends/cortex_m/quantizer/quantization_configs.py index a2fc7d19b21..9bc13c05e9d 100644 --- a/backends/cortex_m/quantizer/quantization_configs.py +++ b/backends/cortex_m/quantizer/quantization_configs.py @@ -10,10 +10,6 @@ _get_int32_per_channel_bias_qspec, ) from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig -from executorch.backends.cortex_m.quantizer.quantizer_reporter import ( - SUPPORTED_QCONFIGS, - SUPPORTED_QSPECS, -) from torch.fx import Node from torchao.quantization.pt2e import ( HistogramObserver, @@ -156,6 +152,7 @@ def get_bias_qspec( INT8_ACTIVATION_PER_TENSOR_QSPEC, INT8_WEIGHT_PER_TENSOR_QSPEC, _get_int32_bias_qspec, + f"{__name__}.INT8_PER_TENSOR_CONFIG", ) @@ -164,25 +161,5 @@ def get_bias_qspec( INT8_ACTIVATION_PER_TENSOR_QSPEC, INT8_WEIGHT_PER_CHANNEL_QSPEC, _get_int32_per_channel_bias_qspec, -) - - -# Register supported quantization configs and qspecs in the reporter for human-readable reporting -# MLETORCH-1854: Temporary solution, refactor to automatically register these instead -SUPPORTED_QCONFIGS.update( - { - INT8_PER_CHANNEL_CONFIG: f"{__name__}.INT8_PER_CHANNEL_QCONFIG", - INT8_PER_TENSOR_CONFIG: f"{__name__}.INT8_PER_TENSOR_QCONFIG", - } -) - -SUPPORTED_QSPECS.update( - { - INT8_ACTIVATION_PER_TENSOR_QSPEC: "INT8_ACTIVATION_PER_TENSOR_QSPEC", - INT8_ACTIVATION_PER_CHANNEL_QSPEC: "INT8_ACTIVATION_PER_CHANNEL_QSPEC", - INT8_WEIGHT_PER_TENSOR_QSPEC: "INT8_WEIGHT_PER_TENSOR_QSPEC", - INT8_WEIGHT_PER_CHANNEL_QSPEC: "INT8_WEIGHT_PER_CHANNEL_QSPEC", - INT8_WEIGHT_PER_CHANNEL_TRANSPOSE_QSPEC: "INT8_WEIGHT_PER_CHANNEL_TRANSPOSE_QSPEC", - SOFTMAX_OUTPUT_FIXED_QSPEC: "SOFTMAX_OUTPUT_FIXED_QSPEC", - } + f"{__name__}.INT8_PER_CHANNEL_CONFIG", ) diff --git a/backends/cortex_m/quantizer/quantizer.py b/backends/cortex_m/quantizer/quantizer.py index 0654862bc1b..e331c80ed1c 100644 --- a/backends/cortex_m/quantizer/quantizer.py +++ b/backends/cortex_m/quantizer/quantizer.py @@ -22,13 +22,13 @@ INT8_PER_CHANNEL_CONFIG, INT8_PER_TENSOR_CONFIG, ) -from executorch.backends.cortex_m.quantizer.quantizer_reporter import QuantizerReporter from executorch.backends.cortex_m.quantizer.quantizer_support import ( __name__ as cortex_m_quantizer_support_module, CONV_OP_PATTERNS, CONV_TRANSPOSE_OP_PATTERNS, CORTEX_M_QUANTIZER_SUPPORT_DICT, ) +from executorch.backends.cortex_m.quantizer_reporter import QuantizerReporter from torch._ops import OpOverload from torch.fx import GraphModule from torchao.quantization.pt2e.quantizer import ComposableQuantizer, Quantizer diff --git a/backends/cortex_m/quantizer/quantizer_reporter.py b/backends/cortex_m/quantizer_reporter.py similarity index 89% rename from backends/cortex_m/quantizer/quantizer_reporter.py rename to backends/cortex_m/quantizer_reporter.py index 84416d97cb8..5e423672cd1 100644 --- a/backends/cortex_m/quantizer/quantizer_reporter.py +++ b/backends/cortex_m/quantizer_reporter.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. """Contains classes for reporting quantization decisions made by Quantizers. -Basic useage: +Basic usage: 1. Implement the QuantizerReporterUser API for all quantizers intending to use the reporter. 2. Instantiate the QuantizerReporter with a list of quantizers to be reported. 3. After annotation, log the report using QuantizerReporter.log_quantizer_report(model). @@ -32,44 +32,29 @@ logger = logging.getLogger(__name__) tabulate = cast(Callable[..., str], import_module("tabulate").tabulate) -# Look-up dicts used to get human readable names for supported quantization configs and specs -SUPPORTED_QCONFIGS: dict[Any, str] = {} -SUPPORTED_QSPECS: dict[QuantizationSpecBase | None, str] = {} +def qspec_repr(qspec: Optional[QuantizationSpecBase]) -> str: + """Get a human-readable representation of a QuantizationSpec.""" -def _qspec_repr(qspec): - """Get a human readable representation of QuantizationSpecs. - - Note that the observer_or_fake_quant_ctr field is created dynamically with - the qspec so two qspecs created at different times will not evaluate as - equal. Therefore a custom comparison is required. - - #TODO: Clean up qconfig/ qspec string representation logic in cortex_m/arm - backend. - - """ if isinstance(qspec, SharedQuantizationSpec): - return "SHARED_QSPEC" + return f"SharedQuantizationSpec(edge_or_node={qspec.edge_or_node})" elif isinstance(qspec, DerivedQuantizationSpec): - return "DERIVED_QSPEC" - elif qspec is None: - return "NO_QSPEC" + return f"DerivedQuantizationSpec(derived_from={qspec.derived_from}, dtype={qspec.dtype})" elif isinstance(qspec, QuantizationSpec): - for key, val in SUPPORTED_QSPECS.items(): - if type(qspec) is not type(key): - continue - if qspec.dtype != key.dtype: - continue - if qspec.quant_min != key.quant_min: - continue - if qspec.quant_max != key.quant_max: - continue - if qspec.qscheme != key.qscheme: - continue - if qspec.is_dynamic != key.is_dynamic: - continue - return val - return "UNREGISTERED_QSPEC" + + def _fmt(obj: Any) -> str: + return str(obj).removeprefix("torch.").upper() + + q_range_fmt = ( + f", range=({qspec.quant_min},{qspec.quant_max})" + if (qspec.quant_min is not None or qspec.quant_max is not None) + else "" + ) + return f"QuantizationSpec(dtype={_fmt(qspec.dtype)}{q_range_fmt})" + elif qspec is None: + return "None" + else: + return qspec.__class__.__name__ class QuantizerInfo(NamedTuple): @@ -77,7 +62,7 @@ class QuantizerInfo(NamedTuple): name: str targeted_nodes_description: str - quantization_config_path: str + qconfig_label: str support_config_path: str @@ -112,8 +97,8 @@ class QuantizerReport: _PREVIOUS_ANNOTATION_REJECT_REASON = "Tried annotating already quantized node." - def __init__(self, quantizer): - self.quantizer = quantizer.get_quantizer_info() + def __init__(self, quantizer_info: QuantizerInfo): + self.quantizer_info = quantizer_info self.accepted_patterns: List[AnnotatedPatternReport] = [] self.rejected_patterns: List[RejectedPatternReport] = [] @@ -155,7 +140,7 @@ def report_accept(self, pattern: List[Node]) -> None: f"Node {node.name} was reported as annotated but annotation metadata is missing." ) qspec_input_map_lines = [ - f"{node.name}: {_qspec_repr(qspec)}" + f"{node.name}: {qspec_repr(qspec)}" for node, qspec in annotation.input_qspec_map.items() ] @@ -163,7 +148,7 @@ def report_accept(self, pattern: List[Node]) -> None: NodeQSpecReport( node.name, qspec_input_map_lines, - _qspec_repr(annotation.output_qspec), + qspec_repr(annotation.output_qspec), ) ) @@ -180,11 +165,11 @@ def report_reject(self, pattern, reason): def get_quantizer_info_rows(self) -> List[str]: rows = [] rows.append( - f"{self.quantizer.name} using {self.quantizer.targeted_nodes_description}" + f"{self.quantizer_info.name} using {self.quantizer_info.targeted_nodes_description}" ) - rows.append(f"Annotating with {self.quantizer.quantization_config_path}") + rows.append(f"Annotating with {self.quantizer_info.qconfig_label}") rows.append( - f"Supported operators and patterns defined by {self.quantizer.support_config_path}" + f"Supported operators and patterns defined by {self.quantizer_info.support_config_path}" ) if ( @@ -317,7 +302,7 @@ def set_quantizers(self, quantizers: List[QuantizerReporterUser]) -> None: f"Quantizer {quantizer.__class__.__name__} does not implement QuantizerReporterUser interface and will not report quantization decisions." ) - self.quantizers[quantizer] = QuantizerReport(quantizer) + self.quantizers[quantizer] = QuantizerReport(quantizer.get_quantizer_info()) def report_reject( self, quantizer: QuantizerReporterUser, pattern: List[Node], reason: str diff --git a/backends/cortex_m/test/misc/test_quantizer_reporter.py b/backends/cortex_m/test/misc/test_quantizer_reporter.py index 368ff78793c..9504820d65c 100644 --- a/backends/cortex_m/test/misc/test_quantizer_reporter.py +++ b/backends/cortex_m/test/misc/test_quantizer_reporter.py @@ -6,19 +6,36 @@ import logging import torch -from executorch.backends.cortex_m.quantizer.quantization_configs import ( - INT8_ACTIVATION_PER_CHANNEL_QSPEC, - INT8_WEIGHT_PER_TENSOR_QSPEC, -) from executorch.backends.cortex_m.quantizer.quantizer import mark_node_as_annotated -from executorch.backends.cortex_m.quantizer.quantizer_reporter import ( +from executorch.backends.cortex_m.quantizer_reporter import ( logger as quantizer_logger, + qspec_repr, QuantizerInfo, QuantizerReport, QuantizerReporter, QuantizerReporterUser, ) from torch.export import export +from torchao.quantization.pt2e import MinMaxObserver, PerChannelMinMaxObserver +from torchao.quantization.pt2e.quantizer import ( + DerivedQuantizationSpec, + QuantizationSpec, + SharedQuantizationSpec, +) + +INT8_WEIGHT_PER_TENSOR_QSPEC = QuantizationSpec( + dtype=torch.int8, + observer_or_fake_quant_ctr=MinMaxObserver, + qscheme=torch.per_tensor_symmetric, + quant_min=-127, + quant_max=127, +) +INT8_ACTIVATION_PER_CHANNEL_QSPEC = QuantizationSpec( + dtype=torch.int8, + observer_or_fake_quant_ctr=PerChannelMinMaxObserver, + qscheme=torch.per_channel_affine, + ch_axis=0, +) class _TwoOpModule(torch.nn.Module): @@ -38,11 +55,79 @@ def get_quantizer_info(self) -> QuantizerInfo: return QuantizerInfo( name="DummyQuantizer", targeted_nodes_description="dummy nodes", - quantization_config_path="dummy.config", + qconfig_label="dummy.config", support_config_path="dummy.support", ) +def test_qspec_repr_quantization_spec_with_range(): + qspec = QuantizationSpec( + torch.int8, + MinMaxObserver, + quant_min=-42, + quant_max=123, + ) + assert qspec_repr(qspec) == "QuantizationSpec(dtype=INT8, range=(-42,123))" + + +def test_qspec_repr_quantization_spec_without_range(): + qspec = QuantizationSpec( + torch.int16, + MinMaxObserver, + ) + assert qspec_repr(qspec) == "QuantizationSpec(dtype=INT16)" + + +def test_qspec_repr_quantization_spec_partial_range(): + qspec = QuantizationSpec( + torch.int16, + MinMaxObserver, + quant_min=-100, + ) + assert qspec_repr(qspec) == "QuantizationSpec(dtype=INT16, range=(-100,None))" + + +def test_qspec_repr_shared_quantization_spec(): + graph_module = _export_two_op_graph_module() + add_node = next( + node + for node in graph_module.graph.nodes + if node.target == torch.ops.aten.add.Tensor + ) + qspec = SharedQuantizationSpec(add_node) + + assert qspec_repr(qspec) == f"SharedQuantizationSpec(edge_or_node={add_node})" + + +def test_qspec_repr_derived_quantization_spec(): + graph_module = _export_two_op_graph_module() + x_node = next(node for node in graph_module.graph.nodes if node.name == "x") + y_node = next(node for node in graph_module.graph.nodes if node.name == "y") + add_node = next( + node + for node in graph_module.graph.nodes + if node.target == torch.ops.aten.add.Tensor + ) + derived_from = [(x_node, add_node), (y_node, add_node)] + qspec = DerivedQuantizationSpec( + derived_from=derived_from, + derive_qparams_fn=lambda _: ( + torch.tensor([1.0]), + torch.tensor([0], dtype=torch.int32), + ), + dtype=torch.int32, + ) + + assert ( + qspec_repr(qspec) + == f"DerivedQuantizationSpec(derived_from={derived_from}, dtype={qspec.dtype})" + ) + + +def test_qspec_repr_none(): + assert qspec_repr(None) == "None" + + def test_warning_log_level(caplog): graph_module = _export_two_op_graph_module() @@ -128,11 +213,11 @@ def test_debug_log_level(caplog): Rejected due to previous annotation: 0 Rejected nodes: 0 - NODE NAME INPUT QSPEC MAP OUTPUT QSPEC MAP - -- ----------- ------------------------------- --------------------------------- - ╒ add x: INT8_WEIGHT_PER_TENSOR_QSPEC NO_QSPEC - | y: NO_QSPEC - ╘ relu INT8_ACTIVATION_PER_CHANNEL_QSPEC + NODE NAME INPUT QSPEC MAP OUTPUT QSPEC MAP + -- ----------- ------------------------------------------------- ---------------------------- + ╒ add x: QuantizationSpec(dtype=INT8, range=(-127,127)) None + | y: None + ╘ relu QuantizationSpec(dtype=INT8) ---------------------------------------------------------------------------------------------------- DummyQuantizer using dummy nodes Annotating with dummy.config