# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import cast, Dict

import torch
from executorch.backends.xnnpack.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
    OutputMinMax,
    XNNClamp,
    XNNGraph,
    XNode,
)


@register_node_visitor
class ClampVisitor(NodeVisitor):
    """Serialize an ``aten.clamp.default`` node as an XNNPACK ``XNNClamp`` node.

    The optional clamp bounds are not encoded in the ``XNNClamp`` payload
    itself; they are attached to the serialized ``XNode`` via
    ``OutputMinMax``.
    """

    target = "aten.clamp.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        xnn_graph: XNNGraph,
        vals_to_ids: Dict[torch.fx.Node, int],
        debug_handle: int,
    ) -> None:
        """Append one clamp ``XNode`` for ``node`` to ``xnn_graph.xnodes``.

        Args:
            node: The ``aten.clamp.default`` FX node. ``node.args[1]`` and
                ``node.args[2]`` are read as the optional min and max bounds;
                either may be absent or ``None``.
            xnn_graph: Graph being built; the new node is appended to its
                ``xnodes`` list.
            vals_to_ids: Mapping from FX nodes to serialized value ids. It is
                used to look up the ids of the clamp's input and output.
            debug_handle: Debug handle stored on the emitted ``XNode``.
        """
        # Make sure the input/output tensors of this node have value ids in
        # `vals_to_ids` before they are looked up below.
        self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids)

        # Missing bounds mean "unbounded". String sentinels are used here
        # rather than float infinities. NOTE(review): presumably because the
        # serialized form cannot carry float inf directly — confirm against
        # the XNNPACK serializer.
        min_val = "-inf"
        max_val = "inf"

        if len(node.args) >= 2 and node.args[1] is not None:
            min_val = cast(float, node.args[1])

        if len(node.args) >= 3 and node.args[2] is not None:
            max_val = cast(float, node.args[2])

        # The first FX input node is the tensor being clamped.
        input_id = vals_to_ids[node.all_input_nodes[0]]

        # The clamp node itself is the output value.
        output_id = vals_to_ids[node]

        output_min_max = OutputMinMax(output_min=min_val, output_max=max_val)

        ser_node = XNode(
            xnode_union=XNNClamp(
                input_id=input_id,
                output_id=output_id,
                flags=0,
            ),
            debug_handle=debug_handle,
            output_min_max=output_min_max,
        )
        xnn_graph.xnodes.append(ser_node)


# Backward-compatible alias: this visitor was previously (mis)named
# `ReluVisitor` even though it handles `aten.clamp.default`.
ReluVisitor = ClampVisitor