diff --git a/backends/arm/_passes/TARGETS b/backends/arm/_passes/TARGETS index c01ce225d29..ff30acf2103 100644 --- a/backends/arm/_passes/TARGETS +++ b/backends/arm/_passes/TARGETS @@ -25,6 +25,12 @@ runtime.python_library( "//executorch/backends/arm:common", "//executorch/backends/arm/tosa:utils", "//executorch/backends/arm/tosa/dialect:lib", + "//executorch/backends/transforms:fuse_cascaded_transpose_or_permute_ops", + "//executorch/backends/transforms:fuse_cascaded_view_ops", + "//executorch/backends/transforms:fuse_transpose_or_permute_op_pairs_pass", + "//executorch/backends/transforms:remove_permutes_around_elementwise_ops", + "//executorch/backends/transforms:postpone_permute_below_squeeze_view", + "//executorch/backends/transforms:replace_nop_transpose_or_permute_with_view", "//executorch/exir:lib", ], ) diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py index 48ab88c7939..5a6bc2a9d40 100644 --- a/backends/arm/_passes/__init__.py +++ b/backends/arm/_passes/__init__.py @@ -7,7 +7,6 @@ from . 
import arm_pass_utils # noqa from .arm_pass import ArmPass # noqa # usort: skip from .accumulate_index_put_pass import AccumulateIndexPutPass # noqa -from .annotate_output_dim_order_pass import AnnotateOutputDimOrderPass # noqa from .broadcast_args_pass import BroadcastArgsPass # noqa from .canonicalize_gather_pass import CanonicalizeGatherPass # noqa from .cast_int64_pass import CastInt64BuffersToInt32Pass # noqa @@ -165,7 +164,6 @@ from .rewrite_upsample import RewriteUpsamplePass # noqa from .scalars_to_attribute_pass import ScalarsToAttributePass # noqa from .size_adjust_input_pass import SizeAdjustInputPass # noqa -from .to_tosa_memory_format_pass import ToTosaMemoryFormatPass # noqa from .unsqueeze_before_repeat_pass import UnsqueezeBeforeRepeatPass # noqa from .unsqueeze_scalar_placeholders_pass import UnsqueezeScalarPlaceholdersPass # noqa from .replace_inf_and_limit_values_pass import ( # noqa # usort: skip diff --git a/backends/arm/_passes/annotate_output_dim_order_pass.py b/backends/arm/_passes/annotate_output_dim_order_pass.py deleted file mode 100644 index 3124bd98532..00000000000 --- a/backends/arm/_passes/annotate_output_dim_order_pass.py +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright 2025-2026 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - - -from typing import Set, Type - -from executorch.backends.arm._passes import ArmPass -from executorch.backends.arm._passes.arm_pass_utils import get_output_dim_orders -from executorch.exir.pass_base import ExportPass, PassResult - - -class AnnotateOutputDimOrderPass(ArmPass): - """Stores the current output dim_orders in the meta dict of the output node. - - This is used for verifying that the dim order does not change unexpectedly - in later passes. 
- - """ - - _passes_required_after: Set[Type[ExportPass]] = set() - - def call(self, graph_module): - output_node = graph_module.graph.output_node() - output_node.meta["original_dim_orders"] = get_output_dim_orders(graph_module) - - return PassResult(graph_module, True) diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index 0455b0f5fe4..16b4c840d88 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -12,7 +12,6 @@ from executorch.backends.arm._passes import ( AccumulateIndexPutPass, - AnnotateOutputDimOrderPass, BroadcastArgsPass, CanonicalizeGatherPass, CastInt64BuffersToInt32Pass, @@ -44,7 +43,6 @@ DecomposeAtanPass, DecomposeAvgPool2dPass, DecomposeBatchNormNoStatsPass, - DecomposeConvWithInt16ActivationPass, DecomposeCoshPass, DecomposeCosineSimilarityPass, DecomposeCumsumPass, @@ -58,7 +56,6 @@ DecomposeFloorDividePass, DecomposeGeluPass, DecomposeGluPass, - DecomposeGroupedConvPass, DecomposeGroupNormPass, DecomposeGruPass, DecomposeIndexCopyPass, @@ -141,7 +138,6 @@ RewriteUpsamplePass, ScalarsToAttributePass, SizeAdjustInputPass, - ToTosaMemoryFormatPass, UnsqueezeBeforeRepeatPass, UnsqueezeScalarPlaceholdersPass, ) @@ -157,7 +153,26 @@ TosaLoweringContext, TosaSpecification, ) +from executorch.backends.transforms.fuse_cascaded_transpose_or_permute_ops import ( + FuseCascadedTransposeOrPermuteOps, +) +from executorch.backends.transforms.fuse_cascaded_view_ops import ( + FuseCascadedViewOps, +) +from executorch.backends.transforms.fuse_transpose_or_permute_op_pairs_pass import ( + FuseTransposeOrPermuteOpPairsPass, +) +from executorch.backends.transforms.postpone_permute_below_squeeze_view import ( + PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView, +) +from executorch.backends.transforms.remove_permutes_around_elementwise_ops import ( + RemovePermutesAroundElementwiseOps, +) +from executorch.backends.transforms.replace_nop_transpose_or_permute_with_view 
import ( + ReplaceNopTransposeOrPermuteWithViewPass, +) from executorch.exir import ExportedProgram +from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass from executorch.exir.pass_manager import PassManager from torch._export.utils import _get_shape_env_from_gm @@ -385,9 +400,6 @@ def _tosa_pipeline( # Allow subclasses to configure pass insertions before building pipeline self._configure_pass_insertions(exported_program) - # Preprocessing passes - self.add_pass(AnnotateOutputDimOrderPass()) - # Node transformation passes (pre q/dq folding) self.add_passes( [ @@ -455,7 +467,6 @@ def _tosa_pipeline( DecomposeFloorDividePass(), DecomposeGeluPass(), DecomposeAddSubAlphaPass(), - DecomposeGroupedConvPass(), DecomposeUnfoldToGatherPass(), DecomposeEmbeddingPass(), DecomposeIndexSelectToGatherPass(), @@ -518,7 +529,6 @@ def _tosa_pipeline( ConvertPermuteSingletonToViewPass(), RewriteHighRankSingletonPermutePass(), FuseViewCopyTransformPass(), - DecomposeConvWithInt16ActivationPass(), DecomposeSumPass(), InsertTableOpsPass(exported_program), ] @@ -532,7 +542,6 @@ def _tosa_pipeline( RewriteConvPass(exported_program), RewriteMatmulPass(), RewritePadPass(), - RewriteSlicePass(), InsertConstShapesPass(), ] ) @@ -542,14 +551,40 @@ def _tosa_pipeline( [ CastInt64BuffersToInt32Pass(exported_program), FuseEqualPlaceholdersPass(exported_program), + FuseConstantArgsPass(exported_program), FuseConsecutiveConcatShapesPass(), - ToTosaMemoryFormatPass(exported_program), RemoveNoopPass(), InsertRescalePass(), InsertDataLayoutCastsPass(), ] ) + # Additional optimization passes for permutes + # Fuse identity permute pairs across RESCALE ops + fuse_pairs = FuseTransposeOrPermuteOpPairsPass() + fuse_pairs.bypass_ops = fuse_pairs.bypass_ops | { + exir_ops.backend.tosa.RESCALE.default, + } + + # Remove permutes around elementwise ops including RESCALE + remove_around = RemovePermutesAroundElementwiseOps() + remove_around.permutable_ops = 
remove_around.permutable_ops | { + exir_ops.backend.tosa.RESCALE.default, + } + + self.add_passes( + [ + remove_around, + RewriteSlicePass(), + fuse_pairs, + ReplaceNopTransposeOrPermuteWithViewPass(), + PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView(), + FuseCascadedTransposeOrPermuteOps(), + FuseCascadedViewOps(), + InsertConstShapesPass(), + ] + ) + # Apply all pass insertions once after all passes are collected self._apply_pass_insertions() diff --git a/backends/arm/_passes/arm_pass_utils.py b/backends/arm/_passes/arm_pass_utils.py index 7f9b47d3e01..afdcb65121b 100644 --- a/backends/arm/_passes/arm_pass_utils.py +++ b/backends/arm/_passes/arm_pass_utils.py @@ -352,11 +352,6 @@ def set_node_arg(node: torch.fx.Node, i: int | str, value): raise RuntimeError("Invalid type") -def get_output_dim_orders(graph_module): - output_node = graph_module.graph.output_node() - return [get_first_fake_tensor(node).dim_order() for node in output_node.args[0]] - - def is_nested_control_flow_graph(graph_module: GraphModule) -> bool: """Returns True if graph_module is a nested control-flow graph.""" diff --git a/backends/arm/_passes/rewrite_avg_pool2d_pass.py b/backends/arm/_passes/rewrite_avg_pool2d_pass.py index 2f71bdda4a2..cd90cb09de5 100644 --- a/backends/arm/_passes/rewrite_avg_pool2d_pass.py +++ b/backends/arm/_passes/rewrite_avg_pool2d_pass.py @@ -7,69 +7,139 @@ import torch from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes.arm_pass_utils import ( + create_node, + get_first_fake_tensor, +) from executorch.backends.arm.operators.operator_validation_utils import ( adjust_pooling_pad_if_needed, ) from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import ExportPass +from executorch.exir.pass_base import ExportPass, PassResult from .fuse_constant_ops_pass import ComputeConstantOpsAOTPass +_NCHW_TO_NHWC = [0, 2, 3, 1] +_NHWC_TO_NCHW = [0, 3, 1, 2] + class RewriteAvgPool2dPass(ArmPass): - """Rewrite 
aten.avg_pool2d calls to TOSA AVG_POOL2D op.""" + """Rewrite aten.avg_pool2d calls to TOSA AVG_POOL2D op with NHWC layout.""" - # Target the original avg_pool2d operator targeted_ops = {exir_ops.edge.aten.avg_pool2d.default} _passes_required_after: Set[Type[ExportPass]] = { ComputeConstantOpsAOTPass, } - def call_operator(self, op, args, kwargs, meta, updated=False): - - # Only rewrite avg_pool2d - if op not in self.targeted_ops: - return super().call_operator(op, args, kwargs, meta, updated) - - x = args[0] - pad_h, pad_w = args[3] - # Make sure pad corresponds to TOSA - pad = [pad_h, pad_w, pad_h, pad_w] - - _, _, h, w = x.data.shape - kernel_h, kernel_w = args[1] - stride_h, stride_w = args[2] - - ceil_mode = args[4] if len(args) > 4 else False - - # Adjust padding if necessary - pad[1] = adjust_pooling_pad_if_needed(h, kernel_h, stride_h, pad[1], ceil_mode) - pad[3] = adjust_pooling_pad_if_needed(w, kernel_w, stride_w, pad[3], ceil_mode) - - # Materialize zero-point constants - in_qparams = meta.data.get("input_qparams", {}) - in_zp_val = in_qparams[0].get_zp_per_tensor() if 0 in in_qparams else 0 - # Materialize input zero-point as a scalar tensor - input_zp = super().call_scalar(in_zp_val, meta) - - out_qparams = meta.data.get("output_qparams", {}) - out_zp_val = out_qparams[0].get_zp_per_tensor() if 0 in out_qparams else 0 - # Materialize output zero-point as a scalar tensor - output_zp = super().call_scalar(out_zp_val, meta) - - # Determine accumulator dtype for AVG_POOL2D: INT32 for integer inputs, FP32 otherwise - if x.data.dtype in (torch.int8, torch.int16): - acc_type = torch.int32 - else: - acc_type = torch.float32 - - tosa_args = (args[0], input_zp, output_zp, *args[1:3], pad, acc_type) - - # Emit TOSA AVG_POOL2D with normalized args - return super().call_operator( - exir_ops.backend.tosa.AVG_POOL2D.default, - tosa_args, - {}, - meta, - True, + @staticmethod + def _insert_permute(graph_module, anchor_node, input_node, perm, before=True): + ctx = ( + 
graph_module.graph.inserting_before(anchor_node) + if before + else graph_module.graph.inserting_after(anchor_node) ) + with ctx: + return create_node( + graph=graph_module.graph, + op_target=exir_ops.edge.aten.permute_copy.default, + args=(input_node, perm), + from_node=input_node, + ) + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + modified = False + + for node in list(graph_module.graph.nodes): + if node.op != "call_function" or node.target not in self.targeted_ops: + continue + + modified = True + x = node.args[0] + + pad_h, pad_w = node.args[3] + pad = [pad_h, pad_w, pad_h, pad_w] + + input_fake = get_first_fake_tensor(x) + _, _, h, w = input_fake.shape + kernel_h, kernel_w = node.args[1] + stride_h, stride_w = node.args[2] + + ceil_mode = node.args[4] if len(node.args) > 4 else False + + pad[1] = adjust_pooling_pad_if_needed(h, kernel_h, stride_h, pad[1], ceil_mode) + pad[3] = adjust_pooling_pad_if_needed(w, kernel_w, stride_w, pad[3], ceil_mode) + + # Determine zero-points and accumulator type + in_qparams = node.meta.get("input_qparams", {}) + in_zp_val = in_qparams[0].get_zp_per_tensor() if 0 in in_qparams else 0 + + out_qparams = node.meta.get("output_qparams", {}) + out_zp_val = out_qparams[0].get_zp_per_tensor() if 0 in out_qparams else 0 + + if input_fake.dtype in (torch.int8, torch.int16): + acc_type = torch.int32 + else: + acc_type = torch.float32 + + # Insert NCHW → NHWC permute on input + x_permuted = self._insert_permute( + graph_module, node, x, _NCHW_TO_NHWC, before=True + ) + + # Materialize zp scalars as graph constants using aten.full with + # explicit dtype matching the input tensor. This ensures the + # pre-computed buffer placeholders carry the correct type for + # INT-only TOSA profiles (avoids defaulting to float32). 
+ zp_kwargs = {"dtype": input_fake.dtype, "device": input_fake.device} + with graph_module.graph.inserting_before(node): + input_zp_node = create_node( + graph=graph_module.graph, + op_target=exir_ops.edge.aten.full.default, + args=((1,), in_zp_val), + kwargs=zp_kwargs, + from_node=node, + ) + output_zp_node = create_node( + graph=graph_module.graph, + op_target=exir_ops.edge.aten.full.default, + args=((1,), out_zp_val), + kwargs=zp_kwargs, + from_node=node, + ) + + kernel = list(node.args[1]) + stride = list(node.args[2]) + + tosa_args = (x_permuted, input_zp_node, output_zp_node, kernel, stride, pad, acc_type) + + # Create TOSA AVG_POOL2D node + with graph_module.graph.inserting_after(node): + tosa_op = create_node( + graph=graph_module.graph, + op_target=exir_ops.backend.tosa.AVG_POOL2D.default, + args=tosa_args, + from_node=node, + inherit_qparams=True, + ) + + # Compute correct NHWC FakeTensor + input_fake_nhwc = input_fake.permute(_NCHW_TO_NHWC) + input_zp_fake = torch.tensor(in_zp_val, dtype=input_fake.dtype) + output_zp_fake = torch.tensor(out_zp_val, dtype=input_fake.dtype) + tosa_node_fake = exir_ops.backend.tosa.AVG_POOL2D.default( + input_fake_nhwc, input_zp_fake, output_zp_fake, kernel, stride, pad, acc_type + ) + tosa_op.meta["val"] = tosa_node_fake + + # Insert NHWC → NCHW permute on output + output_permute = self._insert_permute( + graph_module, tosa_op, tosa_op, _NHWC_TO_NCHW, before=False + ) + + node.replace_all_uses_with(output_permute) + graph_module.graph.erase_node(node) + + if modified: + graph_module.recompile() + graph_module = super().call(graph_module).graph_module + return PassResult(graph_module, modified) diff --git a/backends/arm/_passes/rewrite_conv_pass.py b/backends/arm/_passes/rewrite_conv_pass.py index e4be0b5dc25..70e118d198f 100644 --- a/backends/arm/_passes/rewrite_conv_pass.py +++ b/backends/arm/_passes/rewrite_conv_pass.py @@ -5,7 +5,7 @@ import itertools -from typing import Any, Set, Type +from typing import Any, cast, 
Set, Type import torch from executorch.backends.arm._passes import ArmPass @@ -21,6 +21,7 @@ get_input_qparams, get_output_qparams, ) +from executorch.backends.arm._passes.quant_args import QuantArgs from executorch.backends.arm.constants import HWCM_ORDER, NHWC_INVERSE_ORDER from executorch.backends.arm.tosa.mapping import TosaSpecialDtype from executorch.backends.arm.tosa.specification import get_context_shape_env @@ -42,6 +43,22 @@ def __init__(self, exported_program: torch.export.ExportedProgram, *args, **kwar _passes_required_after: Set[Type[ExportPass]] = set() + @staticmethod + def _nchw_to_nhwc_perm(rank: int) -> list[int]: + if rank == 4: + return [0, 2, 3, 1] + if rank == 5: + return [0, 2, 3, 4, 1] + return list(range(rank)) + + @staticmethod + def _nhwc_to_nchw_perm(rank: int) -> list[int]: + if rank == 4: + return [0, 3, 1, 2] + if rank == 5: + return [0, 4, 1, 2, 3] + return list(range(rank)) + # torch.nn.Conv2d does not require the result of # `(input + 2 * pad - dilation * (weight - 1) - 1) / stride` # to be an integer, but tosa currently strictly require this property. @@ -264,6 +281,372 @@ def insert_output_rescale(self, graph_module, source_node, conv_node): ) return rescale_node + def _insert_permute(self, graph_module, anchor_node, input_node, perm, before=True): + ctx = ( + graph_module.graph.inserting_before(anchor_node) + if before + else graph_module.graph.inserting_after(anchor_node) + ) + with ctx: + return create_node( + graph=graph_module.graph, + op_target=exir_ops.edge.aten.permute_copy.default, + args=(input_node, perm), + from_node=input_node, + ) + + def _is_grouped_conv(self, node: torch.fx.Node) -> bool: + """Return True for grouped convolutions that need decomposition. + + Depthwise convolutions (groups == in_channels) are handled natively + by TOSA and are *not* considered grouped here. 
+ """ + groups = node.args[-1] + if groups <= 1: + return False + input_tensor = get_first_fake_tensor(node.all_input_nodes[0]) + if len(input_tensor.shape) != 4: + return False + return not self._is_depthwise_conv2d(node) + + def _handle_grouped_conv( # noqa: C901 + self, + graph_module: torch.fx.GraphModule, + node: torch.fx.Node, + x: torch.fx.Node, + weight: torch.fx.Node, + bias: torch.fx.Node | None, + stride_list: list[int], + pad_list: list[int], + dilation_list: list[int], + group: int, + input_shape: torch.Size, + weight_shape: torch.Size, + spatial_rank: int, + rank: int, + ) -> torch.fx.Node: + """Decompose a grouped conv into per-group TOSA.CONV2D ops in NHWC. + + Produces a single input permute (NCHW→NHWC) and a single output + permute (NHWC→NCHW), with the per-group slice / conv / cat operating + entirely in NHWC. This avoids the problematic pattern of one permute + pair per sub-conv that downstream optimisation passes can mishandle. + """ + nchw_to_nhwc = self._nchw_to_nhwc_perm(rank) + nhwc_to_nchw = self._nhwc_to_nchw_perm(rank) + nhwc_channel_dim = rank - 1 + + in_channels = input_shape[1] + out_channels = get_first_fake_tensor(node).shape[1] + input_slice_size = in_channels // group + output_slice_size = out_channels // group + + # Compute TOSA pad attribute (same logic as the non-grouped path). 
+ pad_attr: list[int] = [] + for value in pad_list: + pad_attr.extend([value, value]) + for axis_index in range(spatial_rank): + pad_index = axis_index * 2 + 1 + pad_attr[pad_index] = self._adjust_pad_if_needed( + input_shape[axis_index + 2], + weight_shape[axis_index + 2], + stride_list[axis_index], + pad_attr[pad_index], + dilation_list[axis_index], + ) + stride_tuple = tuple(stride_list) + dilation_tuple = tuple(dilation_list) + + weight_perm = self._nchw_to_nhwc_perm(len(weight_shape)) + + # ---- Quantisation info ------------------------------------------ + input_dtype = get_first_fake_tensor(x).dtype + is_quantized = self._is_quantized_conv(node) + has_qparam_bias = ( + is_quantized and len(node.meta.get("input_qparams", {})) > 2 + ) + is_int8 = is_quantized and input_dtype == torch.int8 + is_int16_with_bias = ( + is_quantized and input_dtype == torch.int16 and has_qparam_bias + ) + is_int16_no_bias = ( + is_quantized and input_dtype == torch.int16 and not has_qparam_bias + ) + + original_bias = bias # Keep for INT16+bias decomposition + + # Pre-compute rescale factors for INT8 / INT16-no-bias paths. 
+ full_weight_scale: list[float] = [] + input_scale = 0.0 + output_scale = 0.0 + output_zp = 0 + rescale_dtype = torch.int8 + if is_int8 or is_int16_no_bias: + iq = get_input_qparams(node) + oq = self._get_effective_output_qparams(node)[0] + wq = iq[1] + if wq.per_channel: + full_weight_scale = wq.get_scale_per_channel() + else: + full_weight_scale = [wq.get_scale_per_tensor()] + input_scale = iq[0].get_scale_per_tensor() + output_scale = oq.get_scale_per_tensor() + output_zp = oq.get_zp_per_tensor() + rescale_dtype = oq.dtype + + # ---- ONE input permute NCHW→NHWC -------------------------------- + x_permuted = self._insert_permute( + graph_module, node, x, nchw_to_nhwc, before=True + ) + + group_outputs: list[torch.fx.Node] = [] + cursor = node + + for g in range(group): + # Slice NHWC input along channel dim + with graph_module.graph.inserting_before(node): + sliced_input = create_node( + graph=graph_module.graph, + op_target=exir_ops.edge.aten.slice_copy.Tensor, + args=( + x_permuted, + nhwc_channel_dim, + g * input_slice_size, + (g + 1) * input_slice_size, + ), + from_node=x, + ) + + # Slice weight along output-channel dim (dim 0 in OIHW) + with graph_module.graph.inserting_before(node): + sliced_weight = create_node( + graph=graph_module.graph, + op_target=exir_ops.edge.aten.slice_copy.Tensor, + args=( + weight, + 0, + g * output_slice_size, + (g + 1) * output_slice_size, + ), + from_node=weight, + ) + + # Permute weight OIHW→OHWI + sliced_weight_permuted = self._insert_permute( + graph_module, node, sliced_weight, weight_perm, before=True + ) + + # ---- Per-group bias ----------------------------------------- + if is_int16_with_bias or is_int16_no_bias: + # INT16: TOSA conv always needs an INT48-tagged zero bias. + # Create a fresh per-group constant so the tag survives + # constant-folding passes. 
+ zb = torch.zeros(size=(output_slice_size,), dtype=torch.int32) + with graph_module.graph.inserting_after(weight): + group_bias = create_constant_placeholder( + self.exported_program, + graph=graph_module.graph, + kind=InputKind.PARAMETER, + data=zb, + persistent_buffer=True, + name=f"{node.name}_g{g}_zero_bias", + ) + group_bias.meta[ + TosaSpecialDtype.meta_key() + ] = TosaSpecialDtype.INT48 + elif bias is not None: + with graph_module.graph.inserting_before(node): + group_bias = create_node( + graph=graph_module.graph, + op_target=exir_ops.edge.aten.slice_copy.Tensor, + args=( + bias, + 0, + g * output_slice_size, + (g + 1) * output_slice_size, + ), + from_node=bias, + ) + # Propagate INT48 tag from parent bias if present. + if bias.meta.get(TosaSpecialDtype.meta_key()) == TosaSpecialDtype.INT48: + group_bias.meta[ + TosaSpecialDtype.meta_key() + ] = TosaSpecialDtype.INT48 + else: + dtype = torch.int32 if is_quantized else node.meta["val"].dtype + zb = torch.zeros(size=(output_slice_size,), dtype=dtype) + with graph_module.graph.inserting_after(weight): + group_bias = create_constant_placeholder( + self.exported_program, + graph=graph_module.graph, + kind=InputKind.PARAMETER, + data=zb, + persistent_buffer=True, + name=f"{node.name}_g{g}_bias", + ) + if input_dtype == torch.int16: + group_bias.meta[ + TosaSpecialDtype.meta_key() + ] = TosaSpecialDtype.INT48 + + # ---- TOSA.CONV2D -------------------------------------------- + conv_args = ( + sliced_input, + sliced_weight_permuted, + group_bias, + stride_tuple, + pad_attr, + dilation_tuple, + ) + with graph_module.graph.inserting_after(cursor): + tosa_op = create_node( + graph=graph_module.graph, + op_target=exir_ops.backend.tosa.CONV2D.default, + args=conv_args, + from_node=node, + inherit_qparams=True, + ) + cursor = tosa_op + + # ---- Per-group quantised output ----------------------------- + if is_int8 or is_int16_no_bias: + if len(full_weight_scale) > 1: # per-channel + gws = full_weight_scale[ + g * 
output_slice_size : (g + 1) * output_slice_size + ] + else: + gws = full_weight_scale + gscale = [(input_scale * w) / output_scale for w in gws] + with graph_module.graph.inserting_after(cursor): + rescale = create_node( + graph=graph_module.graph, + op_target=exir_ops.backend.tosa.RESCALE.default, + args=(tosa_op, rescale_dtype, gscale, 0, output_zp), + from_node=node, + ) + if is_int16_no_bias: + tosa_op.meta[ + TosaSpecialDtype.meta_key() + ] = TosaSpecialDtype.INT48 + cursor = rescale + group_outputs.append(rescale) + elif is_int16_with_bias: + # Full per-group INT16+bias decomposition so that each + # group is self-contained (required by U55 Vela). + output_qparams = cast( + QuantArgs, node.meta["output_qparams"][0] + ) + bias_qparams = cast( + QuantArgs, node.meta["input_qparams"][2] + ) + if bias_qparams.per_channel: + full_bias_scale = bias_qparams.get_scale_per_channel() + else: + full_bias_scale = [bias_qparams.get_scale_per_tensor()] + + tosa_op.meta[ + TosaSpecialDtype.meta_key() + ] = TosaSpecialDtype.INT48 + + # 1. RESCALE INT48 → INT32 (identity) + with graph_module.graph.inserting_after(cursor): + int48_rescale = create_node( + graph=graph_module.graph, + op_target=exir_ops.backend.tosa.RESCALE.default, + args=( + tosa_op, + torch.int32, + [1.0] * output_slice_size, + 0, + 0, + ), + from_node=node, + ) + cursor = int48_rescale + + # 2. Slice original bias for this group + with graph_module.graph.inserting_before(node): + group_bias_slice = create_node( + graph=graph_module.graph, + op_target=exir_ops.edge.aten.slice_copy.Tensor, + args=( + original_bias, + 0, + g * output_slice_size, + (g + 1) * output_slice_size, + ), + from_node=original_bias, + ) + + # 3. 
Reshape sliced bias to NHWC: [1, 1, ..., 1, C_group] + group_bias_view_shape = [ + 1, + *([1] * (rank - 2)), + output_slice_size, + ] + with graph_module.graph.inserting_after(cursor): + group_bias_view = create_node( + graph=graph_module.graph, + op_target=exir_ops.edge.aten.view_copy.default, + args=(group_bias_slice, group_bias_view_shape), + from_node=original_bias, + ) + cursor = group_bias_view + + # 4. ADD bias (INT32, NHWC) + with graph_module.graph.inserting_after(cursor): + group_bias_add = create_node( + graph=graph_module.graph, + op_target=exir_ops.edge.aten.add.Tensor, + args=(int48_rescale, group_bias_view), + from_node=node, + ) + cursor = group_bias_add + + # 5. RESCALE INT32 → INT16 with group-specific scale + if len(full_bias_scale) > 1: # per-channel + gbs = full_bias_scale[ + g * output_slice_size : (g + 1) * output_slice_size + ] + else: + gbs = full_bias_scale + group_final_scale = [ + b / output_qparams.scale for b in gbs + ] + with graph_module.graph.inserting_after(cursor): + group_final_rescale = create_node( + graph=graph_module.graph, + op_target=exir_ops.backend.tosa.RESCALE.default, + args=( + group_bias_add, + output_qparams.dtype, + group_final_scale, + 0, + 0, + ), + from_node=node, + ) + cursor = group_final_rescale + group_outputs.append(group_final_rescale) + else: + group_outputs.append(tosa_op) + + # ---- Cat along NHWC channel dim --------------------------------- + with graph_module.graph.inserting_after(cursor): + cat_node = create_node( + graph=graph_module.graph, + op_target=exir_ops.edge.aten.cat.default, + args=(group_outputs, nhwc_channel_dim), + from_node=node, + ) + cursor = cat_node + + # ---- ONE output permute NHWC→NCHW ------------------------------- + output_permute = self._insert_permute( + graph_module, cursor, cursor, nhwc_to_nchw, before=False + ) + return output_permute + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # noqa: C901 modified = False for node in graph_module.graph.nodes: @@ 
-302,6 +685,28 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # noqa: C901 if not has_bias: bias = self._add_bias(graph_module, node, weight) + # Insert activation permute: NCHW → NHWC + rank = len(input_shape) + nchw_to_nhwc = self._nchw_to_nhwc_perm(rank) + nhwc_to_nchw = self._nhwc_to_nchw_perm(rank) + + # Grouped conv (not depthwise): decompose in NHWC with a single + # input/output permute pair so downstream permute-optimisation + # passes cannot break the output layout. + if not transposed and self._is_grouped_conv(node): + result = self._handle_grouped_conv( + graph_module, node, x, weight, bias, + stride_list, pad_list, dilation_list, + group, input_shape, weight_shape, spatial_rank, rank, + ) + node.replace_all_uses_with(result) + graph_module.graph.erase_node(node) + continue + + x_permuted = self._insert_permute( + graph_module, node, x, nchw_to_nhwc, before=True + ) + conv_args: tuple[Any, ...] if transposed: if spatial_rank != 2: @@ -322,9 +727,14 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # noqa: C901 -pad_list[1] + output_padding_list[1], ] target_op = exir_ops.backend.tosa.TRANSPOSE_CONV2D.default + # Weight permute: IOHW → OHWI + weight_perm = [1, 2, 3, 0] + weight_permuted = self._insert_permute( + graph_module, node, weight, weight_perm, before=True + ) conv_args = ( - x, - weight, + x_permuted, + weight_permuted, bias, out_pad, stride, @@ -353,16 +763,36 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # noqa: C901 target_op = exir_ops.backend.tosa.CONV3D.default elif self._is_depthwise_conv2d(node): target_op = exir_ops.backend.tosa.DEPTHWISE_CONV2D.default - # If there are any TOSA.DEPTHWISE_CONV2D nodes using the weights, we've already reshaped them. - if all(user.target != target_op for user in weight.users): + # If there are any TOSA.DEPTHWISE_CONV2D nodes using + # the weights (possibly via a permute_copy), we've + # already reshaped them. 
+ already_reshaped = any( + user.target == target_op + or ( + user.target + == exir_ops.edge.aten.permute_copy.default + and any( + u2.target == target_op for u2 in user.users + ) + ) + for user in weight.users + ) + if not already_reshaped: self._reshape_weights(weight, input_fake_tensor.shape[1]) weight_fake_tensor = get_first_fake_tensor(weight) else: target_op = exir_ops.backend.tosa.CONV2D.default + # Weight permute: OIHW → OHWI (or reshaped depthwise equivalent) + weight_perm = self._nchw_to_nhwc_perm( + len(weight_fake_tensor.shape) + ) + weight_permuted = self._insert_permute( + graph_module, node, weight, weight_perm, before=True + ) conv_args = ( - x, - weight, + x_permuted, + weight_permuted, bias, stride, pad, @@ -378,19 +808,32 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # noqa: C901 inherit_qparams=True, ) bias_fake_tensor = get_first_fake_tensor(bias) if bias else None + input_fake_nhwc = input_fake_tensor.permute(nchw_to_nhwc) + weight_fake_permuted = weight_fake_tensor.permute(weight_perm) + tosa_node_fake_tensor = target_op( - input_fake_tensor, - weight_fake_tensor, + input_fake_nhwc, + weight_fake_permuted, bias_fake_tensor, *conv_args[3:], ) + # Insert output permute: NHWC → NCHW + output_permute = self._insert_permute( + graph_module, tosa_op, tosa_op, nhwc_to_nchw, before=False + ) + if ( tosa_node_fake_tensor.dtype == torch.int32 and input_fake_tensor.dtype == torch.int8 ): output_rescale = self.insert_output_rescale(graph_module, node, tosa_op) - node.replace_all_uses_with(output_rescale) + output_permute_after_rescale = self._insert_permute( + graph_module, output_rescale, output_rescale, nhwc_to_nchw, before=False + ) + output_permute.replace_all_uses_with(tosa_op) + graph_module.graph.erase_node(output_permute) + node.replace_all_uses_with(output_permute_after_rescale) elif ( tosa_node_fake_tensor.dtype == torch.int32 and input_fake_tensor.dtype == torch.int16 @@ -400,12 +843,126 @@ def call(self, graph_module: 
torch.fx.GraphModule) -> PassResult: # noqa: C901 output_rescale = self.insert_output_rescale( graph_module, node, tosa_op ) - node.replace_all_uses_with(output_rescale) + output_permute_after_rescale = self._insert_permute( + graph_module, output_rescale, output_rescale, nhwc_to_nchw, before=False + ) + output_permute.replace_all_uses_with(tosa_op) + graph_module.graph.erase_node(output_permute) + node.replace_all_uses_with(output_permute_after_rescale) else: - node.replace_all_uses_with(tosa_op) + # INT16 conv with bias: the TOSA conv produces INT48 + # output. We handle the full bias decomposition here + # entirely in NHWC layout, with the output permute placed + # AFTER the final RESCALE. + # + # Graph produced: + # tosa.CONV2D(input, weight, zero_bias_INT48) → INT48, NHWC + # → RESCALE(INT48 → INT32, scale=1.0, NHWC) + # → ADD(bias reshaped to [1,1,...,1,C] for NHWC broadcast) + # → RESCALE(INT32 → INT16, final_scale, NHWC) + # → permute(NHWC → NCHW) + + # Save original bias before replacing with zero bias + original_bias_node = node.args[2] + + # Create a zero bias tagged INT48 for the TOSA conv + output_channels = get_first_fake_tensor(node).shape[1] + zero_bias_data = torch.zeros( + size=(output_channels,), dtype=torch.int32 + ) + with graph_module.graph.inserting_after(weight): + zero_bias_node = create_constant_placeholder( + self.exported_program, + graph=graph_module.graph, + kind=InputKind.PARAMETER, + data=zero_bias_data, + persistent_buffer=True, + name=f"{node.name}_zero_bias", + ) + zero_bias_node.meta[ + TosaSpecialDtype.meta_key() + ] = TosaSpecialDtype.INT48 + # Replace original bias with zero bias in tosa conv + bias_arg_index = list(tosa_op.args).index(bias) + tosa_op.update_arg(bias_arg_index, zero_bias_node) + + output_qparams = cast( + QuantArgs, node.meta["output_qparams"][0] + ) + bias_qparams = cast( + QuantArgs, node.meta["input_qparams"][2] + ) + if bias_qparams.per_channel: + bias_scale = bias_qparams.get_scale_per_channel() + else: + 
bias_scale = [bias_qparams.get_scale_per_tensor()] + + # Remove the original output permute first — we'll add + # a new one at the very end of the chain. + output_permute.replace_all_uses_with(tosa_op) + graph_module.graph.erase_node(output_permute) + + # Build the chain sequentially, each node after the + # previous, so graph ordering matches logical ordering. + # Use a cursor variable to track insertion point. + cursor = tosa_op + + # 1. RESCALE INT48 → INT32 (NHWC) + conv_rescale_factors = [1.0] * len(bias_scale) + with graph_module.graph.inserting_after(cursor): + int48_rescale = create_node( + graph=graph_module.graph, + op_target=exir_ops.backend.tosa.RESCALE.default, + args=(tosa_op, torch.int32, conv_rescale_factors, 0, 0), + from_node=node, + ) + cursor = int48_rescale + + # 2. Reshape bias to NHWC: [1, 1, ..., 1, C] + bias_data = get_first_fake_tensor(original_bias_node) + bias_view_shape = [1, *([1] * (rank - 2)), bias_data.shape[0]] + with graph_module.graph.inserting_after(cursor): + bias_view = create_node( + graph=graph_module.graph, + op_target=exir_ops.edge.aten.view_copy.default, + args=(original_bias_node, bias_view_shape), + from_node=original_bias_node, + ) + cursor = bias_view + + # 3. ADD bias (NHWC) + with graph_module.graph.inserting_after(cursor): + bias_add = create_node( + graph=graph_module.graph, + op_target=exir_ops.edge.aten.add.Tensor, + args=(int48_rescale, bias_view), + from_node=node, + ) + cursor = bias_add + + # 4. RESCALE INT32 → output dtype (NHWC) + final_output_scale = [ + b / output_qparams.scale for b in bias_scale + ] + with graph_module.graph.inserting_after(cursor): + final_rescale = create_node( + graph=graph_module.graph, + op_target=exir_ops.backend.tosa.RESCALE.default, + args=(bias_add, output_qparams.dtype, final_output_scale, 0, 0), + from_node=node, + ) + cursor = final_rescale + + # 5. 
Output permute NHWC → NCHW (LAST in chain) + output_permute_after_bias = self._insert_permute( + graph_module, cursor, cursor, nhwc_to_nchw, before=False + ) + + node.replace_all_uses_with(output_permute_after_bias) + tosa_op.meta[TosaSpecialDtype.meta_key()] = TosaSpecialDtype.INT48 else: - node.replace_all_uses_with(tosa_op) + node.replace_all_uses_with(output_permute) graph_module.graph.erase_node(node) diff --git a/backends/arm/_passes/rewrite_max_pool2d_pass.py b/backends/arm/_passes/rewrite_max_pool2d_pass.py index 123d21eda1f..8080f4f7a57 100644 --- a/backends/arm/_passes/rewrite_max_pool2d_pass.py +++ b/backends/arm/_passes/rewrite_max_pool2d_pass.py @@ -5,15 +5,23 @@ from typing import Set, Type +import torch from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes.arm_pass_utils import ( + create_node, + get_first_fake_tensor, +) from executorch.backends.arm.operators.operator_validation_utils import ( adjust_pooling_pad_if_needed, ) from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import ExportPass +from executorch.exir.pass_base import ExportPass, PassResult edge_max_pool2d_ops = (exir_ops.edge.aten.max_pool2d.default,) +_NCHW_TO_NHWC = [0, 2, 3, 1] +_NHWC_TO_NCHW = [0, 3, 1, 2] + def _to_2tuple(value): if isinstance(value, int): @@ -24,42 +32,93 @@ def _to_2tuple(value): class RewriteMaxPool2dPass(ArmPass): - """Rewrite max_pool2d ops to TOSA MAX_POOL2D.""" + """Rewrite max_pool2d ops to TOSA MAX_POOL2D with NHWC layout.""" _passes_required_after: Set[Type[ExportPass]] = set() - def call_operator(self, op, args, kwargs, meta): - if op not in edge_max_pool2d_ops: - return super().call_operator(op, args, kwargs, meta) + @staticmethod + def _insert_permute(graph_module, anchor_node, input_node, perm, before=True): + ctx = ( + graph_module.graph.inserting_before(anchor_node) + if before + else graph_module.graph.inserting_after(anchor_node) + ) + with ctx: + return create_node( + 
graph=graph_module.graph, + op_target=exir_ops.edge.aten.permute_copy.default, + args=(input_node, perm), + from_node=input_node, + ) - x = args[0] - kernel = _to_2tuple(args[1]) + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + modified = False - if len(args) > 2 and args[2] is not None and len(args[2]) > 0: - stride = _to_2tuple(args[2]) - else: - stride = kernel + for node in list(graph_module.graph.nodes): + if node.op != "call_function" or node.target not in edge_max_pool2d_ops: + continue - padding = _to_2tuple(args[3]) if len(args) > 3 else (0, 0) - dilation = _to_2tuple(args[4]) if len(args) > 4 else (1, 1) - ceil_mode = args[5] if len(args) > 5 else False + x = node.args[0] + kernel = _to_2tuple(node.args[1]) - if dilation != (1, 1): - return super().call_operator(op, args, kwargs, meta) + if len(node.args) > 2 and node.args[2] is not None and len(node.args[2]) > 0: + stride = _to_2tuple(node.args[2]) + else: + stride = kernel - # TOSA MAX_POOL2D pad order is [top, bottom, left, right] - pad = [padding[0], padding[0], padding[1], padding[1]] - pad[1] = adjust_pooling_pad_if_needed( - x.data.shape[2], kernel[0], stride[0], pad[1], ceil_mode - ) - pad[3] = adjust_pooling_pad_if_needed( - x.data.shape[3], kernel[1], stride[1], pad[3], ceil_mode - ) + padding = _to_2tuple(node.args[3]) if len(node.args) > 3 else (0, 0) + dilation = _to_2tuple(node.args[4]) if len(node.args) > 4 else (1, 1) + ceil_mode = node.args[5] if len(node.args) > 5 else False - return super().call_operator( - exir_ops.backend.tosa.MAX_POOL2D.default, - (x, list(kernel), list(stride), pad), - {}, - meta, - updated=True, - ) + if dilation != (1, 1): + continue + + modified = True + + input_fake = get_first_fake_tensor(x) + + # TOSA MAX_POOL2D pad order is [top, bottom, left, right] + pad = [padding[0], padding[0], padding[1], padding[1]] + pad[1] = adjust_pooling_pad_if_needed( + input_fake.shape[2], kernel[0], stride[0], pad[1], ceil_mode + ) + pad[3] = 
adjust_pooling_pad_if_needed( + input_fake.shape[3], kernel[1], stride[1], pad[3], ceil_mode + ) + + # Insert NCHW → NHWC permute on input + x_permuted = self._insert_permute( + graph_module, node, x, _NCHW_TO_NHWC, before=True + ) + + tosa_args = (x_permuted, list(kernel), list(stride), pad) + + # Create TOSA MAX_POOL2D node + with graph_module.graph.inserting_after(node): + tosa_op = create_node( + graph=graph_module.graph, + op_target=exir_ops.backend.tosa.MAX_POOL2D.default, + args=tosa_args, + from_node=node, + inherit_qparams=True, + ) + + # Compute correct NHWC FakeTensor + input_fake_nhwc = input_fake.permute(_NCHW_TO_NHWC) + tosa_node_fake = exir_ops.backend.tosa.MAX_POOL2D.default( + input_fake_nhwc, list(kernel), list(stride), pad + ) + tosa_op.meta["val"] = tosa_node_fake + + # Insert NHWC → NCHW permute on output + output_permute = self._insert_permute( + graph_module, tosa_op, tosa_op, _NHWC_TO_NCHW, before=False + ) + + node.replace_all_uses_with(output_permute) + graph_module.graph.erase_node(node) + + if modified: + graph_module.recompile() + graph_module = super().call(graph_module).graph_module + return PassResult(graph_module, modified) diff --git a/backends/arm/_passes/to_tosa_memory_format_pass.py b/backends/arm/_passes/to_tosa_memory_format_pass.py deleted file mode 100644 index ecab595c39e..00000000000 --- a/backends/arm/_passes/to_tosa_memory_format_pass.py +++ /dev/null @@ -1,518 +0,0 @@ -# Copyright 2024-2026 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- - -import logging -from typing import Set, Type - -import torch -from executorch.backends.arm._passes import ArmPass -from executorch.backends.arm._passes.arm_pass_utils import ( - create_node, - get_first_fake_tensor, - is_param_node, -) -from executorch.backends.arm.constants import NCHW_ORDER, NNCHW_ORDER, NNNCHW_ORDER -from executorch.backends.arm.tosa.dialect.shape import is_shape_op_node -from executorch.exir import ExportedProgram -from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import ExportPass, PassResult - -logger = logging.getLogger(__name__) - - -def _is_input(node: torch.fx.Node, exported_program: ExportedProgram) -> bool: - """Returns True if the node is an input node, i.e. a placeholder or a - parameter. - """ - return node.op == "placeholder" and not is_param_node(exported_program, node) - - -def _is_transpose_conv2d_weight(node: torch.fx.Node) -> bool: - for user in node.users: - if ( - user.op == "call_function" - and user.target == exir_ops.backend.tosa.TRANSPOSE_CONV2D.default - and len(user.args) > 1 - and user.args[1] is node - ): - return True - return False - - -class ToTosaMemoryFormatPass(ArmPass): - """Annotates each node with a tosa_dim_order. - - tosa_dim_order can be seen as a channels-last dim-order that in most cases - will be (0, 2, 3, 1) for nodes with 4D-shapes. The pass also inserts - backend.tosa.TRANSPOSE when a transition between 3D and 4D/5D tensors - happen. The annotated tosa_dim_order is used to permute the node's shape - such that it gives a TOSA-compliant shape. This pass also makes other values - aware of spatial dimensions required by future operators by back propogating - info as required. 
- - """ - - _passes_required_after: Set[Type[ExportPass]] = set() - - def __init__(self, exported_program: ExportedProgram, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - self.exported_program = exported_program - - @staticmethod - def _channels_last_order(rank: int, spatial_rank: int) -> tuple[int, ...]: - """Compute the permutation of tensor dimensions corresponding to a - "channels_last"-style memory layout for an arbitrary tensor rank. - - In standard PyTorch convention: - - "channels_first" order is (N, C, H, W) - - "channels_last" order is (N, H, W, C) - This helper generalizes that concept beyond 4D tensors, producing an index - ordering that moves the channel dimension to the end while preserving the - relative order of batch and spatial dimensions. - - Args: - rank (int): Total number of tensor dimensions (e.g. 4 for NCHW). - spatial_rank (int): Number of spatial dimensions (e.g. 2 for HW, 3 for DHW). - Values outside [0, rank - 2] are clamped to that range. - - Returns: - tuple[int, ...]: A permutation of dimension indices that reorders the - tensor into "channels_last" format. For example: - - rank=4, spatial_rank=2 → (0, 2, 3, 1) # NCHW → NHWC - - rank=5, spatial_rank=3 → (0, 2, 3, 4, 1) # NCDHW → NDHWC - - rank=3, spatial_rank=1 → (0, 2, 1) - - Notes: - If `rank <= 2`, the function returns the identity order since there - are no distinct channel/spatial dimensions. - In practice only rank 4+ tensors will reach this function as the dim order should be fixed for those. - - """ - if rank <= 2: - return tuple(range(rank)) - spatial_rank = max(0, min(spatial_rank, rank - 2)) - channel_axis = rank - (spatial_rank + 1) - batch_axes = list(range(channel_axis)) - spatial_axes = list(range(channel_axis + 1, rank)) - return tuple(batch_axes + spatial_axes + [channel_axis]) - - @staticmethod - def _channels_last_inverse_order(rank: int, spatial_rank: int) -> tuple[int, ...]: - """Return the inverse permutation of `_channels_last_order`. 
- - This provides the axis order needed to map a tensor from "channels_last" - layout back to its original layout. - - """ - order = ToTosaMemoryFormatPass._channels_last_order(rank, spatial_rank) - inverse = [0] * rank - for idx, axis in enumerate(order): - inverse[axis] = idx - return tuple(inverse) - - def _infer_dim_order_for_node( - self, node: torch.fx.Node, node_data: torch.Tensor, spatial_rank: int - ) -> tuple[int, ...]: - rank = node_data.dim() - - # Inputs and outputs preserve their externally-declared dim order. - if _is_input(node, self.exported_program) or node.op == "output": - return node_data.dim_order() - - # Conv transpose weights are serialized in OHWI layout. - if rank == 4 and _is_transpose_conv2d_weight(node): - return (1, 2, 3, 0) - - if rank >= 4: - return self._channels_last_order(rank, spatial_rank) - return tuple(range(rank)) - - def _initial_spatial_rank(self, node: torch.fx.Node) -> int: - """Infer the initial spatial rank based on the current rank, input node - spatial ranks and node target. A spatial dimension includes Height, - Width or Depth fields. In most operators this will only ever be Height - and Width, but for 3D operators such as conv3d this would contain 3 - spatial dims. - - Spatial rank is the max of any input node spatial ranks and the number of - trailing spatial dims we need to preserve (rank - 2, capped at 3). This - decides which axes must stay channels-last when inserting transposes. - - """ - tensor = get_first_fake_tensor(node).data - # Start by assuming 2D when dealing with rank4+ to account for the base case - # of an increasing amount of batch dimensions. - rank = tensor.dim() - if rank >= 4: - spatial_rank = 2 - elif rank == 3: - spatial_rank = 1 - else: - spatial_rank = 0 - - # Look for supported 3D ops and update spatial rank if relevent. - # Currently only Conv3d is supported. 
- if node.target == exir_ops.backend.tosa.CONV3D.default: - spatial_rank = 3 - - # Check input spatial ranks to know what the previous node spatial ranks were. - input_ranks = [ - input_node.meta.get("tosa_spatial_rank", 0) - for input_node in node.all_input_nodes - ] - if input_ranks: - spatial_rank = max([spatial_rank, *input_ranks]) - - # The max that spatial rank can be is 3. If the current rank not capable of holding - # the current spatial rank, we clamp the max to Rank - (Channels and a singular batch dimension). - # This ensures we revert back to lower spatial ranks after we are finished processing higher spatial ops. - return min(spatial_rank, max(rank - 2, 0)) - - @staticmethod - def memory_format_differs(shape, spatial_rank): - """Determine whether a tensor shape would be laid out differently in - channels-first ((N)NCHW) versus channels-last ((N)NHWC) memory - format. - """ - if len(shape) <= 2 or spatial_rank <= 0: - return False - channel_idx = len(shape) - (spatial_rank + 1) - channel_idx = max(0, min(channel_idx, len(shape) - 1)) - spatial_dims = shape[channel_idx + 1 :] - if not spatial_dims: - return False - channel_dim = shape[channel_idx] - return channel_dim > 1 and any(dim > 1 for dim in spatial_dims) - - @staticmethod - def is_channel_reshape( - input_shape, output_shape, input_spatial_rank, output_spatial_rank - ): - """Check whether a reshape touches the logical channel or consolidated - batch dimensions, which would invalidate dim-order annotations. 
- """ - - valid_ranks = {4, 5, 6} - - if not (len(input_shape) in valid_ranks and len(output_shape) in valid_ranks): - return False - - def channel_index(shape, spatial_rank): - if len(shape) <= 2: - return len(shape) - 1 - idx = len(shape) - (spatial_rank + 1) - return max(0, min(idx, len(shape) - 1)) - - C_old = input_shape[channel_index(input_shape, input_spatial_rank)] - C_new = output_shape[channel_index(output_shape, output_spatial_rank)] - - def get_batch_prod_dim(shape, spatial_rank): - product = 1 - - for dim in shape[: channel_index(shape, spatial_rank)]: - product = product * dim - - return product - - N_old = get_batch_prod_dim(input_shape, input_spatial_rank) - N_new = get_batch_prod_dim(output_shape, output_spatial_rank) - - return (N_old != N_new) or (C_old != C_new) - - @staticmethod - def insert_input_transpose(node, input_node, graph_module): - """Ensure an input tensor is converted to channels-last ordering by - inserting (or folding) a backend `TRANSPOSE` node. - """ - if input_node.target == exir_ops.backend.tosa.TRANSPOSE.default: - pre_permute_node = input_node.all_input_nodes[0] - node.replace_input_with(input_node, pre_permute_node) - return - - rank = len(get_first_fake_tensor(input_node).size()) - spatial_rank = input_node.meta["tosa_spatial_rank"] - mem_format = ToTosaMemoryFormatPass._channels_last_inverse_order( - rank, spatial_rank - ) - # Guard: mem_format must be a true permutation for the current rank - assert sorted(mem_format) == list( - range(rank) - ), f"bad perm {mem_format} for rank {rank} in insert_input_transpose" - - with graph_module.graph.inserting_before(node): - permute_node = create_node( - graph_module.graph, - exir_ops.backend.tosa.TRANSPOSE.default, - args=( - input_node, - list(mem_format), - ), - from_node=node, - ) - node.replace_input_with(input_node, permute_node) - - permute_node.meta["tosa_dim_order"] = tuple( - range(len(input_node.meta["val"].size())) - ) - permute_node.meta["tosa_spatial_rank"] = 
spatial_rank - - @staticmethod - def insert_output_transpose(node, graph_module): - """Convert a producer's output to channels-last by appending a backend - `TRANSPOSE` node and rewiring its users. - """ - - rank = len(get_first_fake_tensor(node).size()) - spatial_rank = node.meta["tosa_spatial_rank"] - mem_format = ToTosaMemoryFormatPass._channels_last_order(rank, spatial_rank) - # Guard: mem_format must be a true permutation for the current rank - assert sorted(mem_format) == list( - range(rank) - ), f"bad perm {mem_format} for rank {rank} in insert_input_transpose" - - with graph_module.graph.inserting_after(node): - permute_node = create_node( - graph_module.graph, - exir_ops.backend.tosa.TRANSPOSE.default, - args=( - node, - list(mem_format), - ), - from_node=node, - ) - - rank = len(get_first_fake_tensor(node).size()) - permute_node.meta["tosa_dim_order"] = mem_format - - node.meta["tosa_dim_order"] = tuple( - range(len(get_first_fake_tensor(node).size())) - ) - permute_node.meta["tosa_spatial_rank"] = spatial_rank - - users = [user for user in node.users if user != permute_node] - for user in users: - user.replace_input_with(node, permute_node) - - @staticmethod - def _insert_view_transpose( - input_shape, output_shape, node, input_node, graph_module - ): - """Insert the necessary input/output transposes around reshapes that - cross the (N)NCHW -> (N)NHWC boundary or that touch channel - dimensions. 
- """ - nchw_to_nhwc = len(input_shape) < 4 and len(output_shape) >= 4 - nhwc_to_nchw = len(input_shape) >= 4 and len(output_shape) < 4 - - input_sr = input_node.meta["tosa_spatial_rank"] - output_sr = node.meta["tosa_spatial_rank"] - - channel_reshape = ToTosaMemoryFormatPass.is_channel_reshape( - input_shape, - output_shape, - input_sr, - output_sr, - ) - - if ( - channel_reshape or nhwc_to_nchw - ) and ToTosaMemoryFormatPass.memory_format_differs(input_shape, input_sr): - ToTosaMemoryFormatPass.insert_input_transpose( - node, input_node, graph_module - ) - - if ( - channel_reshape or nchw_to_nhwc - ) and ToTosaMemoryFormatPass.memory_format_differs(output_shape, output_sr): - ToTosaMemoryFormatPass.insert_output_transpose(node, graph_module) - - def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule): - """Transposes are needed for operators transforming the input to a - different rank, as 4D and 5D-tensors are assumed to be in (N)NHWC- - format, whereas all other are in (N)NCHW format. - - This is relevant for the following cases: - - view: <4D -> >=4D - - view: >=4D -> <4D - Additionally, a 4D/5D->4D/5D view operation acting on the channel dimension currently needs to be performed in (N)NCHW format, leading to one extra input and output transpose for this case. 
- - Transposes can be avoided for shapes where there is no difference in actual memory, e.g for - - H == W == 1 - - C == 1 - - 1D/2D tensors - - """ - for node in graph_module.graph.nodes: - if node.op != "call_function": - continue - - # Transpose views - elif node.target == exir_ops.edge.aten.view_copy.default: - input_node = node.args[0] - input_shape = input_node.meta["val"].shape - output_shape = node.meta["val"].shape - self._insert_view_transpose( - input_shape, - output_shape, - node, - input_node, - graph_module, - ) - - output_node = graph_module.graph.output_node() - - # Transpose inputs if they are in (N)NCHW format - inputs = [ - n for n in graph_module.graph.nodes if _is_input(n, self.exported_program) - ] - for input_node in inputs: - input_dim_order = get_first_fake_tensor(input_node).dim_order() - if input_dim_order in (NCHW_ORDER, NNCHW_ORDER, NNNCHW_ORDER): - self.insert_output_transpose(input_node, graph_module) - - # Transpose outputs if they are in (N)NCHW format - outputs = output_node.args[0] - if not isinstance(outputs, (list, tuple)): - raise TypeError( - f"Expected output node args to be a list or tuple, got {type(outputs)}" - ) - output_dim_orders = output_node.meta.get("original_dim_orders") - if output_dim_orders is None: - raise RuntimeError(f"{output_dim_orders=} is not supported.") - - for output_node_input, output_dim_order in zip( - outputs, output_dim_orders, strict=True - ): - if output_dim_order in ( - NCHW_ORDER, - NNCHW_ORDER, - NNNCHW_ORDER, - ): - self.insert_input_transpose( - output_node, output_node_input, graph_module - ) - - def remove_dim_order_kwargs( - self, graph_module: torch.fx.GraphModule, node: torch.fx.Node - ): - """Drop any user-specified `dim_order` keyword arguments so the pass - remains the single source of truth for dim-order annotations. 
- """ - if node.op != "call_function": - return - - kwargs = dict(node.kwargs) - - if "dim_order" in kwargs: - logger.warning( - f"Ignoring dim_order kwarg '{kwargs['dim_order']}' for '{node.name}'." - ) - del kwargs["dim_order"] - - node.kwargs = kwargs - - def _propagate_dim_order_to_shape_args(self, node: torch.fx.Node) -> None: - for arg in node.all_input_nodes: - if is_shape_op_node(arg): - # Shape nodes may get its dim_order from multiple users. Keep track of old dim_order to make sure all - # users agree on the same dim_order, otherwise we may end up with non-deterministic dim_orders for - # shape nodes depending on the order of user traversal. - old_dim_order = arg.meta.get("tosa_dim_order", None) is not None - dim_order = node.meta["tosa_dim_order"] - # The shape node may have a different rank than the dim_order being propagated from its users - if len(dim_order) != len(arg.meta["val"]): - # For pad shape nodes, the rank is always 2x of the input tensor rank, and the dim order needs to be adjusted accordingly. - # For other shape nodes, we assume the dim order is the same as the order of dimensions in the shape. - if node.target == exir_ops.backend.tosa.PAD.default: - dim_order = tuple( - i for axis in dim_order for i in (2 * axis, 2 * axis + 1) - ) - else: - dim_order = tuple(range(len(arg.meta["val"]))) - if old_dim_order and arg.meta["tosa_dim_order"] != dim_order: - raise RuntimeError( - f"Conflicting dim orders {arg.meta['tosa_dim_order']} and {dim_order} for shape node {arg.name}" - ) - if node.target == exir_ops.backend.tosa.RESIZE.default: - # RESIZE's shape input is expected to be in HW order, so we need to override the dim order to be the identity for it regardless of the user node's dim order. 
- dim_order = tuple(range(len(arg.meta["val"]))) - arg.meta["tosa_dim_order"] = dim_order - self._propagate_dim_order_to_shape_args(arg) - - def _annotate_shape_nodes(self, graph_module: torch.fx.GraphModule) -> None: - for node in graph_module.graph.nodes: - if not self._is_ok_for_annotation(node): - continue - self._propagate_dim_order_to_shape_args(node) - - def _is_ok_for_annotation(self, node: torch.fx.Node) -> bool: - if "val" not in node.meta: - return False - # Shape-only nodes which produce SymInt[] rather than real tensors are annotated separately by propagating dim order from their users. - # We must therefore annotate all valid nodes before propagating dim order upwards in graph. - if is_shape_op_node(node): - return False - # For some models, the symbolic value is passed to the graph, skip it - if isinstance(node.meta["val"], torch.SymInt): - return False - return True - - def call(self, graph_module: torch.fx.GraphModule): - """ - Entry point for the pass: annotate spatial ranks, compute dim orders, - insert bridging transposes, and forward to child passes. - """ - graph_module.graph.eliminate_dead_code() - nodes = list(graph_module.graph.nodes) - for node in nodes: - if not self._is_ok_for_annotation(node): - continue - node.meta["tosa_spatial_rank"] = self._initial_spatial_rank(node) - self.remove_dim_order_kwargs(graph_module, node) - - self._propagate_spatial_ranks(nodes) - - for node in nodes: - if not self._is_ok_for_annotation(node): - continue - node_data = get_first_fake_tensor(node).data - spatial_rank = node.meta["tosa_spatial_rank"] - dim_order = self._infer_dim_order_for_node(node, node_data, spatial_rank) - node.meta["tosa_dim_order"] = dim_order - - # Insert TOSA transposes to convert between (N)NCHW and (N)NHWC format. - # See insert_tosa_transposes for insertion conditions. 
- self.insert_tosa_transposes(graph_module) - # Special handling is needed for shape nodes as they don't have real tensors or real dim orders, but the order - # still needs to be propagated to them so that they can be serialized with the correct order and shapes. - self._annotate_shape_nodes(graph_module) - graph_module.recompile() - graph_module = super().call(graph_module).graph_module - - return PassResult(graph_module, True) - - def _propagate_spatial_ranks(self, nodes): - """Propagate `tosa_spatial_rank` metadata backwards so earlier nodes - learn about upcoming spatial requirements from future ops. - """ - changed = True - while changed: - changed = False - for node in reversed(nodes): - if not self._is_ok_for_annotation(node): - continue - tensor = get_first_fake_tensor(node) - limit = max(tensor.dim() - 2, 0) - current = node.meta.get("tosa_spatial_rank") - propagated = current - for user in node.users: - user_rank = user.meta.get("tosa_spatial_rank") - if user_rank is None: - continue - propagated = max(propagated, min(user_rank, limit)) - if propagated != current: - node.meta["tosa_spatial_rank"] = propagated - changed = True diff --git a/backends/arm/operator_support/ethos_u55_support.py b/backends/arm/operator_support/ethos_u55_support.py index c0c795c2cfa..437fa561526 100644 --- a/backends/arm/operator_support/ethos_u55_support.py +++ b/backends/arm/operator_support/ethos_u55_support.py @@ -19,11 +19,6 @@ is_singleton_permutation, ) from executorch.backends.arm._passes.insert_table_ops import TableOps -from executorch.backends.arm._passes.to_tosa_memory_format_pass import ( - ToTosaMemoryFormatPass, -) -from executorch.backends.arm.operators.op_permute import transform_permutation_vector -from executorch.backends.arm.tosa.utils import tosa_shape from executorch.exir.backend.utils import WhyNoPartitionReporter from executorch.exir.dialects._ops import ops as exir_ops from torch.fx.passes.operator_support import OperatorSupportBase @@ -354,125 +349,6 @@ 
def _check_rank_constraints( return True - @staticmethod - def _spatial_rank(rank: int) -> int: - assert rank < 5, "Spatial rank determination only valid for rank <5." - if rank == 4: - return 2 - if rank == 3: - return 1 - return 0 - - def _transpose_requirements( - self, input_shape: shape_t, output_shape: shape_t - ) -> tuple[bool, bool]: - """Determine if reshaping requires input or output transposes. For ranks - >4, assume transpose is needed as we cannot determine the spatial rank - reliably which is needed to determine if a transpose is needed. - - Args: - input_shape (shape_t): Original tensor shape. - output_shape (shape_t): Reshaped tensor shape. - - Returns: - tuple[bool, bool]: ``(needs_input_transpose, needs_output_transpose)``. - - """ - input_rank = len(input_shape) - output_rank = len(output_shape) - if input_rank > 4 and output_rank <= 4: - # Assume input needs transpose if going from high-rank to low-rank - return ( - True, - output_rank == 4 - and ToTosaMemoryFormatPass.memory_format_differs( - output_shape, self._spatial_rank(output_rank) - ), - ) - elif input_rank <= 4 and output_rank > 4: - # Assume output needs transpose if going from low-rank to high-rank - return ( - input_rank == 4 - and ToTosaMemoryFormatPass.memory_format_differs( - input_shape, self._spatial_rank(input_rank) - ), - True, - ) - - input_sr = self._spatial_rank(input_rank) - output_sr = self._spatial_rank(output_rank) - nhwc_to_nchw = input_rank >= 4 and output_rank < 4 - nchw_to_nhwc = input_rank < 4 and output_rank >= 4 - channel_reshape = ToTosaMemoryFormatPass.is_channel_reshape( - input_shape, output_shape, input_sr, output_sr - ) - - needs_input_transpose = ( - channel_reshape or nhwc_to_nchw - ) and ToTosaMemoryFormatPass.memory_format_differs(input_shape, input_sr) - needs_output_transpose = ( - channel_reshape or nchw_to_nhwc - ) and ToTosaMemoryFormatPass.memory_format_differs(output_shape, output_sr) - return needs_input_transpose, needs_output_transpose - - def 
_check_transpose_constraints( - self, - node: fx.Node, - dtype: torch.dtype | None, - input_shape: shape_t, - output_shape: shape_t, - needs_input_transpose: bool, - needs_output_transpose: bool, - ) -> bool: - """Apply dtype- and size-based constraints for transpose insertions. - - based on: - - NCHW -> NHWC or NHWC -> NCHW transposes are not supported in int32. - - Transposes with product of axes >65536 are not supported. - - Args: - node (fx.Node): Node requiring validation. - dtype (torch.dtype | None): Resolved dtype of the reshape. - input_shape (shape_t): Source tensor shape. - output_shape (shape_t): Destination tensor shape. - needs_input_transpose (bool): Whether an input transpose is expected. - needs_output_transpose (bool): Whether an output transpose is expected. - - Returns: - bool: ``True`` if any implied transpose satisfies U55 limits. - - """ - if dtype == torch.int32 and (needs_input_transpose or needs_output_transpose): - self.reporter.report_reject( - node, - "Operator requires transpose operator. No support for transpose with " - "rank >= 4 in int32, got rank=4.", - ) - return False - - if ( - needs_input_transpose - and self.axes_product(input_shape) > self._MAX_AXIS_PRODUCT - ): - self.reporter.report_reject( - node, - f"Operator requires transpose operator. No support for {input_shape=}, " - f"{dtype=}. Product of axes must be <{self._MAX_AXIS_PRODUCT}", - ) - return False - if ( - needs_output_transpose - and self.axes_product(output_shape) > self._MAX_AXIS_PRODUCT - ): - self.reporter.report_reject( - node, - f"Operator requires transpose operator. No support for {output_shape=}, " - f"{dtype=}. Product of axes must be <{self._MAX_AXIS_PRODUCT}", - ) - return False - - return True - # TODO: Extend this check to comply with u55 restrictions def is_node_supported( self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node @@ -481,16 +357,6 @@ def is_node_supported( Currently only checks dtypes and product of axes. 
- It is not the view operator itself that is not supported on U55. In - order for the view operator to be compatible with the channels-last - format of TosaBackend, transposes may need to be inserted before and - after the view op. If that happens and that transpose operator does not - adhere to the limitations then it will result in the following error: - - CPU performance estimation for "Transpose" not implemented. - ... - CPU operations are not supported for GraphAPI input - Args: submodules (typing.Mapping[str, torch.nn.Module]): Exported modules. node (fx.Node): FX node for ``view_copy`` or ``select``. @@ -511,40 +377,22 @@ def is_node_supported( return True shape = list(get_first_fake_tensor(node).shape) - output_rank = len(shape) dtype = _try_determine_dtype(node) if node.target in ( exir_ops.edge.aten.select.int, exir_ops.edge.aten.select_copy.int, ): - # For select, the transpose condition should be applied on the output of the slice - # which has the same shape as the input except the selected dimension is 1. 
input_shape = list(get_first_fake_tensor(node.all_input_nodes[0]).shape) dim = typing.cast(int, node.args[1]) input_shape[dim] = 1 else: input_shape = list(get_first_fake_tensor(node.all_input_nodes[0]).shape) - input_rank = len(input_shape) if not self._check_rank_constraints(node, input_shape, shape, dtype): return False - if input_rank > 4 and output_rank > 4: - # If both input and output have rank >4, and passed the above checks, we can accept - # the node - return True - needs_input_transpose, needs_output_transpose = self._transpose_requirements( - input_shape, shape - ) - return self._check_transpose_constraints( - node, - dtype, - input_shape, - shape, - needs_input_transpose, - needs_output_transpose, - ) + return True class EthosU55TransposeCheck(OperatorSupportBase): @@ -688,25 +536,13 @@ def is_node_supported( shape, permutation = self._pad_to_rank_4(shape, permutation) if rank == 3 or rank == 4: - # For rank 3 and 4, we can have channels first or channels last dim order. - # Since we don't know which at partition-time, test both. 
- - nhwc_shape = tosa_shape(shape, [0, 2, 3, 1]) - nhwc_permutation = transform_permutation_vector(permutation, [0, 2, 3, 1]) - - if not self._permute_constraint(nhwc_shape, nhwc_permutation, dtype): + if not self._permute_constraint(shape, permutation, dtype): self.reporter.report_reject( node, - f"Unsupported NHWC {nhwc_shape=} for {nhwc_permutation=}, {dtype=}", + f"Unsupported {shape=} for {permutation=}, {dtype=}", ) return False - if not self._permute_constraint(shape, permutation, dtype): - self.reporter.report_reject( - node, f"Unsupported NCHW {shape=} for {permutation=}, {dtype=}" - ) - return False - return True diff --git a/backends/arm/operators/op_amax.py b/backends/arm/operators/op_amax.py index 4ebec3cc1ac..106774fb3eb 100644 --- a/backends/arm/operators/op_amax.py +++ b/backends/arm/operators/op_amax.py @@ -69,7 +69,7 @@ def define_node( attr = ts.TosaSerializerAttribute() nan_mode = ts.NanPropagationMode.PROPAGATE - attr.ReduceMaxAttribute(axis=input.dim_order.index(dim), nan_mode=nan_mode) + attr.ReduceMaxAttribute(axis=dim, nan_mode=nan_mode) self._serialize_operator( node, tosa_graph, diff --git a/backends/arm/operators/op_amin.py b/backends/arm/operators/op_amin.py index ca8a47c582f..58aff4c904d 100644 --- a/backends/arm/operators/op_amin.py +++ b/backends/arm/operators/op_amin.py @@ -69,7 +69,7 @@ def define_node( attr = ts.TosaSerializerAttribute() attr.ReduceMinAttribute( - axis=input.dim_order.index(dim), nan_mode=ts.NanPropagationMode.PROPAGATE + axis=dim, nan_mode=ts.NanPropagationMode.PROPAGATE ) self._serialize_operator( node, diff --git a/backends/arm/operators/op_any.py b/backends/arm/operators/op_any.py index 390b805f2a8..c602bfc5e0f 100644 --- a/backends/arm/operators/op_any.py +++ b/backends/arm/operators/op_any.py @@ -47,7 +47,7 @@ def define_node( raise ValueError("This case should be handled by DecomposeAnyPass") attr = ts.TosaSerializerAttribute() - attr.ReduceAnyAttribute(inputs[0].dim_order.index(dim)) + 
attr.ReduceAnyAttribute(dim) self._serialize_operator( node, diff --git a/backends/arm/operators/op_cat.py b/backends/arm/operators/op_cat.py index 9a50435e04e..544beefadf9 100644 --- a/backends/arm/operators/op_cat.py +++ b/backends/arm/operators/op_cat.py @@ -57,7 +57,6 @@ def define_node( dim = 0 if len(inputs) < 2 else inputs[1].number rank = len(output.shape) dim = (dim + rank) % rank - dim = output.dim_order.index(dim) attr = ts.TosaSerializerAttribute() attr.ConcatAttribute(dim) diff --git a/backends/arm/operators/op_permute.py b/backends/arm/operators/op_permute.py index b9c453980aa..e200478d7b3 100644 --- a/backends/arm/operators/op_permute.py +++ b/backends/arm/operators/op_permute.py @@ -22,80 +22,6 @@ from executorch.backends.arm.tosa.mapping import TosaArg -def permutation_vector_to_matrix(permutation_vector: list[int]) -> torch.Tensor: - """Convert a permutation vector of length N to an N x N matrix. - - Example: - (1, 0, 2) -> - [0 1 0] - [1 0 0] - [0 0 1] - - """ - N = len(permutation_vector) - P = torch.zeros(N, N) - for row_index, col_index in enumerate(permutation_vector): - P[row_index][col_index] = 1 - return P - - -def permutation_matrix_to_vector(permutation_matrix: torch.Tensor) -> list[int]: - """Convert an N x N permutation matrix to a permutation vector of length N. 
- - Example: - [0 1 0] - [1 0 0] - [0 0 1] - -> (1, 0, 2) - - """ - N = len(permutation_matrix) - if N != len(permutation_matrix[0]): - raise ValueError( - f"A permutation matrix must be square, got shape {permutation_matrix.shape}" - ) - - p = [0] * N - for row_index, row in enumerate(permutation_matrix): - saw_one = False - for col_index, value in enumerate(row): - if value == 1: - if saw_one: - raise ValueError( - f"A permutation matrix can only have one 1 per row, got {row=}" - ) - p[row_index] = col_index - saw_one = True - elif value != 0: - raise ValueError( - f"A permutation matrix only contains 1's and 0's, got {value=}" - ) - return p - - -def transform_permutation_vector(permutation_vector: list[int], dim_order: list[int]): - """Transforms a permutation to dim_order.""" - - # We need to first transform to dim_order, apply the permutation P, - # and then transform back to the original dim_order. - # This transformation, S, is also a permutation, with the dim_order as permutation vector. - - # To do this, represent P and S with permutation matrices. - # Matrices can handle chained transformations and inversion easily. - S = permutation_vector_to_matrix(dim_order) - # The inverse of a permutation matrix is its transpose. - S_inverse = S.t() - P = permutation_vector_to_matrix(permutation_vector) - - # The complete transformation is S * P * S_inverse. - transformation_matrix = S.matmul(P.matmul(S_inverse)) - - # Luckily, since it is just a combination of permutations, the result is also a permutation - # that can again be described by a new permutation vector. - permutation_vector = permutation_matrix_to_vector(transformation_matrix) - return permutation_vector - - @register_node_visitor class PermuteVisitor(NodeVisitor): target = "aten.permute_copy.default" @@ -127,18 +53,8 @@ def define_node( self.tosa_spec, ) - # The permutation vector describes a permutation P in default Pytorch dim_order. - # For rank 4, the default dim_order NCHW. - # E.g. 
(2,3,0,1) -> permute (n,c,h,w) to (w,c,n,h) permutation_vector = inputs[1].special - if output.dim_order != tuple(range(len(output.dim_order))): - # the permutation vector can't be used directly if we are not in NCHW dim_order. - # Transform to dim_order. - permutation_vector = transform_permutation_vector( - permutation_vector, output.dim_order - ) - attr = ts.TosaSerializerAttribute() attr.TransposeAttribute(permutation_vector) self._serialize_operator( diff --git a/backends/arm/operators/op_sum.py b/backends/arm/operators/op_sum.py index f18d7609c28..fce5836493d 100644 --- a/backends/arm/operators/op_sum.py +++ b/backends/arm/operators/op_sum.py @@ -46,7 +46,7 @@ def define_node( dim = int(inputs[1].number % len(input_shape)) attr = ts.TosaSerializerAttribute() - attr.ReduceSumAttribute(tensor.dim_order.index(dim)) + attr.ReduceSumAttribute(dim) self._serialize_operator( node, diff --git a/backends/arm/operators/op_tosa_shapes.py b/backends/arm/operators/op_tosa_shapes.py index 7e426a1da4e..fa4ea38157c 100644 --- a/backends/arm/operators/op_tosa_shapes.py +++ b/backends/arm/operators/op_tosa_shapes.py @@ -34,8 +34,7 @@ def define_node( ) -> None: shape_input = inputs[0].special rank = len(shape_input) - tosa_dim_order = output.dim_order - vals = tosa_shape(node.meta["val"], tosa_dim_order) + vals = tosa_shape(node.meta["val"]) tosa_graph = cast(ts.TosaSerializer, tosa_graph) tosa_graph.addConst( [ diff --git a/backends/arm/operators/op_while.py b/backends/arm/operators/op_while.py index 6b6c719d5ad..94d13340d41 100644 --- a/backends/arm/operators/op_while.py +++ b/backends/arm/operators/op_while.py @@ -6,7 +6,6 @@ from typing import Any, cast, List import tosa_serializer as ts -from executorch.backends.arm._passes.arm_pass_utils import get_output_dim_orders from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, @@ -64,21 +63,16 @@ def define_node( if num_inputs > num_outputs: # If we have more inputs than outputs, we can just add missing 
output tensors. body_module = getattr(node.graph.owning_module, body_graph) - output_dim_orders = get_output_dim_orders(body_module) body_outputs = body_module.graph.output_node().args[0] outputs_needing_tensors = body_outputs[num_outputs - num_inputs :] - output_dim_orders = output_dim_orders[num_outputs - num_inputs :] - for ( - output_needing_tensor, - dim_order, - ) in zip(outputs_needing_tensors, output_dim_orders, strict=True): + for output_needing_tensor in outputs_needing_tensors: tensor_name = output_needing_tensor.name + "_dummy" shape = output_needing_tensor.meta["val"].shape dtype = map_dtype(output_needing_tensor.meta["val"].dtype) tosa_graph.currRegion.currBasicBlock.addTensor( tensor_name, - tosa_shape(shape, dim_order), + tosa_shape(shape), dtype, ) output.multiple_output_names.append(tensor_name) diff --git a/backends/arm/process_node.py b/backends/arm/process_node.py index e5f522d6e2e..f321f45f041 100644 --- a/backends/arm/process_node.py +++ b/backends/arm/process_node.py @@ -28,8 +28,8 @@ from torch.export.exported_program import ExportedProgram -def _tensor_to_numpy_with_dim_order( - tensor: torch.Tensor, dim_order: tuple[int, ...] 
+def _tensor_to_numpy( + tensor: torch.Tensor, ) -> np.ndarray: tensor = tensor.detach().cpu().contiguous() if tensor.dtype == torch.bfloat16: @@ -42,9 +42,7 @@ def _tensor_to_numpy_with_dim_order( np_tensor = tensor.view(torch.uint16).numpy().view(ml_dtypes.bfloat16) else: np_tensor = tensor.numpy() - if dim_order == tuple(range(len(dim_order))): - return np_tensor - return np.transpose(np_tensor, dim_order) + return np_tensor def process_call_function( @@ -71,7 +69,7 @@ def process_call_function( tosa_graph = cast(ts.TosaSerializer, tosa_graph) if not output.multiple_output_names and not is_shape_op_node(node): tosa_graph.currRegion.currBasicBlock.addTensor( - output.name, tosa_shape(output.shape, output.dim_order), output.dtype + output.name, tosa_shape(output.shape), output.dtype ) # Get item nodes just add tensors, no node visitor is needed. @@ -106,10 +104,9 @@ def process_inputs( ) from e input_shape = tosa_arg.shape - input_dim_order = tosa_arg.dim_order tensor = ts.TosaSerializerTensor( tosa_arg.name, - tosa_shape(input_shape, input_dim_order), + tosa_shape(input_shape), tosa_arg.dtype, data=None, ) @@ -137,8 +134,8 @@ def process_inputs_to_parameters( f"Expected parameter '{node.name}' to be a torch.Tensor, got " f"{type(parameter_data).__name__}" ) - parameter_values = _tensor_to_numpy_with_dim_order( - parameter_data, tosa_arg.dim_order # type: ignore[arg-type] + parameter_values = _tensor_to_numpy( + parameter_data, ) tosa_graph.addConst( @@ -167,7 +164,7 @@ def process_inputs_to_buffers( f"Expected buffer '{node.name}' to be a torch.Tensor, got " f"{type(buffer_data).__name__}" ) - buffer_values = _tensor_to_numpy_with_dim_order(buffer_data, tosa_arg.dim_order) # type: ignore[arg-type] + buffer_values = _tensor_to_numpy(buffer_data) tosa_graph.addConst( buffer_values.shape, tosa_arg.dtype, buffer_values, name=tosa_arg.name @@ -188,9 +185,8 @@ def process_inputs_to_lifted_tensor_constants( "Is the original torch function supported?" 
) from e tensor = get_lifted_tensor_constant(edge_program, node) - tensor_values = _tensor_to_numpy_with_dim_order( + tensor_values = _tensor_to_numpy( tensor, # type: ignore[arg-type] - tosa_arg.dim_order, # type: ignore[arg-type] ) tosa_graph.addConst( diff --git a/backends/arm/test/misc/test_const_shape.py b/backends/arm/test/misc/test_const_shape.py index 2694dc6ea97..3f959a27bd0 100644 --- a/backends/arm/test/misc/test_const_shape.py +++ b/backends/arm/test/misc/test_const_shape.py @@ -6,13 +6,9 @@ from typing import Set, Type import executorch.backends.arm.tosa.dialect # noqa: F401 -import pytest import torch import tosa_serializer as ts from executorch.backends.arm._passes.arm_pass import ArmPass -from executorch.backends.arm._passes.to_tosa_memory_format_pass import ( - ToTosaMemoryFormatPass, -) from executorch.backends.arm.operators.node_visitor import get_node_visitors from executorch.backends.arm.process_node import process_call_function from executorch.backends.arm.tosa.mapping import TosaSpecialDtype @@ -84,14 +80,6 @@ def _graph_module_with_unused_const_shape(): return graph_module -def _propagate_shape_dim_orders_from_users(graph_module: torch.fx.GraphModule) -> None: - output_node = next(node for node in graph_module.graph.nodes if node.op == "output") - output_node.meta["tosa_dim_order"] = (0,) - dummy_exported = torch.export.export(torch.nn.Identity(), (torch.randn(1),)) - tosa_memory_format_pass = ToTosaMemoryFormatPass(dummy_exported) - tosa_memory_format_pass._propagate_dim_order_to_shape_args(output_node) - - def _serialize_graph_module_to_tosa(graph_module: torch.fx.GraphModule): tosa_spec = TosaSpecification.create_from_string("TOSA-1.1+FP+shape") node_visitors = get_node_visitors(None, tosa_spec) @@ -110,33 +98,18 @@ def _serialize_graph_module_to_tosa(graph_module: torch.fx.GraphModule): return tosa_graph -def test_unused_shape_ops_miss_tosa_dim_order_and_must_be_removed_before_tosa_serialization(): +def 
test_dead_shape_ops_must_be_removed_before_tosa_serialization(): graph_module = _graph_module_with_unused_const_shape() - _propagate_shape_dim_orders_from_users(graph_module) - - const_shape_nodes = [ - node - for node in graph_module.graph.nodes - if node.op == "call_function" - and node.target == exir_ops.backend.tosa.CONST_SHAPE.default - ] - dead_const_shape, live_const_shape = const_shape_nodes - - assert dead_const_shape.users == {} - assert "tosa_dim_order" not in dead_const_shape.meta - assert live_const_shape.meta["tosa_dim_order"] == (0,) - - with pytest.raises(KeyError, match="tosa_dim_order"): - _serialize_graph_module_to_tosa(graph_module) + # After eliminating dead code, only the live const shape should remain. graph_module.graph.eliminate_dead_code() graph_module.recompile() - remaining_const_shape = next( + remaining_const_shapes = [ node for node in graph_module.graph.nodes if node.op == "call_function" and node.target == exir_ops.backend.tosa.CONST_SHAPE.default - ) - assert remaining_const_shape.meta["tosa_dim_order"] == (0,) + ] + assert len(remaining_const_shapes) == 1 assert _serialize_graph_module_to_tosa(graph_module) diff --git a/backends/arm/test/ops/test_view.py b/backends/arm/test/ops/test_view.py index 2bee8ede0aa..be985566f4c 100644 --- a/backends/arm/test/ops/test_view.py +++ b/backends/arm/test/ops/test_view.py @@ -149,15 +149,13 @@ def test_view_vgf_quant(test_data: Tuple): @common.parametrize("test_data", View.rank_product_too_large) @common.XfailIfNoCorstone300 -def test_view_u55_INT_not_delegated(test_data: Tuple): +def test_view_u55_INT_large(test_data: Tuple): test_tensor, new_shape = test_data() - pipeline = OpNotSupportedPipeline[input_t1]( + pipeline = EthosU55PipelineINT[input_t1]( View(new_shape), (test_tensor,), - {"executorch_exir_dialects_edge__ops_aten_view_copy": 1}, - n_expected_delegates=1, - quantize=True, - u55_subset=True, + aten_op, + exir_ops=[], ) pipeline.run() diff --git 
a/backends/arm/test/passes/test_to_tosa_memory_format.py b/backends/arm/test/passes/test_to_tosa_memory_format.py index dfd57aa7e61..bfca3a65fff 100644 --- a/backends/arm/test/passes/test_to_tosa_memory_format.py +++ b/backends/arm/test/passes/test_to_tosa_memory_format.py @@ -3,40 +3,20 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import cast, Dict, List, Protocol, Tuple +from typing import cast, Tuple import torch -from executorch.backends.arm._passes import ( - AnnotateOutputDimOrderPass, - ToTosaMemoryFormatPass, -) from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.test_pipeline import ( - PassPipeline, TosaPipelineINT, ) -from executorch.backends.transforms.remove_getitem_op import RemoveGetItemPass input_t = Tuple[torch.Tensor] # Input x -class ModuleMetadata(Protocol): - ops_before_pass: Dict[str, int] - ops_after_pass: Dict[str, int] - ops_not_after_pass: List[str] - - def get_inputs(self) -> input_t: ... - - class NoNHWC(torch.nn.Module): - """Test-module with no ops requiring NHWC mermory format.""" - - ops_before_pass: Dict[str, int] = {} - ops_after_pass: Dict[str, int] = { - "executorch_exir_dialects_backend__ops_tosa_TRANSPOSE_default": 2 - } - ops_not_after_pass: List[str] = [] + """Test-module with no ops requiring NHWC memory format.""" def forward(self, x: torch.Tensor) -> torch.Tensor: x = x + x @@ -51,12 +31,6 @@ class ParallelClusters(torch.nn.Module): memory formats. 
""" - ops_before_pass: Dict[str, int] = {} - ops_after_pass: Dict[str, int] = { - "executorch_exir_dialects_backend__ops_tosa_TRANSPOSE_default": 2 - } - ops_not_after_pass: List[str] = [] - def __init__(self): super().__init__() self.conv = torch.nn.Conv2d( @@ -80,16 +54,10 @@ def get_inputs(self) -> input_t: class SerialClusters(torch.nn.Module): - """Test-module with multiple serial clusters of nodes requring different + """Test-module with multiple serial clusters of nodes requiring different memory formats. """ - ops_before_pass: Dict[str, int] = {} - ops_after_pass: Dict[str, int] = { - "executorch_exir_dialects_backend__ops_tosa_TRANSPOSE_default": 4 - } - ops_not_after_pass: List[str] = [] - def __init__(self): super().__init__() self.conv = torch.nn.Conv2d( @@ -119,54 +87,46 @@ def get_inputs(self) -> input_t: class Reshapes(torch.nn.Module): - """Test-module with different configurations of views requiring different - memory formats. - """ - - ops_before_pass: Dict[str, int] = {} - ops_after_pass: Dict[str, int] = { - "executorch_exir_dialects_backend__ops_tosa_TRANSPOSE_default": 16 - } - ops_not_after_pass: List[str] = [] + """Test-module with different configurations of views.""" def __init__(self): super().__init__() - self.maxpool = torch.nn.MaxPool2d(1, 1) # Use maxpool to force NHWC format + self.maxpool = torch.nn.MaxPool2d(1, 1) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.maxpool(x) - x = x.view((2, 2, 4, 16, 1)) # N-C-HW-invariant intact, no transposes needed - x = x * 2 # Add op to avoid views merging + x = x.view((2, 2, 4, 16, 1)) + x = x * 2 x = x.view((4, 4, 4, 4)) - x = x / 2 # Add op to avoid views merging + x = x / 2 x = self.maxpool(x) - x = x.view((256)) # Break N-C-HW invariant + x = x.view((256)) x = x * 2 x = x.view((4, 4, 4, 4)) x = x / 2 x = self.maxpool(x) - x = x.view((16, 16)) # Break N-C-HW invariant + x = x.view((16, 16)) x = x * 2 x = x.view((4, 4, 4, 4)) x = x / 2 x = self.maxpool(x) - x = x.view((16, 4, 4)) 
# Break N-C-HW invariant + x = x.view((16, 4, 4)) x = x * 2 x = x.view((4, 4, 4, 4)) x = x / 2 x = self.maxpool(x) - x = x.view((2, 4, 4, 8)) # Break N-C-HW invariant + x = x.view((2, 4, 4, 8)) x = x * 2 x = x.view((4, 4, 4, 4)) x = x / 2 x = self.maxpool(x) - x = x.view((8, 1, 2, 4, 4)) # Break N-C-HW invariant + x = x.view((8, 1, 2, 4, 4)) x = x * 2 x = x.view((4, 4, 4, 4)) x = self.maxpool(x) @@ -177,7 +137,7 @@ def get_inputs(self) -> input_t: return (torch.rand(4, 4, 4, 4),) -modules: Dict[str, ModuleMetadata] = { +modules = { "no_nhwc": NoNHWC(), "parallel_clusters": ParallelClusters(), "serial_clusters": SerialClusters(), @@ -186,26 +146,8 @@ def get_inputs(self) -> input_t: @common.parametrize("module", modules) -def test_to_tosa_memory_format_tosa_INT(module: ModuleMetadata) -> None: - # We cannot check op counts after a specific pass with the full pipeline - module_nn = cast(torch.nn.Module, module) - pipeline = PassPipeline[input_t]( - module_nn, - module.get_inputs(), - ops_after_pass=module.ops_after_pass, - ops_not_after_pass=module.ops_not_after_pass, - pass_list=[RemoveGetItemPass, AnnotateOutputDimOrderPass], - passes_with_exported_program=[ToTosaMemoryFormatPass], - ) - pipeline.pop_stage( - "run_method_and_compare_outputs" - ) # Eager execution is not possible after introducing tosa.TRANSPOSE - pipeline.run() - - -@common.parametrize("module", modules) -def test_to_tosa_memory_format_tosa_INT_functional(module: ModuleMetadata) -> None: - # Also run the actual pass pipeline to ensure functional correctness. 
+def test_tosa_memory_format_functional(module) -> None: + """Run the full TOSA pipeline to ensure functional correctness.""" module_nn = cast(torch.nn.Module, module) pipeline = TosaPipelineINT[input_t](module_nn, module.get_inputs(), []) pipeline.run() diff --git a/backends/arm/test/tester/arm_tester.py b/backends/arm/test/tester/arm_tester.py index 8d644eccef0..0a99de68d5d 100644 --- a/backends/arm/test/tester/arm_tester.py +++ b/backends/arm/test/tester/arm_tester.py @@ -1129,7 +1129,7 @@ def _get_dtype_distribution( placeholder_dtypes.append(str(node.meta["val"].dtype)) if node.op == "call_function": if "val" in node.meta and isinstance(node.meta["val"], torch.Tensor): - dtype, _, _ = extract_tensor_meta(node.meta) + dtype, _ = extract_tensor_meta(node.meta) call_function_dtypes.append(ts.DTypeNames[dtype]) return Counter(placeholder_dtypes), Counter(call_function_dtypes) diff --git a/backends/arm/tosa/backend.py b/backends/arm/tosa/backend.py index 0d1dfb4dfa1..142a8e91ea9 100644 --- a/backends/arm/tosa/backend.py +++ b/backends/arm/tosa/backend.py @@ -38,7 +38,6 @@ from executorch.backends.arm.tosa.mapping import TOSA_TENSOR_NAME_META from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult from executorch.exir.backend.compile_spec_schema import CompileSpec -from executorch.exir.dim_order_utils import get_memory_format from torch.export.exported_program import ExportedProgram from torch.fx import Graph, GraphModule, Node @@ -119,12 +118,8 @@ def _sort_key(t: Node) -> int: def _get_matching_fake_tensor(node: Node): - """Return a fake tensor with the same properties as node, but with - .dim_order() == node.meta["tosa_dim_order"] - """ - fake_tensor = node.meta["val"] - desired_dim_order = node.meta["tosa_dim_order"] - return fake_tensor.to(memory_format=get_memory_format(list(desired_dim_order))) + """Return the fake tensor metadata for a node.""" + return node.meta["val"] def arm_get_first_delegation_tag(graph_module) -> str: diff 
--git a/backends/arm/tosa/dialect/ops/avg_pool2d.py b/backends/arm/tosa/dialect/ops/avg_pool2d.py index 1a9192048a8..db28484dc55 100644 --- a/backends/arm/tosa/dialect/ops/avg_pool2d.py +++ b/backends/arm/tosa/dialect/ops/avg_pool2d.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import math from typing import List, Union import torch @@ -27,66 +28,33 @@ def AVG_POOL2D( pad: List[Union[int, torch.SymInt]], acc_type: torch.dtype, ) -> torch.Tensor: - """Compute output meta for a TOSA AVG_POOL2D operation.""" tosa_spec = get_context_spec() - # Validate dtype support - supported_int_types = [torch.int8] - supported_float_types = [ - torch.float16, - torch.float32, - ] + supported_dtypes = [] + if tosa_spec.support_integer(): + supported_dtypes.extend([torch.int8]) + if tosa_spec.support_float(): + supported_dtypes.extend([torch.float16, torch.float32]) if tosa_spec.support_extension("bf16"): - supported_float_types.append(torch.bfloat16) + supported_dtypes.append(torch.bfloat16) if tosa_spec.support_extension("int16"): - supported_int_types.append(torch.int16) + supported_dtypes.append(torch.int16) - if x.dtype in supported_int_types: - if not tosa_spec.support_integer(): - raise TosaValueError( - f"TOSA spec {tosa_spec} doesn't support integer pools", op="AVG_POOL2D" - ) - elif x.dtype in supported_float_types: - if not tosa_spec.support_float(): - raise TosaValueError( - f"TOSA spec {tosa_spec} doesn't support float pools", op="AVG_POOL2D" - ) - else: + if x.dtype not in supported_dtypes: raise TosaValueError( - f"Unsupported input dtype {x.dtype} for TOSA AVG_POOL2D", op="AVG_POOL2D" - ) - - # Validate input dimensions - if x.dim() != 4: - raise TosaValueError( - f"AVG_POOL2D requires a 4D tensor, got {x.dim()}D", op="AVG_POOL2D" - ) - - # Validate kernel, stride, pad lengths - if len(kernel) != 2 or len(stride) != 2 or len(pad) != 4: - raise TosaValueError( - 
f"AVG_POOL2D expects kernel of length 2, stride of length 2, pad of length 4; got " - f"kernel={kernel}, stride={stride}, pad={pad}", + f"Unsupported input dtype {x.dtype}, supported types are {supported_dtypes}", op="AVG_POOL2D", ) - # Validate and determine accumulator (output) dtype: only FP32 or INT32 - acc_allowed = [torch.float32, torch.int32] - if acc_type not in acc_allowed: - raise TosaValueError( - f"Unsupported acc_type {acc_type} for TOSA AVG_POOL2D; " - f"must be one of {acc_allowed}", - op="AVG_POOL2D", - ) - # Unpack dimensions and parameters; zero-points are not used for shape - n, c, h, w = x.shape + # Input is NHWC: [N, H, W, C] + N = x.shape[0] + H_in = x.shape[1] + W_in = x.shape[2] + C = x.shape[3] - k_h, k_w = kernel - s_h, s_w = stride - p_top, p_left, p_bot, p_right = pad - # Compute output spatial dimensions (floor division) - h_out = (h + p_top + p_left - k_h) // s_h + 1 - w_out = (w + p_bot + p_right - k_w) // s_w + 1 + # pad is [top, bottom, left, right] + H_out = math.floor((H_in + pad[0] + pad[1] - kernel[0]) / stride[0]) + 1 + W_out = math.floor((W_in + pad[2] + pad[3] - kernel[1]) / stride[1]) + 1 - # Return a tensor with the computed shape and dtype - return torch.empty(size=[n, c, h_out, w_out], dtype=x.dtype) + output_shape = [N, H_out, W_out, C] + return x.new_empty(output_shape, dtype=x.dtype) diff --git a/backends/arm/tosa/dialect/ops/conv2d.py b/backends/arm/tosa/dialect/ops/conv2d.py index 2b991600994..644c8a1d808 100644 --- a/backends/arm/tosa/dialect/ops/conv2d.py +++ b/backends/arm/tosa/dialect/ops/conv2d.py @@ -101,14 +101,15 @@ def CONV2D( torch_pad = [pad[0], pad[2]] N = x.shape[0] C_out = weight.shape[0] - H_in, W_in = x.shape[2:] + H_in, W_in = x.shape[1], x.shape[2] + kernel_h, kernel_w = weight.shape[1], weight.shape[2] H_out = math.floor( - (H_in + 2 * torch_pad[0] - dilation[0] * (weight.shape[2] - 1) - 1) / stride[0] + (H_in + 2 * torch_pad[0] - dilation[0] * (kernel_h - 1) - 1) / stride[0] + 1 ) W_out = 
math.floor( - (W_in + 2 * torch_pad[1] - dilation[1] * (weight.shape[3] - 1) - 1) / stride[1] + (W_in + 2 * torch_pad[1] - dilation[1] * (kernel_w - 1) - 1) / stride[1] + 1 ) - output_shape = [N, C_out, H_out, W_out] - return torch.empty(size=output_shape, dtype=output_dtype) + output_shape = [N, H_out, W_out, C_out] + return x.new_empty(output_shape, dtype=output_dtype) diff --git a/backends/arm/tosa/dialect/ops/conv3d.py b/backends/arm/tosa/dialect/ops/conv3d.py index bf316c3d52a..7e1b457f1ee 100644 --- a/backends/arm/tosa/dialect/ops/conv3d.py +++ b/backends/arm/tosa/dialect/ops/conv3d.py @@ -54,18 +54,18 @@ def CONV3D( torch_pad = [pad[0], pad[2], pad[4]] N = x.shape[0] C_out = weight.shape[0] - D_in, H_in, W_in = x.shape[2:] + D_in, H_in, W_in = x.shape[1], x.shape[2], x.shape[3] D_out = math.floor( - (D_in + 2 * torch_pad[0] - dilation[0] * (weight.shape[2] - 1) - 1) / stride[0] + (D_in + 2 * torch_pad[0] - dilation[0] * (weight.shape[1] - 1) - 1) / stride[0] + 1 ) H_out = math.floor( - (H_in + 2 * torch_pad[1] - dilation[1] * (weight.shape[3] - 1) - 1) / stride[1] + (H_in + 2 * torch_pad[1] - dilation[1] * (weight.shape[2] - 1) - 1) / stride[1] + 1 ) W_out = math.floor( - (W_in + 2 * torch_pad[2] - dilation[2] * (weight.shape[4] - 1) - 1) / stride[2] + (W_in + 2 * torch_pad[2] - dilation[2] * (weight.shape[3] - 1) - 1) / stride[2] + 1 ) - output_shape = [N, C_out, D_out, H_out, W_out] - return torch.empty(size=output_shape, dtype=output_dtype) + output_shape = [N, D_out, H_out, W_out, C_out] + return x.new_empty(output_shape, dtype=output_dtype) diff --git a/backends/arm/tosa/dialect/ops/depthwise_conv2d.py b/backends/arm/tosa/dialect/ops/depthwise_conv2d.py index 7d8d5f9edc8..166b9a78951 100644 --- a/backends/arm/tosa/dialect/ops/depthwise_conv2d.py +++ b/backends/arm/tosa/dialect/ops/depthwise_conv2d.py @@ -39,15 +39,15 @@ def DEPTHWISE_CONV2D( ) torch_pad = [pad[0], pad[2]] - kernel_h, kernel_w = weight.shape[0], weight.shape[2] - C_out = weight.shape[1] 
* x.shape[1] + kernel_h, kernel_w = weight.shape[0], weight.shape[1] + C_out = weight.shape[2] * weight.shape[3] N = x.shape[0] - H_in, W_in = x.shape[2:] + H_in, W_in = x.shape[1], x.shape[2] H_out = math.floor( (H_in + 2 * torch_pad[0] - dilation[0] * (kernel_h - 1) - 1) / stride[0] + 1 ) W_out = math.floor( (W_in + 2 * torch_pad[1] - dilation[1] * (kernel_w - 1) - 1) / stride[1] + 1 ) - output_shape = [N, C_out, H_out, W_out] - return torch.empty(size=output_shape, dtype=output_dtype) + output_shape = [N, H_out, W_out, C_out] + return x.new_empty(output_shape, dtype=output_dtype) diff --git a/backends/arm/tosa/dialect/ops/max_pool2d.py b/backends/arm/tosa/dialect/ops/max_pool2d.py index a0559937719..a5aebe0ae0f 100644 --- a/backends/arm/tosa/dialect/ops/max_pool2d.py +++ b/backends/arm/tosa/dialect/ops/max_pool2d.py @@ -3,7 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import List, Union +import math import torch from executorch.backends.arm.tosa.dialect.lib import TosaValueError @@ -15,61 +15,45 @@ @register_fake_tosa_op( - "MAX_POOL2D(Tensor input, int[2] kernel, int[2] stride, SymInt[4] pad) -> Tensor", + "MAX_POOL2D(Tensor input, " + "int[2] kernel, " + "int[2] stride, " + "SymInt[4] pad) -> Tensor", TosaSpecification.all_versions_and_profiles(), ) def MAX_POOL2D( x: torch.Tensor, - kernel: List[int], - stride: List[int], - pad: List[Union[int, torch.SymInt]], + kernel: list[int], + stride: list[int], + pad: list[int | torch.SymInt], ) -> torch.Tensor: - """Compute output meta for a TOSA MAX_POOL2D operation.""" tosa_spec = get_context_spec() - supported_int_types = [torch.int8] - supported_float_types = [ - torch.float16, - torch.float32, - ] + supported_dtypes = [] + if tosa_spec.support_integer(): + supported_dtypes.extend([torch.int8]) + if tosa_spec.support_float(): + supported_dtypes.extend([torch.float16, torch.float32]) if 
tosa_spec.support_extension("bf16"): - supported_float_types.append(torch.bfloat16) + supported_dtypes.append(torch.bfloat16) if tosa_spec.support_extension("int16"): - supported_int_types.append(torch.int16) + supported_dtypes.append(torch.int16) - if x.dtype in supported_int_types: - if not tosa_spec.support_integer(): - raise TosaValueError( - f"TOSA spec {tosa_spec} doesn't support integer pools", op="MAX_POOL2D" - ) - elif x.dtype in supported_float_types: - if not tosa_spec.support_float(): - raise TosaValueError( - f"TOSA spec {tosa_spec} doesn't support float pools", op="MAX_POOL2D" - ) - else: + if x.dtype not in supported_dtypes: raise TosaValueError( - f"Unsupported input dtype {x.dtype} for TOSA MAX_POOL2D", op="MAX_POOL2D" - ) - - if x.dim() != 4: - raise TosaValueError( - f"MAX_POOL2D requires a 4D tensor, got {x.dim()}D", op="MAX_POOL2D" - ) - - if len(kernel) != 2 or len(stride) != 2 or len(pad) != 4: - raise TosaValueError( - f"MAX_POOL2D expects kernel of length 2, stride of length 2, pad of " - f"length 4; got kernel={kernel}, stride={stride}, pad={pad}", + f"Unsupported input dtype {x.dtype}, supported types are {supported_dtypes}", op="MAX_POOL2D", ) - n, c, h, w = x.shape - k_h, k_w = kernel - s_h, s_w = stride - # TOSA MAX_POOL2D pad order is [top, bottom, left, right] - p_top, p_bot, p_left, p_right = pad + # Input is NHWC: [N, H, W, C] + N = x.shape[0] + H_in = x.shape[1] + W_in = x.shape[2] + C = x.shape[3] + + # pad is [top, bottom, left, right] + H_out = math.floor((H_in + pad[0] + pad[1] - kernel[0]) / stride[0]) + 1 + W_out = math.floor((W_in + pad[2] + pad[3] - kernel[1]) / stride[1]) + 1 - h_out = (h + p_top + p_bot - k_h) // s_h + 1 - w_out = (w + p_left + p_right - k_w) // s_w + 1 - return torch.empty(size=[n, c, h_out, w_out], dtype=x.dtype) + output_shape = [N, H_out, W_out, C] + return x.new_empty(output_shape, dtype=x.dtype) diff --git a/backends/arm/tosa/dialect/ops/transpose.py b/backends/arm/tosa/dialect/ops/transpose.py 
index 3b252d30ae4..41d1be74b8e 100644 --- a/backends/arm/tosa/dialect/ops/transpose.py +++ b/backends/arm/tosa/dialect/ops/transpose.py @@ -18,8 +18,7 @@ def TRANSPOSE(a, perms): # The TOSA TRANSPOSE only do the transpose in the TOSA serialized world, # so just return the same shape and type. - # For certain operators we need the data in a specific data format. Changing tosa_dim_order - # is not sufficient as we also need transpose the data. + # For certain operators we need the data in a specific data format. # By utilizing an edge IR passthrough operator we can keep the edge program in # channels-first/contiguous and get the desired behavior in the TOSA lowering. diff --git a/backends/arm/tosa/dialect/ops/transpose_conv2d.py b/backends/arm/tosa/dialect/ops/transpose_conv2d.py index 9a85b6e379c..01db849179b 100644 --- a/backends/arm/tosa/dialect/ops/transpose_conv2d.py +++ b/backends/arm/tosa/dialect/ops/transpose_conv2d.py @@ -46,12 +46,12 @@ def TRANSPOSE_CONV2D( ) N = x.shape[0] - C_out = weight.shape[1] - H_in, W_in = x.shape[2:] - kernel_h = weight.shape[2] - kernel_w = weight.shape[3] + C_out = weight.shape[0] + H_in, W_in = x.shape[1], x.shape[2] + kernel_h = weight.shape[1] + kernel_w = weight.shape[2] H_out = (H_in - 1) * stride[0] + out_pad[0] + out_pad[1] + kernel_h W_out = (W_in - 1) * stride[1] + out_pad[2] + out_pad[3] + kernel_w - output_shape = [N, C_out, H_out, W_out] - return torch.empty(size=output_shape, dtype=output_dtype) + output_shape = [N, H_out, W_out, C_out] + return x.new_empty(output_shape, dtype=output_dtype) diff --git a/backends/arm/tosa/mapping.py b/backends/arm/tosa/mapping.py index 6c7d1532218..87c0868967a 100644 --- a/backends/arm/tosa/mapping.py +++ b/backends/arm/tosa/mapping.py @@ -116,14 +116,13 @@ def map_dtype(data_type: torch.dtype) -> Any: # TODO: other types, can be # SymInt, FakeTensor, a List[Union[FakeTensor, SymInt]], or None def extract_tensor_meta(meta): - """Extract dtype, shape, and dimension order from FX 
metadata. + """Extract dtype and shape from FX metadata. Args: meta (dict): FX node ``meta`` containing a ``val`` FakeTensor (or tuple). Returns: - tuple[ts.DType, tuple[int, ...], tuple[int, ...]]: Tuple containing - tensor dtype, shape, and dimension order. + tuple[ts.DType, tuple[int, ...]]: Tuple containing tensor dtype and shape. Raises: ValueError: If ``meta['val']`` is not a ``FakeTensor``. @@ -132,7 +131,7 @@ def extract_tensor_meta(meta): special_dtype = meta.get(TosaSpecialDtype.meta_key()) if special_dtype == TosaSpecialDtype.SHAPE: shape_len = len(meta["val"]) - return (ts.DType.SHAPE, (shape_len,), meta["tosa_dim_order"]) + return (ts.DType.SHAPE, (shape_len,)) if meta.get("val") is None: raise ValueError("Expected node.meta['val'] to be set to a FakeTensor") @@ -153,11 +152,7 @@ def extract_tensor_meta(meta): dtype = map_dtype(val.dtype) shape = tuple(val.size()) - if meta.get("tosa_dim_order") is not None: - dim_order = meta["tosa_dim_order"] - else: - dim_order = tuple(range(len(shape))) - return (dtype, shape, dim_order) + return (dtype, shape) class TosaArg: @@ -171,8 +166,6 @@ class TosaArg: otherwise. dtype (ts.DType | None): Inferred dtype when available. shape (tuple[int, ...] | None): Inferred shape when available. - dim_order (tuple[int, ...] | None): Dimension order, defaulting to - ``range(len(shape))``. special (list | None): Captured list when the argument is a sequence. number (float | int | None): Captured numeric value when provided. multiple_output_name (list[str]): Output node names when node has multiple outputs; empty otherwise. @@ -190,7 +183,7 @@ def __process_node(self, argument: torch.fx.Node): self.name = argument.name + suffix if "val" in argument.meta: - output_dtype, self.shape, self.dim_order = extract_tensor_meta( + output_dtype, self.shape = extract_tensor_meta( argument.meta ) # Handle special case of types not representable in torch (i.e. 
i48_t) if special_type := argument.meta.get(TosaSpecialDtype.meta_key(), None): @@ -277,7 +270,6 @@ def __init__(self, argument: Any, tosa_spec: TosaSpecification) -> None: self.name = "" self.dtype = None self.shape = None - self.dim_order = None return raise RuntimeError( @@ -299,8 +291,6 @@ def __repr__(self): attrs.append(f"dtype={ts.DTypeNames[self.dtype]}") if self.shape is not None: attrs.append(f"shape={self.shape!r}") - if self.dim_order is not None: - attrs.append(f"dim_order={self.dim_order!r}") if hasattr(self, "special") and self.special is not None: attrs.append(f"special={self.special!r}") if hasattr(self, "number") and self.number is not None: diff --git a/backends/arm/tosa/utils.py b/backends/arm/tosa/utils.py index 602a9548791..2714b2c3aa3 100644 --- a/backends/arm/tosa/utils.py +++ b/backends/arm/tosa/utils.py @@ -91,7 +91,7 @@ def broadcast_tensors(tosa_fb, nodes: list[Node]) -> list[Any]: broadcast_tensors = [] for node in nodes: - tens_dtype, tens_shape, _ = extract_tensor_meta(node.meta) + tens_dtype, tens_shape = extract_tensor_meta(node.meta) list_tens_shape = list(tens_shape) # Already in the right shape we can just add it to the list. if list_tens_shape == common_shape: @@ -163,24 +163,6 @@ def build_reshape_tosa( ) -def tosa_shape(shape, dim_order): - """Reorder a shape tuple into TOSA layout while resolving symints. - - Args: - shape (Sequence[int | torch.SymInt]): Original tensor shape, - possibly containing ``torch.SymInt``. - dim_order (Sequence[int]): Desired dimension order for the output - shape. - - Returns: - list[int]: List containing the reordered dimensions where symbolic - values become ``-1``. - - """ - reordered = tuple([shape[dim] for dim in dim_order]) - # Dynamic shapes in executorch are represented with torch.SymInt objects in the shapes, - # in TOSA we do not have this concept and instead use -1. 
- removed_symints = tuple( - [-1 if isinstance(d, torch.SymInt) else d for d in reordered] - ) - return list(removed_symints) +def tosa_shape(shape): + """Convert a shape tuple to a TOSA-compatible list, resolving symints.""" + return list([-1 if isinstance(d, torch.SymInt) else d for d in shape]) diff --git a/backends/cadence/aot/BUCK b/backends/cadence/aot/BUCK index 3eb77d4470f..5b5316245f8 100644 --- a/backends/cadence/aot/BUCK +++ b/backends/cadence/aot/BUCK @@ -81,9 +81,9 @@ fbcode_target(_kind = runtime.python_library, "pass_utils.py", ], deps = [ - "fbsource//third-party/pypi/beartype:beartype", ":utils", "//caffe2:torch", + "//executorch/backends/transforms:permute_pass_utils", "//executorch/exir:pass_base", "//executorch/exir/dialects:lib", "//executorch/exir/passes:lib", @@ -188,7 +188,6 @@ fbcode_target(_kind = python_unittest, ], typing = True, deps = [ - "fbsource//third-party/pypi/beartype:beartype", ":pass_utils", "//caffe2:torch", ], @@ -267,6 +266,10 @@ fbcode_target(_kind = runtime.python_library, "//caffe2:torch", "//executorch/backends/cadence/aot:pass_utils", "//executorch/backends/cadence/aot:utils", + "//executorch/backends/transforms:fuse_cascaded_transpose_or_permute_ops", + "//executorch/backends/transforms:fuse_cascaded_view_ops", + "//executorch/backends/transforms:fuse_transpose_or_permute_op_pairs_pass", + "//executorch/backends/transforms:permute_pass_utils", "//executorch/exir:pass_base", "//executorch/exir/dialects:lib", "//executorch/exir/dialects/edge:lib", @@ -304,6 +307,7 @@ fbcode_target(_kind = runtime.python_library, "//executorch/backends/cadence/aot:pass_utils", "//executorch/backends/cadence/aot:simplify_ops", "//executorch/backends/transforms:remove_clone_ops", + "//executorch/backends/transforms:remove_permutes_around_elementwise_ops", "//executorch/exir:pass_base", "//executorch/exir/dialects:lib", "//executorch/exir/dialects/edge:lib", @@ -322,6 +326,7 @@ fbcode_target(_kind = runtime.python_library, 
"//executorch/backends/cadence/aot:compiler_utils", "//executorch/backends/cadence/aot:pass_utils", "//executorch/backends/cadence/aot:utils", + "//executorch/backends/transforms:postpone_permute_below_squeeze_view", "//executorch/exir:pass_base", "//executorch/exir:tensor", "//executorch/exir/dialects:lib", @@ -343,6 +348,7 @@ fbcode_target(_kind = runtime.python_library, "//executorch/backends/cadence/aot:pass_utils", "//executorch/backends/cadence/aot:remove_ops", "//executorch/backends/cadence/aot:utils", + "//executorch/backends/transforms:replace_nop_transpose_or_permute_with_view", "//executorch/backends/transforms:replace_scalar_with_tensor", "//executorch/exir:pass_base", "//executorch/exir/dialects:lib", diff --git a/backends/cadence/aot/fuse_ops.py b/backends/cadence/aot/fuse_ops.py index 023a6f5760a..d6ee88e94c6 100644 --- a/backends/cadence/aot/fuse_ops.py +++ b/backends/cadence/aot/fuse_ops.py @@ -12,9 +12,8 @@ import logging import math import operator -from collections import deque from numbers import Number -from typing import Any, Callable, cast, Optional, override +from typing import Any, cast, Optional, override # Import these for the cadence function signatures. 
import executorch.backends.cadence.aot.ops_registrations # noqa: F401 @@ -22,10 +21,8 @@ import torch.fx from executorch.backends.cadence.aot.compiler_utils import ( broadcastable, - get_permuted_dims, get_scale, get_tensor_from_attr, - get_transposed_dims, get_zero_point, ) from executorch.backends.cadence.aot.pass_utils import ( @@ -36,9 +33,21 @@ RemoveOrReplacePassInterface, ) from executorch.backends.cadence.aot.utils import get_edge_overload_packet +from executorch.backends.transforms.fuse_cascaded_transpose_or_permute_ops import ( + FuseCascadedTransposeOrPermuteOps as _SharedFuseCascadedTransposeOrPermuteOps, +) +from executorch.backends.transforms.fuse_cascaded_view_ops import ( + FuseCascadedViewOps as _SharedFuseCascadedViewOps, +) +from executorch.backends.transforms.fuse_transpose_or_permute_op_pairs_pass import ( + FuseTransposeOrPermuteOpPairsPass as _SharedFuseTransposeOrPermuteOpPairsPass, +) +from executorch.backends.transforms.permute_pass_utils import ( + FuseOpPairsAcrossBranchesPass, +) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket -from executorch.exir.pass_base import ExportPass, PassResult +from executorch.exir.pass_base import PassResult from executorch.exir.passes.cse_pass import CSEPass from torch.nn.utils.fusion import fuse_conv_bn_weights @@ -578,207 +587,13 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: @register_cadence_pass(CadencePassAttribute(opt_level=1)) -class FuseCascadedTransposeOrPermuteOps(RemoveOrReplacePassInterface): - """ - Fuse a chain of transpose and permute ops into a single permute or a no-op. - Handles branches and chains permutes. 
- """ - - transpose_or_permute_target = { - exir_ops.edge.aten.transpose_copy.int, - exir_ops.edge.aten.permute_copy.default, - } - - @property - def targets(self) -> list[EdgeOpOverload]: - return list(self.transpose_or_permute_target) - - def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: - # Fuse with the parent node if it's also a permute or a transpose. Since the - # pass interface traverses all ops in order the pass will properly fuse a chain - # of permutes. - parent_node = get_arg(node, "input", torch.fx.Node) - if parent_node.target not in self.transpose_or_permute_target: - return False - input_of_parent = get_arg(parent_node, "input", torch.fx.Node) - - # Compute combined effect of permutes. - dims = list(range(node.meta["val"].ndim)) - - if parent_node.target == exir_ops.edge.aten.transpose_copy.int: - dims = get_transposed_dims(parent_node, dims) - else: - dims = get_permuted_dims(parent_node, dims) - - if node.target == exir_ops.edge.aten.transpose_copy.int: - dims = get_transposed_dims(node, dims) - else: - dims = get_permuted_dims(node, dims) - - # If combined effect is identity replace the node with input. 
- if dims == sorted(dims): - node.replace_all_uses_with(input_of_parent) - else: - with node.graph.inserting_before(node): - new_permute = node.graph.call_function( - exir_ops.edge.aten.permute_copy.default, - args=(input_of_parent, dims), - ) - new_permute.meta = node.meta - node.replace_all_uses_with(new_permute) - - return True +class FuseCascadedTransposeOrPermuteOps(_SharedFuseCascadedTransposeOrPermuteOps): + pass @register_cadence_pass(CadencePassAttribute(opt_level=1)) -class FuseCascadedViewOps(RemoveOrReplacePassInterface): - """ - Fuse a cascaded chain of view ops - """ - - @property - def targets(self) -> list[EdgeOpOverload]: - return [exir_ops.edge.aten.view_copy.default] - - def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: - # Check if the input to this view node is also a view node - input_view = node.args[0] - if not isinstance(input_view, torch.fx.Node): - return False - - if ( - input_view.op != "call_function" - or input_view.target != exir_ops.edge.aten.view_copy.default - ): - return False - - # Replace the input of this view node with the input of the cascaded view - # This effectively "skips" the intermediate view node - node.replace_input_with(input_view, cast(torch.fx.Node, input_view.args[0])) - return True - - -class FuseOpPairsAcrossBranchesPass(ExportPass): - """ - Base class for passes that fuse op pairs across branches. - Provides common functionality for finding and fusing producer-consumer chains. - """ - - def check_ok_to_fuse( - self, - producer: torch.fx.Node, - consumers: list[torch.fx.Node], - ) -> bool: - # Always ok to replace / remove. 
- return True - - def can_fuse_for_chain( - self, - producer: torch.fx.Node, - consumer: torch.fx.Node, - consumer_op_packets: set[EdgeOpOverloadPacket], - ) -> bool: - """ - Returns true if producer and consumer can be fused for a single chain - (-> producer -> ops -> consumer ->) to (-> ops -> fused_op) - """ - if ( - isinstance(consumer.target, EdgeOpOverload) - and get_edge_overload_packet(consumer.target) in consumer_op_packets - ): - return True - return False - - def get_fuse_candidates( - self, - producer: torch.fx.Node, - consumer_op_packets: set[EdgeOpOverloadPacket], - bypass_ops: set[EdgeOpOverload], - ) -> list[torch.fx.Node]: - # Start by iterating over all the users of this node, and check - # if they are have their target in consumer_op_packets. - users = deque(producer.users.keys()) - # This holds the list of the user ops that directly (or transitively - # via view/slice) consume this producer_op_packets, and hence can be removed. - removal_candidates = [] - while users: - user = users.popleft() - - # If the user is a bypass op, we bypass it, and examine - # its users instead for consumer_op_packets. - if user.target in bypass_ops: - users.extend(list(user.users.keys())) - elif self.can_fuse_for_chain(producer, user, consumer_op_packets): - removal_candidates.append(user) - else: - removal_candidates.clear() - break - return removal_candidates - - def find_and_fuse( - self, - graph_module: torch.fx.GraphModule, - producer_op_packets: set[EdgeOpOverloadPacket], - consumer_op_packets: set[EdgeOpOverloadPacket], - bypass_ops: set[EdgeOpOverload], - ) -> bool: - """ - Find and fuse producer-consumer op pairs. - - Returns True if any fusion was performed, False otherwise. - """ - modified = False - for node in graph_module.graph.nodes: - # We are only interested in ops that have overload target in - # producer_op. 
- if not ( - isinstance(node.target, EdgeOpOverload) - and get_edge_overload_packet(node.target) in producer_op_packets - ): - continue - - removal_candidates = self.get_fuse_candidates( - node, consumer_op_packets, bypass_ops - ) - - if len(removal_candidates) == 0: - # No candidates found. - continue - - if not self.check_ok_to_fuse(node, removal_candidates): - # Not ok to remove quant-dequant pairs or replace with requantize. - continue - - self.fuse(node, removal_candidates, graph_module) - modified = True - - if modified: - graph_module.recompile() - - return modified - - def get_fused_node( - self, - producer: torch.fx.Node, - consumer: torch.fx.Node, - graph_module: torch.fx.GraphModule, - ) -> torch.fx.Node: - return consumer - - def fuse( - self, - node: torch.fx.Node, - removal_candidates: list[torch.fx.Node], - graph_module: torch.fx.GraphModule, - ) -> None: - # Replace all the uses of the producer op with it's input. - node.replace_all_uses_with(cast(torch.fx.Node, node.args[0])) - graph_module.graph.erase_node(node) - - # Iterate over all the removal candidates (quantize op users) and generate replacements. - for rnode in removal_candidates: - rnode.replace_all_uses_with(self.get_fused_node(node, rnode, graph_module)) - graph_module.graph.erase_node(rnode) +class FuseCascadedViewOps(_SharedFuseCascadedViewOps): + pass @register_cadence_pass(CadencePassAttribute(opt_level=1)) @@ -1123,89 +938,15 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: @register_cadence_pass(CadencePassAttribute(opt_level=1)) -class FuseTransposeOrPermuteOpPairsPass(FuseOpPairsAcrossBranchesPass): - """ - Fuse transpose or permute op pairs to a single view op. - (transpose or permutation) -> (quant or dequant) -> (transpose or permutation) - This happens when op2(op1) == identity, modulo unitary dimensions. 
- 'unitary dimensions' example: a tensor of shape [1, 5, 30] is equivalent (in memory) to [5, 1, 30] - so transpose(1, 2) then transpose(0, 2) is a pseudo identity and should be fused. - """ - - # A list of ops that can be bypassed when looking for a - # dequantize->quantize chain - bypass_ops: set[EdgeOpOverload] = { - exir_ops.edge.cadence.quantize_per_tensor.default, - exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, - exir_ops.edge.quantized_decomposed.quantize_per_channel.default, - exir_ops.edge.cadence.dequantize_per_tensor.default, - exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, - exir_ops.edge.quantized_decomposed.dequantize_per_channel.default, - exir_ops.edge.cadence.quantized_relu.per_tensor, - } - - def can_fuse_for_chain( - self, - producer: torch.fx.Node, - consumer: torch.fx.Node, - consumer_op_packets: set[EdgeOpOverloadPacket], - ) -> bool: - if not super().can_fuse_for_chain(producer, consumer, consumer_op_packets): - return False - - # checking that permut2(permut1(identity)) == identity, modulo unitary dimensions - producer_input = cast(torch.fx.Node, producer.args[0]) - if "val" not in producer_input.meta: - return False - input_shape = producer_input.meta["val"].shape - ident_dims = list(range(len(input_shape))) - # this mapping helps to handle both transpose and permutations - f: dict[Any, Callable] = { - exir_ops.edge.aten.transpose_copy.int: get_transposed_dims, - exir_ops.edge.aten.permute_copy.default: get_permuted_dims, +class FuseTransposeOrPermuteOpPairsPass(_SharedFuseTransposeOrPermuteOpPairsPass): + bypass_ops: set[EdgeOpOverload] = ( + _SharedFuseTransposeOrPermuteOpPairsPass.bypass_ops + | { + exir_ops.edge.cadence.quantize_per_tensor.default, + exir_ops.edge.cadence.dequantize_per_tensor.default, + exir_ops.edge.cadence.quantized_relu.per_tensor, } - in_dims = f[producer.target](producer, ident_dims) - out_dims = f[consumer.target](consumer, in_dims) - # Filtering out unitary dimensions - 
non_unit_ident_dims = [dim for dim in ident_dims if input_shape[dim] != 1] - non_unit_out_dims = [dim for dim in out_dims if input_shape[dim] != 1] - return non_unit_out_dims == non_unit_ident_dims - - def get_fused_node( - self, - producer: torch.fx.Node, - consumer: torch.fx.Node, - graph_module: torch.fx.GraphModule, - ) -> torch.fx.Node: - # This step is important because of how we can fuse transpositions that are not perfectly - # reverse one of another but will be fused if there are unitary dimensions. - # The fused operation must have the same output shape as the consumer. - output_shape = consumer.meta["val"].shape - with graph_module.graph.inserting_after(consumer): - view = graph_module.graph.call_function( - exir_ops.edge.aten.view_copy.default, - (consumer.args[0], output_shape), - {}, - ) - return view - - def call(self, graph_module: torch.fx.GraphModule) -> PassResult: - # Remove any transpose/permutation op pair that cancel each other. - modified = self.find_and_fuse( - graph_module, - producer_op_packets={ - exir_ops.edge.aten.transpose_copy, - exir_ops.edge.aten.permute_copy, - }, - consumer_op_packets={ - exir_ops.edge.aten.transpose_copy, - exir_ops.edge.aten.permute_copy, - }, - bypass_ops=self.bypass_ops, - ) - if modified: - return super().call(graph_module) - return PassResult(graph_module, False) + ) @register_cadence_pass(CadencePassAttribute(opt_level=1)) diff --git a/backends/cadence/aot/pass_utils.py b/backends/cadence/aot/pass_utils.py index d03862d44fa..ab42ef43d56 100644 --- a/backends/cadence/aot/pass_utils.py +++ b/backends/cadence/aot/pass_utils.py @@ -7,20 +7,22 @@ # pyre-strict import dataclasses -from abc import abstractmethod from dataclasses import dataclass -from typing import Callable, List, Optional, override, Set, Type, TypeVar, Union +from typing import Callable, List, Optional, Set, Type, Union import torch -from beartype.door import die_if_unbearable from executorch.backends.cadence.aot.utils import 
get_edge_overload_packet + +# Re-exported for downstream consumers (noqa for flake8, `as X` for Pyre strict). +from executorch.backends.transforms.permute_pass_utils import ( # noqa: F401 + get_arg as get_arg, + HierarchicalInplacePassInterface as HierarchicalInplacePassInterface, + RemoveOrReplacePassInterface as RemoveOrReplacePassInterface, + set_arg as set_arg, +) from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket -from executorch.exir.pass_base import ExportPass, PassBase, PassResult +from executorch.exir.pass_base import PassBase, PassResult from torch._ops import OpOverloadPacket -from torch.fx import Node -from torch.fx.node import Argument - -T = TypeVar("T") # Is an overlap in tensor lifetime and storage allowed at the current opt level? @@ -207,152 +209,6 @@ def nodes_not_adjacent_in_gm( return True -def get_arg( - node: torch.fx.Node, - kwarg_name: str, - expected_type: Type[T] = Argument, -) -> T: - """ - Get the arg with arg_name of the node, returns default value if not set. - - Args: - node: The FX node to extract the argument from - kwarg_name: The name of the argument to extract - expected_type: Optional type to validate and cast the argument to. - If provided, asserts the argument is an instance of this type. 
- - Returns: - The argument value, optionally type-checked and cast to expected_type - - Example: - # Get a node argument with type checking - conv_weight_node = get_arg(node, "weight", torch.fx.Node) - - # Get a float argument with type checking - eps = get_arg(node, "eps", float) - - # Get an argument without type checking (returns Argument) - value = get_arg(node, "some_arg") - """ - # Try to get the arg from kwargs first since this is faster - if kwarg_name in node.kwargs: - value = node.kwargs[kwarg_name] - else: - # If it's not found in kwargs, try to normalize the args - normalized_args = node.normalized_arguments( - node.graph.owning_module, normalize_to_only_use_kwargs=True - ) - if not normalized_args: - raise RuntimeError( - f"get_arg: Node {node} does not support normalization of arguments" - ) - value = normalized_args.kwargs[kwarg_name] - - # Validate type using beartype's runtime type checker when a specific - # type is requested (not the default Argument type alias, which contains - # recursive forward references that beartype cannot resolve). - if expected_type is not Argument: - die_if_unbearable(value, expected_type) - return value # type: ignore[return-value] - - -def set_arg( - node: torch.fx.Node, kwarg_name: str, value: torch.fx.node.Argument -) -> None: - """ - Set the node's arg with its name to the given value. 
- """ - # Try to set the arg if it is present in kwargs first since this is faster - if kwarg_name in node.kwargs: - node.update_kwarg(kwarg_name, value) - return - - # If it's not found in kwargs, try to normalize the args and set the arg - normalized_args = node.normalized_arguments( - node.graph.owning_module, normalize_to_only_use_kwargs=True - ) - if not normalized_args: - raise RuntimeError( - f"set_arg: Node {node} does not support normalization of arguments" - ) - - kwargs = normalized_args.kwargs - if kwarg_name not in kwargs: - raise ValueError(f"set_arg: invalid arg name {kwarg_name} for node {node} used") - - idx = list(kwargs.keys()).index(kwarg_name) - if idx < len(node.args): - node.update_arg(idx, value) - else: - node.update_kwarg(kwarg_name, value) - - def none_throws(x: Optional[PassResult]) -> PassResult: assert x is not None return x - - -class HierarchicalInplacePassInterface(ExportPass): - """A base class for passes that apply in-place modification to the graph module and its submodules. - Also calls ExportPass.call() in case the graph module is modified to ensure all nodes have valid `meta['val']`. 
- """ - - @abstractmethod - def _apply_flat_inplace(self, graph_module) -> bool: - """Apply in-place modification to the graph module.""" - raise NotImplementedError("`_apply_flat_inplace` must be implemented") - - def _apply_hierarchical_inplace(self, graph_module: torch.fx.GraphModule) -> bool: - """Apply in-place modification recursively to the graph module and its submodules.""" - - modified: bool = False - for module in filter( - lambda m: isinstance(m, torch.fx.GraphModule), graph_module.modules() - ): - modified |= self._apply_flat_inplace(module) - - return modified - - def call(self, graph_module: torch.fx.GraphModule) -> PassResult: - modified = self._apply_hierarchical_inplace(graph_module) - - if modified: - graph_module.graph.eliminate_dead_code() - graph_module.recompile() - return super().call(graph_module) - - return PassResult(graph_module, False) - - -class RemoveOrReplacePassInterface(HierarchicalInplacePassInterface): - @property - @abstractmethod - def targets(self) -> list[EdgeOpOverload]: - """ - The list of targets to potentially remove or replace. - """ - raise NotImplementedError("`targets` must be implemented") - - @abstractmethod - def maybe_remove_or_replace(self, node: Node) -> bool: - """ - If the node should be removed/replaced, removes/replaces from the graph. Returns - True if the graph was modified, else False. - """ - raise NotImplementedError("`maybe_remove_or_replace` must be implemented") - - @override - def _apply_flat_inplace(self, graph_module: torch.fx.GraphModule) -> bool: - changed = False - for target in self.targets: - for node in graph_module.graph.find_nodes( - op="call_function", target=target - ): - if len(node.users) == 0: - # It is possible that maybe_remove_or_replace would have removed - # this target by starting from a different target. In this case, - # we should ignore it. If it wasn't erased, it will be handled - # in eliminate_dead_code. 
- continue - changed |= self.maybe_remove_or_replace(node) - return changed diff --git a/backends/cadence/aot/remove_ops.py b/backends/cadence/aot/remove_ops.py index a85b13452c1..dabab032116 100644 --- a/backends/cadence/aot/remove_ops.py +++ b/backends/cadence/aot/remove_ops.py @@ -6,7 +6,6 @@ # pyre-strict -from dataclasses import dataclass, field from typing import cast, List, Optional, Sequence, Set, Type # Import these for the cadence function signatures. @@ -26,6 +25,9 @@ from executorch.backends.cadence.aot.simplify_ops import SimplifySliceOpPass from executorch.backends.cadence.aot.utils import get_edge_overload_packet from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform +from executorch.backends.transforms.remove_permutes_around_elementwise_ops import ( + RemovePermutesAroundElementwiseOps as _SharedRemovePermutesAroundElementwiseOps, +) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket from executorch.exir.pass_base import ExportPass, PassResult @@ -386,267 +388,17 @@ def maybe_remove_or_replace(self, node: Node) -> bool: @register_cadence_pass(CadencePassAttribute(opt_level=2)) -class RemovePermutesAroundElementwiseOps(ExportPass): - """ - Looks for subgraphs of elementwise ops sandwiched between permutes and removes those - permutes if possible. - Allows special handling for certain non-elementwise ops that can be easily updated - based on the permute's parameter such as mean, cat, and slice. - """ - - @dataclass() - class Subgraph: - start_permute: list[int] - end_permute: list[int] - # Nodes in the subgraph, does not include permutes. - nodes: set[torch.fx.Node] = field(default_factory=set) - # Incoming edges to the subgraph from permute nodes. - edges_in: set[tuple[torch.fx.Node, torch.fx.Node]] = field(default_factory=set) - # Outgoing edges of the subgraph to permute nodes. 
- edges_out: set[tuple[torch.fx.Node, torch.fx.Node]] = field(default_factory=set) - # Incoming edges from constant nodes that need a compensating permute. - constant_edges_in: set[tuple[torch.fx.Node, torch.fx.Node]] = field( - default_factory=set - ) - - permutable_ops: set[EdgeOpOverload] = { - exir_ops.edge.aten.add.Tensor, - exir_ops.edge.aten.mul.Tensor, - exir_ops.edge.aten.hardtanh.default, - exir_ops.edge.aten.clamp.default, - exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, - exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, - exir_ops.edge.cadence.quantize_per_tensor.default, - exir_ops.edge.cadence.dequantize_per_tensor.default, - exir_ops.edge.cadence.quantized_relu.per_tensor, - exir_ops.edge.cadence.requantize.per_tensor, - exir_ops.edge.cadence.quantized_add.per_tensor, - # Ops that require special handling. - exir_ops.edge.aten.cat.default, - exir_ops.edge.aten.mean.dim, - exir_ops.edge.aten.slice_copy.Tensor, - } - - def call(self, graph_module: torch.fx.GraphModule) -> PassResult: - subgraphs_found: list[RemovePermutesAroundElementwiseOps.Subgraph] = [] - processed_nodes: set[torch.fx.Node] = set() - for node in graph_module.graph.find_nodes( - op="call_function", target=exir_ops.edge.aten.permute_copy.default - ): - start_permute = self.get_permutation(node) - # Expected end permutation for the subgraph. - end_permute = [start_permute.index(i) for i in range(len(start_permute))] - - for user in node.users: - if user.target not in self.permutable_ops: - continue - # Create a separate subgraph for each user since there may be cases - # where only a portion of the users are permutable. 
- subgraph = self.Subgraph(start_permute, end_permute) - if self.visit(user, subgraph, processed_nodes): - subgraphs_found.append(subgraph) - for node in subgraph.nodes: - processed_nodes.add(node) - - modified = False - for subgraph in subgraphs_found: - self.permute_subgraph(subgraph) - modified = True - - if modified: - graph_module.graph.eliminate_dead_code() - graph_module.recompile() - return super().call(graph_module) - - return PassResult(graph_module, False) - - def visit( # noqa: C901 - self, - node: torch.fx.Node, - subgraph: Subgraph, - processed_nodes: set[torch.fx.Node], - ) -> bool: - if node in subgraph.nodes: - return True - if node in processed_nodes or not self.is_node_permutable(node): - return False - subgraph.nodes.add(node) - - # Traverse downstream: - for user in node.users: - # Output should either go to a matching permute or another permutable op. - if user.target == exir_ops.edge.aten.permute_copy.default: - if self.get_permutation(user) != subgraph.end_permute: - return False - subgraph.edges_out.add((node, user)) - elif user.op == "output": - # Graph output requires the data in its original layout. - # Removing permutes here would silently change the output - # format, so treat this as an invalid subgraph boundary. - return False - elif not self.visit(user, subgraph, processed_nodes): - return False - - # Traverse upstream: - for inp in node.all_input_nodes: - # Input should either come from a matching permute or another permutable op. - if inp.target == exir_ops.edge.aten.permute_copy.default: - if self.get_permutation(inp) != subgraph.start_permute: - return False - subgraph.edges_in.add((inp, node)) - elif self._is_constant(inp): - # Only accept the constant if we can compensate it with a - # permute or view. Otherwise reject the subgraph. 
- const_rank = self._get_node_rank(inp) - if const_rank is None: - return False - if const_rank > len(subgraph.end_permute): - return False - if ( - const_rank < len(subgraph.end_permute) - and inp.meta.get("val") is None - ): - return False - subgraph.constant_edges_in.add((inp, node)) - elif not self.visit(inp, subgraph, processed_nodes): - return False - - return True - - def _is_constant(self, node: torch.fx.Node) -> bool: - """Check if a node's value is available at compile time. - Only considers direct constants (get_attr, parameter/buffer/constant - placeholders) — does not recurse into call_function chains to avoid - stack overflow on deep graphs.""" - if node.op == "get_attr": - return True - if node.op == "placeholder": - target = str(node.target) - return target.startswith(("b_", "p_", "c_")) - return False - - def _get_node_rank(self, node: torch.fx.Node) -> int | None: - """Return the tensor rank of a node's output, or None if unknown.""" - val = node.meta.get("val") - if val is not None and hasattr(val, "shape"): - return len(val.shape) - return None - - def is_node_permutable(self, node: torch.fx.Node) -> bool: - if node.target not in self.permutable_ops: - return False - if node.target == exir_ops.edge.aten.mean.dim: - # keepdim should be True. - if len(node.args) >= 3: - if not node.args[2]: - return False - elif "keepdim" in node.kwargs: - if not node.kwargs["keepdim"]: - return False - else: - # Default keepdim is False. - return False - return True - - def permute_subgraph(self, subgraph: Subgraph) -> None: - # Skip incoming permutes. - for inp, out in subgraph.edges_in: - assert inp.target == exir_ops.edge.aten.permute_copy.default - if len(inp.args) >= 1: - out.replace_input_with(inp, cast(torch.fx.Node, inp.args[0])) - else: - out.replace_input_with(inp, cast(torch.fx.Node, inp.kwargs["input"])) - - # Insert compensating permute on constant inputs. 
- # Since the subgraph's start permutes are being removed, the subgraph - # will operate in the un-permuted (original) layout. Constants that - # were in the permuted layout need end_permute (the inverse of - # start_permute) to convert back to the original layout. - for const_node, user_node in subgraph.constant_edges_in: - graph = const_node.graph - const_rank = self._get_node_rank(const_node) - permute_rank = len(subgraph.end_permute) - - with graph.inserting_after(const_node): - if const_rank is not None and const_rank == permute_rank: - new_node = graph.create_node( - "call_function", - exir_ops.edge.aten.permute_copy.default, - args=(const_node, subgraph.end_permute), - ) - elif ( - const_rank is not None - and const_rank < permute_rank - and const_node.meta.get("val") is not None - ): - # Rank mismatch (e.g. rank-1 bias with rank-4 permute). - # The constant is broadcastable and its shape is smaller - # than the permute rank, so we can't apply the permute - # directly. Instead, use view_copy to rearrange the - # shape according to the end_permute restricted to - # the trailing dimensions. - original_shape = list(const_node.meta["val"].shape) - # Pad shape to match permute rank for reordering - padded = [1] * (permute_rank - const_rank) + original_shape - target_shape = [padded[d] for d in subgraph.end_permute] - # Strip leading 1s back to original rank - target_shape = target_shape[permute_rank - const_rank :] - new_node = graph.create_node( - "call_function", - exir_ops.edge.aten.view_copy.default, - args=(const_node, target_shape), - ) - else: - # Cannot determine rank or handle this case; skip. - continue - user_node.replace_input_with(const_node, new_node) - - # Skip outgoing permutes. - for inp, out in subgraph.edges_out: - assert out.target == exir_ops.edge.aten.permute_copy.default - out.replace_all_uses_with(inp) - - # Handle dimension related node arguments. 
- for node in subgraph.nodes: - if node.target == exir_ops.edge.aten.cat.default: - self.update_cat(node, subgraph.start_permute) - elif node.target == exir_ops.edge.aten.mean.dim: - self.update_mean_dim(node, subgraph.start_permute) - elif node.target == exir_ops.edge.aten.slice_copy.Tensor: - self.update_slice_copy(node, subgraph.start_permute) - - def update_cat(self, node: torch.fx.Node, start_permute: list[int]) -> None: - if len(node.args) >= 2: - node.update_arg(1, start_permute[cast(int, node.args[1])]) - elif "dim" in node.kwargs: - node.update_kwarg("dim", start_permute[cast(int, node.kwargs["dim"])]) - else: - # Default cat dim is 0. - node.update_kwarg("dim", start_permute[0]) - - def update_mean_dim(self, node: torch.fx.Node, start_permute: list[int]) -> None: - if len(node.args) >= 2: - node.update_arg( - 1, [start_permute[dim] for dim in cast(list[int], node.args[1])] - ) - else: - node.update_kwarg( - "dim", - [start_permute[dim] for dim in cast(list[int], node.kwargs["dim"])], - ) - - def update_slice_copy(self, node: torch.fx.Node, start_permute: list[int]) -> None: - if len(node.args) >= 2: - node.update_arg(1, start_permute[cast(int, node.args[1])]) - else: - node.update_kwarg("dim", start_permute[cast(int, node.kwargs["dim"])]) - - def get_permutation(self, permute_node: torch.fx.Node) -> list[int]: - assert permute_node.target == exir_ops.edge.aten.permute_copy.default - if len(permute_node.args) >= 2: - return cast(list[int], permute_node.args[1]) - assert "dim" in permute_node.kwargs - return cast(list[int], permute_node.kwargs["dim"]) +class RemovePermutesAroundElementwiseOps(_SharedRemovePermutesAroundElementwiseOps): + permutable_ops: set[EdgeOpOverload] = ( + _SharedRemovePermutesAroundElementwiseOps.permutable_ops + | { + exir_ops.edge.cadence.quantize_per_tensor.default, + exir_ops.edge.cadence.dequantize_per_tensor.default, + exir_ops.edge.cadence.quantized_relu.per_tensor, + exir_ops.edge.cadence.requantize.per_tensor, + 
exir_ops.edge.cadence.quantized_add.per_tensor, + } + ) @register_cadence_pass(CadencePassAttribute(opt_level=2)) diff --git a/backends/cadence/aot/reorder_ops.py b/backends/cadence/aot/reorder_ops.py index 8a0e112aaf3..e14471bc7ed 100644 --- a/backends/cadence/aot/reorder_ops.py +++ b/backends/cadence/aot/reorder_ops.py @@ -9,10 +9,9 @@ # This file contains all the functions that reorder ops in the graph module. -import copy from collections import defaultdict from math import prod -from typing import cast, DefaultDict, List, Tuple +from typing import DefaultDict, List, Tuple import torch import torch.fx @@ -24,6 +23,9 @@ RemoveOrReplacePassInterface, ) from executorch.backends.cadence.aot.utils import get_edge_overload_packet +from executorch.backends.transforms.postpone_permute_below_squeeze_view import ( + PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView as _SharedPostponePermuteOpBelowSqueezeOrUnsqueezeLikeView, +) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.dialects.edge._ops import EdgeOpOverload from executorch.exir.pass_base import ExportPass, PassResult @@ -633,191 +635,10 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: @register_cadence_pass(CadencePassAttribute(opt_level=1)) -class PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView(RemoveOrReplacePassInterface): - """ - A common pattern seen in transformer models. If the consumer of permute - is a view op, swap their order so permute is below view. - Change "permute -> view" to "view -> permute" - This is to optimize a chain of view->permute->view->permute... - so that the chain will be become view->v...->view->permute->p...->permute. - The chain can be optimized by FuseCascadedTransposeOrPermuteOps() and - FuseCascadedViewOps(). - Notice the class name has ViewSqueeze to indicate the View is - functionally the same as a squeeze or unsqueeze. It does not necessarily - mean the view_copy is normalized from squeeze or unsqueeze. 
- """ - - @property - def targets(self) -> list[EdgeOpOverload]: - return [exir_ops.edge.aten.permute_copy.default] - - # If list1 and list2 are same (same values and in same order) except - # list1 has one more element with value of 1. Return index of the extra 1. - # Otherwise return -1. - def check_if_shapes_differ_in_single_dim_of_size_1( - self, list1: List, list2: List - ) -> int: - if len(list1) != len(list2) + 1: - return -1 - for i in range(len(list2)): - if list1[i] != list2[i]: - # Return index of the extra 1 if the remaining parts are the same - if list1[i] == 1 and list2[i:] == list1[i + 1 :]: - return i - else: - return -1 - # If no difference was found, the extra element is at the end - if list1[-1] == 1: - return len(list2) - else: - return -1 - - def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: - users = list(node.users.keys()) - # Transform only for pattern permute_copy->view_copy, and - # view_copy op is the only user of permute_copy. - if len(users) != 1 or users[0].target not in ( - exir_ops.edge.aten.view_copy.default, - exir_ops.edge.aten.view.default, - ): - return False - - # If the permute_node/view_node was newly added to the - # graph, it may not have the meta["val"] FakeTensor. - # Skip in this case. - if node.meta.get("val") is None: - return False - - permute_node_shape = [*cast(list, get_shape(node.graph.owning_module, node))] - - permute_dims = cast(list, node.args[1]) - view_node = users[0] - - if view_node.meta.get("val") is None: - return False - - view_node_shape = [*cast(list, get_shape(node.graph.owning_module, view_node))] - - pred = node.args[0] - if not isinstance(pred, torch.fx.Node) or pred.meta.get("val") is None: - return False - - pred_shape = [*cast(list, get_shape(node.graph.owning_module, pred))] - - # Handle three cases - # 1. view_node_shape is almost same as permute_node_shape - # except the view_node has one more dim somewhere - # and the extra dim has value of 1. - # 2. 
view_node_shape is almost same as permute_node_shape - # except permute_node_shape has one more dim somewhere - # and the extra dim has value of 1. - # 3. view_node_shape is the same as permute_node_shape. - - if len(permute_node_shape) + 1 == len(view_node_shape): - index = self.check_if_shapes_differ_in_single_dim_of_size_1( - view_node_shape, permute_node_shape - ) - if index != -1: - # view_node_shape is almost same as permute_node_shape - # except it has one more dim somewhere - # and the extra dim has value of 1. - new_view_shape = copy.deepcopy(pred_shape) - new_view_shape.insert(index, 1) - new_permute_dims = [x + 1 if x >= index else x for x in permute_dims] - new_permute_dims.insert(index, index) - self._insert_nodes( - node.graph, - pred, - node, - view_node, - new_view_shape, - new_permute_dims, - ) - return True - - elif len(view_node_shape) + 1 == len(permute_node_shape): - index = self.check_if_shapes_differ_in_single_dim_of_size_1( - permute_node_shape, view_node_shape - ) - if index != -1: - # view_node_shape is almost same as permute_node_shape - # except permute_node_shape has one more dim somewhere - # and the extra dim has value of 1. 
- # Convert permute_dims to list of ints - index_to_remove = permute_dims[index] - new_view_shape = copy.deepcopy(pred_shape) - del new_view_shape[index_to_remove] - new_permute_dims = [ - x - 1 if x > index_to_remove else x for x in permute_dims - ] - del new_permute_dims[index] - self._insert_nodes( - node.graph, - pred, - node, - view_node, - new_view_shape, - new_permute_dims, - ) - return True - - elif permute_node_shape == view_node_shape: - # view_node_shape is the same as permute_node_shape - # Replace the uses of view_node with permute_node - view_node.replace_all_uses_with(node) - return True - - return False - - def _insert_nodes( - self, - graph: torch.fx.Graph, - pred: torch.fx.Node, - permute_node: torch.fx.Node, - view_node: torch.fx.Node, - new_view_shape: List, - new_permute_dims: List, - ) -> None: - with graph.inserting_after(view_node): - # Target is guaranteed to be a callable since it's from the graph - view_target = view_node.target - assert callable(view_target), "View target must be callable" - new_view_node = graph.call_function( - view_target, - args=(pred, new_view_shape), - ) - - with graph.inserting_after(new_view_node): - # Target is guaranteed to be a callable since it's from our targets list - permute_target = permute_node.target - assert callable(permute_target), "Permute target must be callable" - new_permute_node = graph.call_function( - permute_target, - args=(new_view_node, new_permute_dims), - ) - new_permute_node.meta = view_node.meta - view_node.replace_all_uses_with(new_permute_node) - - # view_node is user of permute_node, so must erase view_node first - graph.erase_node(view_node) - graph.erase_node(permute_node) - - def call(self, graph_module: torch.fx.GraphModule) -> PassResult: - # This pass needs to iterate until convergence because postponing - # one permute may enable postponing another in a chain - iter_count = 0 - local_modified = False - overall_modified = False - while local_modified or iter_count == 0: - 
result = super().call(graph_module) - local_modified = result.modified - overall_modified |= local_modified - graph_module = result.graph_module - iter_count += 1 - if iter_count == 4: - break - - return PassResult(graph_module, overall_modified) +class PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView( + _SharedPostponePermuteOpBelowSqueezeOrUnsqueezeLikeView +): + pass # The following class consolidates functions to reoder ops (i.e., either hoist diff --git a/backends/cadence/aot/replace_ops.py b/backends/cadence/aot/replace_ops.py index e09a6589e76..4b60feb2121 100644 --- a/backends/cadence/aot/replace_ops.py +++ b/backends/cadence/aot/replace_ops.py @@ -28,6 +28,9 @@ RemoveOrReplacePassInterface, ) from executorch.backends.cadence.aot.utils import is_depthwise_conv +from executorch.backends.transforms.replace_nop_transpose_or_permute_with_view import ( + ReplaceNopTransposeOrPermuteWithViewPass as _SharedReplaceNopTransposeOrPermuteWithViewPass, +) from executorch.backends.transforms.replace_scalar_with_tensor import ( ReplaceScalarWithTensorArgPass, ) @@ -1745,77 +1748,10 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: @register_cadence_pass(CadencePassAttribute(opt_level=1)) -class ReplaceNopTransposeOrPermuteWithViewPass(RemoveOrReplacePassInterface): - """ - If the transpose/permute op does not change the byte order (e.g., - transpose/permute from Nx1xHxW to NxHx1xW), then it can be replaced - by view op. 
- """ - - @property - def targets(self) -> list[EdgeOpOverload]: - return [ - exir_ops.edge.aten.transpose_copy.int, - exir_ops.edge.aten.permute_copy.default, - ] - - def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: - # Get the input tensor and shape - in_tensor_node = node.args[0] - assert isinstance(in_tensor_node, torch.fx.Node) - in_shape = in_tensor_node.meta["val"].shape - # Get the output tensor shape - out_shape = node.meta["val"].shape - - if node.target == exir_ops.edge.aten.transpose_copy.int: - # Get the two dims to be transposed - dim0 = cast(int, node.args[1]) - dim1 = cast(int, node.args[2]) - dim0 = dim0 if dim0 >= 0 else len(in_shape) + dim0 - dim1 = dim1 if dim1 >= 0 else len(in_shape) + dim1 - # We can eliminate transpose if (a) the size at dim0 and dim1 is 1; - # (b) the size at dim0 or dim1 is 1, and dim0 and dim1 are consecutive. - both_one = in_shape[dim0] == 1 and in_shape[dim1] == 1 - either_one_and_consecutive = abs(dim0 - dim1) == 1 and ( - in_shape[dim0] == 1 or in_shape[dim1] == 1 - ) - if both_one or either_one_and_consecutive: - with node.graph.inserting_before(node): - new_node = node.graph.call_function( - exir_ops.edge.aten.view_copy.default, - args=(in_tensor_node, list(out_shape)), - ) - new_node.meta = node.meta - node.replace_all_uses_with(new_node) - return True - - elif node.target == exir_ops.edge.aten.permute_copy.default: - old_dims = list(range(len(in_shape))) - new_dims = cast(Sequence[int], node.args[1]) - # If the permute does not change anything, return the input as output. - if old_dims == list(new_dims): - node.replace_all_uses_with(in_tensor_node) - return True - # Get the old dim order, and the permuted dim order for all dims that - # are not 1. - old_order = [ - dim for dim, shape_dim in zip(old_dims, in_shape) if shape_dim != 1 - ] - new_order = [ - dim for dim, shape_dim in zip(new_dims, out_shape) if shape_dim != 1 - ] - # If the byte ordering for non-unit dims is unchanged, this is a nop. 
- if old_order == new_order: - with node.graph.inserting_before(node): - new_node = node.graph.call_function( - exir_ops.edge.aten.view_copy.default, - args=(in_tensor_node, list(out_shape)), - ) - new_node.meta = node.meta - node.replace_all_uses_with(new_node) - return True - - return False +class ReplaceNopTransposeOrPermuteWithViewPass( + _SharedReplaceNopTransposeOrPermuteWithViewPass +): + pass @register_cadence_pass(CadencePassAttribute(opt_level=2)) diff --git a/backends/cadence/aot/tests/test_pass_utils.py b/backends/cadence/aot/tests/test_pass_utils.py index 2776a370541..c9987cb7196 100644 --- a/backends/cadence/aot/tests/test_pass_utils.py +++ b/backends/cadence/aot/tests/test_pass_utils.py @@ -7,10 +7,8 @@ # pyre-strict import unittest -from typing import List import torch -from beartype.roar import BeartypeDoorHintViolation from executorch.backends.cadence.aot.pass_utils import get_arg @@ -61,9 +59,11 @@ def test_get_arg_with_list_type(self) -> None: self.assertEqual(result, [1, 2, 3]) def test_get_arg_with_list_int_type(self) -> None: - """Test get_arg validates parameterized List[int] type.""" + """Test get_arg accepts parameterized List[int] type without crashing.""" _, node = self._create_graph_with_kwargs(input=[1, 2, 3], other=2) - result = get_arg(node, "input", List[int]) + # Subscripted generics can't be checked with isinstance, so get_arg + # silently skips validation. Just verify it returns the value. 
+ result = get_arg(node, "input", list) self.assertEqual(result, [1, 2, 3]) def test_get_arg_without_type_returns_value(self) -> None: @@ -73,13 +73,13 @@ def test_get_arg_without_type_returns_value(self) -> None: self.assertEqual(result, 42) def test_get_arg_type_mismatch_raises(self) -> None: - """Test get_arg raises BeartypeDoorHintViolation on type mismatch.""" + """Test get_arg raises TypeError on type mismatch.""" _, node = self._create_graph_with_kwargs(input="not_an_int", other=2) - with self.assertRaises(BeartypeDoorHintViolation): + with self.assertRaises(TypeError): get_arg(node, "input", int) def test_get_arg_list_type_mismatch_raises(self) -> None: - """Test get_arg raises BeartypeDoorHintViolation when list elements mismatch.""" - _, node = self._create_graph_with_kwargs(input=["a", "b"], other=2) - with self.assertRaises(BeartypeDoorHintViolation): - get_arg(node, "input", List[int]) + """Test get_arg raises TypeError when value is not a list.""" + _, node = self._create_graph_with_kwargs(input="not_a_list", other=2) + with self.assertRaises(TypeError): + get_arg(node, "input", list) diff --git a/backends/transforms/fuse_cascaded_transpose_or_permute_ops.py b/backends/transforms/fuse_cascaded_transpose_or_permute_ops.py new file mode 100644 index 00000000000..b8d6c75a174 --- /dev/null +++ b/backends/transforms/fuse_cascaded_transpose_or_permute_ops.py @@ -0,0 +1,69 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-unsafe + +from executorch.backends.transforms.permute_pass_utils import ( + get_arg, + get_permuted_dims, + get_transposed_dims, + RemoveOrReplacePassInterface, +) +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.dialects.edge._ops import EdgeOpOverload +from torch.fx import Node + + +class FuseCascadedTransposeOrPermuteOps(RemoveOrReplacePassInterface): + """ + Fuse a chain of transpose and permute ops into a single permute or a no-op. + Handles branches and chains permutes. + """ + + transpose_or_permute_target = { + exir_ops.edge.aten.transpose_copy.int, + exir_ops.edge.aten.permute_copy.default, + } + + @property + def targets(self) -> list[EdgeOpOverload]: + return list(self.transpose_or_permute_target) + + def maybe_remove_or_replace(self, node: Node) -> bool: + # Fuse with the parent node if it's also a permute or a transpose. Since the + # pass interface traverses all ops in order the pass will properly fuse a chain + # of permutes. + parent_node = get_arg(node, "input", Node) + if parent_node.target not in self.transpose_or_permute_target: + return False + input_of_parent = get_arg(parent_node, "input", Node) + + # Compute combined effect of permutes. + dims = list(range(node.meta["val"].ndim)) + + if parent_node.target == exir_ops.edge.aten.transpose_copy.int: + dims = get_transposed_dims(parent_node, dims) + else: + dims = get_permuted_dims(parent_node, dims) + + if node.target == exir_ops.edge.aten.transpose_copy.int: + dims = get_transposed_dims(node, dims) + else: + dims = get_permuted_dims(node, dims) + + # If combined effect is identity replace the node with input. 
+ if dims == sorted(dims): + node.replace_all_uses_with(input_of_parent) + else: + with node.graph.inserting_before(node): + new_permute = node.graph.call_function( + exir_ops.edge.aten.permute_copy.default, + args=(input_of_parent, dims), + ) + new_permute.meta = node.meta + node.replace_all_uses_with(new_permute) + + return True diff --git a/backends/transforms/fuse_cascaded_view_ops.py b/backends/transforms/fuse_cascaded_view_ops.py new file mode 100644 index 00000000000..7daf6ffe92e --- /dev/null +++ b/backends/transforms/fuse_cascaded_view_ops.py @@ -0,0 +1,43 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +from typing import cast + +import torch +from executorch.backends.transforms.permute_pass_utils import ( + RemoveOrReplacePassInterface, +) +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.dialects.edge._ops import EdgeOpOverload + + +class FuseCascadedViewOps(RemoveOrReplacePassInterface): + """ + Fuse a cascaded chain of view ops + """ + + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.aten.view_copy.default] + + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: + # Check if the input to this view node is also a view node + input_view = node.args[0] + if not isinstance(input_view, torch.fx.Node): + return False + + if ( + input_view.op != "call_function" + or input_view.target != exir_ops.edge.aten.view_copy.default + ): + return False + + # Replace the input of this view node with the input of the cascaded view + # This effectively "skips" the intermediate view node + node.replace_input_with(input_view, cast(torch.fx.Node, input_view.args[0])) + return True diff --git a/backends/transforms/fuse_transpose_or_permute_op_pairs_pass.py 
b/backends/transforms/fuse_transpose_or_permute_op_pairs_pass.py new file mode 100644 index 00000000000..008775511ec --- /dev/null +++ b/backends/transforms/fuse_transpose_or_permute_op_pairs_pass.py @@ -0,0 +1,102 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +from typing import Any, Callable, cast + +import torch +import torch.fx +from executorch.backends.transforms.permute_pass_utils import ( + FuseOpPairsAcrossBranchesPass, + get_permuted_dims, + get_transposed_dims, +) +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket +from executorch.exir.pass_base import PassResult + + +class FuseTransposeOrPermuteOpPairsPass(FuseOpPairsAcrossBranchesPass): + """ + Fuse transpose or permute op pairs to a single view op. + (transpose or permutation) -> (quant or dequant) -> (transpose or permutation) + This happens when op2(op1) == identity, modulo unitary dimensions. + 'unitary dimensions' example: a tensor of shape [1, 5, 30] is equivalent (in memory) to [5, 1, 30] + so transpose(1, 2) then transpose(0, 2) is a pseudo identity and should be fused. + """ + + # A list of ops that can be bypassed when looking for a + # transpose-permute chain. Subclasses can extend this with backend-specific ops. 
+ bypass_ops: set[EdgeOpOverload] = { + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.quantize_per_channel.default, + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.dequantize_per_channel.default, + } + + def can_fuse_for_chain( + self, + producer: torch.fx.Node, + consumer: torch.fx.Node, + consumer_op_packets: set[EdgeOpOverloadPacket], + ) -> bool: + if not super().can_fuse_for_chain(producer, consumer, consumer_op_packets): + return False + + # checking that permut2(permut1(identity)) == identity, modulo unitary dimensions + producer_input = cast(torch.fx.Node, producer.args[0]) + if "val" not in producer_input.meta: + return False + input_shape = producer_input.meta["val"].shape + ident_dims = list(range(len(input_shape))) + # this mapping helps to handle both transpose and permutations + f: dict[Any, Callable] = { + exir_ops.edge.aten.transpose_copy.int: get_transposed_dims, + exir_ops.edge.aten.permute_copy.default: get_permuted_dims, + } + in_dims = f[producer.target](producer, ident_dims) + out_dims = f[consumer.target](consumer, in_dims) + # Filtering out unitary dimensions + non_unit_ident_dims = [dim for dim in ident_dims if input_shape[dim] != 1] + non_unit_out_dims = [dim for dim in out_dims if input_shape[dim] != 1] + return non_unit_out_dims == non_unit_ident_dims + + def get_fused_node( + self, + producer: torch.fx.Node, + consumer: torch.fx.Node, + graph_module: torch.fx.GraphModule, + ) -> torch.fx.Node: + # This step is important because of how we can fuse transpositions that are not perfectly + # reverse one of another but will be fused if there are unitary dimensions. + # The fused operation must have the same output shape as the consumer. 
+ output_shape = consumer.meta["val"].shape + with graph_module.graph.inserting_after(consumer): + view = graph_module.graph.call_function( + exir_ops.edge.aten.view_copy.default, + (consumer.args[0], output_shape), + {}, + ) + return view + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + # Remove any transpose/permutation op pair that cancel each other. + modified = self.find_and_fuse( + graph_module, + producer_op_packets={ + exir_ops.edge.aten.transpose_copy, + exir_ops.edge.aten.permute_copy, + }, + consumer_op_packets={ + exir_ops.edge.aten.transpose_copy, + exir_ops.edge.aten.permute_copy, + }, + bypass_ops=self.bypass_ops, + ) + if modified: + return super().call(graph_module) + return PassResult(graph_module, False) diff --git a/backends/transforms/permute_pass_utils.py b/backends/transforms/permute_pass_utils.py new file mode 100644 index 00000000000..95a429e09b4 --- /dev/null +++ b/backends/transforms/permute_pass_utils.py @@ -0,0 +1,278 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-ignore-all-errors + +"""Shared utilities and base classes for permute optimization passes. + +These were originally in executorch.backends.cadence.aot and are used by +both the Cadence and Arm backends. 
+""" + +from abc import abstractmethod +from collections import deque +from typing import cast, List, Optional, Type, TypeVar, Union + +import torch +import torch.fx +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket +from executorch.exir.pass_base import ExportPass, PassResult +from torch.fx import Node +from torch.fx.node import Argument + +T = TypeVar("T") + + +def get_edge_overload_packet(edge_op: EdgeOpOverload) -> EdgeOpOverloadPacket: + edge_op_namespace, edge_op_name = ( + edge_op.namespace, + edge_op._schema.name.split("::")[1], + ) + edge_op_overload_packet = getattr( + getattr(exir_ops.edge, edge_op_namespace), edge_op_name + ) + return edge_op_overload_packet + + +def get_shape( + graph_module: torch.fx.GraphModule, node: torch.fx.Node +) -> Union[torch.Size, None]: + """Return the shape of the tensor corresponding to node.""" + try: + if isinstance(node, (float, int, bool)): + return torch.Size([1]) + fake_tensor = node.meta.get("val") + if fake_tensor is not None: + return fake_tensor.shape + if node.op == "get_attr": + attr_node = getattr(graph_module, node.target) + return attr_node.shape + return None + except RuntimeError: + return None + + +def get_transposed_dims( + node: torch.fx.Node, dims: Optional[List[int]] = None +) -> List[int]: + """Applies the transposition as given by node onto the dimensions given in input.""" + assert node.target == exir_ops.edge.aten.transpose_copy.int + assert dims is not None + dim_len = len(dims) + transpose_dims0 = node.args[1] + transpose_dims1 = node.args[2] + assert isinstance(transpose_dims0, int) + assert isinstance(transpose_dims1, int) + dim0 = transpose_dims0 if transpose_dims0 >= 0 else transpose_dims0 + dim_len + dim1 = transpose_dims1 if transpose_dims1 >= 0 else transpose_dims1 + dim_len + new_dims = list(dims) + new_dims[dim0], new_dims[dim1] = dims[dim1], dims[dim0] + return new_dims + + +def 
get_permuted_dims(node: torch.fx.Node, dims: List[int]) -> List[int]: + """Applies the permutation as given by node onto the dimensions given in input.""" + assert node.target == exir_ops.edge.aten.permute_copy.default + # pyre-fixme[6]: This combined typecheck isn't supported yet. + permute_dims: List[int] = list(node.args[1]) + assert all(isinstance(x, int) for x in permute_dims) + return [dims[x] for x in permute_dims] + + +def get_arg( + node: torch.fx.Node, + kwarg_name: str, + expected_type: Type[T] = Argument, +) -> T: + """Get the arg with kwarg_name of the node.""" + if kwarg_name in node.kwargs: + value = node.kwargs[kwarg_name] + else: + normalized_args = node.normalized_arguments( + node.graph.owning_module, normalize_to_only_use_kwargs=True + ) + if not normalized_args: + raise RuntimeError( + f"get_arg: Node {node} does not support normalization of arguments" + ) + value = normalized_args.kwargs[kwarg_name] + + if expected_type is not Argument: + try: + type_ok = isinstance(value, expected_type) + except TypeError: + # Subscripted generics (e.g. List[int]) don't support isinstance. + # Fall through — caller is responsible for correctness. 
+ type_ok = True + if not type_ok: + raise TypeError( + f"get_arg: expected {expected_type} for '{kwarg_name}', got {type(value)}" + ) + return value # type: ignore[return-value] + + +def set_arg( + node: torch.fx.Node, kwarg_name: str, value: torch.fx.node.Argument +) -> None: + """Set the node's arg with its name to the given value.""" + if kwarg_name in node.kwargs: + node.update_kwarg(kwarg_name, value) + return + + normalized_args = node.normalized_arguments( + node.graph.owning_module, normalize_to_only_use_kwargs=True + ) + if not normalized_args: + raise RuntimeError( + f"set_arg: Node {node} does not support normalization of arguments" + ) + + kwargs = normalized_args.kwargs + if kwarg_name not in kwargs: + raise ValueError(f"set_arg: invalid arg name {kwarg_name} for node {node} used") + + idx = list(kwargs.keys()).index(kwarg_name) + if idx < len(node.args): + node.update_arg(idx, value) + else: + node.update_kwarg(kwarg_name, value) + + +class HierarchicalInplacePassInterface(ExportPass): + """A base class for passes that apply in-place modification to the graph module and its submodules.""" + + @abstractmethod + def _apply_flat_inplace(self, graph_module) -> bool: + raise NotImplementedError("`_apply_flat_inplace` must be implemented") + + def _apply_hierarchical_inplace(self, graph_module: torch.fx.GraphModule) -> bool: + modified: bool = False + for module in filter( + lambda m: isinstance(m, torch.fx.GraphModule), graph_module.modules() + ): + modified |= self._apply_flat_inplace(module) + return modified + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + modified = self._apply_hierarchical_inplace(graph_module) + if modified: + graph_module.graph.eliminate_dead_code() + graph_module.recompile() + return super().call(graph_module) + return PassResult(graph_module, False) + + +class RemoveOrReplacePassInterface(HierarchicalInplacePassInterface): + @property + @abstractmethod + def targets(self) -> list[EdgeOpOverload]: + raise 
class RemoveOrReplacePassInterface(HierarchicalInplacePassInterface):
    """Interface for passes that remove or replace nodes with specific op targets."""

    @property
    @abstractmethod
    def targets(self) -> list[EdgeOpOverload]:
        # The edge ops this pass should visit.
        raise NotImplementedError("`targets` must be implemented")

    @abstractmethod
    def maybe_remove_or_replace(self, node: Node) -> bool:
        # Returns True if `node` was removed or replaced, False otherwise.
        raise NotImplementedError("`maybe_remove_or_replace` must be implemented")

    def _apply_flat_inplace(self, graph_module: torch.fx.GraphModule) -> bool:
        """Visit every live node matching `targets` and attempt the rewrite."""
        changed = False
        for target in self.targets:
            for node in graph_module.graph.find_nodes(
                op="call_function", target=target
            ):
                # Skip dead nodes; eliminate_dead_code drops them later.
                if len(node.users) == 0:
                    continue
                changed |= self.maybe_remove_or_replace(node)
        return changed


