Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
259 changes: 259 additions & 0 deletions backends/nxp/aten_passes/convert_1d_conv_to_2d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,259 @@
# Copyright 2026 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch._subclasses import FakeTensor, FakeTensorMode
from torch.fx import GraphModule, Node
from torch.fx.passes.infra.pass_base import PassBase, PassResult


# Normalized positional args of aten.conv1d.default, as produced by
# ConvertConv1dToConv2dPass._get_conv_1d_args:
#   (input, weight, bias, stride, padding, dilation, groups)
Conv1dArgs = tuple[Node, Node, (Node | None), list[int], list[int], list[int], int]
# Normalized positional args of aten.conv_transpose1d.default, as produced by
# ConvertConv1dToConv2dPass._get_conv_1d_transp_args:
#   (input, weight, bias, stride, padding, output_padding, groups, dilation)
Conv1dTranspArgs = tuple[
    Node, Node, (Node | None), list[int], list[int], list[int], int, list[int]
]


class ConvertConv1dToConv2dPass(PassBase):
r"""
The NXP backend supports only 2D convolutions. Rewrite 1D convolutions into an equivalent 2D form by
inserting a singleton spatial dimension and then removing it again.

x W x W
[N, C1, H1] [I/O, I/O, k] [N, C1, H1] [I/O, I/O, k]
│ │ │ │
│ │ ┌────────▼─────────┐ ┌─────────▼────────┐
│ │ │ unsqueeze(x, 2) │ │ unsqueeze(x, 2) │
│ │ └────────▼─────────┘ └─────────▼────────┘
│ │ │ │
│ │ [N, C1, 1, H1] [I/O, I/O, 1, k]
│ │ │ │
└────────┐ ┌────────┘ └──────────┐ ┌──────────┘
│ │ │ │
┌────────▼───────▼───────┐ ┌────────▼─────▼────────┐
│ convolution ◄──B [O] replace │ convolution ◄──B [O]
│ (1D/transposed 1D) │ ────────────────► │ (2D/transposed 2D) │
└────────────┬───────────┘ with └───────────┬───────────┘
│ │
│ [N, C2, 1, H2]
│ │
│ ┌────────▼─────────┐
│ │ squeeze(x, 2) │
│ └────────┬─────────┘
│ │
▼ ▼
[N, C2, H2] [N, C2, H2]
y y
"""

@staticmethod
def _is_conv_1d(node: Node) -> bool:
return node.target == torch.ops.aten.conv1d.default

@staticmethod
def _is_conv_transposed_1d(node: Node) -> bool:
return node.target == torch.ops.aten.conv_transpose1d.default

@staticmethod
def _listify(x: int | list[int] | tuple[int]) -> list[int]:
if isinstance(x, int):
return [x]

return list(x)

@staticmethod
def _get_node_shape(node: Node):
return node.meta["val"].shape if hasattr(node, "meta") else node.shape

@staticmethod
def _get_node_dtype(node: Node):
return node.meta["val"].dtype if hasattr(node, "meta") else node.dtype
Comment on lines +67 to +71
Copy link

Copilot AI Apr 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_get_node_shape/_get_node_dtype check hasattr(node, "meta"), but torch.fx.Node always has a meta dict; if meta["val"] is missing this will raise a KeyError and the fallback to node.shape/node.dtype will never be used. Consider checking for "val" in node.meta (or node.meta.get("val")) and only indexing when present.

Suggested change
return node.meta["val"].shape if hasattr(node, "meta") else node.shape
@staticmethod
def _get_node_dtype(node: Node):
return node.meta["val"].dtype if hasattr(node, "meta") else node.dtype
meta_val = node.meta.get("val") if hasattr(node, "meta") else None
return meta_val.shape if meta_val is not None else node.shape
@staticmethod
def _get_node_dtype(node: Node):
meta_val = node.meta.get("val") if hasattr(node, "meta") else None
return meta_val.dtype if meta_val is not None else node.dtype

Copilot uses AI. Check for mistakes.

