xref: /aosp_15_r20/external/executorch/backends/arm/operators/op_add.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright 2023-2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
5
6# pyre-unsafe
7
8from typing import List
9
10import executorch.backends.arm.tosa_quant_utils as tqutils
11import executorch.backends.arm.tosa_utils as tutils
12
13import serializer.tosa_serializer as ts
14import torch
15from executorch.backends.arm.operators.node_visitor import (
16    NodeVisitor,
17    register_node_visitor,
18)
19from executorch.backends.arm.tosa_mapping import TosaArg
20from executorch.backends.arm.tosa_specification import TosaSpecification
21from serializer.tosa_serializer import TosaOp
22from torch.fx import Node
23
24
@register_node_visitor
class AddVisitor_080_BI(NodeVisitor):
    """Lower ``aten.add.Tensor`` for the TOSA-0.80.0+BI specification.

    Only integer additions are emitted here. When both operands and the
    result are already int32, a single TOSA ADD is emitted. Otherwise the
    operands are rescaled to int32 with ``tqutils.rescale_nodes_to_int32``,
    added into an INT32 intermediate tensor, and the sum is rescaled with
    ``tqutils.rescale_node_back_to_int8``.
    """

    target = "aten.add.Tensor"

    tosa_specs = [
        TosaSpecification.create_from_string("TOSA-0.80.0+BI"),
    ]

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
    ) -> None:
        """Emit the TOSA operators implementing an integer add for ``node``.

        Args:
            node: FX node being lowered. Its two operand nodes are fetched
                with ``tutils.get_two_inputs``. Operand and result dtypes are
                read from each node's ``meta["val"]``.
            tosa_graph: serializer that receives the emitted tensors and
                operators.
            inputs: TOSA arguments for the node's operands. Only ``inputs[0]``
                and ``inputs[1]`` are referenced, and only on the
                no-rescale path.
            output: TOSA argument the final result is written to.
            is_quant_node: whether the node is quantized. When False, every
                operand must have dtype ``torch.int8`` or ``torch.int32``.

        Raises:
            RuntimeError: if ``is_quant_node`` is False and an operand's
                dtype is neither ``torch.int8`` nor ``torch.int32``.
        """
        operand_nodes = tutils.get_two_inputs(node)
        allowed_dtypes = (torch.int8, torch.int32)

        # Without quantization info, only plain integer operands are accepted.
        if not is_quant_node:
            if any(
                operand.meta["val"].dtype not in allowed_dtypes
                for operand in operand_nodes
            ):
                raise RuntimeError(
                    f"Unexpected non quantized {AddVisitor_080_BI.target} node."
                )

        operands_are_int32 = all(
            operand.meta["val"].dtype == torch.int32 for operand in operand_nodes
        )
        # The result dtype is only looked at once every operand is int32.
        already_int32 = operands_are_int32 and node.meta["val"].dtype == torch.int32

        if already_int32:
            # Nothing to convert: the ADD reads the inputs and writes the
            # output tensor directly.
            add_operands = inputs
            add_result = output
        else:
            # Bring both operands to int32. The returned ``scale`` is handed
            # to the rescale-back helper after the ADD.
            add_operands, scale = tqutils.rescale_nodes_to_int32(
                operand_nodes, tosa_graph
            )
            # INT32 scratch tensor for the sum, shaped from the node output
            # via ``tutils.tosa_shape`` (presumably reordered per dim_order).
            result_shape = tutils.tosa_shape(output.shape, output.dim_order)
            add_result = tosa_graph.addIntermediate(result_shape, ts.DType.INT32)

        # Do the INT32 Add
        add_opcode = TosaOp.Op().ADD
        operand_names = [add_operands[0].name, add_operands[1].name]
        tosa_graph.addOperator(add_opcode, operand_names, [add_result.name], None)

        if not already_int32:
            # Scale the INT32 sum back to the 8-bit output. ``scale`` is only
            # bound on this path, which the type checker cannot see.
            # pyre-ignore
            tqutils.rescale_node_back_to_int8(node, add_result, scale, tosa_graph)
87
88
@register_node_visitor
class AddVisitor_080_MI(AddVisitor_080_BI):
    """Lower ``aten.add.Tensor`` for the TOSA-0.80.0+MI specification.

    Quantized nodes reuse the integer lowering inherited from
    ``AddVisitor_080_BI``. Every other node becomes a single TOSA ADD from
    the first two inputs to the output, with no rescaling.
    """

    # ``target`` is not redefined; it is inherited from AddVisitor_080_BI.

    tosa_specs = [
        TosaSpecification.create_from_string("TOSA-0.80.0+MI"),
    ]

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
    ) -> None:
        """Emit the TOSA add for ``node``.

        Args:
            node: FX node being lowered.
            tosa_graph: serializer that receives the emitted operators.
            inputs: TOSA arguments for the operands. On the non-quantized
                path only ``inputs[0]`` and ``inputs[1]`` are used.
            output: TOSA argument the sum is written to.
            is_quant_node: when True, lowering is delegated to the BI
                (integer) implementation.
        """
        if not is_quant_node:
            # Non-quantized lowering (the original labels it FP32 Add): one
            # ADD straight into the output. No dtype check is performed here.
            tosa_graph.addOperator(
                TosaOp.Op().ADD,
                [inputs[0].name, inputs[1].name],
                [output.name],
                None,
            )
            return

        # Quantized nodes share the integer path of the parent visitor.
        super().define_node(node, tosa_graph, inputs, output, is_quant_node)
118            )
119