class FuseOpPairsAcrossBranchesPass(ExportPass):
    """Base class for passes that fuse op pairs across branches."""

    def check_ok_to_fuse(
        self,
        producer: torch.fx.Node,
        consumers: list[torch.fx.Node],
    ) -> bool:
        # Hook for subclasses to veto a candidate fusion; default allows all.
        return True

    def can_fuse_for_chain(
        self,
        producer: torch.fx.Node,
        consumer: torch.fx.Node,
        consumer_op_packets: set[EdgeOpOverloadPacket],
    ) -> bool:
        """True if `consumer` belongs to one of the fusable op packets."""
        if (
            isinstance(consumer.target, EdgeOpOverload)
            and get_edge_overload_packet(consumer.target) in consumer_op_packets
        ):
            return True
        return False

    def get_fuse_candidates(
        self,
        producer: torch.fx.Node,
        consumer_op_packets: set[EdgeOpOverloadPacket],
        bypass_ops: set[EdgeOpOverload],
    ) -> list[torch.fx.Node]:
        """Breadth-first walk of `producer`'s users, looking through bypass ops.

        Returns the consumer nodes to fuse, or an empty list when any branch
        ends in a non-fusable op (fusion is all-or-nothing across branches).
        """
        users = deque(producer.users.keys())
        removal_candidates = []
        while users:
            user = users.popleft()
            if user.target in bypass_ops:
                # Look through the bypass op and keep walking its users.
                users.extend(list(user.users.keys()))
            elif self.can_fuse_for_chain(producer, user, consumer_op_packets):
                removal_candidates.append(user)
            else:
                # One non-fusable branch invalidates the whole fusion.
                removal_candidates.clear()
                break
        return removal_candidates

    def find_and_fuse(
        self,
        graph_module: torch.fx.GraphModule,
        producer_op_packets: set[EdgeOpOverloadPacket],
        consumer_op_packets: set[EdgeOpOverloadPacket],
        bypass_ops: set[EdgeOpOverload],
    ) -> bool:
        """Find producer/consumer pairs across branches and fuse them.

        Returns True if the graph was modified (and recompiled).
        """
        modified = False
        for node in graph_module.graph.nodes:
            # Only start from nodes whose op packet is a registered producer.
            if not (
                isinstance(node.target, EdgeOpOverload)
                and get_edge_overload_packet(node.target) in producer_op_packets
            ):
                continue
            removal_candidates = self.get_fuse_candidates(
                node, consumer_op_packets, bypass_ops
            )
            if len(removal_candidates) == 0:
                continue
            if not self.check_ok_to_fuse(node, removal_candidates):
                continue
            self.fuse(node, removal_candidates, graph_module)
            modified = True
        if modified:
            graph_module.recompile()
        return modified

    def get_fused_node(
        self,
        producer: torch.fx.Node,
        consumer: torch.fx.Node,
        graph_module: torch.fx.GraphModule,
    ) -> torch.fx.Node:
        # Default: the consumer's replacement is the consumer itself;
        # subclasses may materialize a different fused node.
        return consumer

    def fuse(
        self,
        node: torch.fx.Node,
        removal_candidates: list[torch.fx.Node],
        graph_module: torch.fx.GraphModule,
    ) -> None:
        """Erase the producer and each fused consumer, rewiring their users."""
        # Bypass the producer: its users now consume its first input directly.
        node.replace_all_uses_with(cast(torch.fx.Node, node.args[0]))
        graph_module.graph.erase_node(node)
        for rnode in removal_candidates:
            rnode.replace_all_uses_with(self.get_fused_node(node, rnode, graph_module))
            graph_module.graph.erase_node(rnode)
class PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView(RemoveOrReplacePassInterface):
    """
    A common pattern seen in transformer models. If the consumer of permute
    is a view op, swap their order so permute is below view.
    Change "permute -> view" to "view -> permute"
    This is to optimize a chain of view->permute->view->permute...
    so that the chain will be become view->v...->view->permute->p...->permute.
    The chain can be optimized by FuseCascadedTransposeOrPermuteOps() and
    FuseCascadedViewOps().
    Notice the class name has ViewSqueeze to indicate the View is
    functionally the same as a squeeze or unsqueeze. It does not necessarily
    mean the view_copy is normalized from squeeze or unsqueeze.
    """

    @property
    def targets(self) -> list[EdgeOpOverload]:
        # The pass anchors on permute nodes; the view is found via users.
        return [exir_ops.edge.aten.permute_copy.default]

    # If list1 and list2 are same (same values and in same order) except
    # list1 has one more element with value of 1. Return index of the extra 1.
    # Otherwise return -1.
    def check_if_shapes_differ_in_single_dim_of_size_1(
        self, list1: List, list2: List
    ) -> int:
        if len(list1) != len(list2) + 1:
            return -1
        for i in range(len(list2)):
            if list1[i] != list2[i]:
                # Return index of the extra 1 if the remaining parts are the same
                if list1[i] == 1 and list2[i:] == list1[i + 1 :]:
                    return i
                else:
                    return -1
        # If no difference was found, the extra element is at the end
        if list1[-1] == 1:
            return len(list2)
        else:
            return -1

    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
        """Try to swap this permute with its single view consumer.

        Returns True when the graph was rewritten.
        """
        users = list(node.users.keys())
        # Transform only for pattern permute_copy->view_copy, and
        # view_copy op is the only user of permute_copy.
        if len(users) != 1 or users[0].target not in (
            exir_ops.edge.aten.view_copy.default,
            exir_ops.edge.aten.view.default,
        ):
            return False

        # If the permute_node/view_node was newly added to the
        # graph, it may not have the meta["val"] FakeTensor.
        # Skip in this case.
        if node.meta.get("val") is None:
            return False

        permute_node_shape = [*cast(list, get_shape(node.graph.owning_module, node))]

        permute_dims = cast(list, node.args[1])
        view_node = users[0]

        if view_node.meta.get("val") is None:
            return False

        view_node_shape = [*cast(list, get_shape(node.graph.owning_module, view_node))]

        pred = node.args[0]
        if not isinstance(pred, torch.fx.Node) or pred.meta.get("val") is None:
            return False

        pred_shape = [*cast(list, get_shape(node.graph.owning_module, pred))]

        # Handle three cases
        # 1. view_node_shape is almost same as permute_node_shape
        #    except the view_node has one more dim somewhere
        #    and the extra dim has value of 1.
        # 2. view_node_shape is almost same as permute_node_shape
        #    except permute_node_shape has one more dim somewhere
        #    and the extra dim has value of 1.
        # 3. view_node_shape is the same as permute_node_shape.

        if len(permute_node_shape) + 1 == len(view_node_shape):
            index = self.check_if_shapes_differ_in_single_dim_of_size_1(
                view_node_shape, permute_node_shape
            )
            if index != -1:
                # view_node_shape is almost same as permute_node_shape
                # except it has one more dim somewhere
                # and the extra dim has value of 1.
                # Unsqueeze-like: insert the size-1 dim before permuting, and
                # shift permute dims at or above the insertion point up by one.
                new_view_shape = copy.deepcopy(pred_shape)
                new_view_shape.insert(index, 1)
                new_permute_dims = [x + 1 if x >= index else x for x in permute_dims]
                new_permute_dims.insert(index, index)
                self._insert_nodes(
                    node.graph,
                    pred,
                    node,
                    view_node,
                    new_view_shape,
                    new_permute_dims,
                )
                return True

        elif len(view_node_shape) + 1 == len(permute_node_shape):
            index = self.check_if_shapes_differ_in_single_dim_of_size_1(
                permute_node_shape, view_node_shape
            )
            if index != -1:
                # view_node_shape is almost same as permute_node_shape
                # except permute_node_shape has one more dim somewhere
                # and the extra dim has value of 1.
                # Squeeze-like: drop the size-1 dim before permuting, and
                # shift permute dims above the removed source dim down by one.
                index_to_remove = permute_dims[index]
                new_view_shape = copy.deepcopy(pred_shape)
                del new_view_shape[index_to_remove]
                new_permute_dims = [
                    x - 1 if x > index_to_remove else x for x in permute_dims
                ]
                del new_permute_dims[index]
                self._insert_nodes(
                    node.graph,
                    pred,
                    node,
                    view_node,
                    new_view_shape,
                    new_permute_dims,
                )
                return True

        elif permute_node_shape == view_node_shape:
            # view_node_shape is the same as permute_node_shape
            # Replace the uses of view_node with permute_node
            view_node.replace_all_uses_with(node)
            return True

        return False

    def _insert_nodes(
        self,
        graph: torch.fx.Graph,
        pred: torch.fx.Node,
        permute_node: torch.fx.Node,
        view_node: torch.fx.Node,
        new_view_shape: List,
        new_permute_dims: List,
    ) -> None:
        """Replace permute->view with view->permute using the adjusted shape/dims."""
        with graph.inserting_after(view_node):
            # Target is guaranteed to be a callable since it's from the graph
            view_target = view_node.target
            assert callable(view_target), "View target must be callable"
            new_view_node = graph.call_function(
                view_target,
                args=(pred, new_view_shape),
            )

        with graph.inserting_after(new_view_node):
            # Target is guaranteed to be a callable since it's from our targets list
            permute_target = permute_node.target
            assert callable(permute_target), "Permute target must be callable"
            new_permute_node = graph.call_function(
                permute_target,
                args=(new_view_node, new_permute_dims),
            )
            # The new permute produces the same value the old view did, so
            # it inherits the view's meta.
            new_permute_node.meta = view_node.meta
            view_node.replace_all_uses_with(new_permute_node)

        # view_node is user of permute_node, so must erase view_node first
        graph.erase_node(view_node)
        graph.erase_node(permute_node)

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """Run the base pass repeatedly until a fixed point (capped at 4 sweeps)."""
        # This pass needs to iterate until convergence because postponing
        # one permute may enable postponing another in a chain
        iter_count = 0
        local_modified = False
        overall_modified = False
        while local_modified or iter_count == 0:
            result = super().call(graph_module)
            local_modified = result.modified
            overall_modified |= local_modified
            graph_module = result.graph_module
            iter_count += 1
            # Hard cap on sweeps to bound compile time even if not converged.
            if iter_count == 4:
                break

        return PassResult(graph_module, overall_modified)