def _create_some_conv_2d_node(self, target, *conv_args):
# some_conv_2d_node = could be regular 2d conv or transposed 2d conv
some_conv_node = self.graph_module.graph.call_function(target, conv_args)
some_conv_node.meta["source_fn_stack"] = [(some_conv_node.name, target)]

# take out the bias node argument if bias=False, cannot calculate fake tensor for None
has_b_node = len(conv_args) >= 3 and conv_args[2] is not None
if has_b_node:
node_args = conv_args[:3]
scalar_args = conv_args[3:]
else:
node_args = conv_args[:2]
scalar_args = conv_args[2:]

with FakeTensorMode() as mode:
node_arg_shapes = [self._get_node_shape(arg) for arg in node_args]
node_arg_dtypes = [self._get_node_dtype(arg) for arg in node_args]
fake_node_args = [
FakeTensor.from_tensor(torch.empty(shape, dtype=dtype), mode)
for shape, dtype in zip(node_arg_shapes, node_arg_dtypes)
]

# insert back the bias node argument (= None) if it was taken out earlier
node_args = fake_node_args if has_b_node else fake_node_args + [None]
Comment on lines +95 to +96
Copy link

Copilot AI Apr 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In _create_some_conv_2d_node, the local node_args = fake_node_args if has_b_node else fake_node_args + [None] is assigned but never used. This looks like leftover code from a prior version; removing it (or using it consistently) would reduce confusion when maintaining this pass.

Suggested change
# insert back the bias node argument (= None) if it was taken out earlier
node_args = fake_node_args if has_b_node else fake_node_args + [None]
# scalar_args already preserves the original bias position when bias is None

Copilot uses AI. Check for mistakes.
output = target(*fake_node_args, *scalar_args)

some_conv_node.meta["val"] = FakeTensor.from_tensor(
torch.empty(output.shape, dtype=output.dtype), mode
)

return some_conv_node

def _create_sq_or_unsq_node(self, target, *sq_or_unsq_args) -> Node:
sq_or_unsq_node = self.graph_module.graph.call_function(target, sq_or_unsq_args)

sq_or_unsq_node.meta["source_fn_stack"] = [(sq_or_unsq_node.name, target)]
with FakeTensorMode() as mode:
inp_node = sq_or_unsq_args[0]
fake_input = FakeTensor.from_tensor(
torch.empty(
self._get_node_shape(inp_node), dtype=self._get_node_dtype(inp_node)
),
mode,
)

output = target(fake_input, *sq_or_unsq_args[1:])
sq_or_unsq_node.meta["val"] = FakeTensor.from_tensor(
torch.empty(output.shape, dtype=output.dtype), mode
)

return sq_or_unsq_node

@staticmethod
def _get_conv_1d_transp_args(node: Node):
args = node.args
listify_fn = ConvertConv1dToConv2dPass._listify

b_node = None if len(args) < 3 else args[2]
stride = [1] if len(args) < 4 else listify_fn(args[3])
padding = [0] if len(args) < 5 else listify_fn(args[4])
output_padding = [0] if len(args) < 6 else listify_fn(args[5])
groups = 1 if len(args) < 7 else args[6]
dilation = [1] if len(args) < 8 else listify_fn(args[7])

return (
args[0],
args[1],
b_node,
stride,
padding,
output_padding,
groups,
dilation,
)

@staticmethod
def _get_conv_1d_args(node: Node) -> Conv1dArgs:
args = node.args
listify_fn = ConvertConv1dToConv2dPass._listify

b_node = None if len(args) < 3 else args[2]
stride = [1] if len(args) < 4 else listify_fn(args[3])
padding = [0] if len(args) < 5 else listify_fn(args[4])
dilation = [1] if len(args) < 6 else listify_fn(args[5])
groups = 1 if len(args) < 7 else args[6]

return args[0], args[1], b_node, stride, padding, dilation, groups

def _convert_scalar_1d_args_to_2d(self, old_1d_node: Node):
if self._is_conv_transposed_1d(old_1d_node):
_, _, _, stride, pad, output_pad, groups, dil = (
self._get_conv_1d_transp_args(old_1d_node)
)

