# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import executorch.backends.arm.tosa_quant_utils as tqutils
import serializer.tosa_serializer as ts
import torch.fx
from executorch.backends.arm.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from serializer.tosa_serializer import TosaOp


@register_node_visitor
class ReluVisitor(NodeVisitor):
    """Lower ``aten.relu.default`` to a TOSA ``CLAMP`` operator.

    ReLU is ``max(x, 0)``, emitted here as a clamp to ``[0, +inf)``: with
    integer bounds when ``is_quant_node`` is set, and with floating-point
    bounds otherwise.
    """

    target = "aten.relu.default"

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        tosa_graph: ts.TosaSerializer,
        inputs: list[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
    ) -> None:
        """Append a CLAMP operator implementing ReLU to ``tosa_graph``.

        Args:
            node: The FX node for the relu call.
            tosa_graph: Serializer the CLAMP operator is added to.
            inputs: Lowered operands of ``node``; only ``inputs[0]`` is used.
            output: Lowered result tensor of ``node``.
            is_quant_node: If True, the clamp bounds are computed as quantized
                integers; otherwise floating-point bounds are used.

        Raises:
            ValueError: If ``is_quant_node`` is True but ``node`` has no users
                from which to read output quantization parameters.
        """
        attr = ts.TosaSerializerAttribute()

        # CLAMP carries both an integer and a floating-point bound pair; the
        # pair belonging to the branch not taken below is left at zero.
        clamp_min_qs = 0
        clamp_max_qs = 0
        clamp_min_fp = 0.0
        clamp_max_fp = 0.0

        if is_quant_node:
            if not node.users:
                raise ValueError(
                    f"Quantized relu node '{node.name}' has no users; cannot "
                    "determine its output quantization parameters."
                )
            # Quantization args come from the first consumer of the relu
            # output (presumably its quantize node -- see
            # tqutils.get_quant_arg_downstream).
            first_user = next(iter(node.users))
            out_qargs = tqutils.get_quant_arg_downstream(first_user)
            # Map real 0 and +inf into the quantized domain. Assumes
            # quantize_value saturates +inf to the dtype maximum -- confirm
            # in tosa_quant_utils.
            clamp_min_qs = tqutils.quantize_value(0, out_qargs)
            clamp_max_qs = tqutils.quantize_value(float("inf"), out_qargs)
        else:
            clamp_min_fp = 0.0
            clamp_max_fp = float("inf")

        attr.ClampAttribute(
            tosa_graph.builder,
            clamp_min_qs,
            clamp_max_qs,
            clamp_min_fp,
            clamp_max_fp,
        )

        tosa_graph.addOperator(
            TosaOp.Op().CLAMP, [inputs[0].name], [output.name], attr
        )