# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
from typing import List

import serializer.tosa_serializer as ts
import torch
from executorch.backends.arm.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_quant_utils import (
    build_rescale,
    get_quant_arg_downstream,
    get_quant_arg_upstream,
)
from executorch.backends.arm.tosa_utils import (
    build_reshape,
    expand_dims,
    get_two_inputs,
)
from serializer.tosa_serializer import TosaOp


@register_node_visitor
class MMVisitor(NodeVisitor):
    """Lowers ``aten.mm.default`` to a TOSA MATMUL (plus RESHAPE/RESCALE)."""

    target = "aten.mm.default"

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
    ) -> None:
        """Serialize one aten.mm node into the TOSA graph.

        TOSA MATMUL operates on rank-3 tensors, so both rank-2 operands
        are expanded from (H, W) to (1, H, W), multiplied, and the result
        reshaped back to rank 2. For quantized nodes the MATMUL
        accumulates into INT32 and is rescaled back to INT8 at the end.

        NOTE: For now, only INT8 & FP32 are supported.
        """
        input0, input1 = get_two_inputs(node)

        # For aten.mm, the two inputs are of rank 2; reshape each from
        # (H, W) to (1, H, W) to satisfy TOSA's rank-3 requirement.
        reshape_dtype = ts.DType.INT8 if is_quant_node else ts.DType.FP32
        input0_reshaped = expand_dims(tosa_graph, inputs[0], reshape_dtype, 0)
        input1_reshaped = expand_dims(tosa_graph, inputs[1], reshape_dtype, 0)

        # The MATMUL output also needs to be rank 3.
        output_new_shape = (1, output.shape[0], output.shape[1])

        # For INT8 inputs, fetch the upstream quantization parameters once;
        # otherwise the zero points are simply 0.
        if is_quant_node:
            input0_q_params = get_quant_arg_upstream(input0)
            input1_q_params = get_quant_arg_upstream(input1)
            input0_zp = input0_q_params.zp
            input1_zp = input1_q_params.zp
        else:
            input0_zp, input1_zp = 0, 0

        # INT8 matmuls accumulate into INT32.
        mat_mul_result = tosa_graph.addIntermediate(
            output_new_shape, ts.DType.INT32 if is_quant_node else output.dtype
        )

        attr = ts.TosaSerializerAttribute()
        attr.MatMulAttribute(A_zp=input0_zp, B_zp=input1_zp)

        tosa_graph.addOperator(
            TosaOp.Op().MATMUL,
            [input0_reshaped.name, input1_reshaped.name],
            [mat_mul_result.name],
            attr,
        )

        if is_quant_node:
            # Reshape back to rank 2 into an INT32 intermediate, then
            # rescale INT32 -> INT8 into the final output. Keeping both
            # steps in one branch avoids the previously pyre-ignored
            # "possibly undefined" use of `reshape_intermediate`.
            reshape_intermediate = tosa_graph.addIntermediate(
                output.shape, ts.DType.INT32
            )
            build_reshape(
                tosa_graph,
                mat_mul_result.name,
                output.shape,
                reshape_intermediate.name,
            )

            output_q_params = get_quant_arg_downstream(list(node.users)[0])
            final_output_scale = (
                input0_q_params.scale * input1_q_params.scale
            ) / output_q_params.scale

            # As the rescale input is INT32, its input_zp must be set to 0.
            build_rescale(
                tosa_fb=tosa_graph,
                scale=final_output_scale,
                input_node=reshape_intermediate,
                output_name=output.name,
                output_type=ts.DType.INT8,
                output_shape=reshape_intermediate.shape,
                input_zp=0,
                output_zp=output_q_params.zp,
                is_double_round=False,
            )
        else:
            # FP path: just reshape the rank-3 result back to rank 2.
            build_reshape(tosa_graph, mat_mul_result.name, output.shape, output.name)