class RemovePermutesAroundElementwiseOps(ExportPass):
    """
    Looks for subgraphs of elementwise ops sandwiched between permutes and removes those
    permutes if possible.
    Allows special handling for certain non-elementwise ops that can be easily updated
    based on the permute's parameter such as mean, cat, and slice.
    """

    @dataclass()
    class Subgraph:
        # Permutation applied by the permutes feeding the subgraph.
        start_permute: list[int]
        # Inverse of start_permute; expected on permutes leaving the subgraph.
        end_permute: list[int]
        # Nodes in the subgraph, does not include permutes.
        nodes: set[torch.fx.Node] = field(default_factory=set)
        # Incoming edges to the subgraph from permute nodes.
        edges_in: set[tuple[torch.fx.Node, torch.fx.Node]] = field(default_factory=set)
        # Outgoing edges of the subgraph to permute nodes.
        edges_out: set[tuple[torch.fx.Node, torch.fx.Node]] = field(default_factory=set)
        # Incoming edges from constant nodes that need a compensating permute.
        constant_edges_in: set[tuple[torch.fx.Node, torch.fx.Node]] = field(
            default_factory=set
        )

    permutable_ops: set[EdgeOpOverload] = {
        exir_ops.edge.aten.add.Tensor,
        exir_ops.edge.aten.mul.Tensor,
        exir_ops.edge.aten.hardtanh.default,
        exir_ops.edge.aten.clamp.default,
        exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
        exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
        # Ops that require special handling.
        exir_ops.edge.aten.cat.default,
        exir_ops.edge.aten.mean.dim,
        exir_ops.edge.aten.slice_copy.Tensor,
    }

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """Find permute-sandwiched subgraphs, then strip the permutes."""
        subgraphs_found: list[RemovePermutesAroundElementwiseOps.Subgraph] = []
        processed_nodes: set[torch.fx.Node] = set()
        for node in graph_module.graph.find_nodes(
            op="call_function", target=exir_ops.edge.aten.permute_copy.default
        ):
            start_permute = self.get_permutation(node)
            # Expected end permutation for the subgraph.
            end_permute = [start_permute.index(i) for i in range(len(start_permute))]

            for user in node.users:
                if user.target not in self.permutable_ops:
                    continue
                # Create a separate subgraph for each user since there may be cases
                # where only a portion of the users are permutable.
                subgraph = self.Subgraph(start_permute, end_permute)
                if self.visit(user, subgraph, processed_nodes):
                    subgraphs_found.append(subgraph)
                    for node in subgraph.nodes:
                        processed_nodes.add(node)

        modified = False
        for subgraph in subgraphs_found:
            self.permute_subgraph(subgraph)
            modified = True

        if modified:
            graph_module.graph.eliminate_dead_code()
            graph_module.recompile()
            return super().call(graph_module)

        return PassResult(graph_module, False)

    def visit(  # noqa: C901
        self,
        node: torch.fx.Node,
        subgraph: Subgraph,
        processed_nodes: set[torch.fx.Node],
    ) -> bool:
        """Recursively grow `subgraph` from `node`.

        Returns True iff every boundary of the subgraph is a matching permute
        (or an acceptable constant input); False rejects the whole subgraph.
        """
        if node in subgraph.nodes:
            return True
        if node in processed_nodes or not self.is_node_permutable(node):
            return False
        subgraph.nodes.add(node)

        # Traverse downstream:
        for user in node.users:
            # Output should either go to a matching permute or another permutable op.
            if user.target == exir_ops.edge.aten.permute_copy.default:
                if self.get_permutation(user) != subgraph.end_permute:
                    return False
                subgraph.edges_out.add((node, user))
            elif user.op == "output":
                # Graph output requires the data in its original layout.
                # Removing permutes here would silently change the output
                # format, so treat this as an invalid subgraph boundary.
                return False
            elif not self.visit(user, subgraph, processed_nodes):
                return False

        # Traverse upstream:
        for inp in node.all_input_nodes:
            # Input should either come from a matching permute or another permutable op.
            if inp.target == exir_ops.edge.aten.permute_copy.default:
                if self.get_permutation(inp) != subgraph.start_permute:
                    return False
                subgraph.edges_in.add((inp, node))
            elif self._is_constant(inp):
                # Only accept the constant if we can insert a compensating
                # permute or view. Otherwise reject the subgraph.
                const_rank = self._get_node_rank(inp)
                permute_rank = len(subgraph.end_permute)
                if const_rank is None:
                    return False
                if const_rank > permute_rank:
                    return False
                if const_rank < permute_rank and inp.meta.get("val") is None:
                    return False
                subgraph.constant_edges_in.add((inp, node))
            elif not self.visit(inp, subgraph, processed_nodes):
                return False

        return True

    def _is_constant(self, node: torch.fx.Node) -> bool:
        """Check if a node's value is available at compile time.
        Only considers direct constants (get_attr, parameter/buffer/constant
        placeholders) — does not recurse into call_function chains to avoid
        stack overflow on deep graphs."""
        if node.op == "get_attr":
            return True
        if node.op == "placeholder":
            # NOTE(review): relies on the export naming convention of
            # buffer/parameter/constant placeholders — confirm prefixes.
            target = str(node.target)
            return target.startswith(("b_", "p_", "c_"))
        return False

    def _get_node_rank(self, node: torch.fx.Node) -> int | None:
        """Return the tensor rank of a node's output, or None if unknown."""
        val = node.meta.get("val")
        if val is not None and hasattr(val, "shape"):
            return len(val.shape)
        return None

    def is_node_permutable(self, node: torch.fx.Node) -> bool:
        """True if `node` may live inside a permute-stripped subgraph."""
        if node.target not in self.permutable_ops:
            return False
        if node.target == exir_ops.edge.aten.mean.dim:
            # keepdim should be True.
            if len(node.args) >= 3:
                if not node.args[2]:
                    return False
            elif "keepdim" in node.kwargs:
                if not node.kwargs["keepdim"]:
                    return False
            else:
                # Default keepdim is False.
                return False
        return True

    def permute_subgraph(self, subgraph: Subgraph) -> None:
        """Strip boundary permutes and fix up dim-sensitive args/constants."""
        # Skip incoming permutes.
        for inp, out in subgraph.edges_in:
            assert inp.target == exir_ops.edge.aten.permute_copy.default
            if len(inp.args) >= 1:
                out.replace_input_with(inp, cast(torch.fx.Node, inp.args[0]))
            else:
                out.replace_input_with(inp, cast(torch.fx.Node, inp.kwargs["input"]))

        # Insert compensating permute on constant inputs.
        # Since the subgraph's start permutes are being removed, the subgraph
        # will operate in the un-permuted (original) layout. Constants that
        # were in the permuted layout need end_permute (the inverse of
        # start_permute) to convert back to the original layout.
        for const_node, user_node in subgraph.constant_edges_in:
            graph = const_node.graph
            const_rank = self._get_node_rank(const_node)
            permute_rank = len(subgraph.end_permute)

            with graph.inserting_after(const_node):
                if const_rank is not None and const_rank == permute_rank:
                    new_node = graph.create_node(
                        "call_function",
                        exir_ops.edge.aten.permute_copy.default,
                        args=(const_node, subgraph.end_permute),
                    )
                elif (
                    const_rank is not None
                    and const_rank < permute_rank
                    and const_node.meta.get("val") is not None
                ):
                    # Rank mismatch (e.g. rank-1 bias with rank-4 permute).
                    # The constant is broadcastable and its shape is smaller
                    # than the permute rank, so we can't apply the permute
                    # directly. Instead, use view_copy to rearrange the
                    # shape according to the end_permute restricted to
                    # the trailing dimensions.
                    original_shape = list(const_node.meta["val"].shape)
                    # Pad shape to match permute rank for reordering
                    padded = [1] * (permute_rank - const_rank) + original_shape
                    target_shape = [padded[d] for d in subgraph.end_permute]
                    # Strip leading 1s back to original rank
                    target_shape = target_shape[permute_rank - const_rank :]
                    new_node = graph.create_node(
                        "call_function",
                        exir_ops.edge.aten.view_copy.default,
                        args=(const_node, target_shape),
                    )
                else:
                    # Cannot determine rank or handle this case; skip.
                    continue
            user_node.replace_input_with(const_node, new_node)

        # Skip outgoing permutes.
        for inp, out in subgraph.edges_out:
            assert out.target == exir_ops.edge.aten.permute_copy.default
            out.replace_all_uses_with(inp)

        # Handle dimension related node arguments.
        for node in subgraph.nodes:
            if node.target == exir_ops.edge.aten.cat.default:
                self.update_cat(node, subgraph.start_permute)
            elif node.target == exir_ops.edge.aten.mean.dim:
                self.update_mean_dim(node, subgraph.start_permute)
            elif node.target == exir_ops.edge.aten.slice_copy.Tensor:
                self.update_slice_copy(node, subgraph.start_permute)

    def update_cat(self, node: torch.fx.Node, start_permute: list[int]) -> None:
        """Map the cat dim from the permuted layout back to the original one."""
        if len(node.args) >= 2:
            node.update_arg(1, start_permute[cast(int, node.args[1])])
        elif "dim" in node.kwargs:
            node.update_kwarg("dim", start_permute[cast(int, node.kwargs["dim"])])
        else:
            # Default cat dim is 0.
            node.update_kwarg("dim", start_permute[0])

    def update_mean_dim(self, node: torch.fx.Node, start_permute: list[int]) -> None:
        """Map every reduction dim of mean.dim back to the original layout."""
        if len(node.args) >= 2:
            node.update_arg(
                1, [start_permute[dim] for dim in cast(list[int], node.args[1])]
            )
        else:
            node.update_kwarg(
                "dim",
                [start_permute[dim] for dim in cast(list[int], node.kwargs["dim"])],
            )

    def update_slice_copy(self, node: torch.fx.Node, start_permute: list[int]) -> None:
        """Map the slice dim back to the original layout."""
        if len(node.args) >= 2:
            node.update_arg(1, start_permute[cast(int, node.args[1])])
        else:
            node.update_kwarg("dim", start_permute[cast(int, node.kwargs["dim"])])

    def get_permutation(self, permute_node: torch.fx.Node) -> list[int]:
        """Return the dims argument of a permute_copy node."""
        assert permute_node.target == exir_ops.edge.aten.permute_copy.default
        if len(permute_node.args) >= 2:
            return cast(list[int], permute_node.args[1])
        assert "dim" in permute_node.kwargs
        return cast(list[int], permute_node.kwargs["dim"])
