Arm backend: Quantized aten.cat constant-fold bug with list-valued tensor inputs #18971

@Rob-Hughes-Arm

Description

🐛 Describe the bug

Summary

The Arm TOSA export path can fold a quantized aten.cat incorrectly when the
concatenated inputs are constant tensors provided through a list-valued tensor
argument.

For mixed-input cases such as torch.cat((horizontal_ramp, vertical_ramp), dim=1), the folded constant produced by FuseConstantArgsPass is wrong before
serialization. The bad value is then preserved in the emitted .tosa
flatbuffer.

This reproduces consistently with a small, self-contained repro, and the failure
is explained by a mismatch between how FoldAndAnnotateQParamsPass writes
input_qparams metadata and how FuseConstantArgsPass consumes it.

Self-contained repro

Run the following Python from the ExecuTorch repo root in an environment with
the Arm test dependencies installed. Set PATCH_FUSE_LOGIC = True to validate
the proposed fix shape.

import copy
import json
import tempfile
from pathlib import Path
from typing import Any

import numpy as np
import torch
import torch.nn as nn
import tosa.TosaGraph
from executorch.backends.arm._passes.arm_pass_utils import (
    get_constant_placeholder_kind,
    get_first_fake_tensor,
    get_param_tensor,
    is_persistent_buffer,
)
from executorch.backends.arm._passes.fuse_constant_ops_pass import (
    FuseConstantArgsPass,
)
from executorch.backends.arm._passes.quant_args import QuantArgs
from executorch.backends.arm.quantizer import TOSAQuantizer
from executorch.backends.arm.quantizer.arm_quantizer import QuantizationSpec
from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
from executorch.backends.arm.tosa import TosaSpecification
from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec
from executorch.backends.arm.tosa.partitioner import TOSAPartitioner
from executorch.backends.transforms.utils import create_constant_placeholder
from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower
from torchao.quantization.pt2e.observer import FixedQParamsObserver

HEIGHT = 4
WIDTH = 5
PATCH_FUSE_LOGIC = False


class CatConstantModel(nn.Module):
    def forward(self) -> torch.Tensor:
        h = torch.linspace(
            -1.0 + (1.0 / WIDTH),
            1.0 - (1.0 / WIDTH),
            WIDTH,
            dtype=torch.float32,
        ).view(1, 1, 1, WIDTH).expand(1, 1, HEIGHT, WIDTH)
        v = torch.linspace(
            -1.0 + (1.0 / HEIGHT),
            1.0 - (1.0 / HEIGHT),
            HEIGHT,
            dtype=torch.float32,
        ).view(1, 1, HEIGHT, 1).expand(1, 1, HEIGHT, WIDTH)
        return torch.cat((h, v), dim=1)


def fixed_snorm_int8_qspec() -> QuantizationSpec:
    return QuantizationSpec(
        dtype=torch.int8,
        observer_or_fake_quant_ctr=FixedQParamsObserver.with_args(
            scale=1.0 / 127.0,
            zero_point=0,
            dtype=torch.qint8,
            qscheme=torch.per_tensor_symmetric,
            quant_min=-127,
            quant_max=127,
        ),
        quant_min=-127,
        quant_max=127,
        qscheme=torch.per_tensor_symmetric,
        is_dynamic=False,
    )


OUTPUT_QSPEC = fixed_snorm_int8_qspec()
OUTPUT_QCONFIG = QuantizationConfig(OUTPUT_QSPEC, OUTPUT_QSPEC, None, None)


def decode_name(value: Any) -> str:
    if isinstance(value, (bytes, bytearray)):
        return value.decode()
    return str(value)


def summarize_tensor(array: np.ndarray) -> dict[str, Any]:
    return {
        "shape": list(array.shape),
        "channel_0_unique_count": int(np.unique(array[..., 0]).size),
        "channel_1_unique_count": int(np.unique(array[..., 1]).size),
        "channel_0_minmax": [int(array[..., 0].min()), int(array[..., 0].max())],
        "channel_1_minmax": [int(array[..., 1].min()), int(array[..., 1].max())],
    }


def compare_arrays(actual: np.ndarray, expected: np.ndarray) -> dict[str, Any]:
    diff = np.abs(actual.astype(np.int32) - expected.astype(np.int32))
    return {
        "equal": bool(np.array_equal(actual, expected)),
        "max_abs_diff": int(diff.max()),
        "mae": float(diff.mean()),
    }


def quantize_reference(tensor: torch.Tensor) -> np.ndarray:
    qargs = QuantArgs(
        scale=1.0 / 127.0,
        zp=0,
        qmin=-127,
        qmax=127,
        dtype=torch.int8,
    )
    quantized = qargs.quantize_value(tensor).detach().cpu().numpy().astype(np.int8)
    return np.transpose(quantized, (0, 2, 3, 1)).copy()


def load_tosa_tensor(tosa_path: Path) -> np.ndarray:
    buf = tosa_path.read_bytes()
    graph = tosa.TosaGraph.TosaGraph.GetRootAs(buf, 0)
    block = graph.Regions(0).Blocks(0)
    shape_map = {}
    tensor_map = {}
    for i in range(block.ShapesLength()):
        shape = block.Shapes(i)
        shape_map[decode_name(shape.Name())] = np.frombuffer(
            shape.DataAsNumpy().tobytes(),
            dtype=np.int64,
        ).tolist()
    for i in range(block.TensorsLength()):
        tensor = block.Tensors(i)
        name = decode_name(tensor.Name())
        tensor_map[name] = tensor
        if name not in shape_map:
            shape_map[name] = [int(tensor.Shape(j)) for j in range(tensor.ShapeLength())]
    candidates = [
        name
        for name in tensor_map
        if shape_map.get(name) == [1, HEIGHT, WIDTH, 2]
    ]
    if len(candidates) != 1:
        raise RuntimeError(f"expected one fused cat tensor, got {candidates}")
    name = candidates[0]
    return np.frombuffer(
        tensor_map[name].DataAsNumpy().tobytes(),
        dtype=np.int8,
    ).reshape(shape_map[name])


def resolve_buggy(arg, *, input_nodes, qparams, exported_program):
    if isinstance(arg, torch.fx.Node) and arg in input_nodes:
        idx = input_nodes.index(arg)
        tensor = get_param_tensor(exported_program, arg)
        if qparams and idx in qparams:
            tensor = qparams[idx].dequantize_value(tensor)
        return tensor
    if isinstance(arg, list):
        return [
            resolve_buggy(
                item,
                input_nodes=input_nodes,
                qparams=qparams,
                exported_program=exported_program,
            )
            for item in arg
        ]
    if isinstance(arg, tuple):
        return tuple(
            resolve_buggy(
                item,
                input_nodes=input_nodes,
                qparams=qparams,
                exported_program=exported_program,
            )
            for item in arg
        )
    return arg


def resolve_fixed(arg, *, arg_index, qparams, exported_program):
    qparam = qparams.get(arg_index) if (qparams and arg_index is not None) else None
    if isinstance(arg, torch.fx.Node):
        tensor = get_param_tensor(exported_program, arg)
        if qparam is not None:
            tensor = qparam.dequantize_value(tensor)
        return tensor
    if isinstance(arg, list):
        return [
            resolve_fixed(
                item,
                arg_index=arg_index,
                qparams=qparams,
                exported_program=exported_program,
            )
            for item in arg
        ]
    if isinstance(arg, tuple):
        return tuple(
            resolve_fixed(
                item,
                arg_index=arg_index,
                qparams=qparams,
                exported_program=exported_program,
            )
            for item in arg
        )
    return arg


def fixed_fuse_nodes(pass_instance: FuseConstantArgsPass, node: Any) -> bool:
    input_nodes = list(node.all_input_nodes)
    qparams = node.meta.get("input_qparams", None)

    def resolve_arg(arg, arg_index=None):
        qparam = qparams.get(arg_index) if (qparams and arg_index is not None) else None
        if isinstance(arg, torch.fx.Node) and arg in input_nodes:
            tensor = get_param_tensor(pass_instance.exported_program, arg)
            if qparam is not None:
                tensor = qparam.dequantize_value(tensor)
            return tensor
        if isinstance(arg, tuple):
            return tuple(resolve_arg(x, arg_index) for x in arg)
        if isinstance(arg, list):
            return [resolve_arg(x, arg_index) for x in arg]
        return arg

    new_args = tuple(resolve_arg(arg, i) for i, arg in enumerate(node.args))
    new_kwargs = {key: resolve_arg(value, None) for key, value in node.kwargs.items()}
    data = node.target(*new_args, **new_kwargs)
    if data.numel() > get_first_fake_tensor(node).numel():
        return False
    if "output_qparams" in node.meta and len(node.meta["output_qparams"]) > 0:
        data = node.meta["output_qparams"][0].quantize_value(data)
    insert_pos = list(node.all_input_nodes)[0]
    input_kind = get_constant_placeholder_kind(pass_instance.exported_program, insert_pos)
    persistent_buffer = is_persistent_buffer(pass_instance.exported_program, insert_pos)
    with node.graph.inserting_before(insert_pos):
        const_node = create_constant_placeholder(
            exp_program=pass_instance.exported_program,
            graph=node.graph,
            kind=input_kind,
            name=node.name + "_fused_const",
            data=data,
            persistent_buffer=persistent_buffer,
        )
    node.replace_all_uses_with(const_node)
    return True


def export_case():
    model = CatConstantModel().eval()
    with torch.no_grad():
        reference_float = model()
    reference_quant_nhwc = quantize_reference(reference_float)

    quantizer = TOSAQuantizer(TosaSpecification.create_from_string("TOSA-1.0+INT"))
    quantizer.set_global(OUTPUT_QCONFIG).set_io(OUTPUT_QCONFIG)
    captured_graph = torch.export.export(model, ())
    quantized_graph = quantizer.quantize_with_submodules(
        captured_graph.module(),
        calibration_samples=[()],
        is_qat=False,
    )
    exported = torch.export.export(quantized_graph, ())

    artifact_dir = Path(tempfile.mkdtemp(prefix="quantized_cat_constant_fold_"))
    compile_spec = TosaCompileSpec(
        TosaSpecification.create_from_string("TOSA-1.0+INT")
    ).dump_intermediate_artifacts_to(str(artifact_dir))
    partitioner = TOSAPartitioner(compile_spec)

    captured = {}
    original_fuse_nodes = FuseConstantArgsPass._fuse_nodes

    def wrapped_fuse_nodes(self, node):
        if "aten_cat" in node.name:
            input_nodes = list(node.all_input_nodes)
            qparams = copy.deepcopy(node.meta.get("input_qparams", {}))
            buggy_args = tuple(
                resolve_buggy(
                    arg,
                    input_nodes=input_nodes,
                    qparams=qparams,
                    exported_program=self.exported_program,
                )
                for arg in node.args
            )
            fixed_args = tuple(
                resolve_fixed(
                    arg,
                    arg_index=i,
                    qparams=qparams,
                    exported_program=self.exported_program,
                )
                for i, arg in enumerate(node.args)
            )
            buggy = node.target(*buggy_args, **node.kwargs)
            fixed = node.target(*fixed_args, **node.kwargs)
            if "output_qparams" in node.meta and len(node.meta["output_qparams"]) > 0:
                output_qparams = node.meta["output_qparams"][0]
                buggy = output_qparams.quantize_value(buggy)
                fixed = output_qparams.quantize_value(fixed)
            captured["input_nodes"] = [n.name for n in input_nodes]
            captured["input_qparams_keys"] = sorted(int(k) for k in qparams.keys())
            captured["buggy_eval_summary"] = summarize_tensor(
                np.transpose(buggy.detach().cpu().numpy().astype(np.int8), (0, 2, 3, 1))
            )
            captured["fixed_eval_summary"] = summarize_tensor(
                np.transpose(fixed.detach().cpu().numpy().astype(np.int8), (0, 2, 3, 1))
            )
        if PATCH_FUSE_LOGIC:
            return fixed_fuse_nodes(self, node)
        return original_fuse_nodes(self, node)

    FuseConstantArgsPass._fuse_nodes = wrapped_fuse_nodes
    try:
        to_edge_transform_and_lower(
            exported,
            partitioner=[partitioner],
            compile_config=EdgeCompileConfig(_check_ir_validity=False),
        )
    finally:
        FuseConstantArgsPass._fuse_nodes = original_fuse_nodes

    tosa_files = list(artifact_dir.rglob("*.tosa"))
    if len(tosa_files) != 1:
        raise RuntimeError(f"expected one TOSA file, got {tosa_files}")
    fused_q_nhwc = load_tosa_tensor(tosa_files[0])

    print(
        json.dumps(
            {
                "patch_fuse_logic": PATCH_FUSE_LOGIC,
                "artifact_dir": str(artifact_dir),
                "reference_quant_summary": summarize_tensor(reference_quant_nhwc),
                "fused_tosa_summary": summarize_tensor(fused_q_nhwc),
                "tosa_vs_reference": compare_arrays(fused_q_nhwc, reference_quant_nhwc),
                "fuse_constant_capture": captured,
            },
            indent=2,
        )
    )


export_case()

With PATCH_FUSE_LOGIC = False on current main, the minimized h/v case
shows:

  • eager quantized reference channel 1 unique values: 4 (-95, -32, 32, 95)
  • fused TOSA constant channel 1 unique values: 2 (-127, 127)
  • TOSA vs reference: equal=False, max_abs_diff=95, mae=31.75

With PATCH_FUSE_LOGIC = True, the same repro shows:

  • fused TOSA constant channel 1 unique values: 4
  • TOSA vs reference: equal=True, max_abs_diff=0

So a single mixed-input torch.cat((h, v), dim=1) case is sufficient to
demonstrate the bug and validate the proposed fix shape.

Expected behavior

For quantized constant folding of aten.cat([lhs, rhs], dim=1):

  • each tensor input in the list should be dequantized using the correct input
    qparams before eager evaluation in FuseConstantArgsPass
  • the folded constant should stay close to the eager quantized reference
  • mixed and duplicated-input cases should behave consistently apart from normal
    quantization drift
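
The expected fold can be sketched numerically. This is an illustrative sketch only (the arrays and values are made up; the qparams match the repro's scale 1/127, zero point 0), not the real pass code:

```python
import numpy as np

# Sketch of the expected fold for a quantized cat([lhs, rhs]) with
# per-tensor qparams scale = 1/127, zero point = 0. Values are illustrative.
scale = 1.0 / 127.0

lhs_q = np.array([-64, 64], dtype=np.int8)   # int8 codes for lhs
rhs_q = np.array([-95, 95], dtype=np.int8)   # int8 codes for rhs

# 1. dequantize every tensor in the list with its input qparams
lhs = lhs_q.astype(np.float32) * scale
rhs = rhs_q.astype(np.float32) * scale

# 2. evaluate the op eagerly on the dequantized values
folded = np.concatenate([lhs, rhs])

# 3. requantize the folded result with the output qparams
folded_q = np.clip(np.round(folded / scale), -127, 127).astype(np.int8)
print(folded_q.tolist())  # [-64, 64, -95, 95]: codes survive the round trip
```

When each list element is dequantized with its own qparams before the eager evaluation, the requantized fold reproduces the original codes exactly (up to normal quantization drift).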

Actual behavior

For mixed-input cases:

  • only the first tensor in the aten.cat input list is dequantized
  • later tensors are treated as raw int8 codes cast to float
  • the folded constant is corrupted inside FuseConstantArgsPass
  • the corrupted folded constant is serialized into the emitted .tosa file

Root cause

The bug is an inconsistency between two Arm passes.

1. FoldAndAnnotateQParamsPass

File:

  • backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py

This pass stores input_qparams per top-level argument index.

For aten.cat([tensor0, tensor1], dim=1) the tensor list is argument 0, so
the metadata becomes:

node.meta["input_qparams"] == {0: qparams}

even though the list contains multiple tensor inputs.
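
The mismatch between the two keying schemes can be shown with a plain-Python sketch (names are illustrative, not the real pass API):

```python
# Hypothetical sketch of the index mismatch for cat([t0, t1], dim=1).
# FoldAndAnnotateQParamsPass keys qparams by top-level argument index,
# so the whole list gets one entry at key 0:
input_qparams = {0: "qparams"}

# FuseConstantArgsPass instead looks up by position in the flattened
# node.all_input_nodes list:
input_nodes = ["t0", "t1"]

# The buggy lookup uses the flattened position as the qparams key:
dequantized = [input_nodes.index(n) in input_qparams for n in input_nodes]
print(dequantized)  # [True, False]: only the first list element matches
```

Only the input at flattened position 0 finds a qparam entry; every later list element is silently skipped.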

2. FuseConstantArgsPass

File:

  • backends/arm/_passes/fuse_constant_ops_pass.py

This pass currently consumes qparams as if they were keyed by flattened
node.all_input_nodes index:

input_nodes = list(node.all_input_nodes)
idx = input_nodes.index(arg)
if qparams and idx in qparams.keys():
    t = qparams[idx].dequantize_value(t)

For aten.cat([tensor0, tensor1], dim=1):

  • input_nodes == [tensor0, tensor1]
  • input_qparams == {0: qparams}

As a result:

  • tensor0 is dequantized
  • tensor1 is left as raw int8 codes cast to float

That corrupts the folded result before serialization.
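
The corruption mechanism can be reproduced numerically. The numbers below are illustrative (a short ramp through the repro's qparams, scale 1/127, zero point 0), but they show why the buggy channel collapses to {-127, 127}:

```python
import numpy as np

# Quantize a small ramp, then compare the correct path (dequantize first)
# with the buggy path (int8 codes cast straight to float).
scale = 1.0 / 127.0
v = np.array([-0.75, -0.25, 0.25, 0.75], dtype=np.float32)

codes = np.clip(np.round(v / scale), -127, 127).astype(np.int8)  # [-95 -32 32 95]

dequant = codes.astype(np.float32) * scale  # correct path: values near v
raw = codes.astype(np.float32)              # buggy path: codes as floats

# Requantizing the buggy values with the output qparams saturates everything:
requant = np.clip(np.round(raw / scale), -127, 127).astype(np.int8)
print(np.unique(requant).tolist())  # [-127, 127]: all detail collapsed
```

The raw codes are roughly 127x too large, so requantization clips every element to the quantization limits, matching the two-unique-value channel seen in the fused TOSA constant.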

Why duplicated-input controls can pass

If the same constant node is used twice, for example:

  • torch.cat((horizontal, horizontal), dim=1)
  • torch.cat((vertical, vertical), dim=1)

then node.all_input_nodes deduplicates the repeated node, so these cases expose
only one input node to the buggy code path. That single node accidentally
matches the single qparam entry {0: ...} and hides the bug.
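
The masking effect can be sketched without torch: torch.fx backs all_input_nodes with a dict keyed by node, so duplicate uses collapse to one entry. Names here are illustrative stand-ins, not the real pass code:

```python
# Stdlib-only sketch of the dedup behavior for cat((h, h), dim=1).
args_list = ["h", "h"]                        # two slots in the cat input list
input_nodes = list(dict.fromkeys(args_list))  # mirrors all_input_nodes dedup

qparams = {0: "qparams"}  # single entry keyed by top-level argument index
covered = [input_nodes.index(a) in qparams for a in args_list]
print(len(input_nodes), covered)  # 1 [True, True]: the bug is hidden
```

Because every flattened position resolves to index 0, the single qparam entry covers both list slots and the duplicated-input control passes by accident.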

Proposed fix

The smallest fix is to make FuseConstantArgsPass resolve qparams by top-level
argument index and propagate that qparam through nested list/tuple arguments.

That matches the contract already established by
FoldAndAnnotateQParamsPass, and it is safer than changing the metadata model
because the rest of the Arm backend already treats input_qparams as
argument-index keyed.

Suggested patch shape:

diff --git a/backends/arm/_passes/fuse_constant_ops_pass.py b/backends/arm/_passes/fuse_constant_ops_pass.py
@@
-        def resolve_arg(arg):
+        def resolve_arg(arg, arg_index=None):
+            qparam = qparams.get(arg_index) if (qparams and arg_index is not None) else None
             if isinstance(arg, torch.fx.Node) and arg in input_nodes:
-                idx = input_nodes.index(arg)
                 t = get_param_tensor(self.exported_program, arg)
-                if qparams and idx in qparams.keys():
-                    t = qparams[idx].dequantize_value(t)
+                if qparam is not None:
+                    t = qparam.dequantize_value(t)
                 return t
             if isinstance(arg, tuple):
-                return tuple(resolve_arg(x) for x in arg)
+                return tuple(resolve_arg(x, arg_index) for x in arg)
             if isinstance(arg, list):
-                return [resolve_arg(x) for x in arg]
+                return [resolve_arg(x, arg_index) for x in arg]
             return arg

-        new_args = tuple(resolve_arg(a) for a in node.args)
+        new_args = tuple(resolve_arg(a, i) for i, a in enumerate(node.args))

Fix validation

The proposed fix shape was validated by rerunning the inline repro above with
PATCH_FUSE_LOGIC = True.

Before the patch:

  • fused TOSA channel 0 unique count: 5
  • fused TOSA channel 1 unique count: 2
  • TOSA vs eager quantized reference: equal=False, max_abs_diff=95

After the patch:

  • fused TOSA channel 0 unique count: 5
  • fused TOSA channel 1 unique count: 4
  • TOSA vs eager quantized reference: equal=True, max_abs_diff=0

This fix shape corrects the emitted folded TOSA constant end-to-end for the
minimized repro.

Versions

Collecting environment information...
PyTorch version: 2.10.0+cpu
Is debug build: False
CUDA used to build PyTorch: Could not collect
ROCM used to build PyTorch: N/A

OS: Microsoft Windows 11 Enterprise (10.0.26100 64-bit)
GCC version: Could not collect
Clang version: Could not collect
CMake version: Could not collect
Libc version: N/A

Python version: 3.10.5 (tags/v3.10.5:f377153, Jun 6 2022, 16:14:13) [MSC v.1929 64 bit (AMD64)] (64-bit runtime)
Python platform: Windows-10-10.0.26100-SP0
Is CUDA available: False
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 4070 Ti
Nvidia driver version: 576.88
cuDNN version: Could not collect
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A

CPU:
Name: 13th Gen Intel(R) Core(TM) i9-13900KF
Manufacturer: GenuineIntel
Family: 207
Architecture: 9
ProcessorType: 3
DeviceID: CPU0
CurrentClockSpeed: 3000
MaxClockSpeed: 3000
L2CacheSize: 32768
L2CacheSpeed: None
Revision: None

Versions of relevant libraries:
[pip3] executorch==1.2.0.dev20260305+cpu
[pip3] numpy==2.1.3
[pip3] pytorch_tokenizers==1.1.0
[pip3] torch==2.10.0
[pip3] torchao==0.15.0
[pip3] torchvision==0.25.0
[conda] Could not collect

cc @digantdesai @freddan80 @per @zingo @oscarandersson8218 @mansnils @Sebastian-Larsson @robell
