xref: /aosp_15_r20/external/executorch/backends/arm/operators/op_tanh.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright 2024 Arm Limited and/or its affiliates.
2#
3# This source code is licensed under the BSD-style license found in the
4# LICENSE file in the root directory of this source tree.
5
6# pyre-unsafe
7from typing import List
8
9import numpy as np
10
11import serializer.tosa_serializer as ts
12from executorch.backends.arm.operators.node_visitor import (
13    NodeVisitor,
14    register_node_visitor,
15)
16from executorch.backends.arm.tosa_mapping import TosaArg
17
18from executorch.backends.arm.tosa_quant_utils import (
19    dequantize_value,
20    get_quant_arg_downstream,
21    get_quant_arg_upstream,
22    QuantArgs,
23    quantize_value,
24)
25from serializer.tosa_serializer import TosaOp
26from torch.fx import Node
27
28
@register_node_visitor
class TanhVisitor(NodeVisitor):
    """Node visitor that serializes ``aten.tanh.default`` into a TOSA graph.

    Quantized nodes are emitted as a TOSA ``TABLE`` operator driven by a
    precomputed lookup table; all other nodes are emitted as a TOSA ``TANH``
    operator.
    """

    target = "aten.tanh.default"

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
    ) -> None:
        """Add the operator implementing ``node`` to ``tosa_graph``.

        Args:
            node: The FX node being lowered; must have exactly one input node.
            tosa_graph: Serializer that receives the new operator.
            inputs: TOSA arguments of the node; only ``inputs[0]`` is used.
            output: TOSA argument naming the operator's result.
            is_quant_node: Selects the table-lookup lowering when true and the
                plain ``TANH`` lowering otherwise.
        """
        assert len(node.all_input_nodes) == 1

        src_name = inputs[0].name
        dst_name = output.name

        if not is_quant_node:
            # Non-quantized path: a single TANH operator is sufficient.
            tosa_graph.addOperator(TosaOp.Op().TANH, [src_name], [dst_name])
            return

        # Quantized path. The input is assumed to be 8 bit, and the node must
        # feed exactly one user, which supplies the output quantization.
        assert len(node.users) == 1

        # Quantization parameters are read from the producer of the input and
        # from the sole consumer of this node, in that order.
        producer = node.all_input_nodes[0]
        upstream_qargs = get_quant_arg_upstream(producer)
        consumer = list(node.users)[0]
        downstream_qargs = get_quant_arg_downstream(consumer)

        # Wrap the 8 bit lookup table in a serializer attribute for TABLE.
        lookup = tanh_table_8bit(upstream_qargs, downstream_qargs)
        attribute = ts.TosaSerializerAttribute()
        attribute.TableAttribute(lookup)

        tosa_graph.addOperator(
            TosaOp.Op().TABLE, [src_name], [dst_name], attribute
        )
66
67
def tanh_table_8bit(in_quantargs: QuantArgs, out_quantargs: QuantArgs):
    """
    Returns a 256-entry lookup table implementing tanh on quantized values.

    Each entry is computed by dequantizing a quantized input sample with
    ``in_quantargs``, evaluating tanh in floating point, and quantizing the
    result with ``out_quantargs``.
    Reference: https://www.mlplatform.org/tosa/tosa_spec.html#_tanh

    Args:
        in_quantargs: Quantization parameters of the tanh input. Its ``qmin``
            and ``qmax`` bound the sampled input range.
        out_quantargs: Quantization parameters of the tanh output.

    Returns:
        A list of 256 values as returned by ``quantize_value``, one for each
        sample of ``np.linspace(qmin, qmax, 256)`` cast to int8.
    """

    def tanh(x):
        # Convert quantized input to floating point tanh input space.
        v = dequantize_value(x, in_quantargs)
        # Use np.tanh rather than (1 - exp(-2v)) / (1 + exp(-2v)): for strongly
        # negative v the exponential overflows to inf and the quotient becomes
        # NaN, whereas np.tanh saturates cleanly to -1.0 / +1.0.
        v = np.tanh(v)

        # Convert tanh output back to quantized space.
        return quantize_value(v, out_quantargs)

    # NOTE(review): the samples are exactly the 256 int8 codes only when
    # qmin == -128 and qmax == 127; other ranges yield truncated, non-unit
    # steps. Confirm that this is the intended behaviour for narrower ranges.
    return [
        tanh(x)
        for x in np.linspace(in_quantargs.qmin, in_quantargs.qmax, 256, dtype=np.int8)
    ]
88