xref: /aosp_15_r20/external/executorch/backends/arm/operators/op_reciprocal.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright 2023-2024 Arm Limited and/or its affiliates.
2#
3# This source code is licensed under the BSD-style license found in the
4# LICENSE file in the root directory of this source tree.
5
6# pyre-unsafe
7from typing import List
8
9import numpy as np
10
11import serializer.tosa_serializer as ts
12import torch
13from executorch.backends.arm.operators.node_visitor import (
14    NodeVisitor,
15    register_node_visitor,
16)
17from executorch.backends.arm.tosa_mapping import TosaArg
18from executorch.backends.arm.tosa_quant_utils import (
19    dequantize_value,
20    get_quant_arg_downstream,
21    get_quant_arg_upstream,
22    QuantArgs,
23    quantize_value,
24)
25from serializer.tosa_serializer import TosaOp
26
27
@register_node_visitor
class DivVisitor(NodeVisitor):
    """
    Lower ``aten.reciprocal.default`` (elementwise ``1 / x``) to TOSA.

    Quantized nodes are lowered to a TOSA ``TABLE`` lookup whose contents are
    produced by ``div_table_8bit``; non-quantized nodes are lowered to a single
    TOSA ``RECIPROCAL`` operator.
    """

    # NOTE(review): the class is named "Div" but handles reciprocal; the name
    # is kept as-is in case it is referenced elsewhere.
    target = "aten.reciprocal.default"

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
    ) -> None:
        """
        Append the TOSA operator computing ``1 / inputs[0]`` into ``output``.

        Args:
            node: The FX node being lowered.
            tosa_graph: Serializer the TOSA operator is appended to.
            inputs: Operand list; only ``inputs[0]`` is used.
            output: Destination tensor of the operator.
            is_quant_node: If True, emit the table-based quantized lowering;
                otherwise emit a plain ``RECIPROCAL``.
        """
        operand = inputs[0]

        if not is_quant_node:
            tosa_graph.addOperator(
                TosaOp.Op().RECIPROCAL, [operand.name], [output.name]
            )
            return

        # Quantized path: input quantization parameters come from the node's
        # first input, output parameters from the first user of this node.
        input_qargs = get_quant_arg_upstream(node.all_input_nodes[0])
        output_qargs = get_quant_arg_downstream(list(node.users)[0])

        # Precompute 1/x for every 8-bit input value and apply it as a lookup.
        reciprocal_table = div_table_8bit(input_qargs, output_qargs)

        table_attr = ts.TosaSerializerAttribute()
        table_attr.TableAttribute(reciprocal_table)
        tosa_graph.addOperator(
            TosaOp.Op().TABLE, [operand.name], [output.name], table_attr
        )
62
63
def div_table_8bit(in_quantargs: QuantArgs, out_quantargs: QuantArgs):
    """
    Build a 256-entry lookup table implementing reciprocal (``1 / x``).

    Entry ``i`` holds the quantized reciprocal of the int8 input value
    ``i - 128``. The table therefore always spans the full int8 range
    -128..127 in ascending order, independent of the input's (possibly
    narrower) ``[qmin, qmax]`` range. That is the layout the TOSA ``TABLE``
    operator uses for 8-bit inputs (index = value + 128).

    Args:
        in_quantargs: Quantization parameters of the int8 input tensor.
        out_quantargs: Quantization parameters of the output tensor.

    Returns:
        A list of 256 quantized reciprocal values, one per int8 input value.
    """

    def reciprocal(qx: int):
        # qx is a plain Python int rather than np.int8. If dequantize_value
        # subtracts the zero point in the argument's dtype, 8-bit arithmetic
        # could otherwise wrap around for extreme inputs.
        real_x = dequantize_value(qx, in_quantargs)
        # 1/0 is undefined. Map it explicitly to +inf instead of relying on
        # float division-by-zero semantics (ZeroDivisionError for Python
        # floats, a RuntimeWarning for NumPy floats).
        # NOTE(review): assumes quantize_value clamps +inf to the output's
        # qmax (e.g. via clipping) — confirm.
        real_y = float("inf") if real_x == 0 else 1.0 / real_x
        return quantize_value(real_y, out_quantargs)

    return [reciprocal(qx) for qx in range(-128, 128)]
83