Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 33 additions & 100 deletions backends/arm/test/misc/test_high_rank_permute_view_invariants.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,29 +4,31 @@
# LICENSE file in the root directory of this source tree.

import random
from pathlib import Path
from typing import Any
from dataclasses import dataclass
from typing import Any, Tuple

import torch
import torch.nn as nn

from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.test_pipeline import TosaPipelineINT
from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec


InputT = Tuple[Any, ...]


class HighRankPermuteViewModel(torch.nn.Module):
def __init__(self, ops: list[tuple[str, Any]]):
super().__init__()
self.ops = ops
self.block = nn.Sequential(
nn.Conv2d(
self.block = torch.nn.Sequential(
torch.nn.Conv2d(
in_channels=3,
out_channels=64,
kernel_size=3,
stride=2,
padding=1,
),
nn.ReLU(),
torch.nn.ReLU(),
)

def forward(self, x):
Expand All @@ -41,6 +43,13 @@ def forward(self, x):
return x


@dataclass(frozen=True)
class TransposeInvariantCase:
    """One parametrized test case: a module, its inputs, and the expected TRANSPOSE count."""

    # Module under test, built from a generated chain of permute/view ops.
    module: torch.nn.Module
    # Example inputs forwarded to the lowering pipeline (tuple of tensors).
    inputs: InputT
    # Exact number of TOSA TRANSPOSE ops expected in the lowered graph
    # (checked via pipeline.count_tosa_ops).
    expected_transposes: int


def _random_non_identity_permutation(
rng: random.Random, rank: int
) -> tuple[int, ...] | None:
Expand Down Expand Up @@ -130,7 +139,6 @@ def _generate_chain(
shape = new_shape
break

# Ensure each case has at least one rank>4 permute.
while len(shape) <= 4:
new_shape = _reshape_add_singleton(rng, shape)
if new_shape is None:
Expand All @@ -146,108 +154,33 @@ def _generate_chain(
return ops


def _build_high_rank_permute_cases() -> dict[str, TransposeInvariantCase]:
    """Build deterministic fuzzed permute/view cases with expected TRANSPOSE counts.

    Returns:
        Mapping from test-case name to TransposeInvariantCase. The RNG seed is
        fixed so the generated op chains — and therefore the hard-coded
        expected TRANSPOSE counts below — are stable across runs.
    """
    rng = random.Random(
        20260225
    )  # nosec B311: deterministic RNG for test case generation
    # Conv output shape for a 1x3x32x32 input after NHWC permute (stride-2 conv).
    start_shape = [1, 16, 16, 64]
    # Expected TOSA TRANSPOSE count per generated case; tied to the seed above.
    expected_transpose_counts = [6, 11, 10, 10, 7, 7, 10, 10, 8, 10]
    cases: dict[str, TransposeInvariantCase] = {}
    # Iterate over the counts list itself so the case count can never drift
    # out of sync with the expected-counts table.
    for idx, expected in enumerate(expected_transpose_counts):
        ops = _generate_chain(rng, start_shape, steps=8)
        cases[f"high_rank_permute_fuzz_case_{idx}"] = TransposeInvariantCase(
            module=HighRankPermuteViewModel(ops).eval(),
            inputs=(torch.randn(1, 3, 32, 32),),
            expected_transposes=expected,
        )
    return cases


@common.parametrize("case", _build_high_rank_permute_cases())
def test_transpose_invariants_tosa_INT_high_rank_permute_view(
    case: TransposeInvariantCase,
) -> None:
    """Lower a fuzzed permute/view model to TOSA INT and check the TRANSPOSE count.

    The pipeline is compile-only (run_on_tosa_ref_model=False): the assertion
    is on graph structure — the exact number of TRANSPOSE ops — rather than on
    numerical output.
    """
    pipeline = TosaPipelineINT[InputT](
        case.module,
        case.inputs,
        aten_op=[],
        exir_op=[],
        run_on_tosa_ref_model=False,
        # NOTE(review): extension list assumed to be required for the generated
        # chains — confirm against the TosaPipelineINT defaults.
        tosa_extensions=["int16", "int4", "cf"],
    )
    pipeline.count_tosa_ops({"TRANSPOSE": case.expected_transposes})
    pipeline.run()


def _assert_transpose_invariants(tosa_path: Path) -> int:
    """Parse a serialized .tosa file and validate every TRANSPOSE operator in it.

    For each TRANSPOSE op in the first block of the first region, asserts that
    it has exactly one input and one output, that its perms attribute is a
    valid permutation of range(rank), and that the output shape equals the
    input shape mapped through perms.

    Args:
        tosa_path: Path to a serialized .tosa artifact.

    Returns:
        The number of TRANSPOSE ops whose input rank is greater than 4.

    Raises:
        AssertionError: If any TRANSPOSE op violates an invariant above.
    """
    # Local imports: the generated `tosa` bindings (presumably flatbuffers —
    # GetRootAs/Init(Bytes, Pos) below match that API) are only needed here.
    import tosa.Op as Op  # type: ignore[import-not-found,import-untyped]
    from tosa.TosaGraph import (  # type: ignore[import-not-found,import-untyped]
        TosaGraph,
    )
    from tosa.TransposeAttribute import (  # type: ignore[import-not-found,import-untyped]
        TransposeAttribute,
    )

    graph = TosaGraph.GetRootAs(tosa_path.read_bytes(), 0)
    # Only the first block of the first region is inspected.
    block = graph.Regions(0).Blocks(0)

    # Tensor name -> shape, used to resolve operator input/output shapes below.
    shape_by_name = {
        block.Tensors(i).Name().decode(): list(block.Tensors(i).ShapeAsNumpy())
        for i in range(block.TensorsLength())
    }

    # Build a reverse map (enum value -> name) because the generated Op enum
    # only exposes name -> value constants.
    op_enum = Op.Op()
    op_value_to_name = {
        getattr(op_enum, name): name for name in dir(op_enum) if name.isupper()
    }

    high_rank_transpose_count = 0
    for i in range(block.OperatorsLength()):
        op = block.Operators(i)
        if op_value_to_name.get(op.Op()) != "TRANSPOSE":
            continue

        inputs = [op.Inputs(j).decode() for j in range(op.InputsLength())]
        outputs = [op.Outputs(j).decode() for j in range(op.OutputsLength())]
        # TRANSPOSE must be unary: exactly one input and one output tensor.
        assert len(inputs) == 1 and len(outputs) == 1, (
            f"Unexpected TRANSPOSE arity at op #{i}: "
            f"{len(inputs)} inputs, {len(outputs)} outputs"
        )

        # Decode the perms attribute from the raw attribute table.
        attr_tbl = op.Attribute()
        transpose_attr = TransposeAttribute()
        transpose_attr.Init(attr_tbl.Bytes, attr_tbl.Pos)
        perms = [int(perm) for perm in transpose_attr.PermsAsNumpy()]

        in_shape = [int(v) for v in shape_by_name[inputs[0]]]
        out_shape = [int(v) for v in shape_by_name[outputs[0]]]

        rank = len(in_shape)
        # perms must be a permutation of range(rank)...
        assert (
            len(perms) == rank
        ), f"Invalid TRANSPOSE rank at op #{i}: len(perms)={len(perms)} rank={rank}"
        assert sorted(perms) == list(
            range(rank)
        ), f"Invalid TRANSPOSE permutation at op #{i}: perms={perms}, rank={rank}"
        # ...and the output shape must be the input shape permuted by perms.
        expected_out_shape = [in_shape[perm] for perm in perms]
        assert expected_out_shape == out_shape, (
            f"Invalid TRANSPOSE shape mapping at op #{i}: "
            f"perms={perms}, in_shape={in_shape}, out_shape={out_shape}, "
            f"expected_out_shape={expected_out_shape}"
        )
        # Count the high-rank (rank > 4) transposes for the caller to assert on.
        if rank > 4:
            high_rank_transpose_count += 1

    return high_rank_transpose_count


@common.parametrize("model", _build_cases())
def test_high_rank_permute_view_tosa_INT_transpose_invariants(
    model: torch.nn.Module, tmp_path
):
    """Lower each fuzzed model to TOSA INT, then validate TRANSPOSE invariants.

    Dumps the TOSA artifact into a per-test temp directory (pytest tmp_path
    fixture), parses it, and requires at least one rank>4 TRANSPOSE so the
    case actually exercises the high-rank path.
    """
    out_dir = tmp_path / "high_rank_permute_view_fuzz"
    out_dir.mkdir(parents=True, exist_ok=True)
    tosa_path = _run_model(model, str(out_dir))
    assert tosa_path.exists(), f"Missing TOSA dump: {tosa_path}"
    high_rank_count = _assert_transpose_invariants(tosa_path)
    assert (
        high_rank_count > 0
    ), "Expected at least one rank>4 TRANSPOSE in generated case."
Loading
Loading