xref: /aosp_15_r20/external/executorch/backends/arm/operators/op_bmm.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright 2024 Arm Limited and/or its affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7# pyre-unsafe
8from typing import List
9
10import serializer.tosa_serializer as ts
11import torch.fx
12from executorch.backends.arm.operators.node_visitor import (
13    NodeVisitor,
14    register_node_visitor,
15)
16from executorch.backends.arm.tosa_mapping import TosaArg
17from executorch.backends.arm.tosa_quant_utils import (
18    build_rescale,
19    get_quant_arg_downstream,
20    get_quant_arg_upstream,
21)
22from executorch.backends.arm.tosa_utils import get_two_inputs
23from serializer.tosa_serializer import TosaOp
24
25
@register_node_visitor
class BMMVisitor(NodeVisitor):
    """Node visitor that emits a TOSA MATMUL for ``aten.bmm.default``.

    NOTE: For now, only INT8 & FP32 is supported.
    """

    target = "aten.bmm.default"

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
    ) -> None:
        """Add the MATMUL (and, when quantized, a trailing rescale) for ``node``.

        Args:
            node: FX node being lowered. Its two operands are fetched with
                ``get_two_inputs``; in the quantized case its first user is
                queried for the output quantization parameters.
            tosa_graph: Serializer that receives the intermediate tensor and
                the operators.
            inputs: Not read by this visitor; operand names come from
                ``get_two_inputs(node)``.
            output: Supplies the name and shape of the result tensor.
            is_quant_node: When true, the MATMUL writes to an INT32
                intermediate which is then rescaled to INT8 under
                ``output.name``. Otherwise the MATMUL writes straight to
                ``output.name`` with zero points of 0.

        Raises:
            IndexError: In the quantized case, if ``node`` has no users.
        """
        lhs, rhs = get_two_inputs(node)

        # aten.bmm maps directly onto MATMUL; only the zero points and the
        # tensor the MATMUL writes to differ between the two supported modes.
        if is_quant_node:
            # Quantized: take the zero points from the producers of each
            # operand and accumulate into an INT32 intermediate that is
            # rescaled further down.
            lhs_qargs = get_quant_arg_upstream(lhs)
            rhs_qargs = get_quant_arg_upstream(rhs)
            lhs_zp, rhs_zp = lhs_qargs.zp, rhs_qargs.zp
            accumulator = tosa_graph.addIntermediate(output.shape, ts.DType.INT32)
            matmul_target = accumulator.name
        else:
            lhs_zp = rhs_zp = 0
            matmul_target = output.name

        # Emit the MATMUL itself.
        matmul_attr = ts.TosaSerializerAttribute()
        matmul_attr.MatMulAttribute(A_zp=lhs_zp, B_zp=rhs_zp)
        tosa_graph.addOperator(
            TosaOp.Op().MATMUL,
            [lhs.name, rhs.name],
            [matmul_target],
            matmul_attr,
        )

        if not is_quant_node:
            # Non-quantized result already lives in the output tensor.
            return

        # The INT8 product was accumulated as INT32, so bring it back to INT8
        # using the combined scale (input0 * input1 / output).
        consumer = list(node.users)[0]
        out_qargs = get_quant_arg_downstream(consumer)
        rescale_factor = lhs_qargs.scale * rhs_qargs.scale / out_qargs.scale

        build_rescale(
            tosa_fb=tosa_graph,
            scale=rescale_factor,
            # pyre-ignore[61]: Uninitialized local [61]: Local variable `accumulator` is undefined, or not always defined.
            input_node=accumulator,
            output_name=output.name,
            output_type=ts.DType.INT8,
            output_shape=accumulator.shape,
            input_zp=0,
            output_zp=out_qargs.zp,
            is_double_round=False,
        )
90