# conversion of 1d args to 2d, ie. padding with default values
stride = [1] + stride
pad = [0] + pad
output_pad = [0] + output_pad
dil = [1] + dil

return stride, pad, output_pad, groups, dil

else:
_, _, _, stride, pad, dil, groups = self._get_conv_1d_args(old_1d_node)

# conversion of 1d args to 2d, ie. padding with default values
stride = [1] + stride
pad = [0] + pad
dil = [1] + dil

return stride, pad, dil, groups

def _convert_node_1d_args_to_2d(self, old_1d_node: Node):
if self._is_conv_transposed_1d(old_1d_node):
input_node, w_node, b_node, _, _, _, _, _ = self._get_conv_1d_transp_args(
old_1d_node
)
else:
input_node, w_node, b_node, _, _, _, _ = self._get_conv_1d_args(old_1d_node)

with self.graph_module.graph.inserting_before(old_1d_node):
unsqueeze_target = torch.ops.aten.unsqueeze.default

# weights = [i/o, i/o, k] => [i/o, i/o, 1, k]
w_unsq_args = (w_node, 2)
w_unsq_node = self._create_sq_or_unsq_node(unsqueeze_target, *w_unsq_args)

# input = [n, c, h] => [n, c, 1, h]
inp_unsq_args = (input_node, 2)
inp_unsq_node = self._create_sq_or_unsq_node(
unsqueeze_target, *inp_unsq_args
)

return (inp_unsq_node, w_unsq_node, b_node)

def call(self, graph_module: GraphModule) -> PassResult:
self.graph_module = graph_module
made_changes = False

for node in list(graph_module.graph.nodes):
is_conv_1d = self._is_conv_1d(node)
is_conv_1d_transp = self._is_conv_transposed_1d(node)

# some_1d_conv = regular 1d conv or 1d transposed conv
is_some_1d_conv = is_conv_1d or is_conv_1d_transp
if not is_some_1d_conv:
continue

# invalid number of args
if len(node.args) < 2:
continue

old_1d_node = node

# get input, weight and bias arguments for the new 2d conv
node_args = self._convert_node_1d_args_to_2d(old_1d_node)
# get stride, padding etc. arguments for the new 2d conv
scalar_args = self._convert_scalar_1d_args_to_2d(old_1d_node)

new_2d_target = (
torch.ops.aten.conv_transpose2d.input
if is_conv_1d_transp
else torch.ops.aten.conv2d.default
)

# create the new conv 2d and unsqueeze the input and weights
with self.graph_module.graph.inserting_before(old_1d_node):
new_2d_args = node_args + scalar_args
new_2d_node = self._create_some_conv_2d_node(
new_2d_target, *new_2d_args
)

# the original 1d conv output shape must be retained, thus insert squeeze
with self.graph_module.graph.inserting_after(new_2d_node):
squeeze_target = torch.ops.aten.squeeze.dim

out_sq_args = (new_2d_node, 2)
out_sq_node = self._create_sq_or_unsq_node(squeeze_target, *out_sq_args)

old_1d_node.replace_all_uses_with(out_sq_node)
graph_module.graph.erase_node(old_1d_node)

made_changes = True

graph_module.recompile()
graph_module.graph.eliminate_dead_code()
return PassResult(graph_module, made_changes)
4 changes: 4 additions & 0 deletions backends/nxp/aten_passes/neutron_aten_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@

import torch

from executorch.backends.nxp.aten_passes.convert_1d_conv_to_2d import (
ConvertConv1dToConv2dPass,
)
from executorch.backends.nxp.aten_passes.convert_div_to_mul import ConvertDivToMulPass
from executorch.backends.nxp.aten_passes.decompose_split_to_slices_pass import (
DecomposeSplitToSlicesPass,
Expand Down Expand Up @@ -49,6 +52,7 @@ def _get_default_passes(neutron_target_spec, qat_mode: bool = False) -> list[Pas
FuseLinearAndAddPass(),
MoveActivationBeforeConcat(neutron_target_spec),
ConvertDivToMulPass(),
ConvertConv1dToConv2dPass(),
]

if not qat_mode:
Expand Down
Loading
Loading