xref: /aosp_15_r20/external/executorch/backends/arm/operators/op_relu.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright 2024 Arm Limited and/or its affiliates.
2#
3# This source code is licensed under the BSD-style license found in the
4# LICENSE file in the root directory of this source tree.
5
6# pyre-unsafe
7
8import executorch.backends.arm.tosa_quant_utils as tqutils
9import serializer.tosa_serializer as ts
10import torch.fx
11from executorch.backends.arm.operators.node_visitor import (
12    NodeVisitor,
13    register_node_visitor,
14)
15from executorch.backends.arm.tosa_mapping import TosaArg
16from serializer.tosa_serializer import TosaOp
17
18
@register_node_visitor
class ReluVisitor(NodeVisitor):
    """Lower ``aten.relu.default`` to a TOSA ``CLAMP`` operator.

    ReLU is expressed as a clamp with a lower bound of 0 and an upper bound
    of +inf. For quantized nodes both bounds are mapped into the output's
    quantized domain; for float nodes they are passed as float bounds.
    """

    target = "aten.relu.default"

    def define_node(
        self,
        node: torch.fx.Node,
        tosa_graph: ts.TosaSerializer,
        inputs: list[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
    ) -> None:
        """Append a CLAMP operator implementing ReLU for ``node``.

        Args:
            node: The FX node being lowered; for quantized nodes its first
                user is consulted for output quantization parameters.
            tosa_graph: Serializer the CLAMP operator is added to.
            inputs: Lowered operands; only ``inputs[0]`` is consumed.
            output: Destination tensor of the CLAMP operator.
            is_quant_node: If True, compute integer (quantized) clamp bounds;
                otherwise compute floating-point bounds.
        """
        # ClampAttribute is given both integer and float bounds; whichever
        # pair is not computed below keeps its zero default.
        clamp_min_qs = 0
        clamp_max_qs = 0
        clamp_min_fp = 0.0
        clamp_max_fp = 0.0

        if is_quant_node:
            # Output quantization parameters are read from the first consumer
            # of this node (presumably the downstream quantize op -- confirm
            # against tosa_quant_utils.get_quant_arg_downstream).
            out_qargs = tqutils.get_quant_arg_downstream(list(node.users)[0])
            clamp_min_qs = tqutils.quantize_value(0, out_qargs)
            # NOTE(review): relies on quantize_value mapping +inf to the top of
            # the quantized range (i.e. saturating); verify in tosa_quant_utils.
            clamp_max_qs = tqutils.quantize_value(float("inf"), out_qargs)
        else:
            # Keep both float bounds as floats.
            clamp_min_fp = 0.0
            clamp_max_fp = float("inf")

        attr = ts.TosaSerializerAttribute()
        attr.ClampAttribute(
            tosa_graph.builder,
            clamp_min_qs,
            clamp_max_qs,
            clamp_min_fp,
            clamp_max_fp,
        )

        tosa_graph.addOperator(
            TosaOp.Op().CLAMP, [inputs[0].name], [output.name], attr
        )
58