# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from typing import cast, List

import executorch.backends.arm.tosa_quant_utils as tqutils
import executorch.backends.arm.tosa_utils as tutils

import serializer.tosa_serializer as ts
import torch

from executorch.backends.arm.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from serializer.tosa_serializer import TosaOp


def _emit_mul(tosa_graph: ts.TosaSerializer, lhs_name, rhs_name, out_name) -> None:
    """Append one TOSA MUL operator (``shift=0``) to ``tosa_graph``.

    Args:
        tosa_graph: Serializer the operator is added to.
        lhs_name: Name of the first operand tensor.
        rhs_name: Name of the second operand tensor.
        out_name: Name of the tensor the product is written to.
    """
    mul_attr = ts.TosaSerializerAttribute()
    mul_attr.MulAttribute(shift=0)
    tosa_graph.addOperator(
        TosaOp.Op().MUL,
        [lhs_name, rhs_name],
        [out_name],
        mul_attr,
    )


@register_node_visitor
class MulVisitor(NodeVisitor):
    """Serialize ``aten.mul.Tensor`` nodes as TOSA MUL operators."""

    target = "aten.mul.Tensor"

    def define_node(
        self,
        node: torch.fx.Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
    ) -> None:
        """Emit the TOSA ops implementing ``node`` into ``tosa_graph``.

        Non-quantized nodes become a single MUL on ``inputs[0]`` and
        ``inputs[1]`` writing directly to ``output``.

        Quantized nodes:
        1. Rescale both operands to INT32.
        2. Multiply them into an INT32 intermediate tensor.
        3. Hand that intermediate to ``tqutils.rescale_node_back_to_int8``
           together with the combined scale ``scale_A * scale_B``.

        Args:
            node: The FX node being lowered; its first two ``args`` are the
                operands whose upstream quantization parameters are read.
            tosa_graph: Serializer that receives the new operators.
            inputs: TOSA arguments for the operands (only indices 0 and 1 are
                used).
            output: TOSA argument describing the node's result.
            is_quant_node: Selects the quantized lowering when true.
        """
        if not is_quant_node:
            # Plain path: one MUL straight from the inputs to the output.
            _emit_mul(tosa_graph, inputs[0].name, inputs[1].name, output.name)
            return

        lhs = inputs[0]
        rhs = inputs[1]
        lhs_qargs = tqutils.get_quant_arg_upstream(cast(torch.fx.Node, node.args[0]))
        rhs_qargs = tqutils.get_quant_arg_upstream(cast(torch.fx.Node, node.args[1]))

        # Convert operand/output shapes with tutils.tosa_shape using their
        # dim_order. NOTE(review): this overwrites `.shape` on the TosaArg
        # objects owned by the caller; presumably intended, since the rescale
        # helpers below receive those same objects.
        for operand in (lhs, rhs):
            operand.shape = tutils.tosa_shape(operand.shape, operand.dim_order)
        out_shape = tutils.tosa_shape(output.shape, output.dim_order)

        # Rescale each operand to INT32 (zp=0), passing its upstream zero point
        # and a scale of 1.0; lhs is processed before rhs.
        lhs_i32, rhs_i32 = [
            tqutils.build_rescale_to_int32(
                tosa_graph,
                operand,
                qargs.zp,
                rescale_scale=1.0,
            )
            for operand, qargs in ((lhs, lhs_qargs), (rhs, rhs_qargs))
        ]

        # INT32 multiply into a fresh intermediate tensor.
        int32_product = tosa_graph.addIntermediate(out_shape, ts.DType.INT32)
        _emit_mul(tosa_graph, lhs_i32.name, rhs_i32.name, int32_product.name)

        # The product of two quantized values carries the product of the two
        # input scales; pass that combined scale to the rescale-back helper.
        combined_scale = lhs_qargs.scale * rhs_qargs.scale
        tqutils.rescale_node_back_to_int8(
            node, int32_product, combined_scale, tosa_graph
        )