# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
from typing import List

import serializer.tosa_serializer as ts
import torch.fx
from executorch.backends.arm.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_quant_utils import (
    build_rescale,
    get_quant_arg_downstream,
    get_quant_arg_upstream,
)
from executorch.backends.arm.tosa_utils import get_two_inputs
from serializer.tosa_serializer import TosaOp


@register_node_visitor
class BMMVisitor(NodeVisitor):
    """Lower ``aten.bmm.default`` to a TOSA ``MATMUL`` operator.

    aten.bmm maps directly to MATMUL. For now, only FP32 and INT8 are
    supported. In the INT8 case the MATMUL accumulates into an INT32
    intermediate tensor, which is then rescaled back to INT8.
    """

    target = "aten.bmm.default"

    def __init__(self, *args):
        super().__init__(*args)

    @staticmethod
    def _add_matmul(
        tosa_graph: ts.TosaSerializer,
        input0_name: str,
        input1_name: str,
        output_name: str,
        input0_zp: int,
        input1_zp: int,
    ) -> None:
        """Append a MATMUL of two named tensors, writing into ``output_name``.

        Args:
            tosa_graph: Serializer the operator is appended to.
            input0_name: Name of the left-hand operand tensor.
            input1_name: Name of the right-hand operand tensor.
            output_name: Name of the tensor that receives the product.
            input0_zp: Zero point passed to MATMUL as ``A_zp``.
            input1_zp: Zero point passed to MATMUL as ``B_zp``.
        """
        attr = ts.TosaSerializerAttribute()
        attr.MatMulAttribute(A_zp=input0_zp, B_zp=input1_zp)
        tosa_graph.addOperator(
            TosaOp.Op().MATMUL,
            [input0_name, input1_name],
            [output_name],
            attr,
        )

    def define_node(
        self,
        node: torch.fx.Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
    ) -> None:
        """Emit the TOSA operators for a single ``aten.bmm`` node.

        Args:
            node: The ``aten.bmm`` FX node being lowered.
            tosa_graph: Serializer the operators are appended to.
            inputs: Not read here; the two operands are obtained from
                ``node`` via ``get_two_inputs``.
            output: TOSA tensor that receives the final result.
            is_quant_node: True for the INT8 path (MATMUL into an INT32
                intermediate followed by a RESCALE), False for FP32.
        """
        input0, input1 = get_two_inputs(node)

        if not is_quant_node:
            # FP32: MATMUL writes the node's output directly and both zero
            # points are 0.
            self._add_matmul(
                tosa_graph, input0.name, input1.name, output.name, 0, 0
            )
            return

        # INT8: take the zero points/scales from the inputs' quantization
        # parameters and accumulate into an INT32 intermediate tensor that is
        # rescaled below. Keeping this path linear guarantees every local is
        # bound before it is used.
        input0_q_params = get_quant_arg_upstream(input0)
        input1_q_params = get_quant_arg_upstream(input1)
        bmm_result = tosa_graph.addIntermediate(output.shape, ts.DType.INT32)

        self._add_matmul(
            tosa_graph,
            input0.name,
            input1.name,
            bmm_result.name,
            input0_q_params.zp,
            input1_q_params.zp,
        )

        # Output quantization parameters are read from the first user of
        # `node` only. NOTE(review): assumes that user carries the output
        # quantization (e.g. a quantize op) - confirm for multi-user graphs.
        output_q_params = get_quant_arg_downstream(list(node.users)[0])

        # The INT32 accumulator is in units of (scale0 * scale1); convert it
        # to the output scale.
        final_output_scale = (
            input0_q_params.scale * input1_q_params.scale
        ) / output_q_params.scale

        # The input zero points were already handed to MATMUL (A_zp/B_zp),
        # so the INT32 result is rescaled with an input zero point of 0.
        build_rescale(
            tosa_fb=tosa_graph,
            scale=final_output_scale,
            input_node=bmm_result,
            output_name=output.name,
            output_type=ts.DType.INT8,
            output_shape=bmm_result.shape,
            input_zp=0,
            output_zp=output_q_params.zp,
            is_double_round=False,
        )