# Copyright 2023-2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from typing import List

import executorch.backends.arm.tosa_quant_utils as tqutils
import executorch.backends.arm.tosa_utils as tutils

import serializer.tosa_serializer as ts
import torch
from executorch.backends.arm.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_specification import TosaSpecification
from serializer.tosa_serializer import TosaOp
from torch.fx import Node


def _emit_add(tosa_graph: ts.TosaSerializer, lhs, rhs, result) -> None:
    """Append a single TOSA ADD operator computing ``result = lhs + rhs``.

    ``lhs``, ``rhs`` and ``result`` are only accessed through their ``name``
    attribute; the operator is wired up by tensor name and carries no extra
    attributes (``None``).
    """
    tosa_graph.addOperator(
        TosaOp.Op().ADD,
        [lhs.name, rhs.name],
        [result.name],
        None,
    )


@register_node_visitor
class AddVisitor_080_BI(NodeVisitor):
    """Lower ``aten.add.Tensor`` for the TOSA-0.80.0+BI specification.

    Only integer additions are accepted here. If both operands and the result
    are already int32, one TOSA ADD is emitted directly. Otherwise the
    operands are rescaled to int32 and added into an int32 intermediate. The
    sum is then rescaled back via ``tqutils.rescale_node_back_to_int8``.
    """

    target = "aten.add.Tensor"

    tosa_specs = [
        TosaSpecification.create_from_string("TOSA-0.80.0+BI"),
    ]

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
    ) -> None:
        """Serialize the integer add ``node`` into ``tosa_graph``.

        Args:
            node: FX node for the add. The ``meta["val"].dtype`` of its two
                operand nodes and of the node itself decide whether a rescale
                round-trip is emitted.
            tosa_graph: Serializer that receives the emitted operators.
            inputs: TOSA arguments for the two operands.
            output: TOSA argument the final result is written to.
            is_quant_node: Whether the node was marked as quantized.

        Raises:
            RuntimeError: If the node is not quantized and any operand dtype
                is neither int8 nor int32.
        """
        operand_nodes = tutils.get_two_inputs(node)

        integer_dtypes = (torch.int8, torch.int32)
        if not is_quant_node and any(
            operand.meta["val"].dtype not in integer_dtypes
            for operand in operand_nodes
        ):
            raise RuntimeError(
                f"Unexpected non quantized {AddVisitor_080_BI.target} node."
            )

        # The add itself is carried out in int32. A rescale round-trip is
        # needed unless both operands and the result are already int32.
        needs_rescale = (
            any(
                operand.meta["val"].dtype != torch.int32
                for operand in operand_nodes
            )
            or node.meta["val"].dtype != torch.int32
        )

        if not needs_rescale:
            # Pure int32 add: operate on the incoming tensors as-is.
            _emit_add(tosa_graph, inputs[0], inputs[1], output)
            return

        # Bring both operands to 32 bit. `scale` feeds the rescale-back step.
        int32_operands, scale = tqutils.rescale_nodes_to_int32(
            operand_nodes, tosa_graph
        )

        # int32 scratch tensor whose TOSA shape is derived from the output's
        # shape and dim order. It holds the un-rescaled sum.
        int32_sum = tosa_graph.addIntermediate(
            tutils.tosa_shape(output.shape, output.dim_order), ts.DType.INT32
        )
        _emit_add(tosa_graph, int32_operands[0], int32_operands[1], int32_sum)

        # Scale the int32 sum back to 8 bit.
        # pyre-ignore
        tqutils.rescale_node_back_to_int8(node, int32_sum, scale, tosa_graph)


@register_node_visitor
class AddVisitor_080_MI(AddVisitor_080_BI):
    """Lower ``aten.add.Tensor`` for the TOSA-0.80.0+MI specification.

    Quantized nodes reuse the integer lowering of ``AddVisitor_080_BI``. All
    other nodes become one TOSA ADD on the operands exactly as given.
    """

    # `target` is inherited from AddVisitor_080_BI.

    tosa_specs = [
        TosaSpecification.create_from_string("TOSA-0.80.0+MI"),
    ]

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
    ) -> None:
        """Serialize the add ``node`` into ``tosa_graph``.

        Non-quantized nodes (e.g. the FP32 case) emit a single ADD with no
        rescaling. Quantized nodes are delegated to the BI implementation.
        """
        if not is_quant_node:
            _emit_add(tosa_graph, inputs[0], inputs[1], output)
            return

        # Integer (quantized) add: defer to the inherited BI lowering.
        super().define_node(node, tosa_graph, inputs, output, is_quant_node)