🐛 Describe the bug
[QNN] KeyError: 'aten.pow.Scalar' – QNN Partitioner lacks node visitor for aten.pow.Scalar
Bug Description
During PT2E quantization export via to_edge_transform_and_lower_to_qnn() on the Qualcomm QNN (HTP) backend, the QNN partitioner raises a KeyError when it encounters aten.pow.Scalar nodes. This happens because the QNN HTP node visitor registration does not include a visitor for aten.pow.Scalar.
Error Traceback
File "executorch/backends/qualcomm/export_utils.py", line 569, in build_executorch_binary
File "executorch/backends/qualcomm/utils/utils.py", line 456, in to_edge_transform_and_lower_to_qnn
File "executorch/exir/program/_program.py", line 1407, in to_edge_transform_and_lower
File "executorch/exir/program/_program.py", line 1709, in to_backend
File "executorch/exir/backend/backend_api.py", line 721, in _
File "executorch/exir/backend/partitioner.py", line 66, in __call__
File "executorch/backends/qualcomm/partition/qnn_partitioner.py", line 221, in partition
File "executorch/backends/qualcomm/partition/qnn_partitioner.py", line 171, in generate_partitions
File "executorch/exir/backend/canonical_partitioners/pattern_op_partitioner.py", line 54, in generate_partitions_from_list_of_nodes
File "executorch/backends/qualcomm/partition/qnn_partitioner.py", line 107, in is_node_supported
KeyError: 'aten.pow.Scalar'
Environment
- ExecuTorch version: 1.3.0a0+490ec5c
- Backend: Qualcomm QNN (HTP)
- Target SoC: SM8850 / SM8650
- Quantization: PT2E w8a8 (int8)
- PyTorch version: 2.11.0+cu128
- Export command:
torch.export.export(model, sample_input, strict=False) → to_edge_transform_and_lower_to_qnn(qnn_config)
Root Cause
The QNN HTP backend's node visitor registration table in executorch/backends/qualcomm/builders/ does not include a visitor for aten.pow.Scalar. When the partitioner validates exported nodes, it finds the aten.pow.Scalar op but cannot find a corresponding visitor to lower it to a QNN op, resulting in a KeyError.
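The failure mode can be illustrated with a minimal, torch-free sketch (all names here are hypothetical stand-ins, not the actual ExecuTorch classes): the partitioner resolves node support by indexing a visitor registry keyed by the op's string name, so any op missing from the table surfaces as a `KeyError` rather than simply being left unpartitioned.

```python
# Hypothetical sketch of the partitioner's support check. The registry maps
# aten op names to visitor factories; "aten.pow.Scalar" is absent, which is
# exactly the condition that produces the reported KeyError.
NODE_VISITORS = {
    "aten.add.Tensor": lambda node: "qnn.ElementWiseAdd",
    "aten.pow.Tensor_Tensor": lambda node: "qnn.ElementWisePower",
    # "aten.pow.Scalar" is missing -> KeyError during partitioning
}

def is_node_supported_unsafe(op_name: str) -> bool:
    # Mirrors the failing path: direct indexing raises KeyError for
    # unregistered ops instead of reporting them as unsupported.
    return NODE_VISITORS[op_name] is not None

def is_node_supported_safe(op_name: str) -> bool:
    # A defensive variant: unknown ops are simply not partitioned.
    return op_name in NODE_VISITORS

try:
    is_node_supported_unsafe("aten.pow.Scalar")
except KeyError as exc:
    print(f"KeyError: {exc}")  # same shape as the reported traceback tail
print(is_node_supported_safe("aten.pow.Scalar"))  # False
```

A membership check like `is_node_supported_safe` would turn this class of bug from a hard export failure into a graceful CPU fallback, independent of whichever fix below is adopted.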
Note that aten.pow.Scalar's schema is pow.Scalar(Scalar self, Tensor exponent), i.e. a Python-scalar base raised to a tensor exponent. The related overloads aten.pow.Tensor_Tensor (tensor base, tensor exponent), aten.pow.Tensor_Scalar (tensor base, scalar exponent), and their out variants are separate aten overloads and may or may not have visitor coverage.
Proposed Solutions
I have identified two approaches and would like the community's guidance on which is preferred.
Solution A: Add a QNN node visitor for aten.pow.Scalar
Implement a new visitor class (e.g., PowScalarVisitor) in executorch/backends/qualcomm/builders/ that maps aten.pow.Scalar to the QNN Power op in the HTP backend.
Solution B: Graph rewrite – convert aten.pow.Scalar to aten.pow.Tensor_Tensor before partitioner evaluation
Apply an FX graph transformation, after export but before the partitioner runs, that rewrites aten.pow.Scalar nodes (scalar base, tensor exponent) into aten.pow.Tensor_Tensor nodes with the scalar base lifted into a tensor:
import torch
import torch.fx as fx

def rewrite_pow_scalar_to_tensor(gm: fx.GraphModule) -> fx.GraphModule:
    """Rewrite aten.pow.Scalar → aten.pow.Tensor_Tensor for QNN compatibility.

    Schema reminder: pow.Scalar(Scalar self, Tensor exponent), i.e. the
    *base* is the Python scalar and the exponent is the tensor input.
    """
    graph = gm.graph
    for n in list(graph.nodes):
        if n.op != "call_function" or n.target != torch.ops.aten.pow.Scalar:
            continue
        scalar_base = n.args[0]
        exponent = n.args[1] if len(n.args) > 1 else n.kwargs.get("exponent")
        if exponent is None:
            continue
        with graph.inserting_before(n):
            # Lift the scalar base into a 0-dim tensor; it broadcasts
            # against the exponent under pow.Tensor_Tensor.
            base_tensor = graph.call_function(torch.tensor, args=(scalar_base,))
            new_node = graph.call_function(
                torch.ops.aten.pow.Tensor_Tensor, args=(base_tensor, exponent)
            )
        new_node.meta = dict(n.meta)  # preserve node metadata for later passes
        n.replace_all_uses_with(new_node)
        graph.erase_node(n)
    graph.lint()
    gm.recompile()
    return gm

# Usage: torch.export.export returns an ExportedProgram; run the pass on its
# GraphModule before handing it to to_edge_transform_and_lower_to_qnn.
ep = torch.export.export(model, sample_input, strict=False)
rewrite_pow_scalar_to_tensor(ep.graph_module)
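Whatever FX API the pass ends up using, the property it must preserve is simple: a scalar base raised elementwise to a tensor exponent equals a broadcast one-element base under the tensor-tensor overload. A torch-free sanity check of that equivalence (plain lists stand in for tensors):

```python
# Torch-free check of the rewrite's semantics: pow.Scalar(b, E) computes
# b ** E elementwise; after the rewrite, a broadcast single-element base
# is raised to E under the tensor-tensor overload.
def pow_scalar(base, exponent_vec):
    # aten.pow.Scalar: scalar base, tensor exponent
    return [base ** e for e in exponent_vec]

def pow_tensor_tensor(base_vec, exponent_vec):
    # aten.pow.Tensor_Tensor with the base broadcast to the exponent's shape
    if len(base_vec) == 1:
        base_vec = base_vec * len(exponent_vec)
    return [b ** e for b, e in zip(base_vec, exponent_vec)]

exps = [0.0, 1.0, 2.0, 3.0]
assert pow_scalar(2.0, exps) == pow_tensor_tensor([2.0], exps)
```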
Questions for the Community
- Which approach is preferred? Solution A (native QNN visitor) or Solution B (graph rewrite)?
- For Solution A: is the QNN HTP Power op available and stable enough to expose as a visitor?
- For Solution B: does aten.pow.Tensor_Tensor have QNN visitor coverage, or would the same rewrite be needed for that op too?
Any guidance on the preferred direction would be greatly appreciated. Happy to contribute either a visitor implementation or a graph rewrite utility, depending on what the maintainers prefer.
Sorry to bother @abhinaykukkadapu @shewu-quic
Labels: bug, module:qualcomm, backend:qualcomm
Versions
$ python collect_env.py
Collecting environment information...
PyTorch version: 2.11.0+cu128
Is debug build: False
CUDA used to build PyTorch: 12.8
ROCM used to build PyTorch: N/A
OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04.3) 11.4.0
Clang version: 14.0.0-1ubuntu1.1
CMake version: version 3.31.10
Libc version: glibc-2.35
Python version: 3.10.20 | packaged by conda-forge | (main, Mar 5 2026, 16:42:22) [GCC 14.3.0] (64-bit runtime)
Python platform: Linux-5.15.0-173-generic-x86_64-with-glibc2.35
CPU:
Architecture: x86_64
Versions of relevant libraries:
[pip3] executorch==1.3.0a0+490ec5c
[pip3] numpy==2.2.6
[pip3] nvidia-cublas-cu12==12.8.4.1
[pip3] nvidia-cuda-cupti-cu12==12.8.90
[pip3] nvidia-cuda-nvrtc-cu12==12.8.93
[pip3] nvidia-cuda-runtime-cu12==12.8.90
[pip3] nvidia-cudnn-cu12==9.19.0.56
[pip3] nvidia-cufft-cu12==11.3.3.83
[pip3] nvidia-curand-cu12==10.3.9.90
[pip3] nvidia-cusolver-cu12==11.7.3.90
[pip3] nvidia-cusparse-cu12==12.5.8.93
[pip3] nvidia-cusparselt-cu12==0.7.1
[pip3] nvidia-nccl-cu12==2.28.9
[pip3] nvidia-nvjitlink-cu12==12.8.93
[pip3] nvidia-nvtx-cu12==12.8.90
[pip3] pytorch_tokenizers==1.1.0
[pip3] torch==2.11.0+cu128
[pip3] torchao==0.17.0+git02105d46c
[pip3] torchaudio==2.11.0+cu128
[pip3] torchdata==0.11.0
[pip3] torchsr==1.0.4
[pip3] torchtune==0.0.0
[pip3] torchvision==0.26.0+cu128
[pip3] triton==3.6.0+git884fdae8
[conda] executorch 1.3.0a0+490ec5c pypi_0 pypi
[conda] numpy 2.2.6 pypi_0 pypi
[conda] nvidia-cublas-cu12 12.8.4.1 pypi_0 pypi
[conda] nvidia-cuda-cupti-cu12 12.8.90 pypi_0 pypi
[conda] nvidia-cuda-nvrtc-cu12 12.8.93 pypi_0 pypi
[conda] nvidia-cuda-runtime-cu12 12.8.90 pypi_0 pypi
[conda] nvidia-cudnn-cu12 9.19.0.56 pypi_0 pypi
[conda] nvidia-cufft-cu12 11.3.3.83 pypi_0 pypi
[conda] nvidia-curand-cu12 10.3.9.90 pypi_0 pypi
[conda] nvidia-cusolver-cu12 11.7.3.90 pypi_0 pypi
[conda] nvidia-cusparse-cu12 12.5.8.93 pypi_0 pypi
[conda] nvidia-cusparselt-cu12 0.7.1 pypi_0 pypi
[conda] nvidia-nccl-cu12 2.28.9 pypi_0 pypi
[conda] nvidia-nvjitlink-cu12 12.8.93 pypi_0 pypi
[conda] nvidia-nvtx-cu12 12.8.90 pypi_0 pypi
[conda] pytorch-tokenizers 1.1.0 pypi_0 pypi
[conda] torch 2.11.0+cu128 pypi_0 pypi
[conda] torchao 0.17.0+git02105d46c pypi_0 pypi
[conda] torchaudio 2.11.0+cu128 pypi_0 pypi
[conda] torchdata 0.11.0 pypi_0 pypi
[conda] torchsr 1.0.4 pypi_0 pypi
[conda] torchtune 0.0.0 pypi_0 pypi
[conda] torchvision 0.26.0+cu128 pypi_0 pypi
[conda] triton 3.6.0+git884fdae8 pypi_0 pypi
cc @cccclai @cbilgin @abhinaykukkadapu