# Copyright 2023-2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
from typing import List

import numpy as np

import serializer.tosa_serializer as ts
import torch
from executorch.backends.arm.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_quant_utils import (
    dequantize_value,
    get_quant_arg_downstream,
    get_quant_arg_upstream,
    QuantArgs,
    quantize_value,
)
from serializer.tosa_serializer import TosaOp


@register_node_visitor
class DivVisitor(NodeVisitor):
    """Lower ``aten.reciprocal.default`` (1/x) to TOSA.

    Quantized nodes are lowered to a TOSA TABLE lookup whose entries are
    produced by ``div_table_8bit``; non-quantized nodes map directly to the
    TOSA RECIPROCAL operator.
    """

    # NOTE(review): the class is named DivVisitor although it handles
    # reciprocal; the name is kept unchanged in case anything refers to it.
    target = "aten.reciprocal.default"

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
    ) -> None:
        """Append the TOSA operator computing ``1 / inputs[0]`` to the graph.

        Args:
            node: The FX node being lowered.
            tosa_graph: Serializer the operator is appended to.
            inputs: Lowered arguments of ``node``; only ``inputs[0]`` is used.
            output: Destination tensor of the emitted operator.
            is_quant_node: If True, emit a TABLE lookup over the quantized
                input; otherwise emit RECIPROCAL.
        """
        if not is_quant_node:
            tosa_graph.addOperator(
                TosaOp.Op().RECIPROCAL, [inputs[0].name], [output.name]
            )
            return

        input_arg = inputs[0]
        # Input quantization params come from the producer of the input;
        # output params from the first user of this node.
        input_qargs = get_quant_arg_upstream(node.all_input_nodes[0])
        output_qargs = get_quant_arg_downstream(list(node.users)[0])

        div_table = div_table_8bit(input_qargs, output_qargs)

        table_attr = ts.TosaSerializerAttribute()
        table_attr.TableAttribute(div_table)
        tosa_graph.addOperator(
            TosaOp.Op().TABLE, [input_arg.name], [output.name], table_attr
        )


def div_table_8bit(in_quantargs: QuantArgs, out_quantargs: QuantArgs) -> list:
    """Build a 256-entry lookup table for quantized reciprocal (1/x).

    Entry ``i`` is ``quantize(1 / dequantize(q_i))``, where ``q_0 .. q_255``
    are 256 evenly spaced integers from ``in_quantargs.qmin`` to
    ``in_quantargs.qmax`` (inclusive).

    NOTE(review): the sample points are consecutive integers only when
    ``qmax - qmin == 255`` (e.g. the full int8 range -128..127); presumably
    callers always pass such a range — confirm.

    Args:
        in_quantargs: Quantization parameters of the input tensor.
        out_quantargs: Quantization parameters of the output tensor.

    Returns:
        A list of 256 quantized outputs, ordered by increasing input value.
    """

    def div(x):
        # Convert quantized input to floating point div input space.
        v1 = dequantize_value(x, in_quantargs)
        # Compute 1/x. The input equal to the zero point dequantizes to 0;
        # for NumPy floats this yields inf rather than raising, so the
        # (expected) divide-by-zero warning is silenced here.
        with np.errstate(divide="ignore"):
            v2 = 1.0 / v1
        # Convert reciprocal output back to quantized space.
        return quantize_value(v2, out_quantargs)

    # Sample the inputs as int64 rather than int8: under NumPy 2 promotion
    # rules (NEP 50), arithmetic between an int8 scalar and a Python int
    # stays int8, so subtracting a zero point could silently wrap around.
    samples = np.linspace(
        in_quantargs.qmin, in_quantargs.qmax, 256, dtype=np.int64
    )
    return [div(x) for x in samples]