# Copyright 2023-2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from typing import cast, List

import executorch.backends.arm.tosa_quant_utils as tqutils
import executorch.backends.arm.tosa_utils as tutils

import serializer.tosa_serializer as ts
from executorch.backends.arm.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from serializer.tosa_serializer import TosaOp
from torch.fx import Node


@register_node_visitor
class AddVisitor(NodeVisitor):
    """Lower ``aten.sum.dim_IntList`` to a chain of TOSA REDUCE_SUM operators.

    NOTE: despite its name, this visitor handles ``sum`` (see ``target``);
    the class name is left unchanged so any existing references keep working.
    """

    target = "aten.sum.dim_IntList"

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
    ) -> None:
        """Emit one REDUCE_SUM per reduced dimension into ``tosa_graph``.

        Each REDUCE_SUM emitted here reduces a single axis (set through
        ``AxisAttribute``), so reducing several dims becomes a chain of ops.
        Each op shrinks one dim to size 1 and feeds the next.

        Args:
            node: The FX node being lowered.
            tosa_graph: Serializer that the operators are appended to.
            inputs: ``[input, dim_list, keep_dim?]``. ``dim_list`` may hold
                negative indices. An empty (or missing) list reduces over
                every dimension, matching PyTorch's ``sum`` semantics.
            output: Destination tensor written directly by the last
                REDUCE_SUM on the non-quantized path.
            is_quant_node: When True, the input is first rescaled to INT32,
                summed in INT32, and then rescaled back through the
                ``tqutils`` helpers.
        """
        input_node = inputs[0]
        rank = len(input_node.shape)
        dim_list = cast(List[int], inputs[1].special)
        if not dim_list:
            # PyTorch treats an empty/None dim list as "reduce over all dims";
            # without this the float path would never write `output`.
            dim_list = list(range(rank))
        # Normalize negative dims so they index into the shape / dim_order.
        dim_list = [dim % rank for dim in dim_list]
        keep_dim = cast(bool, inputs[2].number if len(inputs) > 2 else False)
        assert keep_dim, "This case should be handled by InsertSqueezeAfterSumPass"

        if is_quant_node:
            # Rescale input to 32 bit
            rescaled_inputs, scale = tqutils.rescale_nodes_to_int32(
                [node.all_input_nodes[0]], tosa_graph
            )
            last_tensor = rescaled_inputs[0]
            dtype = ts.DType.INT32
        else:
            last_tensor = input_node
            dtype = ts.DType.FP32

        reduced_shape = list(input_node.shape)
        prev_name = last_tensor.name
        last_index = len(dim_list) - 1

        # Reduce all dims in dim_list one-by-one.
        for index, dim in enumerate(dim_list):
            # When reduced, the size of the dim becomes 1.
            reduced_shape[dim] = 1

            attr = ts.TosaSerializerAttribute()
            attr.AxisAttribute(input_node.dim_order.index(dim))

            # Compare by position (not by dim value) to identify the final step.
            if index == last_index and not is_quant_node:
                # Float path: the final reduction writes straight into output.
                next_name = output.name
            else:
                # Intermediate result, or the INT32 result that is rescaled
                # back to int8 after the loop.
                last_tensor = tosa_graph.addIntermediate(
                    tutils.tosa_shape(reduced_shape, input_node.dim_order),
                    dtype=dtype,
                )
                next_name = last_tensor.name

            tosa_graph.addOperator(
                TosaOp.Op().REDUCE_SUM, [prev_name], [next_name], attr
            )
            prev_name = next_name

        if is_quant_node:
            tqutils.rescale_node_back_to_int8(node, last_tensor, scale, tosa_graph)