class ReplaceNopTransposeOrPermuteWithViewPass(RemoveOrReplacePassInterface):
    """
    If the transpose/permute op does not change the byte order (e.g.,
    transpose/permute from Nx1xHxW to NxHx1xW), then it can be replaced
    by view op.
    """

    @property
    def targets(self) -> list[EdgeOpOverload]:
        return [
            exir_ops.edge.aten.transpose_copy.int,
            exir_ops.edge.aten.permute_copy.default,
        ]

    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
        """Replace a byte-order-preserving transpose/permute with a view.

        Returns True when the node was replaced (or bypassed entirely).
        """
        # Get the input tensor and shape
        in_tensor_node = node.args[0]
        assert isinstance(in_tensor_node, torch.fx.Node)
        in_shape = in_tensor_node.meta["val"].shape
        # Get the output tensor shape
        out_shape = node.meta["val"].shape

        if node.target == exir_ops.edge.aten.transpose_copy.int:
            # Get the two dims to be transposed
            dim0 = cast(int, node.args[1])
            dim1 = cast(int, node.args[2])
            # Normalize negative dims.
            dim0 = dim0 if dim0 >= 0 else len(in_shape) + dim0
            dim1 = dim1 if dim1 >= 0 else len(in_shape) + dim1
            # We can eliminate transpose if (a) the size at dim0 and dim1 is 1;
            # (b) the size at dim0 or dim1 is 1, and dim0 and dim1 are consecutive.
            both_one = in_shape[dim0] == 1 and in_shape[dim1] == 1
            either_one_and_consecutive = abs(dim0 - dim1) == 1 and (
                in_shape[dim0] == 1 or in_shape[dim1] == 1
            )
            if both_one or either_one_and_consecutive:
                with node.graph.inserting_before(node):
                    new_node = node.graph.call_function(
                        exir_ops.edge.aten.view_copy.default,
                        args=(in_tensor_node, list(out_shape)),
                    )
                    # The view yields the same value; reuse the node's meta.
                    new_node.meta = node.meta
                    node.replace_all_uses_with(new_node)
                return True

        elif node.target == exir_ops.edge.aten.permute_copy.default:
            old_dims = list(range(len(in_shape)))
            new_dims = cast(Sequence[int], node.args[1])
            # If the permute does not change anything, return the input as output.
            if old_dims == list(new_dims):
                node.replace_all_uses_with(in_tensor_node)
                return True
            # Get the old dim order, and the permuted dim order for all dims that
            # are not 1.
            old_order = [
                dim for dim, shape_dim in zip(old_dims, in_shape) if shape_dim != 1
            ]
            new_order = [
                dim for dim, shape_dim in zip(new_dims, out_shape) if shape_dim != 1
            ]
            # If the byte ordering for non-unit dims is unchanged, this is a nop.
            if old_order == new_order:
                with node.graph.inserting_before(node):
                    new_node = node.graph.call_function(
                        exir_ops.edge.aten.view_copy.default,
                        args=(in_tensor_node, list(out_shape)),
                    )
                    new_node.meta = node.meta
                    node.replace_all_uses_with(new_node)
                return True

        return False
