# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest
from typing import Tuple

import torch

from executorch.backends.arm.quantizer.arm_quantizer import (
    ArmQuantizer,
    get_symmetric_quantization_config,
)

from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.arm_tester import ArmTester
from executorch.backends.xnnpack.test.tester.tester import Quantize
from executorch.exir.backend.compile_spec_schema import CompileSpec
from parameterized import parameterized
from torchvision.ops import Permute

test_data_suite = [
    # (test_name, test_data, dims)
    ("rank_2", torch.rand(10, 10), [1, 0]),
    ("rank_3", torch.rand(10, 10, 10), [2, 0, 1]),
    ("rank_3", torch.rand(10, 10, 10), [1, 2, 0]),
    ("rank_4", torch.rand(1, 5, 1, 10), [0, 2, 3, 1]),
    ("rank_4", torch.rand(1, 2, 5, 10), [1, 0, 2, 3]),
    ("rank_4", torch.rand(1, 10, 10, 5), [2, 0, 1, 3]),
]


class TestPermute(unittest.TestCase):
    """Tests Permute Operator."""

    class Permute(torch.nn.Module):
        """Single-op module that permutes its input by a fixed ``dims`` order."""

        def __init__(self, dims: list[int]):
            super().__init__()

            # NOTE: method bodies do not see names defined in the enclosing
            # class body, so ``Permute`` here resolves to the module-level
            # ``torchvision.ops.Permute`` import, not to this nested class
            # (this is not a recursive construction).
            self.permute = Permute(dims=dims)

        def forward(self, x):
            return self.permute(x)

    def _test_permute_tosa_MI_pipeline(
        self,
        module: torch.nn.Module,
        test_data: Tuple[torch.Tensor],
        permute_memory_to_nhwc: bool,
    ):
        """Lower ``module`` through the floating-point TOSA (MI) flow.

        Asserts that export yields exactly one un-quantized
        ``aten.permute.default``, that partitioning leaves no permute in the
        edge graph and produces a single delegate call, and finally runs
        ``run_method_and_compare_outputs`` on ``test_data``.

        Args:
            module: The module under test.
            test_data: Example inputs, used both for export and for the run.
            permute_memory_to_nhwc: Forwarded to the TOSA compile spec.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec(
                    "TOSA-0.80.0+MI", permute_memory_to_nhwc=permute_memory_to_nhwc
                ),
            )
            .export()
            # Same expectation as the quantized pipelines: exactly one permute.
            .check_count({"torch.ops.aten.permute.default": 1})
            .check_not(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten_permute_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _lower_quantized_permute(
        self,
        module: torch.nn.Module,
        compile_spec: list[CompileSpec],
        test_data: Tuple[torch.Tensor],
    ):
        """Quantize, export, partition and lower ``module`` with ``compile_spec``.

        Shared by the TOSA-BI and Ethos-U BI pipelines. Asserts that export
        yields exactly one ``aten.permute.default`` plus quantize/dequantize
        ops, and that partitioning leaves no permute in the edge graph and
        produces a single delegate call.

        Returns:
            The tester after ``to_executorch()``, so each caller can choose
            the final stage (run-and-compare or serialize).
        """
        quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config())
        return (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=compile_spec,
            )
            .quantize(Quantize(quantizer, get_symmetric_quantization_config()))
            .export()
            .check_count({"torch.ops.aten.permute.default": 1})
            .check(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten_permute_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
        )

    def _test_permute_tosa_BI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ):
        """Lower ``module`` through the quantized TOSA (BI) flow and run it."""
        self._lower_quantized_permute(
            module, common.get_tosa_compile_spec("TOSA-0.80.0+BI"), test_data
        ).run_method_and_compare_outputs(inputs=test_data)

    def _test_permute_ethos_BI_pipeline(
        self,
        module: torch.nn.Module,
        compile_spec: list[CompileSpec],
        test_data: Tuple[torch.Tensor],
    ):
        """Lower ``module`` for an Ethos-U target and serialize the program.

        No output comparison is done here; the pipeline ends at
        ``serialize()``.
        """
        self._lower_quantized_permute(module, compile_spec, test_data).serialize()

    @parameterized.expand(test_data_suite)
    def test_permute_tosa_MI(
        self, test_name: str, test_data: torch.Tensor, dims: list[int]
    ):
        # Exercise both memory-format settings, with a fresh module each time.
        for permute_memory_to_nhwc in (True, False):
            self._test_permute_tosa_MI_pipeline(
                self.Permute(dims=dims), (test_data,), permute_memory_to_nhwc
            )

    @parameterized.expand(test_data_suite)
    def test_permute_tosa_BI(
        self, test_name: str, test_data: torch.Tensor, dims: list[int]
    ):
        self._test_permute_tosa_BI_pipeline(self.Permute(dims=dims), (test_data,))

    # Expected to fail as TOSA.Transpose is not supported by Ethos-U55.
    @parameterized.expand(test_data_suite[0:1])
    @unittest.expectedFailure
    def test_permute_u55_BI(
        self, test_name: str, test_data: torch.Tensor, dims: list[int]
    ):
        self._test_permute_ethos_BI_pipeline(
            self.Permute(dims=dims), common.get_u55_compile_spec(), (test_data,)
        )

    @parameterized.expand(test_data_suite)
    def test_permute_u85_BI(
        self, test_name: str, test_data: torch.Tensor, dims: list[int]
    ):
        self._test_permute_ethos_BI_pipeline(
            self.Permute(dims=dims), common.get_u85_compile_spec(), (test_data,)
        )