🐛 Describe the bug
Quantized aten.cat constant-fold bug with list-valued tensor inputs
Summary
The Arm TOSA export path can fold a quantized aten.cat incorrectly when the
concatenated inputs are constant tensors provided through a list-valued tensor
argument.
For mixed-input cases such as torch.cat((horizontal_ramp, vertical_ramp), dim=1), the folded constant produced by FuseConstantArgsPass is wrong before
serialization. The bad value is then preserved in the emitted .tosa
flatbuffer.
This reproduces consistently with a small self-contained repro and the failure is
explained by a mismatch between how FoldAndAnnotateQParamsPass writes
input_qparams metadata and how FuseConstantArgsPass consumes it.
Self-contained repro
Run the following Python from the ExecuTorch repo root in an environment with
the Arm test dependencies installed. Set PATCH_FUSE_LOGIC = True to validate
the proposed fix shape.
import copy
import json
import tempfile
from pathlib import Path
from typing import Any
import numpy as np
import torch
import torch.nn as nn
import tosa.TosaGraph
from executorch.backends.arm._passes.arm_pass_utils import (
get_constant_placeholder_kind,
get_first_fake_tensor,
get_param_tensor,
is_persistent_buffer,
)
from executorch.backends.arm._passes.fuse_constant_ops_pass import (
FuseConstantArgsPass,
)
from executorch.backends.arm._passes.quant_args import QuantArgs
from executorch.backends.arm.quantizer import TOSAQuantizer
from executorch.backends.arm.quantizer.arm_quantizer import QuantizationSpec
from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
from executorch.backends.arm.tosa import TosaSpecification
from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec
from executorch.backends.arm.tosa.partitioner import TOSAPartitioner
from executorch.backends.transforms.utils import create_constant_placeholder
from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower
from torchao.quantization.pt2e.observer import FixedQParamsObserver
# Spatial dimensions of the generated ramp tensors (NCHW: 1 x 1 x HEIGHT x WIDTH).
HEIGHT = 4
WIDTH = 5
# When True, export_case() routes constant folding through the proposed
# fixed_fuse_nodes instead of the stock FuseConstantArgsPass._fuse_nodes.
PATCH_FUSE_LOGIC = False
class CatConstantModel(nn.Module):
    """Constant-only module concatenating a horizontal and a vertical ramp.

    forward() takes no inputs and returns a 1 x 2 x HEIGHT x WIDTH float
    tensor: channel 0 varies along width, channel 1 varies along height.
    """

    @staticmethod
    def _centered_ramp(steps: int) -> torch.Tensor:
        # Evenly spaced values in (-1, 1), offset by half a step from the ends.
        half_step = 1.0 / steps
        return torch.linspace(
            -1.0 + half_step,
            1.0 - half_step,
            steps,
            dtype=torch.float32,
        )

    def forward(self) -> torch.Tensor:
        horizontal = (
            self._centered_ramp(WIDTH)
            .view(1, 1, 1, WIDTH)
            .expand(1, 1, HEIGHT, WIDTH)
        )
        vertical = (
            self._centered_ramp(HEIGHT)
            .view(1, 1, HEIGHT, 1)
            .expand(1, 1, HEIGHT, WIDTH)
        )
        return torch.cat((horizontal, vertical), dim=1)
def fixed_snorm_int8_qspec() -> QuantizationSpec:
    """Build a fixed symmetric int8 quantization spec (scale 1/127, zp 0)."""
    observer_ctr = FixedQParamsObserver.with_args(
        scale=1.0 / 127.0,
        zero_point=0,
        dtype=torch.qint8,
        qscheme=torch.per_tensor_symmetric,
        quant_min=-127,
        quant_max=127,
    )
    return QuantizationSpec(
        dtype=torch.int8,
        observer_or_fake_quant_ctr=observer_ctr,
        quant_min=-127,
        quant_max=127,
        qscheme=torch.per_tensor_symmetric,
        is_dynamic=False,
    )
# One shared fixed qspec is reused for activations and weights; bias and
# per-channel specs are left unset (None, None).
OUTPUT_QSPEC = fixed_snorm_int8_qspec()
OUTPUT_QCONFIG = QuantizationConfig(OUTPUT_QSPEC, OUTPUT_QSPEC, None, None)
def decode_name(value: Any) -> str:
    """Coerce a flatbuffer name field (bytes/bytearray or other) to str."""
    return value.decode() if isinstance(value, (bytes, bytearray)) else str(value)
