xref: /aosp_15_r20/external/executorch/backends/arm/operators/op_sum.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright 2023-2024 Arm Limited and/or its affiliates.
2#
3# This source code is licensed under the BSD-style license found in the
4# LICENSE file in the root directory of this source tree.
5
6# pyre-unsafe
7
8from typing import cast, List
9
10import executorch.backends.arm.tosa_quant_utils as tqutils
11import executorch.backends.arm.tosa_utils as tutils
12
13import serializer.tosa_serializer as ts
14from executorch.backends.arm.operators.node_visitor import (
15    NodeVisitor,
16    register_node_visitor,
17)
18from executorch.backends.arm.tosa_mapping import TosaArg
19from serializer.tosa_serializer import TosaOp
20from torch.fx import Node
21
22
@register_node_visitor
class AddVisitor(NodeVisitor):
    """Lower ``aten.sum.dim_IntList`` to a chain of TOSA REDUCE_SUM operators.

    NOTE(review): despite its name this visitor handles sum, not add; the
    class name is kept unchanged so any existing references keep working.
    """

    target = "aten.sum.dim_IntList"

    def __init__(self, *args):
        super().__init__(*args)

    @staticmethod
    def _emit_reduce_sum_chain(
        tosa_graph: ts.TosaSerializer,
        input_name: str,
        input_shape: List[int],
        dim_order,
        dim_list: List[int],
        dtype,
        output_name=None,
    ):
        """Emit one REDUCE_SUM per entry of ``dim_list``, chained in order.

        Each op consumes the previous op's result; the first consumes
        ``input_name``.

        Args:
            tosa_graph: Serializer that operators and intermediates are added to.
            input_name: Name of the tensor fed into the first REDUCE_SUM.
            input_shape: Shape of that tensor (as given by ``TosaArg.shape``);
                intermediate shapes are derived from it and passed through
                ``tutils.tosa_shape`` together with ``dim_order``.
            dim_order: Sequence used to map a dim to its axis attribute; each
                op's axis is ``dim_order.index(dim)``.
            dim_list: Non-negative dims to reduce, in emission order.
            dtype: ``ts.DType`` of every intermediate tensor created here.
            output_name: If given, the last REDUCE_SUM writes into this tensor
                instead of into a new intermediate.

        Returns:
            The last intermediate tensor created, or None if none was created
            (``dim_list`` empty, or a single dim written to ``output_name``).
        """
        # Copy so the caller's shape list is never mutated.
        reduced_shape = list(input_shape)
        current_name = input_name
        last_intermediate = None
        last_index = len(dim_list) - 1

        for index, dim in enumerate(dim_list):
            # When reduced, the size of the dim becomes 1.
            reduced_shape[dim] = 1

            attr = ts.TosaSerializerAttribute()
            attr.AxisAttribute(dim_order.index(dim))

            # Identify the final op by position rather than by comparing dim
            # values: a value comparison would also match an earlier entry
            # holding the same dim and write into the output too early.
            if index == last_index and output_name is not None:
                next_name = output_name
            else:
                last_intermediate = tosa_graph.addIntermediate(
                    tutils.tosa_shape(reduced_shape, dim_order),
                    dtype=dtype,
                )
                next_name = last_intermediate.name

            tosa_graph.addOperator(
                TosaOp.Op().REDUCE_SUM, [current_name], [next_name], attr
            )
            current_name = next_name

        return last_intermediate

    def define_node(
        self,
        node: Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
    ) -> None:
        """Serialize ``node`` as REDUCE_SUM ops, one per reduced dim.

        ``inputs`` holds the input tensor, the dim list (read from
        ``inputs[1].special``) and, optionally, keep_dim
        (``inputs[2].number``). Only keep_dim=True is handled here; per the
        assert message, keep_dim=False is expected to be rewritten earlier
        by InsertSqueezeAfterSumPass.

        Quantized nodes: the input is rescaled to INT32 with
        ``tqutils.rescale_nodes_to_int32`` and reduced with INT32
        intermediates. The result is then passed, together with ``node``, to
        ``tqutils.rescale_node_back_to_int8``.

        Float nodes: reduced with FP32 intermediates, and the last REDUCE_SUM
        writes directly into ``output``.
        """
        input_node = inputs[0]
        input_shape = list(input_node.shape)
        dim_list = cast(list[int], inputs[1].special)
        # Normalize negative dims (e.g. -1 -> rank - 1).
        dim_list = [dim % len(input_shape) for dim in dim_list]
        keep_dim = cast(bool, inputs[2].number if len(inputs) > 2 else False)
        assert keep_dim, "This case should be handled by InsertSqueezeAfterSumPass"

        if is_quant_node:
            # Rescale input to 32 bit
            rescaled_inputs, scale = tqutils.rescale_nodes_to_int32(
                [node.all_input_nodes[0]], tosa_graph
            )
            rescaled_input = rescaled_inputs[0]

            last_intermediate = self._emit_reduce_sum_chain(
                tosa_graph,
                rescaled_input.name,
                input_shape,
                input_node.dim_order,
                dim_list,
                ts.DType.INT32,
            )
            # With an empty dim_list no REDUCE_SUM is emitted, so the rescaled
            # input itself is what gets rescaled back.
            result = rescaled_input if last_intermediate is None else last_intermediate
            tqutils.rescale_node_back_to_int8(node, result, scale, tosa_graph)
        else:
            # NOTE(review): with an empty dim_list nothing is written to
            # `output` on this path; confirm such nodes never reach here.
            self._emit_reduce_sum_chain(
                tosa_graph,
                input_node.name,
                input_shape,
                input_node.dim_order,
                dim_list,
                ts.DType.FP32,
                output_name=output.name,
            )
99