# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.backends.xnnpack.test.tester import Tester


class TestPermute(unittest.TestCase):
    """Check that permute / permute_copy are delegated to XNNPACK.

    Each test exports a small module that doubles its input and then permutes
    it, lowers the program with ``to_edge_transform_and_lower``, asserts that
    exactly one delegate call is present and that no edge ``permute_copy`` op
    remains, then serializes the program and compares its outputs via
    ``run_method_and_compare_outputs``.
    """

    # Permutation used by the tests: for a rank-4 tensor it moves dim 1 to the
    # end, e.g. (N, C, H, W) -> (N, H, W, C).
    DIMS = (0, 2, 3, 1)

    # Node names checked after lowering.
    DELEGATE_CALL = "torch.ops.higher_order.executorch_call_delegate"
    EDGE_PERMUTE_COPY = "executorch_exir_dialects_edge__ops_aten_permute_copy_default"

    class Permute(torch.nn.Module):
        """Compute ``torch.permute(x + x, dims)``."""

        def __init__(self, dims):
            # torch.nn.Module requires its __init__ to run before any
            # attribute is assigned on the module.
            super().__init__()
            self.dims = dims

        def forward(self, x):
            y = x + x
            return torch.permute(y, self.dims)

    class PermuteCopy(torch.nn.Module):
        """Compute ``torch.permute_copy(x + x, dims)``."""

        def __init__(self, dims):
            # torch.nn.Module requires its __init__ to run before any
            # attribute is assigned on the module.
            super().__init__()
            self.dims = dims

        def forward(self, x):
            y = x + x
            return torch.permute_copy(y, self.dims)

    def _lower_run_and_compare(self, tester, forbidden=()):
        """Lower an exported ``tester``, verify delegation, run and compare.

        Args:
            tester: a ``Tester`` on which ``export()`` has already been called.
            forbidden: extra entries passed to ``check_not`` in addition to
                the edge ``permute_copy`` op name.
        """
        (
            tester.to_edge_transform_and_lower()
            .check_count({self.DELEGATE_CALL: 1})
            .check_not([self.EDGE_PERMUTE_COPY, *forbidden])
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )

    def _run_fp_pipeline(self, module, inputs, exported_op):
        """Run the float (non-quantized) export/lower/compare pipeline.

        Args:
            module: the ``torch.nn.Module`` under test.
            inputs: tuple of example inputs.
            exported_op: qualified op name (string, as used by
                ``check_count``) expected exactly once after export.
        """
        tester = Tester(module, inputs).export().check_count({exported_op: 1})
        self._lower_run_and_compare(tester)

    def _run_qs8_pipeline(self, module, inputs, exported_op):
        """Run the quantized (qs8) quantize/export/lower/compare pipeline.

        Args:
            module: the ``torch.nn.Module`` under test.
            inputs: tuple of example inputs.
            exported_op: op overload object (as used by ``check_node_count``)
                expected exactly once after export.
        """
        tester = (
            Tester(module, inputs)
            .quantize()
            .export()
            .check_node_count(
                {
                    exported_op: 1,
                    # NOTE(review): presumably one quantize node per quantized
                    # tensor (input, add output, permute output) -- confirm.
                    torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
                }
            )
        )
        # After lowering, no quantized_decomposed ops may remain.
        self._lower_run_and_compare(
            tester, forbidden=["torch.ops.quantized_decomposed"]
        )

    def _test_permute(self, inputs, dims=DIMS):
        """Run the float pipeline on ``Permute(dims)`` with ``inputs``."""
        self._run_fp_pipeline(
            self.Permute(list(dims)), inputs, "torch.ops.aten.permute.default"
        )

    def test_fp16_permute(self):
        inputs = (torch.randn(1, 1, 4, 4).to(torch.float16),)
        self._test_permute(inputs)

    def test_fp32_permute(self):
        inputs = (torch.randn(1, 1, 4, 4),)
        self._test_permute(inputs)

    def test_fp32_permute_copy(self):
        inputs = (torch.randn(1, 1, 4, 4),)
        self._run_fp_pipeline(
            self.PermuteCopy(list(self.DIMS)),
            inputs,
            "torch.ops.aten.permute_copy.default",
        )

    def test_qs8_permute(self):
        inputs = (torch.randn(1, 1, 4, 4),)
        self._run_qs8_pipeline(
            self.Permute(list(self.DIMS)),
            inputs,
            torch.ops.aten.permute.default,
        )

    def test_qs8_permute_copy(self):
        inputs = (torch.randn(1, 1, 4, 4),)
        self._run_qs8_pipeline(
            self.PermuteCopy(list(self.DIMS)),
            inputs,
            torch.ops.aten.permute_copy.default,
        )