xref: /aosp_15_r20/external/executorch/backends/xnnpack/test/ops/permute.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import unittest
8
9import torch
10from executorch.backends.xnnpack.test.tester import Tester
11
12
class TestPermute(unittest.TestCase):
    """Drive permute / permute_copy modules through the XNNPACK ``Tester`` flow.

    Each test wraps a tiny module (``x + x`` followed by an axis reorder with
    ``(0, 2, 3, 1)``) in a ``Tester`` and walks it through export, lowering,
    serialization and an output comparison, with op-count checks in between.
    """

    # Axis order used by every test; kept immutable and copied into a fresh
    # list for each module so instances never share a mutable object.
    _DIMS = (0, 2, 3, 1)

    # Strings handed to Tester.check_count / Tester.check_not after lowering.
    _DELEGATE_CALL = "torch.ops.higher_order.executorch_call_delegate"
    _EDGE_PERMUTE_COPY = "executorch_exir_dialects_edge__ops_aten_permute_copy_default"

    class Permute(torch.nn.Module):
        """Doubles the input and reorders its axes with ``torch.permute``."""

        def __init__(self, dims):
            # nn.Module requires the parent constructor to run before any
            # attribute is assigned on the subclass.
            super().__init__()
            self.dims = dims

        def forward(self, x):
            """Return ``torch.permute(x + x, self.dims)``."""
            doubled = x + x
            return torch.permute(doubled, self.dims)

    class PermuteCopy(torch.nn.Module):
        """Doubles the input and reorders its axes with ``torch.permute_copy``."""

        def __init__(self, dims):
            # nn.Module requires the parent constructor to run before any
            # attribute is assigned on the subclass.
            super().__init__()
            self.dims = dims

        def forward(self, x):
            """Return ``torch.permute_copy(x + x, self.dims)``."""
            doubled = x + x
            return torch.permute_copy(doubled, self.dims)

    def _lower_and_compare(self, module, inputs, exported_op):
        """Run the non-quantized Tester pipeline for ``module``.

        Args:
            module: the ``torch.nn.Module`` under test.
            inputs: tuple of example input tensors passed to ``Tester``.
            exported_op: op string expected exactly once after ``export()``.
        """
        (
            Tester(module, inputs)
            .export()
            .check_count({exported_op: 1})
            .to_edge_transform_and_lower()
            # After lowering: one delegate call, and the edge permute_copy
            # string must no longer appear.
            .check_count({self._DELEGATE_CALL: 1})
            .check_not([self._EDGE_PERMUTE_COPY])
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )

    def _quantize_lower_and_compare(self, module, inputs, exported_op):
        """Run the quantized Tester pipeline for ``module``.

        Args:
            module: the ``torch.nn.Module`` under test.
            inputs: tuple of example input tensors passed to ``Tester``.
            exported_op: aten op overload expected exactly once in the
                exported graph, next to three ``quantize_per_tensor`` nodes.
        """
        (
            Tester(module, inputs)
            .quantize()
            .export()
            .check_node_count(
                {
                    exported_op: 1,
                    torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
                }
            )
            .to_edge_transform_and_lower()
            .check_count({self._DELEGATE_CALL: 1})
            # Neither the edge permute_copy nor any quantized_decomposed
            # string may remain after lowering.
            .check_not(
                [
                    self._EDGE_PERMUTE_COPY,
                    "torch.ops.quantized_decomposed",
                ]
            )
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )

    def _test_permute(self, inputs):
        """Run the non-quantized pipeline on ``Permute`` with ``inputs``."""
        self._lower_and_compare(
            self.Permute(list(self._DIMS)),
            inputs,
            "torch.ops.aten.permute.default",
        )

    def test_fp16_permute(self):
        inputs = (torch.randn(1, 1, 4, 4).to(torch.float16),)
        self._test_permute(inputs)

    def test_fp32_permute(self):
        inputs = (torch.randn(1, 1, 4, 4),)
        self._test_permute(inputs)

    def test_fp32_permute_copy(self):
        inputs = (torch.randn(1, 1, 4, 4),)
        self._lower_and_compare(
            self.PermuteCopy(list(self._DIMS)),
            inputs,
            "torch.ops.aten.permute_copy.default",
        )

    def test_qs8_permute(self):
        inputs = (torch.randn(1, 1, 4, 4),)
        self._quantize_lower_and_compare(
            self.Permute(list(self._DIMS)),
            inputs,
            torch.ops.aten.permute.default,
        )

    def test_qs8_permute_copy(self):
        inputs = (torch.randn(1, 1, 4, 4),)
        self._quantize_lower_and_compare(
            self.PermuteCopy(list(self._DIMS)),
            inputs,
            torch.ops.aten.permute_copy.default,
        )
118