def summarize_tensor(array: np.ndarray) -> dict[str, Any]:
    """Summarize an NHWC array: shape plus unique-count and min/max per channel.

    Only the first two channels (last-axis indices 0 and 1) are inspected,
    matching the two-channel fused cat output.
    """
    unique_counts = {c: int(np.unique(array[..., c]).size) for c in (0, 1)}
    minmax = {
        c: [int(array[..., c].min()), int(array[..., c].max())] for c in (0, 1)
    }
    return {
        "shape": list(array.shape),
        "channel_0_unique_count": unique_counts[0],
        "channel_1_unique_count": unique_counts[1],
        "channel_0_minmax": minmax[0],
        "channel_1_minmax": minmax[1],
    }
def compare_arrays(actual: np.ndarray, expected: np.ndarray) -> dict[str, Any]:
    """Compare two integer arrays: exact equality, max and mean absolute diff.

    Both inputs are widened to int32 before subtraction so int8 codes cannot
    overflow.
    """
    delta = np.abs(actual.astype(np.int32) - expected.astype(np.int32))
    result: dict[str, Any] = {}
    result["equal"] = bool(np.array_equal(actual, expected))
    result["max_abs_diff"] = int(delta.max())
    result["mae"] = float(delta.mean())
    return result
def quantize_reference(tensor: torch.Tensor) -> np.ndarray:
    """Quantize a float NCHW tensor to int8 (scale 1/127, zp 0), return NHWC.

    Mirrors the fixed qspec used for the whole repro so the result is directly
    comparable to the constant serialized into the .tosa file.
    """
    reference_qargs = QuantArgs(
        scale=1.0 / 127.0,
        zp=0,
        qmin=-127,
        qmax=127,
        dtype=torch.int8,
    )
    quantized_nchw = (
        reference_qargs.quantize_value(tensor).detach().cpu().numpy().astype(np.int8)
    )
    # TOSA serializes constants in NHWC, so transpose before returning.
    return np.transpose(quantized_nchw, (0, 2, 3, 1)).copy()
def load_tosa_tensor(tosa_path: Path) -> np.ndarray:
    """Extract the fused int8 constant of shape [1, HEIGHT, WIDTH, 2] from a .tosa file.

    Walks the first block of the first region of the serialized TOSA graph,
    building a name -> shape map (explicit shape entries take precedence over
    per-tensor shape fields), then returns the single tensor whose shape
    matches the expected fused-cat NHWC output.

    Raises:
        RuntimeError: if zero or multiple tensors match the expected shape.
    """
    buf = tosa_path.read_bytes()
    graph = tosa.TosaGraph.TosaGraph.GetRootAs(buf, 0)
    # NOTE(review): assumes the single partitioned subgraph lives in
    # region 0 / block 0 — holds for this one-op repro.
    block = graph.Regions(0).Blocks(0)
    shape_map = {}
    tensor_map = {}
    # Shape entries carry dimension data as raw int64 bytes keyed by name.
    for i in range(block.ShapesLength()):
        shape = block.Shapes(i)
        shape_map[decode_name(shape.Name())] = np.frombuffer(
            shape.DataAsNumpy().tobytes(),
            dtype=np.int64,
        ).tolist()
    for i in range(block.TensorsLength()):
        tensor = block.Tensors(i)
        name = decode_name(tensor.Name())
        tensor_map[name] = tensor
        # Fall back to the tensor's own shape field when no shape entry exists.
        if name not in shape_map:
            shape_map[name] = [int(tensor.Shape(j)) for j in range(tensor.ShapeLength())]
    # The fused cat constant is NHWC with 2 channels (h-ramp + v-ramp).
    candidates = [
        name
        for name in tensor_map
        if shape_map.get(name) == [1, HEIGHT, WIDTH, 2]
    ]
    if len(candidates) != 1:
        raise RuntimeError(f"expected one fused cat tensor, got {candidates}")
    name = candidates[0]
    return np.frombuffer(
        tensor_map[name].DataAsNumpy().tobytes(),
        dtype=np.int8,
    ).reshape(shape_map[name])
