From 8568135ef77acbaeee77b8a5c09a984d85c0a7b8 Mon Sep 17 00:00:00 2001 From: Vaclav Novak Date: Mon, 20 Apr 2026 18:26:48 +0200 Subject: [PATCH] feat: added transposed conv 1d support + refactored conv_converter --- .../nxp/aten_passes/convert_1d_conv_to_2d.py | 259 ++++++++++ .../aten_passes/neutron_aten_pass_manager.py | 4 + .../ops_converters/convolution_converter.py | 104 +--- backends/nxp/quantizer/neutron_quantizer.py | 2 - backends/nxp/quantizer/patterns.py | 42 +- backends/nxp/quantizer/utils.py | 26 + .../node_converter/test_conv_converter.py | 214 +-------- backends/nxp/tests/models.py | 64 ++- .../nxp/tests/test_convert_1d_conv_to_2d.py | 447 ++++++++++++++++++ 9 files changed, 827 insertions(+), 335 deletions(-) create mode 100644 backends/nxp/aten_passes/convert_1d_conv_to_2d.py create mode 100644 backends/nxp/tests/test_convert_1d_conv_to_2d.py diff --git a/backends/nxp/aten_passes/convert_1d_conv_to_2d.py b/backends/nxp/aten_passes/convert_1d_conv_to_2d.py new file mode 100644 index 00000000000..9ec1b8fde0b --- /dev/null +++ b/backends/nxp/aten_passes/convert_1d_conv_to_2d.py @@ -0,0 +1,259 @@ +# Copyright 2026 NXP +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch._subclasses import FakeTensor, FakeTensorMode +from torch.fx import GraphModule, Node +from torch.fx.passes.infra.pass_base import PassBase, PassResult + + +Conv1dArgs = tuple[Node, Node, (Node | None), list[int], list[int], list[int], int] +Conv1dTranspArgs = tuple[ + Node, Node, (Node | None), list[int], list[int], list[int], int, list[int] +] + + +class ConvertConv1dToConv2dPass(PassBase): + r""" + The NXP backend supports only 2D convolutions. Rewrite 1D convolutions into an equivalent 2D form by + inserting a singleton spatial dimension and then removing it again. 
+ + x W x W + [N, C1, H1] [I/O, I/O, k] [N, C1, H1] [I/O, I/O, k] + │ │ │ │ + │ │ ┌────────▼─────────┐ ┌─────────▼────────┐ + │ │ │ unsqueeze(x, 2) │ │ unsqueeze(x, 2) │ + │ │ └────────▼─────────┘ └─────────▼────────┘ + │ │ │ │ + │ │ [N, C1, 1, H1] [I/O, I/O, 1, k] + │ │ │ │ + └────────┐ ┌────────┘ └──────────┐ ┌──────────┘ + │ │ │ │ + ┌────────▼───────▼───────┐ ┌────────▼─────▼────────┐ + │ convolution ◄──B [O] replace │ convolution ◄──B [O] + │ (1D/transposed 1D) │ ────────────────► │ (2D/transposed 2D) │ + └────────────┬───────────┘ with └───────────┬───────────┘ + │ │ + │ [N, C2, 1, H2] + │ │ + │ ┌────────▼─────────┐ + │ │ squeeze(x, 2) │ + │ └────────┬─────────┘ + │ │ + ▼ ▼ + [N, C2, H2] [N, C2, H2] + y y + """ + + @staticmethod + def _is_conv_1d(node: Node) -> bool: + return node.target == torch.ops.aten.conv1d.default + + @staticmethod + def _is_conv_transposed_1d(node: Node) -> bool: + return node.target == torch.ops.aten.conv_transpose1d.default + + @staticmethod + def _listify(x: int | list[int] | tuple[int]) -> list[int]: + if isinstance(x, int): + return [x] + + return list(x) + + @staticmethod + def _get_node_shape(node: Node): + return node.meta["val"].shape if hasattr(node, "meta") else node.shape + + @staticmethod + def _get_node_dtype(node: Node): + return node.meta["val"].dtype if hasattr(node, "meta") else node.dtype + + def _create_some_conv_2d_node(self, target, *conv_args): + # some_conv_2d_node = could be regular 2d conv or transposed 2d conv + some_conv_node = self.graph_module.graph.call_function(target, conv_args) + some_conv_node.meta["source_fn_stack"] = [(some_conv_node.name, target)] + + # take out the bias node argument if bias=False, cannot calculate fake tensor for None + has_b_node = len(conv_args) >= 3 and conv_args[2] is not None + if has_b_node: + node_args = conv_args[:3] + scalar_args = conv_args[3:] + else: + node_args = conv_args[:2] + scalar_args = conv_args[2:] + + with FakeTensorMode() as mode: + node_arg_shapes = 
[self._get_node_shape(arg) for arg in node_args] + node_arg_dtypes = [self._get_node_dtype(arg) for arg in node_args] + fake_node_args = [ + FakeTensor.from_tensor(torch.empty(shape, dtype=dtype), mode) + for shape, dtype in zip(node_arg_shapes, node_arg_dtypes) + ] + + # insert back the bias node argument (= None) if it was taken out earlier + node_args = fake_node_args if has_b_node else fake_node_args + [None] + output = target(*fake_node_args, *scalar_args) + + some_conv_node.meta["val"] = FakeTensor.from_tensor( + torch.empty(output.shape, dtype=output.dtype), mode + ) + + return some_conv_node + + def _create_sq_or_unsq_node(self, target, *sq_or_unsq_args) -> Node: + sq_or_unsq_node = self.graph_module.graph.call_function(target, sq_or_unsq_args) + + sq_or_unsq_node.meta["source_fn_stack"] = [(sq_or_unsq_node.name, target)] + with FakeTensorMode() as mode: + inp_node = sq_or_unsq_args[0] + fake_input = FakeTensor.from_tensor( + torch.empty( + self._get_node_shape(inp_node), dtype=self._get_node_dtype(inp_node) + ), + mode, + ) + + output = target(fake_input, *sq_or_unsq_args[1:]) + sq_or_unsq_node.meta["val"] = FakeTensor.from_tensor( + torch.empty(output.shape, dtype=output.dtype), mode + ) + + return sq_or_unsq_node + + @staticmethod + def _get_conv_1d_transp_args(node: Node): + args = node.args + listify_fn = ConvertConv1dToConv2dPass._listify + + b_node = None if len(args) < 3 else args[2] + stride = [1] if len(args) < 4 else listify_fn(args[3]) + padding = [0] if len(args) < 5 else listify_fn(args[4]) + output_padding = [0] if len(args) < 6 else listify_fn(args[5]) + groups = 1 if len(args) < 7 else args[6] + dilation = [1] if len(args) < 8 else listify_fn(args[7]) + + return ( + args[0], + args[1], + b_node, + stride, + padding, + output_padding, + groups, + dilation, + ) + + @staticmethod + def _get_conv_1d_args(node: Node) -> Conv1dArgs: + args = node.args + listify_fn = ConvertConv1dToConv2dPass._listify + + b_node = None if len(args) < 3 else 
args[2] + stride = [1] if len(args) < 4 else listify_fn(args[3]) + padding = [0] if len(args) < 5 else listify_fn(args[4]) + dilation = [1] if len(args) < 6 else listify_fn(args[5]) + groups = 1 if len(args) < 7 else args[6] + + return args[0], args[1], b_node, stride, padding, dilation, groups + + def _convert_scalar_1d_args_to_2d(self, old_1d_node: Node): + if self._is_conv_transposed_1d(old_1d_node): + _, _, _, stride, pad, output_pad, groups, dil = ( + self._get_conv_1d_transp_args(old_1d_node) + ) + + # conversion of 1d args to 2d, ie. padding with default values + stride = [1] + stride + pad = [0] + pad + output_pad = [0] + output_pad + dil = [1] + dil + + return stride, pad, output_pad, groups, dil + + else: + _, _, _, stride, pad, dil, groups = self._get_conv_1d_args(old_1d_node) + + # conversion of 1d args to 2d, ie. padding with default values + stride = [1] + stride + pad = [0] + pad + dil = [1] + dil + + return stride, pad, dil, groups + + def _convert_node_1d_args_to_2d(self, old_1d_node: Node): + if self._is_conv_transposed_1d(old_1d_node): + input_node, w_node, b_node, _, _, _, _, _ = self._get_conv_1d_transp_args( + old_1d_node + ) + else: + input_node, w_node, b_node, _, _, _, _ = self._get_conv_1d_args(old_1d_node) + + with self.graph_module.graph.inserting_before(old_1d_node): + unsqueeze_target = torch.ops.aten.unsqueeze.default + + # weights = [i/o, i/o, k] => [i/o, i/o, 1, k] + w_unsq_args = (w_node, 2) + w_unsq_node = self._create_sq_or_unsq_node(unsqueeze_target, *w_unsq_args) + + # input = [n, c, h] => [n, c, 1, h] + inp_unsq_args = (input_node, 2) + inp_unsq_node = self._create_sq_or_unsq_node( + unsqueeze_target, *inp_unsq_args + ) + + return (inp_unsq_node, w_unsq_node, b_node) + + def call(self, graph_module: GraphModule) -> PassResult: + self.graph_module = graph_module + made_changes = False + + for node in list(graph_module.graph.nodes): + is_conv_1d = self._is_conv_1d(node) + is_conv_1d_transp = self._is_conv_transposed_1d(node) + + 
# some_1d_conv = regular 1d conv or 1d transposed conv + is_some_1d_conv = is_conv_1d or is_conv_1d_transp + if not is_some_1d_conv: + continue + + # invalid number of args + if len(node.args) < 2: + continue + + old_1d_node = node + + # get input, weight and bias arguments for the new 2d conv + node_args = self._convert_node_1d_args_to_2d(old_1d_node) + # get stride, padding etc. arguments for the new 2d conv + scalar_args = self._convert_scalar_1d_args_to_2d(old_1d_node) + + new_2d_target = ( + torch.ops.aten.conv_transpose2d.input + if is_conv_1d_transp + else torch.ops.aten.conv2d.default + ) + + # create the new conv 2d and unsqueeze the input and weights + with self.graph_module.graph.inserting_before(old_1d_node): + new_2d_args = node_args + scalar_args + new_2d_node = self._create_some_conv_2d_node( + new_2d_target, *new_2d_args + ) + + # the original 1d conv output shape must be retained, thus insert squeeze + with self.graph_module.graph.inserting_after(new_2d_node): + squeeze_target = torch.ops.aten.squeeze.dim + + out_sq_args = (new_2d_node, 2) + out_sq_node = self._create_sq_or_unsq_node(squeeze_target, *out_sq_args) + + old_1d_node.replace_all_uses_with(out_sq_node) + graph_module.graph.erase_node(old_1d_node) + + made_changes = True + + graph_module.recompile() + graph_module.graph.eliminate_dead_code() + return PassResult(graph_module, made_changes) diff --git a/backends/nxp/aten_passes/neutron_aten_pass_manager.py b/backends/nxp/aten_passes/neutron_aten_pass_manager.py index 703a8cf03a5..4f1ff2648aa 100644 --- a/backends/nxp/aten_passes/neutron_aten_pass_manager.py +++ b/backends/nxp/aten_passes/neutron_aten_pass_manager.py @@ -7,6 +7,9 @@ import torch +from executorch.backends.nxp.aten_passes.convert_1d_conv_to_2d import ( + ConvertConv1dToConv2dPass, +) from executorch.backends.nxp.aten_passes.convert_div_to_mul import ConvertDivToMulPass from executorch.backends.nxp.aten_passes.decompose_split_to_slices_pass import ( DecomposeSplitToSlicesPass, 
@@ -49,6 +52,7 @@ def _get_default_passes(neutron_target_spec, qat_mode: bool = False) -> list[Pas FuseLinearAndAddPass(), MoveActivationBeforeConcat(neutron_target_spec), ConvertDivToMulPass(), + ConvertConv1dToConv2dPass(), ] if not qat_mode: diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/convolution_converter.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/convolution_converter.py index 148b90a331e..5fa994be7ae 100644 --- a/backends/nxp/backend/ir/converter/node_converters/ops_converters/convolution_converter.py +++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/convolution_converter.py @@ -15,7 +15,6 @@ from executorch.backends.nxp.backend.ir.converter.conversion import ( aten_translator, common, - translator, ) from executorch.backends.nxp.backend.ir.converter.conversion.common import try_get_input from executorch.backends.nxp.backend.ir.converter.conversion.translator import ( @@ -42,7 +41,6 @@ from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options import ( conv_2d_options, depthwise_conv_2d_options, - reshape_options, transpose_conv_options, ) from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec @@ -70,8 +68,9 @@ def _is_supported_on_target( return False if conv_params.transposed: - # TransposeConv1d is not supported on Neutron - if len(conv_params.dilation) == 1: + # TransposeConv2d with groups > 1 is not supported + # TODO: split into multiple convs with groups = 1 + if conv_params.groups > 1: return False if not node_is_effectively_static_tensor(weights, parameters_mapping): # Only supported if the weights are static, because TFLite `TransposeConv` uses permuted @@ -187,99 +186,6 @@ def _get_convolution_arguments( groups, ) - def _convert_1d_conv( - self, t_op: tflite_model.Operator, conv_params: ConvParameters - ) -> list[tflite_model.Operator]: - """Convert the 'Conv' operator with a 1D kernel to TFLite 'Conv2D'. 
- TFLite doesn't support 1D convolution, but this behaviour can be represented using - Reshape -> Conv2D -> Reshape. - The first reshape introduces a 4th dimension with size 1. The second Reshape removes the temporary dimension. - """ - # -- Calculate the shapes for equivalent 2D convolution -- - conv_2d_input_shape = translator.nhc_dimensions_to_nhwc( - t_op.tmp_inputs[0].shape.vector - ) - conv_2d_weight_shape = translator.nhc_dimensions_to_nhwc( - t_op.tmp_inputs[1].shape.vector - ) - conv_2d_output_shape = translator.nhc_dimensions_to_nhwc( - t_op.tmp_outputs[0].shape.vector - ) - - # -- Generate tensors taking part in the conversion -- - reshape1_input = t_op.tmp_inputs[0] - - reshape1_output = self.builder.duplicate_tensor( - reshape1_input, name_suffix="_4D_" - ) - reshape1_output.shape = tflite_model.Shape(conv_2d_input_shape) - - reshape2_input = self.builder.duplicate_tensor( - t_op.tmp_outputs[0], name_suffix="_4D_" - ) - reshape2_input.shape = tflite_model.Shape(conv_2d_output_shape) - - reshape2_output = t_op.tmp_outputs[0] - - pre_reshapes = [] - - # Extend the weights tensor to 4D - weights_tensor = t_op.tmp_inputs[1] - if tensor_has_data(weights_tensor): - # Do it statically - weights_tensor.shape = tflite_model.Shape(conv_2d_weight_shape) - weights_tensor.tmp_buffer.data = weights_tensor.tmp_buffer.data.reshape( - conv_2d_weight_shape - ) - - else: - # Add a Reshape before the weights tensor - new_weights_tensor = self.builder.duplicate_tensor( - weights_tensor, name_suffix="_4D_" - ) - new_weights_tensor.shape = tflite_model.Shape(conv_2d_weight_shape) - - weight_reshape = tflite_model.Operator( - builtin_options=reshape_options.Reshape(conv_2d_weight_shape) - ) - weight_reshape.tmp_inputs = [weights_tensor] - weight_reshape.tmp_outputs = [new_weights_tensor] - - pre_reshapes.append(weight_reshape) - - # Save the new weights tensor, to assign it later. 
- weights_tensor = new_weights_tensor - - # -- Create the new operators -- - reshape1 = tflite_model.Operator( - builtin_options=reshape_options.Reshape(conv_2d_input_shape) - ) - reshape1.tmp_inputs = [reshape1_input] - reshape1.tmp_outputs = [reshape1_output] - pre_reshapes.append(reshape1) - - reshape2 = tflite_model.Operator( - builtin_options=reshape_options.Reshape(reshape2_output.shape.vector) - ) - reshape2.tmp_inputs = [reshape2_input] - reshape2.tmp_outputs = [reshape2_output] - - # Assign the new input and output of the Conv2D - t_op.tmp_inputs = [reshape1_output, weights_tensor] + t_op.tmp_inputs[ - 2: - ] # Add bias as well, if present - t_op.tmp_outputs = [reshape2_input] - - # Extend all Conv attributes to 2D - common.extend_1d_stride_to_2d(conv_params.stride) - common.extend_1d_dilation_to_2d(conv_params.dilation) - common.extend_1d_padding_to_2d(conv_params.padding) - - # Convert the now 2D Conv - converted_conv_ops = self._convert_2d_conv(t_op, conv_params) - - return pre_reshapes + converted_conv_ops + [reshape2] - # noinspection PyPep8Naming def _convert_unpadded_2D( self, t_op: tflite_model.Operator, conv_params: ConvParameters @@ -523,9 +429,7 @@ def convert(self, node: Node): ) rank = t_op.tmp_inputs[1].shape.len() - if rank == 3: # Conv1D - ops_to_add = self._convert_1d_conv(t_op, conv_params) - elif rank == 4: # Conv2D + if rank == 4: # Conv2D ops_to_add = self._convert_2d_conv(t_op, conv_params) else: raise NotImplementedError( diff --git a/backends/nxp/quantizer/neutron_quantizer.py b/backends/nxp/quantizer/neutron_quantizer.py index 73c3167d728..6efdaf31250 100644 --- a/backends/nxp/quantizer/neutron_quantizer.py +++ b/backends/nxp/quantizer/neutron_quantizer.py @@ -23,7 +23,6 @@ BMMPattern, CatPattern, ClampPattern, - Conv1dPattern, Conv2dPattern, ConvTranspose2dPattern, DropoutPattern, @@ -266,7 +265,6 @@ def __init__(self, neutron_target_spec: NeutronTargetSpec, is_qat: bool = False) OpQuantizer(BMMPattern(is_qat=is_qat), 
static_qconfig), OpQuantizer(CatPattern(is_qat=is_qat), static_qconfig), OpQuantizer(ClampPattern(is_qat=is_qat), static_qconfig), - OpQuantizer(Conv1dPattern(is_qat=is_qat), static_qconfig), OpQuantizer(Conv2dPattern(self, is_qat=is_qat), static_qconfig), OpQuantizer(ConvTranspose2dPattern(is_qat=is_qat), static_qconfig), OpQuantizer(DropoutPattern(is_qat=is_qat), static_qconfig), diff --git a/backends/nxp/quantizer/patterns.py b/backends/nxp/quantizer/patterns.py index 60afa6bf4d2..961880d76fe 100644 --- a/backends/nxp/quantizer/patterns.py +++ b/backends/nxp/quantizer/patterns.py @@ -7,10 +7,14 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, field +from functools import partial import torch -from executorch.backends.nxp.quantizer.utils import get_bias_qparams +from executorch.backends.nxp.quantizer.utils import ( + get_bias_qparams, + get_padded_bias_qparams, +) from torch import fx from torch._ops import OpOverload from torch.fx import Node @@ -482,16 +486,6 @@ def get_anchors( ) -class Conv1dPattern(ConvPattern): - def partition_types(self) -> list[OpOverload]: - return [torch.ops.aten.conv1d.default] - - -class ConvTranspose1dPattern(ConvPattern): - def partition_types(self) -> list[OpOverload]: - return [torch.ops.aten.conv_transpose1d.default] - - class Conv2dPattern(ConvPattern): def __init__(self, neutron_quantizer, is_qat: bool = False): super().__init__(is_qat=is_qat) @@ -580,12 +574,25 @@ def get_anchors( ) -> PartitionAnchors: conv_node = fused_partition[0].nodes[-1] + # When `groups` > 1, the per-channel weight qparams have shape (`out_channels` / `groups`), + # but bias qparams have shape (`out_channels`) - not divided by `groups`. + # So the weight qparams must be expanded to match the shape correctly. 
+ groups = 1 if len(conv_node.args) < 7 else conv_node.args[6] + if groups > 1: + out_channels = conv_node.meta["val"].shape[1] + derive_qparams_fn = partial( + get_padded_bias_qparams, out_channels=out_channels + ) + + else: + derive_qparams_fn = get_bias_qparams + bias_quantization_qspec = DerivedQuantizationSpec( derived_from=[ (conv_node.args[0], conv_node), (conv_node.args[1], conv_node), ], - derive_qparams_fn=get_bias_qparams, + derive_qparams_fn=derive_qparams_fn, dtype=torch.int32, quant_min=-(2**31) + 1, quant_max=2**31 - 1, @@ -593,14 +600,21 @@ def get_anchors( ch_axis=0, ) - weight_observer_or_fake_quant_ctr = PerChannelMinMaxObserver + w_ch_axis = 1 + weight_observer_or_fake_quant_ctr = ( + FakeQuantize.with_args( + observer=MovingAveragePerChannelMinMaxObserver, ch_axis=w_ch_axis + ) + if self.is_qat + else PerChannelMinMaxObserver.with_args(ch_axis=w_ch_axis) + ) weight_quantization_spec = QuantizationSpec( dtype=torch.int8, observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr, quant_min=-127, quant_max=127, qscheme=torch.per_channel_symmetric, - ch_axis=1, + ch_axis=w_ch_axis, ) # Keep bias empty if not supplied diff --git a/backends/nxp/quantizer/utils.py b/backends/nxp/quantizer/utils.py index da2448fb773..fccd29b245e 100644 --- a/backends/nxp/quantizer/utils.py +++ b/backends/nxp/quantizer/utils.py @@ -73,6 +73,32 @@ def get_bias_qparams( return bias_scale, bias_zero_point +def get_padded_bias_qparams( + obs_or_fqs: List[ObserverOrFakeQuantize], + out_channels: int | None = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + act_scale, _ = obs_or_fqs[0].calculate_qparams() + weight_scale, _ = obs_or_fqs[1].calculate_qparams() + + # It may happen that `torch.ao` incorrectly sets the weight qparams, not matching bias qparams. + # If `out_channels` is given, ensure bias qparams are per-output-channel: + # So for example w = [w1, w2, w3] -> [w1, w2, w3, w1, w2, w3, ...] 
+ if out_channels is not None: + weight_scale = weight_scale.flatten() + if weight_scale.numel() != out_channels: + if out_channels % weight_scale.numel() != 0: + raise RuntimeError( + "Weight qparams cannot be repeated if not divisible by `out_channels`." + ) + weight_scale = weight_scale.repeat(out_channels // weight_scale.numel()) + + act_scale = act_scale.flatten()[0] + + bias_scale = act_scale * weight_scale + bias_zero_point = torch.zeros_like(bias_scale, dtype=torch.int64) + return bias_scale, bias_zero_point + + def get_aten_node_target_partitions( graph: torch.fx.Graph, wanted_original_aten_op: List[OpOverload], diff --git a/backends/nxp/tests/ir/converter/node_converter/test_conv_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_conv_converter.py index 785bd5cc854..5580d0ca729 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_conv_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_conv_converter.py @@ -27,7 +27,7 @@ ToChannelFirstPreprocess, ToChannelLastPreprocess, ) -from executorch.backends.nxp.tests.models import Conv1dModule, Conv2dModule +from executorch.backends.nxp.tests.models import Conv2dModule from executorch.exir.dialects._ops import ops as exir_ops from torch.export import ExportedProgram from executorch.backends.nxp.tests.use_qat import * # noqa F403 @@ -39,218 +39,6 @@ def reseed_model_per_test_run(): np.random.seed(23) -@pytest.mark.parametrize("bias", [False, True]) -@pytest.mark.parametrize("stride", [1, 2]) -@pytest.mark.parametrize("dilation", [2, 1]) -@pytest.mark.parametrize("kernel_size", [(1,), (3,)]) -def test_conv1d_quant_conversion(bias, stride, dilation, kernel_size, mocker, use_qat): - input_shape = (1, 4, 16) - model = Conv1dModule( - bias=bias, stride=stride, dilation=dilation, kernel_size=kernel_size - ) - converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") - ops_spy = mocker.spy(ModelBuilder, "finish") - - # Run conversion - _ = 
to_quantized_edge_program(model, input_shape, use_qat=use_qat) - - # Capture generated model - tflite_flatbuffers_model, io_formats = converter_spy.spy_return - - # Capture converted program - exported_program: ExportedProgram = converter_spy.call_args.args[1] - - input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype(np.int8) - - convert_run_compare( - exported_program, - tflite_input_preprocess=ToChannelLastPreprocess(), - tfl_model=tflite_flatbuffers_model, - tflite_output_preprocess=ToChannelFirstPreprocess(), - input_data=input_data, - atol=1.0, - ) - - # Capture IR model ops - conversion_result = ops_spy.spy_return - ops = conversion_result.sub_graphs[0].operators.vector - - assert len(ops) == 3 - assert ops[0].builtin_options.operator_type == BuiltinOperator.RESHAPE - assert ops[1].builtin_options.operator_type == BuiltinOperator.CONV_2D - assert ops[2].builtin_options.operator_type == BuiltinOperator.RESHAPE - - -@pytest.mark.parametrize("stride", [1, 2]) -@pytest.mark.parametrize("dilation", [2, 1]) -@pytest.mark.parametrize("kernel_size", [(1,), (3,)]) -@pytest.mark.parametrize("padding", [(1,), 2]) -def test_conv1d_quant_conversion__padded( - stride, dilation, kernel_size, padding, mocker, use_qat -): - input_shape = (1, 4, 16) - model = Conv1dModule( - stride=stride, dilation=dilation, kernel_size=kernel_size, padding=padding - ) - converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") - ops_spy = mocker.spy(ModelBuilder, "finish") - - # Run conversion - _ = to_quantized_edge_program(model, input_shape, use_qat=use_qat) - - # Capture generated model - tflite_flatbuffers_model, io_formats = converter_spy.spy_return - - # Capture converted program - exported_program: ExportedProgram = converter_spy.call_args.args[1] - - input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype(np.int8) - - convert_run_compare( - exported_program, - tflite_input_preprocess=ToChannelLastPreprocess(), - 
tfl_model=tflite_flatbuffers_model, - tflite_output_preprocess=ToChannelFirstPreprocess(), - input_data=input_data, - atol=1.0, - ) - - # Capture IR model ops - conversion_result = ops_spy.spy_return - ops = conversion_result.sub_graphs[0].operators.vector - - assert len(ops) == 4 - assert ops[0].builtin_options.operator_type == BuiltinOperator.RESHAPE - assert ops[1].builtin_options.operator_type == BuiltinOperator.PADV2 - assert ops[2].builtin_options.operator_type == BuiltinOperator.CONV_2D - assert ops[3].builtin_options.operator_type == BuiltinOperator.RESHAPE - - # Make sure the padding used the `zero-point`. - pad_value = ops[1].tmp_inputs[2].tmp_buffer.data.item() - assert ( - pad_value == ops[1].tmp_inputs[0].quantization.zero_point[0] - ) # `Pad` input zp. - assert ( - pad_value == ops[1].tmp_outputs[0].quantization.zero_point[0] - ) # `Pad` output zp. - assert ( - pad_value == ops[2].tmp_inputs[0].quantization.zero_point[0] - ) # `Conv` input zp. - - -@pytest.mark.parametrize("bias", [False, True]) -@pytest.mark.parametrize("stride", [1, 2]) -@pytest.mark.parametrize("dilation", [2, 1]) -@pytest.mark.parametrize("kernel_size", [(1,), (3,)]) -def test_conv1d_quant_conversion__depthwise( - bias, stride, dilation, kernel_size, mocker, use_qat -): - input_shape = (1, 4, 16) - group = input_shape[1] - model = Conv1dModule( - bias=bias, - group=group, - in_channels=group, - out_channels=group, - stride=stride, - dilation=dilation, - kernel_size=kernel_size, - ) - converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") - ops_spy = mocker.spy(ModelBuilder, "finish") - - # Run conversion - _ = to_quantized_edge_program(model, input_shape, use_qat=use_qat) - - # Capture generated model - tflite_flatbuffers_model, io_formats = converter_spy.spy_return - - # Capture converted program - exported_program: ExportedProgram = converter_spy.call_args.args[1] - - input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype(np.int8) - - 
convert_run_compare( - exported_program, - tflite_input_preprocess=ToChannelLastPreprocess(), - tfl_model=tflite_flatbuffers_model, - tflite_output_preprocess=ToChannelFirstPreprocess(), - input_data=input_data, - atol=1.0, - ) - - # Capture IR model ops - ops = ops_spy.spy_return.sub_graphs[0].operators.vector - - assert len(ops) == 3 - assert ops[0].builtin_options.operator_type == BuiltinOperator.RESHAPE - assert ops[1].builtin_options.operator_type == BuiltinOperator.DEPTHWISE_CONV_2D - assert ops[2].builtin_options.operator_type == BuiltinOperator.RESHAPE - - -@pytest.mark.parametrize("stride", [1, 2]) -@pytest.mark.parametrize("dilation", [2, 1]) -@pytest.mark.parametrize("kernel_size", [(1,), (3,)]) -@pytest.mark.parametrize("padding", [(1,), 2]) -def test_conv1d_quant_conversion__depthwise__padded( - stride, dilation, kernel_size, padding, mocker, use_qat -): - input_shape = (1, 4, 16) - group = input_shape[1] - model = Conv1dModule( - group=group, - in_channels=group, - out_channels=group, - stride=stride, - dilation=dilation, - kernel_size=kernel_size, - padding=padding, - ) - converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") - ops_spy = mocker.spy(ModelBuilder, "finish") - - # Run conversion - _ = to_quantized_edge_program(model, input_shape, use_qat=use_qat) - - # Capture generated model - tflite_flatbuffers_model, io_formats = converter_spy.spy_return - - # Capture converted program - exported_program: ExportedProgram = converter_spy.call_args.args[1] - - input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype(np.int8) - - convert_run_compare( - exported_program, - tflite_input_preprocess=ToChannelLastPreprocess(), - tfl_model=tflite_flatbuffers_model, - tflite_output_preprocess=ToChannelFirstPreprocess(), - input_data=input_data, - atol=1.0, - ) - - # Capture IR model ops - ops = ops_spy.spy_return.sub_graphs[0].operators.vector - - assert len(ops) == 4 - assert ops[0].builtin_options.operator_type == 
BuiltinOperator.RESHAPE - assert ops[1].builtin_options.operator_type == BuiltinOperator.PADV2 - assert ops[2].builtin_options.operator_type == BuiltinOperator.DEPTHWISE_CONV_2D - assert ops[3].builtin_options.operator_type == BuiltinOperator.RESHAPE - - # Make sure the padding used the `zero-point`. - pad_value = ops[1].tmp_inputs[2].tmp_buffer.data.item() - assert ( - pad_value == ops[1].tmp_inputs[0].quantization.zero_point[0] - ) # `Pad` input zp. - assert ( - pad_value == ops[1].tmp_outputs[0].quantization.zero_point[0] - ) # `Pad` output zp. - assert ( - pad_value == ops[2].tmp_inputs[0].quantization.zero_point[0] - ) # `Conv` input zp. - - @pytest.mark.parametrize( "model, input_shape", [ diff --git a/backends/nxp/tests/models.py b/backends/nxp/tests/models.py index 17bea708352..2e93404e03b 100644 --- a/backends/nxp/tests/models.py +++ b/backends/nxp/tests/models.py @@ -8,20 +8,21 @@ from typing import Callable, Collection, Union import torch +import torch.nn.functional as F from torch import nn class Conv1dModule(torch.nn.Module): def __init__( self, - bias: bool = True, - dilation: Union[int, tuple[int, int]] = 1, in_channels: int = 4, - kernel_size: Union[int, tuple[int, int]] = 3, out_channels: int = 8, - padding: Union[str, int, Collection[int]] = 0, + kernel_size: Union[int, tuple[int, int]] = 3, stride: Union[int, tuple[int, int]] = 2, - group: int = 1, + padding: Union[str, int, tuple[int]] = 0, + dilation: Union[int, tuple[int, int]] = 1, + groups: int = 1, + bias: bool = True, ): super().__init__() @@ -33,13 +34,64 @@ def __init__( padding=padding, dilation=dilation, bias=bias, - groups=group, + groups=groups, ) def forward(self, x): return self.conv(x) +class ConvTranspose1dModule(torch.nn.Module): + def __init__( + self, + in_channels: int = 4, + out_channels: int = 8, + kernel_size: Union[int, tuple[int, int]] = 3, + stride: Union[int, tuple[int, int]] = 1, + padding: Union[int, tuple[int]] = 0, + output_padding: Union[int, tuple[int]] = 0, + 
groups: int = 1, + bias: bool = True, + dilation: Union[int, tuple[int, int]] = 1, + ): + super().__init__() + + self.conv_transp = torch.nn.ConvTranspose1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + output_padding=output_padding, + groups=groups, + bias=bias, + dilation=dilation, + ) + + def forward(self, x): + return self.conv_transp(x) + + +class Conv1dRuntimeWeightModule(torch.nn.Module): + def __init__(self, stride=1, padding=0, dilation=1, groups=1): + super().__init__() + self.stride = stride + self.padding = padding + self.dilation = dilation + self.groups = groups + + def forward(self, x, weight, bias=None): + return F.conv1d( + x, + weight, + bias, + stride=self.stride, + padding=self.padding, + dilation=self.dilation, + groups=self.groups, + ) + + class Conv2dModule(torch.nn.Module): def __init__( self, diff --git a/backends/nxp/tests/test_convert_1d_conv_to_2d.py b/backends/nxp/tests/test_convert_1d_conv_to_2d.py new file mode 100644 index 00000000000..21ff7cbe9b4 --- /dev/null +++ b/backends/nxp/tests/test_convert_1d_conv_to_2d.py @@ -0,0 +1,447 @@ +# Copyright 2026 NXP +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
import numpy as np
import pytest
import torch
from executorch.backends.nxp.aten_passes.neutron_aten_pass_manager import (
    ConvertConv1dToConv2dPass,
    NeutronAtenPassManager,
)

from executorch.backends.nxp.backend.edge_program_converter import (
    EdgeProgramToIRConverter,
)
from executorch.backends.nxp.tests.executorch_pipeline import (
    neutron_target_spec,
    to_quantized_edge_program,
)
from executorch.backends.nxp.tests.executors import (
    convert_run_compare,
    graph_contains_any_of_ops,
)
from executorch.backends.nxp.tests.models import (
    Conv1dModule,
    Conv1dRuntimeWeightModule,
    ConvTranspose1dModule,
)
from executorch.exir.dialects._ops import ops as exir_ops
from torch.export import ExportedProgram


@pytest.fixture(autouse=True)
def reseed_model_per_test_run():
    # Fixed seeds so model weights and random inputs are reproducible per test.
    torch.manual_seed(23)
    np.random.seed(23)


# Short aliases for the aten/edge operator overloads asserted on below.
AtenConv1d = torch.ops.aten.conv1d.default
AtenConv2d = torch.ops.aten.conv2d.default
AtenConvTranspose1d = torch.ops.aten.conv_transpose1d.default
AtenConvTranspose2d = torch.ops.aten.conv_transpose2d.input
AtenSqueeze = torch.ops.aten.squeeze.dim
AtenUnsqueeze = torch.ops.aten.unsqueeze.default

EdgeConvolution = exir_ops.edge.aten.convolution.default
ExecutorchDelegateCall = torch.ops.higher_order.executorch_call_delegate


@pytest.mark.parametrize(
    "input_shape, kernel_size, stride, padding, dilation, groups, bias",
    [
        pytest.param((3, 7, 23), 3, 1, 0, 1, 1, True, id="All default."),
        pytest.param(
            (3, 7, 23), 2, 1, 0, 1, 1, True, id="kernel_size=2, otherwise all default."
        ),
        pytest.param(
            (3, 7, 23), 3, 2, 0, 1, 1, True, id="stride=2, otherwise all default."
        ),
        pytest.param(
            (3, 7, 23), 3, 1, 1, 1, 1, True, id="pad=1, otherwise all default."
        ),
        pytest.param(
            (3, 7, 23), 3, 1, 0, 2, 1, True, id="dilation=2, otherwise all default."
        ),
        pytest.param(
            (3, 7, 23), 3, 1, 0, 1, 7, True, id="group=7, otherwise all default."
        ),
        pytest.param(
            (3, 7, 23), 3, 1, 0, 1, 1, False, id="bias=False, otherwise all default."
        ),
        pytest.param((3, 7, 23), 5, 3, 2, 3, 7, False, id="Nothing is default."),
    ],
)
def test_convert_conv_1d_to_conv2d(
    input_shape, kernel_size, stride, padding, dilation, groups, bias
):
    """`ConvertConv1dToConv2dPass` must replace a single `aten.conv1d` with the
    `unsqueeze -> conv2d -> squeeze` pattern and keep the numerical output unchanged."""
    in_channels = input_shape[1]
    out_channels = 14
    model = Conv1dModule(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        groups=groups,
        bias=bias,
    )
    example_input = torch.rand(input_shape)

    exir_program_aten = torch.export.export(model, (example_input,)).module()

    # Make sure `aten.conv1d` is present.
    assert graph_contains_any_of_ops(exir_program_aten.graph, [AtenConv1d])
    outputs_before = [o.detach().numpy() for o in exir_program_aten(example_input)]

    # Apply the optimization.
    NeutronAtenPassManager(neutron_target_spec, [ConvertConv1dToConv2dPass()])(
        exir_program_aten
    )

    # Make sure no `aten.conv1d` nodes are in the model.
    assert not graph_contains_any_of_ops(
        exir_program_aten.graph,
        [
            AtenConv1d,
        ],
    )

    # Check correct count and placement.
    # The pass must produce exactly one conv2d, wrapped by unsqueeze/squeeze.
    nodes = list(exir_program_aten.graph.nodes)

    conv_nodes = [i for i, n in enumerate(nodes) if n.target == AtenConv2d]
    assert len(conv_nodes) == 1
    i = conv_nodes[0]

    assert nodes[i - 1].target == AtenUnsqueeze
    assert nodes[i].target == AtenConv2d
    assert nodes[i + 1].target == AtenSqueeze

    outputs_after = [o.detach().numpy() for o in exir_program_aten(example_input)]

    # Make sure the model still produces the exact same output.
    assert len(outputs_before) == len(outputs_after)
    for i in range(len(outputs_before)):
        assert np.allclose(outputs_before[i], outputs_after[i])


@pytest.mark.parametrize(
    "input_shape, kernel_size, stride, padding, dilation, groups, bias",
    [
        pytest.param((3, 7, 23), 3, 1, 0, 1, 1, True, id="All default."),
        pytest.param((3, 7, 23), 5, 3, 2, 3, 7, False, id="Nothing is default."),
    ],
)
def test_convert_conv_1d_to_conv2d_runtime_weight(
    input_shape, kernel_size, stride, padding, dilation, groups, bias
):
    """Same rewrite as above, but the weight/bias are graph inputs supplied at
    call time rather than module parameters."""
    in_channels = input_shape[1]
    out_channels = 14

    model = Conv1dRuntimeWeightModule(
        stride=stride, padding=padding, dilation=dilation, groups=groups
    )
    example_input = torch.rand(input_shape)

    # Runtime-provided weights/bias.
    weight_t = torch.rand(out_channels, in_channels // groups, kernel_size)
    bias_t = torch.rand(out_channels) if bias else None

    exir_program_aten = torch.export.export(
        model, (example_input, weight_t, bias_t)
    ).module()

    # Make sure `aten.conv1d` is present.
    assert graph_contains_any_of_ops(exir_program_aten.graph, [AtenConv1d])
    outputs_before = [
        o.detach().numpy() for o in exir_program_aten(example_input, weight_t, bias_t)
    ]

    # Apply the optimization.
    NeutronAtenPassManager(neutron_target_spec, [ConvertConv1dToConv2dPass()])(
        exir_program_aten
    )

    # Make sure no `aten.conv1d` nodes are in the model.
    assert not graph_contains_any_of_ops(exir_program_aten.graph, [AtenConv1d])

    # Check correct count and placement.
    nodes = list(exir_program_aten.graph.nodes)

    conv_nodes = [i for i, n in enumerate(nodes) if n.target == AtenConv2d]
    assert len(conv_nodes) == 1
    i = conv_nodes[0]

    assert nodes[i - 1].target == AtenUnsqueeze
    assert nodes[i].target == AtenConv2d
    assert nodes[i + 1].target == AtenSqueeze

    outputs_after = [
        o.detach().numpy() for o in exir_program_aten(example_input, weight_t, bias_t)
    ]

    # Make sure the model still produces the exact same output.
    assert len(outputs_before) == len(outputs_after)
    for i in range(len(outputs_before)):
        assert np.allclose(outputs_before[i], outputs_after[i])


# Note: The first case is the default; the remaining cases are chosen to test various parameter combinations.
# To satisfy requirements for delegation, some parameters could not be chosen arbitrarily.
@pytest.mark.parametrize(
    "input_shape, kernel_size, stride, padding, output_padding, groups, bias, dilation",
    [
        pytest.param((3, 7, 23), 3, 1, 0, 0, 1, True, 1, id="All default."),
        pytest.param(
            (3, 7, 23),
            2,
            1,
            0,
            0,
            1,
            True,
            1,
            id="kernel_size=2, otherwise all default.",
        ),
        pytest.param(
            (3, 7, 23), 3, 2, 0, 0, 1, True, 1, id="stride=2, otherwise all default."
        ),
        pytest.param(
            (3, 7, 23), 3, 1, 1, 0, 1, True, 1, id="pad=1, otherwise all default."
        ),
        pytest.param(
            (3, 7, 23),
            3,
            2,
            0,
            1,
            1,
            True,
            1,
            id="output_padding=1 (stride=2 - restriction from definition), otherwise all default.",
        ),
        pytest.param(
            (3, 7, 23), 3, 1, 0, 0, 7, True, 1, id="group=7, otherwise all default."
        ),
        pytest.param(
            (3, 7, 23), 3, 1, 0, 0, 1, False, 1, id="bias=False, otherwise all default."
        ),
        pytest.param(
            (3, 7, 23), 3, 1, 0, 0, 1, True, 2, id="dilation=2, otherwise all default."
        ),
        pytest.param((3, 7, 23), 5, 3, 2, 1, 7, False, 3, id="Nothing is default."),
    ],
)
def test_convert_conv_1d_transp_to_conv2d_transp(
    input_shape, kernel_size, stride, padding, output_padding, groups, bias, dilation
):
    """`ConvertConv1dToConv2dPass` must replace `aten.conv_transpose1d` with
    `unsqueeze -> conv_transpose2d -> squeeze` and keep the output unchanged."""
    in_channels = input_shape[1]
    out_channels = 14
    model = ConvTranspose1dModule(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        output_padding=output_padding,
        groups=groups,
        bias=bias,
        dilation=dilation,
    )
    example_input = torch.rand(input_shape)

    exir_program_aten = torch.export.export(model, (example_input,)).module()

    # Make sure `aten.conv_transpose1d` is present.
    assert graph_contains_any_of_ops(exir_program_aten.graph, [AtenConvTranspose1d])
    outputs_before = [o.detach().numpy() for o in exir_program_aten(example_input)]

    # Apply the optimization.
    NeutronAtenPassManager(neutron_target_spec, [ConvertConv1dToConv2dPass()])(
        exir_program_aten
    )

    # Make sure no `aten.conv_transpose1d` nodes are in the model.
    assert not graph_contains_any_of_ops(
        exir_program_aten.graph,
        [
            AtenConvTranspose1d,
        ],
    )

    # Check correct count and placement.
    nodes = list(exir_program_aten.graph.nodes)

    conv_nodes = [i for i, n in enumerate(nodes) if n.target == AtenConvTranspose2d]
    assert len(conv_nodes) == 1
    i = conv_nodes[0]

    assert nodes[i - 1].target == AtenUnsqueeze
    assert nodes[i].target == AtenConvTranspose2d
    assert nodes[i + 1].target == AtenSqueeze

    outputs_after = [o.detach().numpy() for o in exir_program_aten(example_input)]

    # Make sure the model still produces the exact same output.
    assert len(outputs_before) == len(outputs_after)
    for i in range(len(outputs_before)):
        assert np.allclose(outputs_before[i], outputs_after[i])


# Note: The first case is the default; the remaining cases are chosen to test various parameter combinations.
# To satisfy requirements for delegation, some parameters could not be chosen arbitrarily.
@pytest.mark.parametrize("use_qat", [True, False])
@pytest.mark.parametrize(
    "kernel_size, stride, padding, dilation, groups, bias",
    [
        pytest.param(3, 1, 1, 1, 1, True, id="All default, except for padding = 1."),
        pytest.param(1, 1, 0, 1, 1, True, id="kernel_size = 1"),
        pytest.param(3, 2, 5, 1, 1, True, id="stride = 2"),
        pytest.param(3, 1, 2, 2, 1, True, id="dilation = 2"),
        pytest.param(3, 1, 1, 1, 1, False, id="bias = False, padding = 1"),
    ],
)
def test_convert_conv_1d_to_conv2d_full_pipeline(
    mocker, kernel_size, stride, padding, dilation, groups, bias, use_qat
):
    """End-to-end: quantize + lower a Conv1d model. The conv must be fully
    delegated (no `aten.conv1d` left), a 2D `edge.aten.convolution` must reach
    the IR converter, and the generated Neutron IR must match the edge program
    numerically on a random int8 input."""
    converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")

    input_shape = (1, 8, 24)
    in_channels = input_shape[1]
    out_channels = 16

    model = Conv1dModule(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        groups=groups,
        bias=bias,
    )

    delegated_ep = to_quantized_edge_program(
        model, input_shape, use_qat=use_qat
    ).exported_program()

    # Make sure no `conv1d` nodes are in the model.
    assert not graph_contains_any_of_ops(
        delegated_ep.graph,
        [
            AtenConv1d,
        ],
    )

    # Check correct count and placement.
    # NOTE(review): assumes the delegated graph collapses to exactly 7 nodes with
    # the delegate call at index 3 — confirm if the partitioner layout changes.
    nodes = list(delegated_ep.graph.nodes)
    assert len(nodes) == 7
    assert nodes[3].target == ExecutorchDelegateCall

    # Capture generated model.
    neutron_ir_model = converter_spy.spy_return[0]
    exported_program: ExportedProgram = converter_spy.call_args.args[1]

    # Make sure `edge.aten.convolution.default` is in the model.
    assert graph_contains_any_of_ops(
        exported_program.graph,
        [EdgeConvolution],
    )

    example_input = (np.random.random(input_shape).astype(np.float32) * 50).astype(
        np.int8
    )
    convert_run_compare(
        exported_program,
        input_data=example_input,
        tfl_model=neutron_ir_model,
    )


# Note: The first case is the default; the remaining cases are chosen to test various parameter combinations.
# To satisfy requirements for delegation, some parameters could not be chosen arbitrarily.
@pytest.mark.parametrize("use_qat", [False, True])
@pytest.mark.parametrize(
    "kernel_size, stride, padding, output_padding, groups, bias, dilation",
    [
        pytest.param(2, 2, 0, 0, 1, True, 1, id="All default."),
        pytest.param(4, 2, 1, 0, 1, True, 1, id="kernel_size = 4 (and padding = 1)"),
        pytest.param(4, 4, 0, 0, 1, True, 1, id="stride = 4 (and kernel_size = 4)"),
        pytest.param(
            4,
            4,
            1,
            2,
            1,
            True,
            1,
            id="output_padding = 2 (and kernel_size = 4, stride = 4, padding = 1)",
        ),
        pytest.param(2, 2, 0, 0, 1, False, 1, id="bias=False"),
    ],
)
def test_convert_conv_1d_to_conv2d_transp_full_pipeline(
    mocker,
    kernel_size,
    stride,
    padding,
    output_padding,
    groups,
    bias,
    dilation,
    use_qat,
):
    """End-to-end variant of the pipeline test for `ConvTranspose1d`: the
    transposed conv must be delegated, converted via `edge.aten.convolution`,
    and the generated Neutron IR must match the edge program numerically."""
    converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")

    input_shape = (1, 8, 24)
    in_channels = input_shape[1]
    out_channels = 16
    model = ConvTranspose1dModule(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        output_padding=output_padding,
        groups=groups,
        bias=bias,
        dilation=dilation,
    )

    # Run conversion.
    delegated_ep = to_quantized_edge_program(
        model, input_shape, use_qat=use_qat
    ).exported_program()

    # Make sure no `aten.conv_transpose1d` nodes are in the model.
    assert not graph_contains_any_of_ops(
        delegated_ep.graph,
        [AtenConvTranspose1d],
    )

    # Check correct count and placement.
    # NOTE(review): same 7-node / delegate-at-index-3 layout assumption as the
    # regular-conv pipeline test above — confirm if the partitioner changes.
    nodes = list(delegated_ep.graph.nodes)
    assert len(nodes) == 7
    assert nodes[3].target == ExecutorchDelegateCall

    # Capture generated model.
    neutron_ir_model = converter_spy.spy_return[0]
    exported_program: ExportedProgram = converter_spy.call_args.args[1]

    # Make sure `edge.aten.convolution.default` is in the model.
    assert graph_contains_any_of_ops(
        exported_program.graph,
        [EdgeConvolution],
    )

    example_input = (np.random.random(input_shape).astype(np.float32) * 50).astype(
        np.int8
    )
    convert_run_compare(
        exported_program,
        input_data=example_input,
        tfl_model=neutron_ir_model,
    )