# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict

import torch

from executorch.backends.xnnpack.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
    XNNGraph,
    XNNPReLU,
    XNode,
)

from executorch.backends.xnnpack.utils.utils import get_input_node


@register_node_visitor
class PReLUVisitor(NodeVisitor):
    """Node visitor that lowers an ``aten.prelu.default`` FX node to ``XNNPReLU``."""

    target = "aten.prelu.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        xnn_graph: XNNGraph,
        vals_to_ids: Dict[torch.fx.Node, int],
        debug_handle: int,
    ) -> None:
        """Append one serialized PReLU ``XNode`` to ``xnn_graph.xnodes``.

        Args:
            node: The prelu call. Argument 0 is the data input and
                argument 1 is the PReLU weight.
            xnn_graph: Graph under construction; the new node is appended
                to its ``xnodes`` list.
            vals_to_ids: Mapping from FX nodes to serialized value ids.
            debug_handle: Handle stored on the emitted ``XNode``.
        """
        # NOTE(review): this presumably registers the node's tensor
        # inputs/outputs in vals_to_ids; the lookups below depend on it
        # running first -- confirm in NodeVisitor.
        self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids)

        data_id = vals_to_ids[get_input_node(node, 0)]  # operand 0: input
        slope_id = vals_to_ids[get_input_node(node, 1)]  # operand 1: weight
        result_id = vals_to_ids[node]  # the node's own output value

        prelu_op = XNNPReLU(
            input1_id=data_id,
            input2_id=slope_id,
            output_id=result_id,
            flags=0,
        )
        xnn_graph.xnodes.append(XNode(xnode_union=prelu_op, debug_handle=debug_handle))