def resolve_buggy(arg, *, input_nodes, qparams, exported_program):
    """Replicate the current (buggy) FuseConstantArgsPass arg resolution.

    Looks qparams up by the node's position in the flattened input-node list,
    which mismatches the top-level-argument-index keying written by
    FoldAndAnnotateQParamsPass for list-valued args. This helper is kept
    bug-for-bug on purpose to demonstrate the failure — do not fix it.
    """
    if isinstance(arg, torch.fx.Node) and arg in input_nodes:
        tensor = get_param_tensor(exported_program, arg)
        position = input_nodes.index(arg)
        if qparams and position in qparams:
            tensor = qparams[position].dequantize_value(tensor)
        return tensor
    if isinstance(arg, (list, tuple)):
        resolved = [
            resolve_buggy(
                item,
                input_nodes=input_nodes,
                qparams=qparams,
                exported_program=exported_program,
            )
            for item in arg
        ]
        return resolved if isinstance(arg, list) else tuple(resolved)
    return arg
def resolve_fixed(arg, *, arg_index, qparams, exported_program):
    """Resolve a constant arg using argument-index-keyed qparams (fix shape).

    The qparam belonging to the *top-level* argument index is propagated
    unchanged through nested lists/tuples, matching the contract established
    by FoldAndAnnotateQParamsPass for input_qparams.
    """
    if qparams and arg_index is not None:
        qparam = qparams.get(arg_index)
    else:
        qparam = None
    if isinstance(arg, torch.fx.Node):
        tensor = get_param_tensor(exported_program, arg)
        return qparam.dequantize_value(tensor) if qparam is not None else tensor
    if isinstance(arg, (list, tuple)):
        resolved = [
            resolve_fixed(
                item,
                arg_index=arg_index,
                qparams=qparams,
                exported_program=exported_program,
            )
            for item in arg
        ]
        return resolved if isinstance(arg, list) else tuple(resolved)
    return arg
def fixed_fuse_nodes(pass_instance: FuseConstantArgsPass, node: Any) -> bool:
    """Proposed replacement for FuseConstantArgsPass._fuse_nodes.

    Resolves input qparams by *top-level argument index* — the keying used by
    FoldAndAnnotateQParamsPass — instead of by flattened all_input_nodes
    index, and propagates that qparam through nested list/tuple arguments.

    Returns:
        True when the node was folded into a constant placeholder, False when
        the folded result would be larger than the node's output.
    """
    input_nodes = list(node.all_input_nodes)
    qparams = node.meta.get("input_qparams", None)

    def resolve_arg(arg, arg_index=None):
        # Qparams are keyed by top-level positional-arg index; kwargs get None.
        qparam = qparams.get(arg_index) if (qparams and arg_index is not None) else None
        if isinstance(arg, torch.fx.Node) and arg in input_nodes:
            tensor = get_param_tensor(pass_instance.exported_program, arg)
            if qparam is not None:
                tensor = qparam.dequantize_value(tensor)
            return tensor
        # Propagate the same top-level qparam through nested containers.
        if isinstance(arg, tuple):
            return tuple(resolve_arg(x, arg_index) for x in arg)
        if isinstance(arg, list):
            return [resolve_arg(x, arg_index) for x in arg]
        return arg

    new_args = tuple(resolve_arg(arg, i) for i, arg in enumerate(node.args))
    new_kwargs = {key: resolve_arg(value, None) for key, value in node.kwargs.items()}
    # Evaluate the op eagerly on the dequantized constant tensors.
    data = node.target(*new_args, **new_kwargs)
    # Only fold when the constant is no larger than the node's output.
    if data.numel() > get_first_fake_tensor(node).numel():
        return False
    # Re-quantize the folded value using the node's output qparams, if any.
    if "output_qparams" in node.meta and len(node.meta["output_qparams"]) > 0:
        data = node.meta["output_qparams"][0].quantize_value(data)
    # Insert the new constant before the first original input placeholder,
    # inheriting its placeholder kind and persistence.
    insert_pos = list(node.all_input_nodes)[0]
    input_kind = get_constant_placeholder_kind(pass_instance.exported_program, insert_pos)
    persistent_buffer = is_persistent_buffer(pass_instance.exported_program, insert_pos)
    with node.graph.inserting_before(insert_pos):
        const_node = create_constant_placeholder(
            exp_program=pass_instance.exported_program,
            graph=node.graph,
            kind=input_kind,
            name=node.name + "_fused_const",
            data=data,
            persistent_buffer=persistent_buffer,
        )
    node.replace_all_uses_with(const_node)
    return True
