xref: /aosp_15_r20/external/executorch/backends/arm/operators/op_permute.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright 2023-2024 Arm Limited and/or its affiliates.
2#
3# This source code is licensed under the BSD-style license found in the
4# LICENSE file in the root directory of this source tree.
5
6# pyre-unsafe
7
8from typing import List
9
10import serializer.tosa_serializer as ts
11import torch
12from executorch.backends.arm.operators.node_visitor import (
13    NodeVisitor,
14    register_node_visitor,
15)
16from executorch.backends.arm.tosa_mapping import TosaArg
17from serializer.tosa_serializer import TosaOp
18
19
def permutation_vector_to_matrix(permutation_vector: list[int]) -> torch.Tensor:
    """
    Build the NxN matrix equivalent of a permutation vector of length N.

    Row i of the returned matrix holds a 1 in column permutation_vector[i]
    and 0 everywhere else. Example:
    (1,0,2)
    ->
    [0 1 0]
    |1 0 0|
    [0 0 1]
    """
    size = len(permutation_vector)
    matrix = torch.zeros(size, size)
    for row, col in enumerate(permutation_vector):
        # Mark the column that this row selects.
        matrix[row, col] = 1
    return matrix
35
36
def permutation_matrix_to_vector(permutation_matrix: torch.Tensor) -> list[int]:
    """
    Converts a NxN permutation matrix to a permutation vector of length N that describes the same permutation.
    [0 1 0]
    |1 0 0|
    [0 0 1]
    ->
    (1,0,2)

    Entry i of the returned list is the column index of the single 1 in row i.
    An empty (0x0) matrix yields an empty list.

    Raises:
        AssertionError: if the input is not a 2-D square matrix, contains a value
            other than 0 or 1, has a row without exactly one 1, or has two rows
            selecting the same column.
    """
    shape = tuple(permutation_matrix.shape)
    # Check the shape up front instead of indexing row 0, so that an empty
    # matrix does not raise IndexError and higher-rank tensors are rejected.
    assert (
        len(shape) == 2 and shape[0] == shape[1]
    ), f"A permutation matrix must be square, got shape {permutation_matrix.shape}"
    N = shape[0]

    p = [0] * N
    used_columns: set[int] = set()
    for row_index, row in enumerate(permutation_matrix):
        saw_one = False
        for col_index, value in enumerate(row):
            if value == 1:
                assert (
                    not saw_one
                ), f"A permutation matrix can only have one 1 per row, got row {row}."
                # Two rows mapping to the same column is not a permutation.
                assert (
                    col_index not in used_columns
                ), f"A permutation matrix can only have one 1 per column, got a second 1 in column {col_index}."
                used_columns.add(col_index)
                p[row_index] = col_index
                saw_one = True
            else:
                assert (
                    value == 0
                ), f"A permutation matrix only contains 1's and 0's, got value {value}."
        # Without this check an all-zero row would silently map to index 0.
        assert (
            saw_one
        ), f"A permutation matrix must have exactly one 1 per row, got row {row}."
    return p
66
67
@register_node_visitor
class PermuteVisitor(NodeVisitor):
    """Node visitor that emits a TOSA TRANSPOSE operator for aten.permute_copy.default."""

    target = "aten.permute_copy.default"

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
    ) -> None:
        """
        Add a TRANSPOSE operator to tosa_graph for the given permute node.

        Args:
            node: The fx node being lowered (not read by this method).
            tosa_graph: Serializer that receives the TRANSPOSE operator.
            inputs: inputs[0] names the operator input; inputs[1].special
                holds the permutation vector.
            output: Names the operator output; its dim_order decides whether
                the permutation vector has to be transformed first.
            is_quant_node: Not read by this method.
        """
        # The permutation vector describes a permutation P in default Pytorch dim_order.
        # For rank 4, the default dim_order NCHW.
        # E.g. (2,3,0,1) -> permute (n,c,h,w) to (h,w,n,c)
        raw_permutation = inputs[1].special

        # permute accepts negative dims (counted from the end), while the vector
        # is handed to the TOSA transpose attribute as-is below. Map negative
        # entries to their non-negative equivalent; other entries are untouched.
        rank = len(raw_permutation)
        permutation_vector = [
            dim + rank if dim < 0 else dim for dim in raw_permutation
        ]

        if output.dim_order != tuple(range(len(output.dim_order))):
            # the permutation vector can't be used directly if we are not in NCHW dim_order.
            # We need to first transform to NCHW, apply P,
            # and then transform back to the original dim_order.
            # This transformation, S, is also a permutation, with the dim_order as permutation vector.

            # To do this, represent P and S with permutation matrices.
            # Matrices can handle chained transformations and inversion easily.
            S = permutation_vector_to_matrix(output.dim_order)
            # The inverse of a permutation matrix is its transpose.
            S_inverse = S.transpose(1, 0)
            P = permutation_vector_to_matrix(permutation_vector)

            # The complete transformation is S * P * S_inverse.
            transformation_matrix = S.matmul(P.matmul(S_inverse))

            # Luckily, since it is just a combination of permutations, the result is also a permutation
            # that can again be described by a new permutation vector.
            permutation_vector = permutation_matrix_to_vector(transformation_matrix)

        attr = ts.TosaSerializerAttribute()
        attr.TransposeAttribute(permutation_vector)
        tosa_graph.addOperator(
            TosaOp.Op().TRANSPOSE, [inputs[0].name], [output.name], attr
        )
113