# Copyright 2023-2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from typing import List

import serializer.tosa_serializer as ts
import torch
from executorch.backends.arm.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from serializer.tosa_serializer import TosaOp


def permutation_vector_to_matrix(permutation_vector: list[int]) -> torch.Tensor:
    """
    Converts a permutation vector of length N to a NxN matrix that describes the same permutation.

    Row i of the result holds a single 1 in column permutation_vector[i]
    (P[i][permutation_vector[i]] = 1) and 0 everywhere else.
    For example:
        (1,0,2)
        ->
        [[0 1 0]
         [1 0 0]
         [0 0 1]]
    """
    N = len(permutation_vector)
    P = torch.zeros(N, N)
    for row_index, col_index in enumerate(permutation_vector):
        P[row_index][col_index] = 1
    return P


def permutation_matrix_to_vector(permutation_matrix: torch.Tensor) -> list[int]:
    """
    Converts a NxN permutation matrix to a permutation vector of length N that describes the same permutation.

    This is the inverse of permutation_vector_to_matrix: p[i] is the column
    index of the single 1 in row i.
    For example:
        [[0 1 0]
         [1 0 0]
         [0 0 1]]
        ->
        (1,0,2)

    Raises:
        AssertionError: if the matrix is not square, contains values other
            than 0 and 1, or does not have exactly one 1 in every row and
            every column.
    """
    N = len(permutation_matrix)
    # Check every row (not only the first) so ragged/non-square input is
    # rejected; this also avoids indexing row 0 of an empty (0x0) matrix.
    assert all(
        len(row) == N for row in permutation_matrix
    ), f"A permutation matrix must be square, got shape {permutation_matrix.shape}"

    p = [0] * N
    for row_index, row in enumerate(permutation_matrix):
        saw_one = False
        for col_index, value in enumerate(row):
            if value == 1:
                assert (
                    not saw_one
                ), f"A permutation matrix can only have one 1 per row, got row {row}."
                p[row_index] = col_index
                saw_one = True
            else:
                assert (
                    value == 0
                ), f"A permutation matrix only contains 1's and 0's, got value {value}."
        # Without this check an all-zero row would silently map to column 0.
        assert saw_one, f"A permutation matrix must have a 1 in every row, got row {row}."

    # One 1 per row is not sufficient: two rows selecting the same column
    # would yield a vector with duplicates, which is not a permutation.
    assert sorted(p) == list(
        range(N)
    ), f"A permutation matrix can only have one 1 per column, got vector {p}."
    return p


def transform_permutation_vector(
    permutation_vector: list[int], dim_order: tuple[int, ...]
) -> list[int]:
    """
    Re-expresses a permutation given in default (contiguous) dim_order as the
    equivalent permutation on a tensor laid out in `dim_order`.

    The permutation vector cannot be used directly when the tensor is not in
    the default dim_order. Instead, transform to the default order, apply P,
    and transform back. That transformation S is itself a permutation, with
    dim_order as its permutation vector. Representing P and S as permutation
    matrices makes chaining and inversion simple: the complete transformation
    is S * P * S^-1.

    Args:
        permutation_vector: the permutation in default dim_order.
        dim_order: the dim_order of the tensor being permuted.

    Returns:
        The permutation vector to apply to the tensor in `dim_order` layout.
    """
    S = permutation_vector_to_matrix(dim_order)
    # The inverse of a permutation matrix is its transpose.
    S_inverse = S.transpose(1, 0)
    P = permutation_vector_to_matrix(permutation_vector)

    # The complete transformation is S * P * S_inverse.
    transformation_matrix = S.matmul(P.matmul(S_inverse))

    # A product of permutation matrices is again a permutation matrix, so the
    # result can be described by a new permutation vector.
    return permutation_matrix_to_vector(transformation_matrix)


@register_node_visitor
class PermuteVisitor(NodeVisitor):
    """Lowers aten.permute_copy.default to a TOSA TRANSPOSE operator."""

    target = "aten.permute_copy.default"

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
    ) -> None:
        """Emits a TOSA TRANSPOSE of inputs[0] into output, using the
        permutation in inputs[1] adjusted for the output's dim_order."""
        # The permutation vector describes a permutation P in default PyTorch dim_order.
        # For rank 4, the default dim_order is NCHW.
        # E.g. (2,3,0,1) -> permute (n,c,h,w) to (h,w,n,c)
        rank = len(inputs[1].special)
        # Normalize any negative dims (e.g. -1 -> rank - 1). This is a no-op for
        # non-negative dims, and it guarantees TOSA never receives a negative
        # index on the default-dim_order path below.
        permutation_vector = [dim % rank for dim in inputs[1].special]

        if output.dim_order != tuple(range(len(output.dim_order))):
            # Not in the default (e.g. NCHW) dim_order: re-express P for this layout.
            # NOTE(review): only output.dim_order is consulted, so this assumes
            # the input shares the output's dim_order -- confirm with callers.
            permutation_vector = transform_permutation_vector(
                permutation_vector, output.dim_order
            )

        attr = ts.TosaSerializerAttribute()
        attr.TransposeAttribute(permutation_vector)
        tosa_graph.addOperator(
            TosaOp.Op().TRANSPOSE, [inputs[0].name], [output.name], attr
        )