# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
from typing import List

import numpy as np

import serializer.tosa_serializer as ts
from executorch.backends.arm.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg

from executorch.backends.arm.tosa_quant_utils import (
    dequantize_value,
    get_quant_arg_downstream,
    get_quant_arg_upstream,
    QuantArgs,
    quantize_value,
)
from serializer.tosa_serializer import TosaOp
from torch.fx import Node


@register_node_visitor
class TanhVisitor(NodeVisitor):
    """Lower ``aten.tanh.default`` to TOSA.

    Quantized nodes become a TOSA ``TABLE`` lookup built from the node's
    input/output quantization parameters; non-quantized nodes map directly
    to TOSA ``TANH``.
    """

    target = "aten.tanh.default"

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
    ) -> None:
        """Append the TOSA operator implementing ``node`` to ``tosa_graph``.

        Args:
            node: The FX ``aten.tanh.default`` node; must have exactly one input.
            tosa_graph: Serializer the operator is added to.
            inputs: TOSA arguments for the node's inputs; only ``inputs[0]``
                is used.
            output: TOSA argument that receives the result.
            is_quant_node: When True, emit an 8-bit ``TABLE`` lookup instead
                of ``TANH``.
        """
        assert len(node.all_input_nodes) == 1

        if not is_quant_node:
            tosa_graph.addOperator(TosaOp.Op().TANH, [inputs[0].name], [output.name])
            return

        # Assume quantized input is 8 bit. The output quantization parameters
        # are read from the node's single downstream user, so exactly one user
        # is required.
        assert len(node.users) == 1

        input_node = node.all_input_nodes[0]
        in_quantargs = get_quant_arg_upstream(input_node)
        output_node = next(iter(node.users))
        out_quantargs = get_quant_arg_downstream(output_node)

        # Create attribute for 8 bit table lookup.
        table = tanh_table_8bit(in_quantargs, out_quantargs)
        table_attr = ts.TosaSerializerAttribute()
        table_attr.TableAttribute(table)

        tosa_graph.addOperator(
            TosaOp.Op().TABLE, [inputs[0].name], [output.name], table_attr
        )


def tanh_table_8bit(in_quantargs: QuantArgs, out_quantargs: QuantArgs):
    """
    Return the 256-entry TOSA TABLE contents implementing quantized tanh.

    Entry ``i`` holds the quantized tanh of the int8 input value ``i - 128``,
    so the table covers every int8 input in ascending order from -128 to 127.
    Reference: https://www.mlplatform.org/tosa/tosa_spec.html#_tanh

    Args:
        in_quantargs: Quantization parameters used to dequantize each table
            input.
        out_quantargs: Quantization parameters used to quantize each tanh
            output.

    Returns:
        A list of 256 quantized tanh values, one per int8 input value.
    """

    def tanh(x):
        # Convert quantized input to floating point tanh input space.
        v = dequantize_value(x, in_quantargs)
        # np.tanh saturates cleanly to +/-1. The closed form
        # (1 - e^-2v) / (1 + e^-2v) overflows to inf/inf = nan for very
        # negative v.
        v = np.tanh(v)
        # Convert tanh output back to quantized space.
        return quantize_value(v, out_quantargs)

    # TOSA TABLE indexes an int8 input by (value + 128), so the table must
    # enumerate all 256 int8 values in order, independent of qmin/qmax.
    # (linspace over [qmin, qmax] produces fractional, truncated steps when
    # that range is not exactly [-128, 127].) Plain Python ints are used so the
    # quantization helpers never receive np.int8 scalars, whose arithmetic can
    # wrap around (e.g. when subtracting a zero point).
    return [tanh(x) for x in range(-128, 128)]