-
Notifications
You must be signed in to change notification settings - Fork 954
NXP backend: added support for aten.conv_transpose1d and refactored convolution_converter
#19004
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,259 @@ | ||||||||
| # Copyright 2026 NXP | ||||||||
| # | ||||||||
| # This source code is licensed under the BSD-style license found in the | ||||||||
| # LICENSE file in the root directory of this source tree. | ||||||||
|
|
||||||||
| import torch | ||||||||
| from torch._subclasses import FakeTensor, FakeTensorMode | ||||||||
| from torch.fx import GraphModule, Node | ||||||||
| from torch.fx.passes.infra.pass_base import PassBase, PassResult | ||||||||
|
|
||||||||
|
|
||||||||
| Conv1dArgs = tuple[Node, Node, (Node | None), list[int], list[int], list[int], int] | ||||||||
| Conv1dTranspArgs = tuple[ | ||||||||
| Node, Node, (Node | None), list[int], list[int], list[int], int, list[int] | ||||||||
| ] | ||||||||
|
|
||||||||
|
|
||||||||
| class ConvertConv1dToConv2dPass(PassBase): | ||||||||
| r""" | ||||||||
| The NXP backend supports only 2D convolutions. Rewrite 1D convolutions into an equivalent 2D form by | ||||||||
| inserting a singleton spatial dimension and then removing it again. | ||||||||
|
|
||||||||
| x W x W | ||||||||
| [N, C1, H1] [I/O, I/O, k] [N, C1, H1] [I/O, I/O, k] | ||||||||
| │ │ │ │ | ||||||||
| │ │ ┌────────▼─────────┐ ┌─────────▼────────┐ | ||||||||
| │ │ │ unsqueeze(x, 2) │ │ unsqueeze(x, 2) │ | ||||||||
| │ │ └────────▼─────────┘ └─────────▼────────┘ | ||||||||
| │ │ │ │ | ||||||||
| │ │ [N, C1, 1, H1] [I/O, I/O, 1, k] | ||||||||
| │ │ │ │ | ||||||||
| └────────┐ ┌────────┘ └──────────┐ ┌──────────┘ | ||||||||
| │ │ │ │ | ||||||||
| ┌────────▼───────▼───────┐ ┌────────▼─────▼────────┐ | ||||||||
| │ convolution ◄──B [O] replace │ convolution ◄──B [O] | ||||||||
| │ (1D/transposed 1D) │ ────────────────► │ (2D/transposed 2D) │ | ||||||||
| └────────────┬───────────┘ with └───────────┬───────────┘ | ||||||||
| │ │ | ||||||||
| │ [N, C2, 1, H2] | ||||||||
| │ │ | ||||||||
| │ ┌────────▼─────────┐ | ||||||||
| │ │ squeeze(x, 2) │ | ||||||||
| │ └────────┬─────────┘ | ||||||||
| │ │ | ||||||||
| ▼ ▼ | ||||||||
| [N, C2, H2] [N, C2, H2] | ||||||||
| y y | ||||||||
| """ | ||||||||
|
|
||||||||
| @staticmethod | ||||||||
| def _is_conv_1d(node: Node) -> bool: | ||||||||
| return node.target == torch.ops.aten.conv1d.default | ||||||||
|
|
||||||||
| @staticmethod | ||||||||
| def _is_conv_transposed_1d(node: Node) -> bool: | ||||||||
| return node.target == torch.ops.aten.conv_transpose1d.default | ||||||||
|
|
||||||||
| @staticmethod | ||||||||
| def _listify(x: int | list[int] | tuple[int]) -> list[int]: | ||||||||
| if isinstance(x, int): | ||||||||
| return [x] | ||||||||
|
|
||||||||
| return list(x) | ||||||||
|
|
||||||||
| @staticmethod | ||||||||
| def _get_node_shape(node: Node): | ||||||||
| return node.meta["val"].shape if hasattr(node, "meta") else node.shape | ||||||||
|
|
||||||||
| @staticmethod | ||||||||
| def _get_node_dtype(node: Node): | ||||||||
| return node.meta["val"].dtype if hasattr(node, "meta") else node.dtype | ||||||||
|
|
||||||||
| def _create_some_conv_2d_node(self, target, *conv_args): | ||||||||
| # some_conv_2d_node = could be regular 2d conv or transposed 2d conv | ||||||||
| some_conv_node = self.graph_module.graph.call_function(target, conv_args) | ||||||||
| some_conv_node.meta["source_fn_stack"] = [(some_conv_node.name, target)] | ||||||||
|
|
||||||||
| # take out the bias node argument if bias=False, cannot calculate fake tensor for None | ||||||||
| has_b_node = len(conv_args) >= 3 and conv_args[2] is not None | ||||||||
| if has_b_node: | ||||||||
| node_args = conv_args[:3] | ||||||||
| scalar_args = conv_args[3:] | ||||||||
| else: | ||||||||
| node_args = conv_args[:2] | ||||||||
| scalar_args = conv_args[2:] | ||||||||
|
|
||||||||
| with FakeTensorMode() as mode: | ||||||||
| node_arg_shapes = [self._get_node_shape(arg) for arg in node_args] | ||||||||
| node_arg_dtypes = [self._get_node_dtype(arg) for arg in node_args] | ||||||||
| fake_node_args = [ | ||||||||
| FakeTensor.from_tensor(torch.empty(shape, dtype=dtype), mode) | ||||||||
| for shape, dtype in zip(node_arg_shapes, node_arg_dtypes) | ||||||||
| ] | ||||||||
|
|
||||||||
| # insert back the bias node argument (= None) if it was taken out earlier | ||||||||
| node_args = fake_node_args if has_b_node else fake_node_args + [None] | ||||||||
|
Comment on lines
+95
to
+96
|
||||||||
| # insert back the bias node argument (= None) if it was taken out earlier | |
| node_args = fake_node_args if has_b_node else fake_node_args + [None] | |
| # scalar_args already preserves the original bias position when bias is None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`_get_node_shape`/`_get_node_dtype` check `hasattr(node, "meta")`, but `torch.fx.Node` always has a `meta` dict; if `meta["val"]` is missing this will raise a `KeyError` and the fallback to `node.shape`/`node.dtype` will never be used. Consider checking for `"val" in node.meta` (or `node.meta.get("val")`) and only indexing when present.