xref: /aosp_15_r20/external/executorch/backends/arm/operators/op_mul.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright 2024 Arm Limited and/or its affiliates.
2#
3# This source code is licensed under the BSD-style license found in the
4# LICENSE file in the root directory of this source tree.
5
6# pyre-unsafe
7
8from typing import cast, List
9
10import executorch.backends.arm.tosa_quant_utils as tqutils
11import executorch.backends.arm.tosa_utils as tutils
12
13import serializer.tosa_serializer as ts
14import torch
15
16from executorch.backends.arm.operators.node_visitor import (
17    NodeVisitor,
18    register_node_visitor,
19)
20from executorch.backends.arm.tosa_mapping import TosaArg
21from serializer.tosa_serializer import TosaOp
22
23
@register_node_visitor
class MulVisitor(NodeVisitor):
    """Node visitor that lowers ``aten.mul.Tensor`` to a TOSA ``MUL`` operator."""

    target = "aten.mul.Tensor"

    def define_node(
        self,
        node: torch.fx.Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
    ) -> None:
        """Append the operators implementing an element-wise multiply.

        Args:
            node: The FX node being lowered. On the quantized path its first
                two ``args`` are used to look up the quantization arguments
                of the operands.
            tosa_graph: Serializer that receives the emitted operators and
                intermediates.
            inputs: Operands of the multiply; only ``inputs[0]`` and
                ``inputs[1]`` are read. On the quantized path their ``shape``
                attribute is overwritten with the result of
                ``tutils.tosa_shape``.
            output: Destination of the multiply.
            is_quant_node: Selects the quantized lowering when truthy,
                otherwise a single ``MUL`` is emitted directly.
        """
        # Non-quantized case: one MUL straight from the inputs to the output.
        if not is_quant_node:
            plain_attr = ts.TosaSerializerAttribute()
            plain_attr.MulAttribute(shift=0)
            tosa_graph.addOperator(
                TosaOp.Op().MUL,
                [inputs[0].name, inputs[1].name],
                [output.name],
                plain_attr,
            )
            return

        lhs = inputs[0]
        rhs = inputs[1]

        # Quantization arguments are looked up from the nodes feeding this one.
        lhs_qargs = tqutils.get_quant_arg_upstream(cast(torch.fx.Node, node.args[0]))
        rhs_qargs = tqutils.get_quant_arg_upstream(cast(torch.fx.Node, node.args[1]))

        # Replace the operand shapes with the ones derived from their
        # dim_order; the rescale helper below receives these same objects.
        lhs.shape = tutils.tosa_shape(lhs.shape, lhs.dim_order)
        rhs.shape = tutils.tosa_shape(rhs.shape, rhs.dim_order)
        product_shape = tutils.tosa_shape(output.shape, output.dim_order)

        # Bring both operands to INT32, handing each one's zero point to the
        # rescale helper together with a unit rescale factor.
        lhs_int32 = tqutils.build_rescale_to_int32(
            tosa_graph, lhs, lhs_qargs.zp, rescale_scale=1.0
        )
        rhs_int32 = tqutils.build_rescale_to_int32(
            tosa_graph, rhs, rhs_qargs.zp, rescale_scale=1.0
        )

        # INT32 intermediate that holds the raw product.
        product = tosa_graph.addIntermediate(product_shape, ts.DType.INT32)

        # The multiply itself, carried out in INT32.
        quant_attr = ts.TosaSerializerAttribute()
        quant_attr.MulAttribute(shift=0)
        tosa_graph.addOperator(
            TosaOp.Op().MUL,
            [lhs_int32.name, rhs_int32.name],
            [product.name],
            quant_attr,
        )

        # Rescale the product back, passing the product of both input scales.
        combined_scale = lhs_qargs.scale * rhs_qargs.scale
        tqutils.rescale_node_back_to_int8(node, product, combined_scale, tosa_graph)
90