def export_case():
    """Run the end-to-end repro: quantize, lower to TOSA, compare, print JSON.

    Temporarily wraps FuseConstantArgsPass._fuse_nodes to capture both the
    buggy (flattened-index) and fixed (arg-index) eager evaluations of the
    quantized aten.cat fold. When PATCH_FUSE_LOGIC is True, folding is routed
    through fixed_fuse_nodes. Finally the fused constant is read back out of
    the single emitted .tosa artifact and summarized against the eager
    quantized reference.
    """
    model = CatConstantModel().eval()
    with torch.no_grad():
        reference_float = model()
    reference_quant_nhwc = quantize_reference(reference_float)
    quantizer = TOSAQuantizer(TosaSpecification.create_from_string("TOSA-1.0+INT"))
    quantizer.set_global(OUTPUT_QCONFIG).set_io(OUTPUT_QCONFIG)
    captured_graph = torch.export.export(model, ())
    quantized_graph = quantizer.quantize_with_submodules(
        captured_graph.module(),
        calibration_samples=[()],
        is_qat=False,
    )
    exported = torch.export.export(quantized_graph, ())
    artifact_dir = Path(tempfile.mkdtemp(prefix="quantized_cat_constant_fold_"))
    compile_spec = TosaCompileSpec(
        TosaSpecification.create_from_string("TOSA-1.0+INT")
    ).dump_intermediate_artifacts_to(str(artifact_dir))
    partitioner = TOSAPartitioner(compile_spec)
    captured = {}
    original_fuse_nodes = FuseConstantArgsPass._fuse_nodes

    def wrapped_fuse_nodes(self, node):
        # Intercept only the cat node; evaluate both resolution strategies
        # side by side and record summaries for the printed report.
        if "aten_cat" in node.name:
            input_nodes = list(node.all_input_nodes)
            # Deep-copy so downstream pass logic sees pristine metadata.
            qparams = copy.deepcopy(node.meta.get("input_qparams", {}))
            buggy_args = tuple(
                resolve_buggy(
                    arg,
                    input_nodes=input_nodes,
                    qparams=qparams,
                    exported_program=self.exported_program,
                )
                for arg in node.args
            )
            fixed_args = tuple(
                resolve_fixed(
                    arg,
                    arg_index=i,
                    qparams=qparams,
                    exported_program=self.exported_program,
                )
                for i, arg in enumerate(node.args)
            )
            buggy = node.target(*buggy_args, **node.kwargs)
            fixed = node.target(*fixed_args, **node.kwargs)
            if "output_qparams" in node.meta and len(node.meta["output_qparams"]) > 0:
                output_qparams = node.meta["output_qparams"][0]
                buggy = output_qparams.quantize_value(buggy)
                fixed = output_qparams.quantize_value(fixed)
            captured["input_nodes"] = [n.name for n in input_nodes]
            captured["input_qparams_keys"] = sorted(int(k) for k in qparams.keys())
            # Summaries use NHWC to match the serialized TOSA layout.
            captured["buggy_eval_summary"] = summarize_tensor(
                np.transpose(buggy.detach().cpu().numpy().astype(np.int8), (0, 2, 3, 1))
            )
            captured["fixed_eval_summary"] = summarize_tensor(
                np.transpose(fixed.detach().cpu().numpy().astype(np.int8), (0, 2, 3, 1))
            )
            if PATCH_FUSE_LOGIC:
                return fixed_fuse_nodes(self, node)
        return original_fuse_nodes(self, node)

    FuseConstantArgsPass._fuse_nodes = wrapped_fuse_nodes
    try:
        to_edge_transform_and_lower(
            exported,
            partitioner=[partitioner],
            compile_config=EdgeCompileConfig(_check_ir_validity=False),
        )
    finally:
        # Always restore the stock pass method, even if lowering fails.
        FuseConstantArgsPass._fuse_nodes = original_fuse_nodes
    tosa_files = list(artifact_dir.rglob("*.tosa"))
    if len(tosa_files) != 1:
        raise RuntimeError(f"expected one TOSA file, got {tosa_files}")
    fused_q_nhwc = load_tosa_tensor(tosa_files[0])
    print(
        json.dumps(
            {
                "patch_fuse_logic": PATCH_FUSE_LOGIC,
                "artifact_dir": str(artifact_dir),
                "reference_quant_summary": summarize_tensor(reference_quant_nhwc),
                "fused_tosa_summary": summarize_tensor(fused_q_nhwc),
                "tosa_vs_reference": compare_arrays(fused_q_nhwc, reference_quant_nhwc),
                "fuse_constant_capture": captured,
            },
            indent=2,
        )
    )
