xref: /aosp_15_r20/external/executorch/backends/arm/test/ops/test_permute.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# Copyright 2024 Arm Limited and/or its affiliates.
3# All rights reserved.
4#
5# This source code is licensed under the BSD-style license found in the
6# LICENSE file in the root directory of this source tree.
7
8import unittest
9from typing import Tuple
10
11import torch
12
13from executorch.backends.arm.quantizer.arm_quantizer import (
14    ArmQuantizer,
15    get_symmetric_quantization_config,
16)
17
18from executorch.backends.arm.test import common
19from executorch.backends.arm.test.tester.arm_tester import ArmTester
20from executorch.backends.xnnpack.test.tester.tester import Quantize
21from executorch.exir.backend.compile_spec_schema import CompileSpec
22from parameterized import parameterized
23from torchvision.ops import Permute
24
# Entries are (test_name, test_data, dims): test_data is a torch.rand tensor of
# the listed shape and dims is the permutation order under test. Labels repeat
# per rank; parameterized.expand keeps the generated test names unique by
# prefixing the entry index. Order matters: test_permute_u55_BI uses [0:1].
test_data_suite = [
    (test_name, torch.rand(*shape), dims)
    for test_name, shape, dims in (
        ("rank_2", (10, 10), [1, 0]),
        ("rank_3", (10, 10, 10), [2, 0, 1]),
        ("rank_3", (10, 10, 10), [1, 2, 0]),
        ("rank_4", (1, 5, 1, 10), [0, 2, 3, 1]),
        ("rank_4", (1, 2, 5, 10), [1, 0, 2, 3]),
        ("rank_4", (1, 10, 10, 5), [2, 0, 1, 3]),
    )
]
34
35
class TestPermute(unittest.TestCase):
    """Tests Permute Operator.

    Each test lowers a module containing a single permute through ArmTester
    and checks that, after partitioning, no edge ``aten_permute`` op remains
    and exactly one ``executorch_call_delegate`` is present.
    """

    class Permute(torch.nn.Module):
        """Module whose forward is a single ``torch.permute`` with fixed dims."""

        def __init__(self, dims: list[int]):
            super().__init__()
            # Call torch.permute directly instead of the module-level
            # torchvision ``Permute`` import. That import has the same name
            # as this nested class, so ``Permute(dims=dims)`` here read like
            # recursive construction. It only resolved to torchvision because
            # method bodies skip class scope.
            self.dims = dims

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.permute(x, self.dims)

    def _test_permute_tosa_MI_pipeline(
        self,
        module: torch.nn.Module,
        test_data: Tuple[torch.Tensor],
        permute_memory_to_nhwc: bool,
    ):
        """Lower ``module`` with the "TOSA-0.80.0+MI" spec and run it.

        Args:
            module: Module under test.
            test_data: Example inputs, also used for the output comparison.
            permute_memory_to_nhwc: Forwarded to
                ``common.get_tosa_compile_spec``.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec(
                    "TOSA-0.80.0+MI", permute_memory_to_nhwc=permute_memory_to_nhwc
                ),
            )
            .export()
            .check(["torch.ops.aten.permute.default"])
            # This pipeline never quantizes, so no quantized ops may appear.
            .check_not(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            # The permute must have been absorbed into the delegate.
            .check_not(["executorch_exir_dialects_edge__ops_aten_permute_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _lower_quantized_permute(
        self,
        module: torch.nn.Module,
        compile_spec: CompileSpec,
        test_data: Tuple[torch.Tensor],
    ):
        """Quantize, export and lower ``module``; return the tester after to_executorch().

        Shared by the TOSA-BI and Ethos-BI pipelines so both apply identical
        stage checks. Callers pick the final step (run or serialize).
        """
        quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config())
        return (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=compile_spec,
            )
            .quantize(Quantize(quantizer, get_symmetric_quantization_config()))
            .export()
            .check_count({"torch.ops.aten.permute.default": 1})
            .check(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            # The permute must have been absorbed into the delegate.
            .check_not(["executorch_exir_dialects_edge__ops_aten_permute_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
        )

    def _test_permute_tosa_BI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ):
        """Quantize and lower with "TOSA-0.80.0+BI", then run and compare outputs."""
        self._lower_quantized_permute(
            module, common.get_tosa_compile_spec("TOSA-0.80.0+BI"), test_data
        ).run_method_and_compare_outputs(inputs=test_data)

    def _test_permute_ethos_BI_pipeline(
        self,
        module: torch.nn.Module,
        compile_spec: CompileSpec,
        test_data: Tuple[torch.Tensor],
    ):
        """Quantize and lower with an Ethos-U ``compile_spec``, then serialize."""
        self._lower_quantized_permute(module, compile_spec, test_data).serialize()

    @parameterized.expand(test_data_suite)
    def test_permute_tosa_MI(
        self, test_name: str, test_data: torch.Tensor, dims: list[int]
    ):
        # Exercise both settings of the permute_memory_to_nhwc compile option.
        for permute_memory_to_nhwc in (True, False):
            self._test_permute_tosa_MI_pipeline(
                self.Permute(dims=dims), (test_data,), permute_memory_to_nhwc
            )

    @parameterized.expand(test_data_suite)
    def test_permute_tosa_BI(
        self, test_name: str, test_data: torch.Tensor, dims: list[int]
    ):
        self._test_permute_tosa_BI_pipeline(self.Permute(dims=dims), (test_data,))

    # Expected to fail as TOSA.Transpose is not supported by Ethos-U55.
    # Only the first suite entry ([0:1]) is exercised.
    @parameterized.expand(test_data_suite[0:1])
    @unittest.expectedFailure
    def test_permute_u55_BI(
        self, test_name: str, test_data: torch.Tensor, dims: list[int]
    ):
        self._test_permute_ethos_BI_pipeline(
            self.Permute(dims=dims), common.get_u55_compile_spec(), (test_data,)
        )

    @parameterized.expand(test_data_suite)
    def test_permute_u85_BI(
        self, test_name: str, test_data: torch.Tensor, dims: list[int]
    ):
        self._test_permute_ethos_BI_pipeline(
            self.Permute(dims=dims), common.get_u85_compile_spec(), (test_data,)
        )
153