From ee9a981bfe10a62f191fc0c98ff8651595718333 Mon Sep 17 00:00:00 2001 From: KevinJie Date: Sun, 19 Apr 2026 07:22:55 +0000 Subject: [PATCH 1/3] Fix AttributeError in _mark_nodes_as_annotated when node is None --- backends/qualcomm/quantizer/rules.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/backends/qualcomm/quantizer/rules.py b/backends/qualcomm/quantizer/rules.py index 878acfea422..77989f8e874 100644 --- a/backends/qualcomm/quantizer/rules.py +++ b/backends/qualcomm/quantizer/rules.py @@ -29,6 +29,8 @@ def _mark_nodes_as_annotated(nodes: List[Node]): for node in nodes: + if node is None: + continue if Q_ANNOTATION_KEY not in node.meta: node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation() node.meta[Q_ANNOTATION_KEY]._annotated = True From 40404fdd974e75deec9afb0e299bdbe2cf9a04fc Mon Sep 17 00:00:00 2001 From: KevinJie Date: Mon, 20 Apr 2026 15:15:44 +0000 Subject: [PATCH 2/3] [QNN] Guard get_parameter against node=None in LayerNormVisitor Fixes AttributeError when aten.native_layer_norm has optional weight=None. Both weight and bias are guarded to handle the None case gracefully. 
Co-Authored-By: Claude Opus 4.6 --- backends/qualcomm/builders/op_layer_norm.py | 26 +++++++++++++-------- backends/qualcomm/builders/utils.py | 4 +++- 2 files changed, 19 insertions(+), 11 deletions(-) diff --git a/backends/qualcomm/builders/op_layer_norm.py b/backends/qualcomm/builders/op_layer_norm.py index a51056eb7bb..9882c77cc7f 100644 --- a/backends/qualcomm/builders/op_layer_norm.py +++ b/backends/qualcomm/builders/op_layer_norm.py @@ -55,16 +55,22 @@ def define_node( axis_shape = [len(axis)] weight_node = self.get_node(node.args[2]) - weight_tensor = get_parameter(weight_node, self.edge_program) - weight_tensor_wrapper = self.define_tensor( - weight_node, - node, - weight_tensor, - PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, - nodes_to_wrappers, - ) - - layer_norm_input_tensors = [input_tensor_wrapper, weight_tensor_wrapper] + if weight_node is not None: + weight_tensor = get_parameter(weight_node, self.edge_program) + weight_tensor_wrapper = self.define_tensor( + weight_node, + node, + weight_tensor, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, + nodes_to_wrappers, + ) + layer_norm_input_tensors = [input_tensor_wrapper, weight_tensor_wrapper] + else: + warnings.warn( + "[QNN Delegate Op Builder]: LayerNorm weight is None, skipping", + stacklevel=1, + ) + layer_norm_input_tensors = [input_tensor_wrapper] bias_node = self.get_node(node.args[3]) if bias_node is not None: diff --git a/backends/qualcomm/builders/utils.py b/backends/qualcomm/builders/utils.py index 3345f2e1fc9..acf4400f212 100755 --- a/backends/qualcomm/builders/utils.py +++ b/backends/qualcomm/builders/utils.py @@ -29,7 +29,9 @@ def is_parameter( def get_parameter( node: torch.fx.Node, edge_program: torch.export.ExportedProgram -) -> torch.Tensor: +) -> Optional[torch.Tensor]: + if node is None: + return None param = None if is_param(edge_program, node): param = get_param(edge_program, node) From 5beaa5774c7f782e2a7e5edaa3bfb56a3d6ef1e2 Mon Sep 17 00:00:00 2001 From: 
KevinJie Date: Wed, 22 Apr 2026 03:08:10 +0000 Subject: [PATCH 3/3] [Qualcomm] Support native_layer_norm and affine-free LayerNorm in QNN backend - add QNN layer norm support for aten.native_layer_norm.default - handle missing weight/bias by creating identity weight and zero bias - always provide bias tensor for QNN LayerNorm op - add floating-point and quantized tests for native_layer_norm - print generated pte filename after export --- backends/qualcomm/builders/op_layer_norm.py | 84 +++++++++++++------ backends/qualcomm/export_utils.py | 1 + .../quantizer/annotators/htp_rules.py | 3 +- backends/qualcomm/tests/models.py | 20 +++++ backends/qualcomm/tests/test_qnn_delegate.py | 15 ++++ 5 files changed, 96 insertions(+), 27 deletions(-) diff --git a/backends/qualcomm/builders/op_layer_norm.py b/backends/qualcomm/builders/op_layer_norm.py index 9882c77cc7f..d4ca7e4589a 100644 --- a/backends/qualcomm/builders/op_layer_norm.py +++ b/backends/qualcomm/builders/op_layer_norm.py @@ -11,7 +11,12 @@ import numpy as np import torch -from executorch.backends.qualcomm.utils.constants import QCOM_DATA +from executorch.backends.qualcomm.utils.constants import ( + QCOM_DATA, + QCOM_QUANT_ATTRS, + QCOM_ZERO_POINT, +) +from executorch.exir.dialects._ops import ops as exir_ops from .node_visitor import NodeVisitor from .node_visitor_manager import register_node_visitor @@ -31,6 +36,7 @@ def define_node( node: torch.fx.Node, nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], ) -> PyQnnManager.PyQnnOpWrapper: + # args of node : ['input', 'normalized_shape', 'weight', 'bias', 'eps'] input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( @@ -54,37 +60,61 @@ def define_node( axis = [len(input_tensor.shape) - 1] axis_shape = [len(axis)] - weight_node = self.get_node(node.args[2]) - if weight_node is not None: + has_weight = len(node.args) > 2 and node.args[2] is not None + if has_weight: + 
weight_node = self.get_node(node.args[2]) weight_tensor = get_parameter(weight_node, self.edge_program) - weight_tensor_wrapper = self.define_tensor( - weight_node, - node, - weight_tensor, - PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, - nodes_to_wrappers, - ) - layer_norm_input_tensors = [input_tensor_wrapper, weight_tensor_wrapper] else: - warnings.warn( - "[QNN Delegate Op Builder]: LayerNorm weight is None, skipping", - stacklevel=1, + # elementwise_affine=False: use all-ones weight as identity + weight_tensor = torch.ones(normalized_shapes, dtype=torch.float32) + weight_node = torch.fx.Node( + node.graph, + node.name + "_runtime_weight", + "call_function", + exir_ops.edge.aten.tensor.default, + (), + {}, ) - layer_norm_input_tensors = [input_tensor_wrapper] + if quant_attrs := node.meta.get(QCOM_QUANT_ATTRS): + quant_attrs = quant_attrs.copy() + quant_attrs[QCOM_ZERO_POINT] = 0 + weight_node.meta[QCOM_QUANT_ATTRS] = quant_attrs + weight_tensor_wrapper = self.define_tensor( + weight_node, + node, + weight_tensor, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, + nodes_to_wrappers, + ) - bias_node = self.get_node(node.args[3]) - if bias_node is not None: + # Fake node: even when original bias is absent, QNN still needs it + has_bias = len(node.args) > 3 and node.args[3] is not None + if has_bias: + bias_node = self.get_node(node.args[3]) bias_tensor = get_parameter(bias_node, self.edge_program) - bias_tensor_wrapper = self.define_tensor( - bias_node, - node, - bias_tensor, - PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, - nodes_to_wrappers, + else: + bias_tensor = torch.zeros(normalized_shapes, dtype=torch.float32) + bias_node = torch.fx.Node( + node.graph, + node.name + "_runtime_bias", + "call_function", + exir_ops.edge.aten.tensor.default, + (), + {}, ) - layer_norm_input_tensors.append(bias_tensor_wrapper) + if quant_attrs := node.meta.get(QCOM_QUANT_ATTRS): + quant_attrs = quant_attrs.copy() + quant_attrs[QCOM_ZERO_POINT] = 0 + 
bias_node.meta[QCOM_QUANT_ATTRS] = quant_attrs + bias_tensor_wrapper = self.define_tensor( + bias_node, + node, + bias_tensor, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, + nodes_to_wrappers, + ) - epsilon = node.args[4] + epsilon = node.args[4] if len(node.args) > 4 else 1e-05 output_tensor = self.get_tensor(node, node, 0) output_tensor_wrapper = self.define_tensor( @@ -100,7 +130,9 @@ def define_node( QNN_OP_PACKAGE_NAME_QTI_AISW, OpLayerNorm.op_name, ) - layer_norm_op.AddInputTensors(layer_norm_input_tensors) + layer_norm_op.AddInputTensors( + [input_tensor_wrapper, weight_tensor_wrapper, bias_tensor_wrapper] + ) layer_norm_op.AddOutputTensors([output_tensor_wrapper]) layer_norm_op.AddScalarParam( OpLayerNorm.param_epsilon, diff --git a/backends/qualcomm/export_utils.py b/backends/qualcomm/export_utils.py index 2c7ab2abd02..f66bf9d5858 100644 --- a/backends/qualcomm/export_utils.py +++ b/backends/qualcomm/export_utils.py @@ -617,6 +617,7 @@ def build_executorch_binary( with open(pte_name, "wb") as file: exec_prog_mgr.write_to_file(file) + print(f"Successfully generated {pte_name}.") if qnn_config.compile_only: sys.exit(0) diff --git a/backends/qualcomm/quantizer/annotators/htp_rules.py b/backends/qualcomm/quantizer/annotators/htp_rules.py index ce01fceca80..5b763bf9b11 100644 --- a/backends/qualcomm/quantizer/annotators/htp_rules.py +++ b/backends/qualcomm/quantizer/annotators/htp_rules.py @@ -828,7 +828,8 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None: @register_annotator( - [torch.ops.aten.layer_norm.default], QnnConstants.OpLayerNorm.op_name + [torch.ops.aten.layer_norm.default, torch.ops.aten.native_layer_norm.default], + QnnConstants.OpLayerNorm.op_name, ) class LayerNorm(GeneralOpDef): @staticmethod diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py index 053e5d26455..d5cb42d72f2 100644 --- a/backends/qualcomm/tests/models.py +++ b/backends/qualcomm/tests/models.py @@ -1388,6 
+1388,26 @@ def forward(self, x): return self.linear(self.layer_norm(x)) +class NativeLayerNorm(torch.nn.Module): + def __init__(self, affine=True): + super().__init__() + self.affine = affine + self.weight = torch.nn.Parameter(torch.ones(768)) + self.bias = torch.nn.Parameter(torch.zeros(768)) + self.normalized_shape = [768] + self.eps = 1e-6 + + def forward(self, x): + if self.affine: + return torch.native_layer_norm( + x, self.normalized_shape, self.weight, self.bias, self.eps + )[0] + else: + return torch.native_layer_norm( + x, self.normalized_shape, None, None, self.eps + )[0] + + class LayerNormAdd(torch.nn.Module): def __init__(self): super().__init__() diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index 48f07da06e9..ef1a61715e8 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -1384,6 +1384,13 @@ def test_qnn_backend_layer_norm(self): with self.subTest(i=i): self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_native_layer_norm(self): + modules = [NativeLayerNorm(), NativeLayerNorm(affine=False)] # noqa: F405 + sample_input = (torch.randn(196, 768),) + for i, module in enumerate(modules): + with self.subTest(i=i): + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_leaky_relu(self): torch.manual_seed(8) test_comb = [ @@ -3811,6 +3818,14 @@ def test_qnn_backend_layer_norm(self): module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_native_layer_norm(self): + modules = [NativeLayerNorm(), NativeLayerNorm(affine=False)] # noqa: F405 + sample_input = (torch.randn(196, 768),) + for i, module in enumerate(modules): + with self.subTest(i=i): + module = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_leaky_relu(self): test_comb = [ 
{