export_case()
With PATCH_FUSE_LOGIC = False on current main, the minimized hv case
shows:
- eager quantized reference channel 1 unique values:
4 (-95, -32, 32, 95)
- fused TOSA constant channel 1 unique values:
2 (-127, 127)
- TOSA vs reference:
equal=False, max_abs_diff=95, mae=31.75
With PATCH_FUSE_LOGIC = True, the same repro shows:
- fused TOSA constant channel 1 unique values:
4
- TOSA vs reference:
equal=True, max_abs_diff=0
So a single mixed-input torch.cat((h, v), dim=1) case is sufficient to
demonstrate the bug and validate the proposed fix shape.
Expected behavior
For quantized constant folding of aten.cat([lhs, rhs], dim=1):
- each tensor input in the list should be dequantized using the correct input
qparams before eager evaluation in FuseConstantArgsPass
- the folded constant should stay close to the eager quantized reference
- mixed and duplicated-input cases should behave consistently apart from normal
quantization drift
Actual behavior
For mixed-input cases:
- only the first tensor in the
aten.cat input list is dequantized
- later tensors are treated as raw int8 codes cast to float
- the folded constant is corrupted inside
FuseConstantArgsPass
- the corrupted folded constant is serialized into the emitted
.tosa file
Root cause
The bug is an inconsistency between two Arm passes.
1. FoldAndAnnotateQParamsPass
File:
backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py
This pass stores input_qparams per top-level argument index.
For aten.cat([tensor0, tensor1], dim=1) the tensor list is argument 0, so
the metadata becomes:
node.meta["input_qparams"] == {0: qparams}
even though the list contains multiple tensor inputs.
2. FuseConstantArgsPass
File:
backends/arm/_passes/fuse_constant_ops_pass.py
This pass currently consumes qparams as if they were keyed by flattened
node.all_input_nodes index:
input_nodes = list(node.all_input_nodes)
idx = input_nodes.index(arg)
if qparams and idx in qparams.keys():
t = qparams[idx].dequantize_value(t)
For aten.cat([tensor0, tensor1], dim=1):
input_nodes == [tensor0, tensor1]
input_qparams == {0: qparams}
As a result:
tensor0 is dequantized
tensor1 is left as raw int8 codes cast to float
That corrupts the folded result before serialization.
Why duplicated-input controls can pass
If the same constant node is used twice, for example:
torch.cat((horizontal, horizontal), dim=1)
torch.cat((vertical, vertical), dim=1)
Because node.all_input_nodes contains unique nodes, the duplicated cases only
expose one input node to the buggy code path. That accidentally matches the
single qparam entry {0: ...} and hides the bug.
Proposed fix
The smallest fix is to make FuseConstantArgsPass resolve qparams by top-level
argument index and propagate that qparam through nested list/tuple arguments.
That matches the contract already established by
FoldAndAnnotateQParamsPass, and it is safer than changing the metadata model
because the rest of the Arm backend already treats input_qparams as
argument-index keyed.
Suggested patch shape:
diff --git a/backends/arm/_passes/fuse_constant_ops_pass.py b/backends/arm/_passes/fuse_constant_ops_pass.py
@@
- def resolve_arg(arg):
+ def resolve_arg(arg, arg_index=None):
+ qparam = qparams.get(arg_index) if (qparams and arg_index is not None) else None
if isinstance(arg, torch.fx.Node) and arg in input_nodes:
- idx = input_nodes.index(arg)
t = get_param_tensor(self.exported_program, arg)
- if qparams and idx in qparams.keys():
- t = qparams[idx].dequantize_value(t)
+ if qparam is not None:
+ t = qparam.dequantize_value(t)
return t
if isinstance(arg, tuple):
- return tuple(resolve_arg(x) for x in arg)
+ return tuple(resolve_arg(x, arg_index) for x in arg)
if isinstance(arg, list):
- return [resolve_arg(x) for x in arg]
+ return [resolve_arg(x, arg_index) for x in arg]
return arg
- new_args = tuple(resolve_arg(a) for a in node.args)
+ new_args = tuple(resolve_arg(a, i) for i, a in enumerate(node.args))
Fix validation
The proposed fix shape was validated by rerunning the inline repro above with
PATCH_FUSE_LOGIC = True.
Before the patch:
- fused TOSA channel 0 unique count:
5
- fused TOSA channel 1 unique count:
2
- TOSA vs eager quantized reference:
equal=False, max_abs_diff=95
After the patch:
- fused TOSA channel 0 unique count:
5
- fused TOSA channel 1 unique count:
4
- TOSA vs eager quantized reference:
equal=True, max_abs_diff=0
This fix shape corrects the emitted folded TOSA constant end-to-end for the
minimized repro.
Versions
Collecting environment information...
PyTorch version: 2.10.0+cpu
Is debug build: False
CUDA used to build PyTorch: Could not collect
ROCM used to build PyTorch: N/A
OS: Microsoft Windows 11 Enterprise (10.0.26100 64-bit)
GCC version: Could not collect
Clang version: Could not collect
CMake version: Could not collect
Libc version: N/A
Python version: 3.10.5 (tags/v3.10.5:f377153, Jun 6 2022, 16:14:13) [MSC v.1929 64 bit (AMD64)] (64-bit runtime)
Python platform: Windows-10-10.0.26100-SP0
Is CUDA available: False
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 4070 Ti
Nvidia driver version: 576.88
cuDNN version: Could not collect
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A
CPU:
Name: 13th Gen Intel(R) Core(TM) i9-13900KF
Manufacturer: GenuineIntel
Family: 207
Architecture: 9
ProcessorType: 3
DeviceID: CPU0
CurrentClockSpeed: 3000
MaxClockSpeed: 3000
L2CacheSize: 32768
L2CacheSpeed: None
Revision: None
Versions of relevant libraries:
[pip3] executorch==1.2.0.dev20260305+cpu
[pip3] numpy==2.1.3
[pip3] pytorch_tokenizers==1.1.0
[pip3] torch==2.10.0
[pip3] torchao==0.15.0
[pip3] torchvision==0.25.0
[conda] Could not collect
cc @digantdesai @freddan80 @per @zingo @oscarandersson8218 @mansnils @Sebastian-Larsson @robell
🐛 Describe the bug
Quantized `aten.cat` constant-fold bug with list-valued tensor inputs
Summary
The Arm TOSA export path can fold a quantized `aten.cat` incorrectly when the
concatenated inputs are constant tensors provided through a list-valued tensor
argument.
For mixed-input cases such as
`torch.cat((horizontal_ramp, vertical_ramp), dim=1)`, the folded constant
produced by `FuseConstantArgsPass` is wrong before serialization. The bad
value is then preserved in the emitted `.tosa` flatbuffer.
This reproduces consistently with a small self-contained repro and the failure is
explained by a mismatch between how `FoldAndAnnotateQParamsPass` writes
`input_qparams` metadata and how `FuseConstantArgsPass` consumes it.
Self-contained repro
Run the following Python from the ExecuTorch repo root in an environment with
the Arm test dependencies installed. Set `PATCH_FUSE_LOGIC = True` to validate
the proposed fix shape.
With `PATCH_FUSE_LOGIC = False` on current `main`, the minimized h+v case shows:
- eager quantized reference channel 1 unique values: 4 (-95, -32, 32, 95)
- fused TOSA constant channel 1 unique values: 2 (-127, 127)
- TOSA vs reference: equal=False, max_abs_diff=95, mae=31.75
With `PATCH_FUSE_LOGIC = True`, the same repro shows:
- fused TOSA constant channel 1 unique values: 4
- TOSA vs reference: equal=True, max_abs_diff=0
So a single mixed-input `torch.cat((h, v), dim=1)` case is sufficient to
demonstrate the bug and validate the proposed fix shape.
Expected behavior
For quantized constant folding of
aten.cat([lhs, rhs], dim=1):qparams before eager evaluation in
FuseConstantArgsPassquantization drift
Actual behavior
For mixed-input cases:
aten.catinput list is dequantizedFuseConstantArgsPass.tosafileRoot cause
The bug is an inconsistency between two Arm passes.
1.
FoldAndAnnotateQParamsPassFile:
backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.pyThis pass stores
input_qparamsper top-level argument index.For
aten.cat([tensor0, tensor1], dim=1)the tensor list is argument0, sothe metadata becomes:
even though the list contains multiple tensor inputs.
2.
FuseConstantArgsPassFile:
backends/arm/_passes/fuse_constant_ops_pass.pyThis pass currently consumes qparams as if they were keyed by flattened
`node.all_input_nodes` index.
For `aten.cat([tensor0, tensor1], dim=1)`:
- `input_nodes == [tensor0, tensor1]`
- `input_qparams == {0: qparams}`
As a result:
- `tensor0` is dequantized
- `tensor1` is left as raw int8 codes cast to float
That corrupts the folded result before serialization.
Why duplicated-input controls can pass
If the same constant node is used twice, for example:
torch.cat((horizontal, horizontal), dim=1)torch.cat((vertical, vertical), dim=1)Because
node.all_input_nodescontains unique nodes, the duplicated cases onlyexpose one input node to the buggy code path. That accidentally matches the
single qparam entry
{0: ...}and hides the bug.Proposed fix
The smallest fix is to make
FuseConstantArgsPassresolve qparams by top-levelargument index and propagate that qparam through nested list/tuple arguments.
That matches the contract already established by
FoldAndAnnotateQParamsPass, and it is safer than changing the metadata modelbecause the rest of the Arm backend already treats
input_qparamsasargument-index keyed.
Suggested patch shape:
Fix validation
The proposed fix shape was validated by rerunning the inline repro above with
`PATCH_FUSE_LOGIC = True`.
Before the patch:
- fused TOSA channel 0 unique count: 5
- fused TOSA channel 1 unique count: 2
- TOSA vs eager quantized reference: equal=False, max_abs_diff=95
After the patch:
- fused TOSA channel 0 unique count: 5
- fused TOSA channel 1 unique count: 4
- TOSA vs eager quantized reference: equal=True, max_abs_diff=0
This fix shape corrects the emitted folded TOSA constant end-to-end for the
minimized repro.
Versions
Collecting environment information...
PyTorch version: 2.10.0+cpu
Is debug build: False
CUDA used to build PyTorch: Could not collect
ROCM used to build PyTorch: N/A
OS: Microsoft Windows 11 Enterprise (10.0.26100 64-bit)
GCC version: Could not collect
Clang version: Could not collect
CMake version: Could not collect
Libc version: N/A
Python version: 3.10.5 (tags/v3.10.5:f377153, Jun 6 2022, 16:14:13) [MSC v.1929 64 bit (AMD64)] (64-bit runtime)
Python platform: Windows-10-10.0.26100-SP0
Is CUDA available: False
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 4070 Ti
Nvidia driver version: 576.88
cuDNN version: Could not collect
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A
CPU:
Name: 13th Gen Intel(R) Core(TM) i9-13900KF
Manufacturer: GenuineIntel
Family: 207
Architecture: 9
ProcessorType: 3
DeviceID: CPU0
CurrentClockSpeed: 3000
MaxClockSpeed: 3000
L2CacheSize: 32768
L2CacheSpeed: None
Revision: None
Versions of relevant libraries:
[pip3] executorch==1.2.0.dev20260305+cpu
[pip3] numpy==2.1.3
[pip3] pytorch_tokenizers==1.1.0
[pip3] torch==2.10.0
[pip3] torchao==0.15.0
[pip3] torchvision==0.25.0
[conda] Could not collect
cc @digantdesai @freddan80 @per @zingo @oscarandersson8218 @mansnils @Sebastian-Larsson @robell