xref: /aosp_15_r20/external/executorch/backends/arm/operators/op_mm.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright 2024 Arm Limited and/or its affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7# pyre-unsafe
8from typing import List
9
10import serializer.tosa_serializer as ts
11import torch
12from executorch.backends.arm.operators.node_visitor import (
13    NodeVisitor,
14    register_node_visitor,
15)
16from executorch.backends.arm.tosa_mapping import TosaArg
17from executorch.backends.arm.tosa_quant_utils import (
18    build_rescale,
19    get_quant_arg_downstream,
20    get_quant_arg_upstream,
21)
22from executorch.backends.arm.tosa_utils import (
23    build_reshape,
24    expand_dims,
25    get_two_inputs,
26)
27from serializer.tosa_serializer import TosaOp
28
29
@register_node_visitor
class MMVisitor(NodeVisitor):
    """Node visitor that emits a TOSA MATMUL for ``aten.mm.default``.

    The rank-2 operands are expanded to rank 3 before the MATMUL and the
    result is reshaped back to the rank-2 output shape. On the quantized
    path the MATMUL result is produced as INT32 and rescaled to INT8.
    """

    target = "aten.mm.default"

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
    ) -> None:
        """Serialize ``node`` into ``tosa_graph`` as MATMUL (+ RESHAPE/RESCALE).

        Args:
            node: The FX node being lowered; its two inputs are read via
                ``get_two_inputs`` and, on the quantized path, its first
                user supplies the output quantization parameters.
            tosa_graph: Serializer that receives the intermediates and
                operators.
            inputs: TOSA arguments for the two matrix operands.
            output: TOSA argument describing the (rank-2) result tensor.
            is_quant_node: When true, take the INT8 path (INT32 accumulate,
                then rescale); otherwise take the FP32 path.
        """
        input0, input1 = get_two_inputs(node)

        # For aten.mm, the two inputs are of rank 2
        # For TOSA it needs to be rank 3
        # So they need to be reshaped from (H, W) to (1, H, W)
        # NOTE: For now, only INT8 & FP32 is supported
        reshape_dtype = ts.DType.INT8 if is_quant_node else ts.DType.FP32
        input0_reshaped = expand_dims(tosa_graph, inputs[0], reshape_dtype, 0)
        input1_reshaped = expand_dims(tosa_graph, inputs[1], reshape_dtype, 0)

        # The output also needs to be rank 3
        output_new_shape = (1, output.shape[0], output.shape[1])

        if is_quant_node:
            # Look the upstream quantization parameters up once; they feed
            # both the MATMUL zero points and the final rescale factor.
            input0_q_params = get_quant_arg_upstream(input0)
            input1_q_params = get_quant_arg_upstream(input1)
            input0_zp, input1_zp = input0_q_params.zp, input1_q_params.zp
            # INT8 operands accumulate into INT32.
            mat_mul_dtype = ts.DType.INT32
        else:
            # Without quantization the zero points are 0.
            input0_zp, input1_zp = 0, 0
            mat_mul_dtype = output.dtype

        mat_mul_result = tosa_graph.addIntermediate(output_new_shape, mat_mul_dtype)

        attr = ts.TosaSerializerAttribute()
        attr.MatMulAttribute(A_zp=input0_zp, B_zp=input1_zp)

        tosa_graph.addOperator(
            TosaOp.Op().MATMUL,
            [input0_reshaped.name, input1_reshaped.name],
            [mat_mul_result.name],
            attr,
        )

        if not is_quant_node:
            # Float path: reshape the result straight into the output
            # tensor (back to rank 2) and finish.
            build_reshape(
                tosa_graph, mat_mul_result.name, output.shape, output.name
            )
            return

        # Quantized path: reshape back to rank 2 into an INT32 intermediate
        # first, because the rescale below is what writes the real output.
        reshape_intermediate = tosa_graph.addIntermediate(
            output.shape, ts.DType.INT32
        )
        build_reshape(
            tosa_graph,
            mat_mul_result.name,
            output.shape,
            reshape_intermediate.name,
        )

        # As INT8 accumulates into INT32, we need to rescale it back to INT8.
        # NOTE(review): only the first user is consulted for the output
        # quantization parameters - confirm that is sufficient when the node
        # has several users.
        output_q_params = get_quant_arg_downstream(list(node.users)[0])

        final_output_scale = (
            input0_q_params.scale * input1_q_params.scale
        ) / output_q_params.scale

        # As the input will be INT32, the input_zp must be set to 0
        build_rescale(
            tosa_fb=tosa_graph,
            scale=final_output_scale,
            input_node=reshape_intermediate,
            output_name=output.name,
            output_type=ts.DType.INT8,
            output_shape=reshape_intermediate.shape,
            input_zp=0,
            output_zp=output_q_params.zp,
            is_double_round=False,
        )
114