# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from typing import List

import serializer.tosa_serializer as ts
import torch
from executorch.backends.arm.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from serializer.tosa_serializer import TosaOp


@register_node_visitor
class TransposeVisitor(NodeVisitor):
    """
    Node visitor for the _transpose op defined in the passthrough_to_tosa
    library. That op is used when switching between tosa_dim_orders; this
    visitor lowers it by inserting a TOSA TRANSPOSE.
    """

    target = "_transpose"

    def define_node(
        self,
        node: torch.fx.Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
    ) -> None:
        """
        Add a single TOSA TRANSPOSE operator to ``tosa_graph``.

        Args:
            node: The FX node being lowered (not read by this visitor).
            tosa_graph: Serializer that receives the new operator.
            inputs: ``inputs[0]`` supplies the name of the tensor to
                transpose; ``inputs[1].special`` supplies the dimension
                order to apply.
            output: Destination tensor; its name is the operator output and
                the length of its shape is used as the rank.
            is_quant_node: Not read by this visitor.
        """
        rank = len(output.shape)
        # Reduce every requested axis modulo the rank so that negative
        # indices (e.g. -1) become their non-negative equivalents.
        permutation = [axis % rank for axis in inputs[1].special]

        transpose_attr = ts.TosaSerializerAttribute()
        transpose_attr.TransposeAttribute(permutation)

        tosa_graph.addOperator(
            TosaOp.Op().TRANSPOSE,
            [inputs[0].name],
            [output.name],
            transpose_attr,
        )