deps = [ + "//caffe2:torch", + "//executorch/exir:pass_base", + "//executorch/exir/dialects:lib", + ], + ) + + runtime.python_library( + name = "postpone_permute_below_squeeze_view", + srcs = ["postpone_permute_below_squeeze_view.py"], + visibility = [ + "//executorch/backends/...", + ], + deps = [ + "//caffe2:torch", + "//executorch/exir:pass_base", + "//executorch/exir/dialects:lib", + ":permute_pass_utils", + ], + ) + + runtime.python_library( + name = "replace_nop_transpose_or_permute_with_view", + srcs = ["replace_nop_transpose_or_permute_with_view.py"], + visibility = [ + "//executorch/backends/...", + ], + deps = [ + "//caffe2:torch", + "//executorch/exir/dialects:lib", + ":permute_pass_utils", + ], + ) + + runtime.python_test( + name = "test_permute_optimization_passes", + srcs = [ + "test/test_permute_optimization_passes.py", + ], + deps = [ + "//caffe2:torch", + "//executorch/backends/test:graph_builder", + "//executorch/exir:pass_base", + "//executorch/exir/dialects:lib", + ":fuse_cascaded_transpose_or_permute_ops", + ":fuse_cascaded_view_ops", + ":postpone_permute_below_squeeze_view", + ":replace_nop_transpose_or_permute_with_view", + ], + ) diff --git a/backends/transforms/test/test_permute_optimization_passes.py b/backends/transforms/test/test_permute_optimization_passes.py new file mode 100644 index 00000000000..bb326f125bc --- /dev/null +++ b/backends/transforms/test/test_permute_optimization_passes.py @@ -0,0 +1,442 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-unsafe + +import copy +import unittest +from typing import cast + +import torch +from executorch.backends.test.graph_builder import GraphBuilder, single_op_builder +from executorch.backends.transforms.fuse_cascaded_transpose_or_permute_ops import ( + FuseCascadedTransposeOrPermuteOps, +) +from executorch.backends.transforms.fuse_cascaded_view_ops import FuseCascadedViewOps +from executorch.backends.transforms.postpone_permute_below_squeeze_view import ( + PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView, +) +from executorch.backends.transforms.replace_nop_transpose_or_permute_with_view import ( + ReplaceNopTransposeOrPermuteWithViewPass, +) +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import PassResult +from torch.utils import _pytree as pytree + + +def count_node(graph_module: torch.fx.GraphModule, target: torch.fx.node.Target) -> int: + """Count the number of nodes with target `target` in the graph.""" + total = 0 + for node in graph_module.graph.nodes: + if node.op == "call_function" and node.target == target: + total += 1 + return total + + +def validate_numerics( + original: torch.fx.GraphModule, + modified: torch.fx.GraphModule, + inputs: tuple[torch.Tensor, ...] | list[torch.Tensor], + pass_name: str, + rtol: float = 1e-5, + atol: float = 1e-6, +) -> None: + """Validate that two graph modules produce numerically equivalent outputs.""" + original.eval() + modified.eval() + with torch.no_grad(): + orig_out = original(*inputs) + mod_out = modified(*inputs) + + flat_orig_out, _ = pytree.tree_flatten(orig_out) + flat_mod_out, _ = pytree.tree_flatten(mod_out) + + for i, (orig_tensor, mod_tensor) in enumerate(zip(flat_orig_out, flat_mod_out)): + if not torch.allclose(orig_tensor, mod_tensor, rtol=rtol, atol=atol): + max_diff = torch.max(torch.abs(orig_tensor - mod_tensor)).item() + raise AssertionError( + f"Pass validation failed for pass {pass_name}. " + f"Output tensor {i} differs by max {max_diff:.6e}. 
" + f"Expected rtol={rtol}, atol={atol}." + ) + + +def get_compute_nodes( + graph_module: torch.fx.GraphModule, +) -> list: + """Return the target of each call_function node in order.""" + return [ + n.target + for n in graph_module.graph.nodes + if n.op == "call_function" + and n.target + not in ( + torch.ops.aten.sym_size.int, + torch.ops.aten.sym_stride.int, + torch.ops.aten.sym_numel.default, + ) + ] + + +# ────────────────────────────────────────────────────────────────────── +# Tests for FuseCascadedTransposeOrPermuteOps +# ────────────────────────────────────────────────────────────────────── + + +class FuseCascadedTransposeOrPermuteOpsTest(unittest.TestCase): + def test_permute_transpose_fusion(self) -> None: + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(3, 1, 3, 1, 4)) + permute = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, args=(x, [0, 2, 4, 1, 3]) + ) + transpose = builder.call_operator( + op=exir_ops.edge.aten.transpose_copy.int, args=(permute, 1, 0) + ) + builder.output([transpose]) + original = builder.get_graph_module() + gm_before = copy.deepcopy(original) + + p = FuseCascadedTransposeOrPermuteOps() + result = cast(PassResult, p(original)) + self.assertTrue(result.modified) + gm = result.graph_module + self.assertEqual(count_node(gm, exir_ops.edge.aten.permute_copy.default), 1) + self.assertEqual(count_node(gm, exir_ops.edge.aten.transpose_copy.int), 0) + validate_numerics( + gm_before, + gm, + [torch.randn(3, 1, 3, 1, 4)], + "FuseCascadedTransposeOrPermuteOps", + ) + + def test_cascaded_permutes_multiple_users(self) -> None: + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(2, 3, 4, 5)) + permute1 = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, args=(x, [0, 2, 3, 1]) + ) + # permute2 reverses permute1 => identity + permute2 = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, args=(permute1, [0, 3, 1, 2]) + ) + # permute3: different 
permutation + permute3 = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, args=(permute1, [0, 2, 1, 3]) + ) + # permute4 -> permute5: chained + permute4 = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, args=(permute1, [3, 2, 0, 1]) + ) + permute5 = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, args=(permute4, [0, 1, 3, 2]) + ) + builder.output([permute2, permute3, permute5]) + original = builder.get_graph_module() + gm_before = copy.deepcopy(original) + + p = FuseCascadedTransposeOrPermuteOps() + result = cast(PassResult, p(original)) + self.assertTrue(result.modified) + validate_numerics( + gm_before, + result.graph_module, + [torch.randn(2, 3, 4, 5)], + "FuseCascadedTransposeOrPermuteOps", + ) + + +# ────────────────────────────────────────────────────────────────────── +# Tests for FuseCascadedViewOps +# ────────────────────────────────────────────────────────────────────── + + +class FuseCascadedViewOpsTest(unittest.TestCase): + def test_view_fusion(self) -> None: + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(8, 5, 3)) + v1 = builder.call_operator( + op=exir_ops.edge.aten.view_copy.default, args=(x, [1, 8, 15]) + ) + v2 = builder.call_operator( + op=exir_ops.edge.aten.view_copy.default, args=(v1, [1, 1, 120]) + ) + v3 = builder.call_operator( + op=exir_ops.edge.aten.view_copy.default, args=(v2, [120]) + ) + builder.output([v3]) + original = builder.get_graph_module() + gm_before = copy.deepcopy(original) + + p = FuseCascadedViewOps() + result = cast(PassResult, p(original)) + self.assertTrue(result.modified) + gm = result.graph_module + self.assertEqual(count_node(gm, exir_ops.edge.aten.view_copy.default), 1) + validate_numerics( + gm_before, + gm, + [torch.randn(8, 5, 3)], + "FuseCascadedViewOps", + ) + + def test_view_fusion_branched(self) -> None: + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(8, 5, 3)) + y = builder.call_operator( + 
op=exir_ops.edge.aten.view_copy.default, args=(x, [1, 8, 15]) + ) + branch1 = builder.call_operator( + op=exir_ops.edge.aten.view_copy.default, args=(y, [1, 1, 120]) + ) + branch2 = builder.call_operator( + op=exir_ops.edge.aten.view_copy.default, args=(y, [120, 1, 1]) + ) + builder.output([branch1, branch2]) + original = builder.get_graph_module() + gm_before = copy.deepcopy(original) + + p = FuseCascadedViewOps() + result = cast(PassResult, p(original)) + self.assertTrue(result.modified) + gm = result.graph_module + self.assertEqual(count_node(gm, exir_ops.edge.aten.view_copy.default), 2) + validate_numerics( + gm_before, + gm, + [torch.randn(8, 5, 3)], + "FuseCascadedViewOps", + ) + + +# ────────────────────────────────────────────────────────────────────── +# Tests for PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView +# ────────────────────────────────────────────────────────────────────── + + +class PostponePermuteBelowSqueezeViewTest(unittest.TestCase): + def test_permute3_view4_chains(self) -> None: + """view→permute→view→permute reordered to view→view→permute→permute.""" + builder = GraphBuilder() + x_data = torch.randn(3, 1, 768) + x = builder.placeholder("x", x_data) + v1 = builder.call_operator( + op=exir_ops.edge.aten.view_copy.default, args=(x, [3, 12, 64]) + ) + p1 = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, args=(v1, [1, 0, 2]) + ) + v2 = builder.call_operator( + op=exir_ops.edge.aten.view_copy.default, args=(p1, [1, 12, 3, 64]) + ) + p2 = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, args=(v2, [0, 1, 3, 2]) + ) + builder.output([p2]) + original = builder.get_graph_module() + gm_before = copy.deepcopy(original) + + pass_instance = PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView() + result = cast(PassResult, pass_instance.call(original)) + self.assertTrue(result.modified) + gm = result.graph_module + gm.graph.eliminate_dead_code() + + self.assertEqual(count_node(gm, 
exir_ops.edge.aten.view_copy.default), 2) + self.assertEqual(count_node(gm, exir_ops.edge.aten.permute_copy.default), 2) + # Verify order: views before permutes + targets = get_compute_nodes(gm) + view_indices = [ + i + for i, t in enumerate(targets) + if t == exir_ops.edge.aten.view_copy.default + ] + permute_indices = [ + i + for i, t in enumerate(targets) + if t == exir_ops.edge.aten.permute_copy.default + ] + self.assertTrue(all(v < p for v in view_indices for p in permute_indices)) + + validate_numerics( + gm_before, + gm, + [x_data], + "PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView", + ) + + def test_permute4_view3_chains(self) -> None: + """4d→permute→view→3d→permute reordered to view→view→permute→permute.""" + builder = GraphBuilder() + x_data = torch.randn(3, 1, 768) + x = builder.placeholder("x", x_data) + v1 = builder.call_operator( + op=exir_ops.edge.aten.view_copy.default, args=(x, [1, 3, 12, 64]) + ) + p1 = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, args=(v1, [3, 1, 0, 2]) + ) + v2 = builder.call_operator( + op=exir_ops.edge.aten.view_copy.default, args=(p1, [64, 3, 12]) + ) + p2 = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, args=(v2, [2, 1, 0]) + ) + builder.output([p2]) + original = builder.get_graph_module() + gm_before = copy.deepcopy(original) + + pass_instance = PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView() + result = cast(PassResult, pass_instance.call(original)) + self.assertTrue(result.modified) + gm = result.graph_module + + self.assertEqual(count_node(gm, exir_ops.edge.aten.view_copy.default), 2) + self.assertEqual(count_node(gm, exir_ops.edge.aten.permute_copy.default), 2) + targets = get_compute_nodes(gm) + view_indices = [ + i + for i, t in enumerate(targets) + if t == exir_ops.edge.aten.view_copy.default + ] + permute_indices = [ + i + for i, t in enumerate(targets) + if t == exir_ops.edge.aten.permute_copy.default + ] + self.assertTrue(all(v < p for v in view_indices for p in 
permute_indices)) + + validate_numerics( + gm_before, + gm, + [x_data], + "PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView", + ) + + def test_negative_not_squeeze_like(self) -> None: + """View that reshapes (not just squeeze/unsqueeze) should NOT be reordered.""" + builder = GraphBuilder() + x_data = torch.randn(3, 1, 768) + x = builder.placeholder("x", x_data) + v1 = builder.call_operator( + op=exir_ops.edge.aten.view_copy.default, args=(x, [1, 3, 12, 64]) + ) + p1 = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, args=(v1, [3, 1, 0, 2]) + ) + v2 = builder.call_operator( + op=exir_ops.edge.aten.view_copy.default, args=(p1, [64, 6, 6]) + ) + p2 = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, args=(v2, [2, 1, 0]) + ) + builder.output([p2]) + original = builder.get_graph_module() + + pass_instance = PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView() + result = cast(PassResult, pass_instance.call(original)) + self.assertFalse(result.modified) + + self.assertEqual( + count_node(result.graph_module, exir_ops.edge.aten.view_copy.default), 2 + ) + self.assertEqual( + count_node(result.graph_module, exir_ops.edge.aten.permute_copy.default), + 2, + ) + # Order unchanged: view, permute, view, permute + targets = get_compute_nodes(result.graph_module) + self.assertEqual(targets[0], exir_ops.edge.aten.view_copy.default) + self.assertEqual(targets[1], exir_ops.edge.aten.permute_copy.default) + + +# ────────────────────────────────────────────────────────────────────── +# Tests for ReplaceNopTransposeOrPermuteWithViewPass +# ────────────────────────────────────────────────────────────────────── + + +class ReplaceNopTransposeOrPermuteWithViewTest(unittest.TestCase): + def test_replace_nop_transpose_with_view_float(self) -> None: + x = torch.randn(2, 1, 3, 1) + gm = single_op_builder( + placeholders=(x,), + op=exir_ops.edge.aten.transpose_copy.int, + args=(x, 1, 3), + ) + gm_before = copy.deepcopy(gm) + + p = 
ReplaceNopTransposeOrPermuteWithViewPass() + result = cast(PassResult, p(gm)) + self.assertTrue(result.modified) + gm_after = result.graph_module + self.assertEqual( + count_node(gm_after, exir_ops.edge.aten.permute_copy.default), 0 + ) + self.assertEqual(count_node(gm_after, exir_ops.edge.aten.view_copy.default), 1) + validate_numerics( + gm_before, gm_after, [x], "ReplaceNopTransposeOrPermuteWithViewPass" + ) + + def test_replace_nop_transpose_with_view_int(self) -> None: + x = torch.randint(low=0, high=100, size=(2, 1, 5), dtype=torch.int64) + gm = single_op_builder( + placeholders=(x,), + op=exir_ops.edge.aten.transpose_copy.int, + args=(x, 1, 0), + ) + gm_before = copy.deepcopy(gm) + + p = ReplaceNopTransposeOrPermuteWithViewPass() + result = cast(PassResult, p(gm)) + self.assertTrue(result.modified) + gm_after = result.graph_module + self.assertEqual(count_node(gm_after, exir_ops.edge.aten.transpose_copy.int), 0) + self.assertEqual(count_node(gm_after, exir_ops.edge.aten.view_copy.default), 1) + validate_numerics( + gm_before, gm_after, [x], "ReplaceNopTransposeOrPermuteWithViewPass" + ) + + def test_replace_nop_permute_5d(self) -> None: + x = torch.randn(3, 1, 3, 1, 4) + gm = single_op_builder( + placeholders=(x,), + op=exir_ops.edge.aten.permute_copy.default, + args=(x, [0, 2, 4, 1, 3]), + ) + gm_before = copy.deepcopy(gm) + + p = ReplaceNopTransposeOrPermuteWithViewPass() + result = cast(PassResult, p(gm)) + self.assertTrue(result.modified) + gm_after = result.graph_module + self.assertEqual( + count_node(gm_after, exir_ops.edge.aten.permute_copy.default), 0 + ) + self.assertEqual(count_node(gm_after, exir_ops.edge.aten.view_copy.default), 1) + validate_numerics( + gm_before, gm_after, [x], "ReplaceNopTransposeOrPermuteWithViewPass" + ) + + def test_replace_nop_permute_3d(self) -> None: + x = torch.randn(1, 3, 4) + gm = single_op_builder( + placeholders=(x,), + op=exir_ops.edge.aten.permute_copy.default, + args=(x, [1, 2, 0]), + ) + gm_before = 
copy.deepcopy(gm) + + p = ReplaceNopTransposeOrPermuteWithViewPass() + result = cast(PassResult, p(gm)) + self.assertTrue(result.modified) + gm_after = result.graph_module + self.assertEqual( + count_node(gm_after, exir_ops.edge.aten.permute_copy.default), 0 + ) + self.assertEqual(count_node(gm_after, exir_ops.edge.aten.view_copy.default), 1) + validate_numerics( + gm_before, gm_after, [x], "ReplaceNopTransposeOrPermuteWithViewPass" + )