
[Qualcomm] Support native_layer_norm and affine-free LayerNorm in QNN backend #18990

Open
KevinUW114514 wants to merge 3 commits into pytorch:main from KevinUW114514:fix/qnn-layer-norm-none-check

Conversation


@KevinUW114514 KevinUW114514 commented Apr 19, 2026

[Qualcomm] Support native_layer_norm and affine-free LayerNorm in QNN backend

Summary

Adds QNN backend support for aten.native_layer_norm.default (which is the decomposed form of torch.nn.LayerNorm) and handles models where weight/bias are not provided (elementwise_affine=False).

Problem

When exporting models that use torch.native_layer_norm or torch.nn.LayerNorm(..., elementwise_affine=False) to the QNN backend, the following issues occur:

  1. Missing native_layer_norm visitor: The original LayerNormVisitor only targets aten.layer_norm.default, but PyTorch decomposes torch.nn.LayerNorm to aten.native_layer_norm.default during export.

  2. None weight/bias: When elementwise_affine=False, the weight and bias arguments are None. QNN x86_64 runtime cannot handle None tensor inputs, causing AttributeError when calling get_parameter().
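The second failure can be reproduced in miniature without any QNN machinery; the hypothetical `weight_node` below simply stands in for the `None` weight argument the visitor receives:

```python
# Minimal stand-in for the reported failure: with elementwise_affine=False,
# the weight/bias args arrive as None, and dereferencing .name on them
# raises the AttributeError quoted in the linked issue.
weight_node = None  # what the visitor sees for an affine-free LayerNorm

try:
    _ = weight_node.name
except AttributeError as exc:
    print(exc)  # e.g. 'NoneType' object has no attribute 'name'
```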

Solution

1. Update visitor target (op_layer_norm.py)

Change the visitor target from aten.layer_norm.default to aten.native_layer_norm.default:

# Before
target = ["aten.layer_norm.default"]

# After
target = ["aten.native_layer_norm.default"]

This is correct because during ExecuTorch export, aten.layer_norm.default is decomposed to aten.native_layer_norm.default before the QNN lowering stage.

2. Handle None weight/bias (op_layer_norm.py)

When weight/bias are None, create synthetic tensors:

  • Missing weight → torch.ones(normalized_shapes) (identity transform)
  • Missing bias → torch.zeros(normalized_shapes) (no offset)

Create synthetic fx.Node objects to register these as QNN static tensors:

weight_tensor = torch.ones(normalized_shapes, dtype=torch.float32)
weight_node = torch.fx.Node(
    node.graph,
    node.name + "_runtime_weight",
    "call_function",
    exir_ops.edge.aten.tensor.default,
    (),
    {},
)
# Preserve quant_attrs with zero_point=0 for QNN compatibility
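Why ones/zeros are safe defaults can be checked against the layer-norm math itself. The plain-Python sketch below (a hypothetical helper, not the visitor's actual code) shows that substituting ones/zeros reproduces the affine-free result exactly:

```python
import math

def layer_norm_1d(x, weight=None, bias=None, eps=1e-5):
    # Decomposed layer_norm over a single dimension. When weight or bias
    # is None, fall back to the same identity defaults the visitor
    # synthesizes: ones for weight, zeros for bias.
    n = len(x)
    mean = sum(x) / n
    var = sum((v - mean) ** 2 for v in x) / n
    weight = weight if weight is not None else [1.0] * n  # like torch.ones
    bias = bias if bias is not None else [0.0] * n        # like torch.zeros
    inv_std = 1.0 / math.sqrt(var + eps)
    return [(v - mean) * inv_std * w + b for v, w, b in zip(x, weight, bias)]

# Affine-free call equals an explicit identity affine:
x = [1.0, 2.0, 3.0, 4.0]
assert layer_norm_1d(x) == layer_norm_1d(x, [1.0] * 4, [0.0] * 4)
```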

3. Use same annotator for both ops (htp_rules.py)

The quantizer annotator registers both aten.layer_norm.default and aten.native_layer_norm.default to the same LayerNorm class, since both ops have identical argument schemas:

@register_annotator(
    [torch.ops.aten.layer_norm.default, torch.ops.aten.native_layer_norm.default],
    QnnConstants.OpLayerNorm.op_name,
)

4. Add None check to get_parameter() (utils.py)

Guard against None nodes to prevent AttributeError:

if node is None:
    return None
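In context, the guard sits at the top of `get_parameter()`; a hedged sketch (the real lookup in `builders/utils.py` differs, and `FakeNode` is purely illustrative) of how callers then handle the `None` result:

```python
def get_parameter(node, edge_program):
    # Guard first: affine-free LayerNorm passes weight/bias as None, and
    # the old code crashed on node.name before reaching any lookup.
    if node is None:
        return None
    # Hypothetical stand-in for the real parameter lookup.
    return edge_program.get(node.name)

class FakeNode:
    def __init__(self, name):
        self.name = name

params = {"ln_weight": [1.0, 1.0]}
assert get_parameter(None, params) is None           # no crash, caller decides
assert get_parameter(FakeNode("ln_weight"), params) == [1.0, 1.0]
```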

Files Changed

| File | Changes |
| --- | --- |
| builders/op_layer_norm.py | Add native_layer_norm support + handle None weight/bias |
| builders/utils.py | Add None guard in get_parameter() |
| quantizer/annotators/htp_rules.py | Register annotator for both ops |
| tests/models.py | Add NativeLayerNorm test model |
| tests/test_qnn_delegate.py | Add floating-point and quantized tests |

Test Plan

Run QNN delegate tests for layer_norm:

python backends/qualcomm/tests/test_qnn_delegate.py \
    -k "test_qnn_backend_layer_norm or test_qnn_backend_native_layer_norm" \
    --soc_model SM8650 \
    --build_folder build-x86/ \
    --executorch_root . \
    --enable_x86_64

Expected: 4 tests pass (2 floating-point, 2 quantized).

Release Notes

  • Release notes: qualcomm

Related Issues

This resolves the issue where FLUX2 transformer export fails with:

  • [QNN Delegate Op Builder]: LayerNorm weight is None, skipping
  • AttributeError: 'NoneType' object has no attribute 'name'

Fixes #18989

  • Labels: bug, module:qnn

@abhinaykukkadapu

Copilot AI review requested due to automatic review settings April 19, 2026 07:29
@pytorch-bot

pytorch-bot Bot commented Apr 19, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18990

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla

meta-cla Bot commented Apr 19, 2026

Hi @KevinUW114514!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!

@KevinUW114514 (Author)

@pytorchbot label "release notes: none"

@pytorch-bot pytorch-bot Bot added the release notes: none Do not include this in the release notes label Apr 19, 2026
@meta-cla

meta-cla Bot commented Apr 19, 2026

Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Apr 19, 2026

Copilot AI left a comment


Pull request overview

Fixes a crash in the Qualcomm QNN PT2E quantizer by making _mark_nodes_as_annotated robust to None entries in node lists (e.g., when aten.layer_norm has optional affine args like weight=None).

Changes:

  • Skip None entries in _mark_nodes_as_annotated to avoid AttributeError when accessing node.meta.


Review comment on lines +32 to +33 of the diff, inside _mark_nodes_as_annotated(nodes: List[Node]):

    if node is None:
        continue
@abhinaykukkadapu
Contributor

abhinaykukkadapu commented Apr 20, 2026

Hi @KevinUW114514, thank you for your contribution. I think the root cause is that we need to guard weight and bias creation in the rules files for htp and lpai, similar to #18219; let me know if you are willing to change it. Adding the guard might silently propagate bad configs like these through the pipeline, and I think we should fail loudly. CC: @shewu-quic

@KevinUW114514
Author

Hi @abhinaykukkadapu , thanks for the follow-up! Actually I also realized this root issue as I encountered the error in my downstream tasks. I am currently working on fixing this. I can edit the issue and PR to re-state the issue and submit a complete fix for it. Let me know if any concern. Thank you!

@abhinaykukkadapu
Contributor

> Hi @abhinaykukkadapu , thanks for the follow-up! Actually I also realized this root issue as I encountered the error in my downstream tasks. I am currently working on fixing this. I can edit the issue and PR to re-state the issue and submit a complete fix for it. Let me know if any concern. Thank you!

Thanks, that would be awesome, will look forward to your changes.

@KevinUW114514 KevinUW114514 changed the title Fix AttributeError in _mark_nodes_as_annotated when node is None [QNN] Fix AttributeError in _mark_nodes_as_annotated when node is None Apr 20, 2026
Fixes AttributeError when aten.native_layer_norm has optional weight=None.
Both weight and bias are guarded to handle the None case gracefully.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@nil-is-all nil-is-all added the module: qnn Issues related to Qualcomm's QNN delegate and code under backends/qualcomm/ label Apr 20, 2026
… backend

- add QNN layer norm support for aten.native_layer_norm.default
- handle missing weight/bias by creating identity weight and zero bias
- always provide bias tensor for QNN LayerNorm op
- add floating-point and quantized tests for native_layer_norm
- print generated pte filename after export
@KevinUW114514 KevinUW114514 changed the title [QNN] Fix AttributeError in _mark_nodes_as_annotated when node is None [Qualcomm] Support native_layer_norm and affine-free LayerNorm in QNN backend Apr 22, 2026

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. module: qnn Issues related to Qualcomm's QNN delegate and code under backends/qualcomm/ release notes: none Do not include this in the release notes


Development

Successfully merging this pull request may close these issues.

[QNN] AttributeError in _mark_nodes_as_annotated when a layer_norm node has optional weight as None

4 participants