# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict

import torch
from executorch.backends.xnnpack.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
    XNNGraph,
    XNNNegate,
    XNode,
)


@register_node_visitor
class NegateVisitor(NodeVisitor):
    """Node visitor that lowers ``aten.neg.default`` to an ``XNNNegate`` node.

    The class is registered through ``register_node_visitor``. It is keyed by
    the ``target`` string below.
    """

    target = "aten.neg.default"

    def __init__(self, *args) -> None:
        # Forward every positional constructor argument unchanged to NodeVisitor.
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        xnn_graph: XNNGraph,
        vals_to_ids: Dict[torch.fx.Node, int],
        debug_handle: int,
    ) -> None:
        """Append one serialized negate node for ``node`` to ``xnn_graph``.

        Args:
            node: The FX node being lowered. Its first input node is the
                value that gets negated.
            xnn_graph: The serialized graph. A new ``XNode`` is appended to
                its ``xnodes`` list.
            vals_to_ids: Mapping from FX nodes to serialized value ids. Both
                the input node and ``node`` itself are looked up here.
            debug_handle: Debug handle attached to the emitted ``XNode``.
        """
        # Presumably this defines the tensor values for node's inputs and
        # outputs and records their ids in vals_to_ids, so it must run before
        # the lookups below. TODO confirm against NodeVisitor.
        self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids)

        # Only the first input is consumed. The keyword arguments are evaluated
        # in order: input id first, then output id.
        negate_op = XNNNegate(
            input_id=vals_to_ids[node.all_input_nodes[0]],
            output_id=vals_to_ids[node],
            flags=0,
        )

        wrapped = XNode(xnode_union=negate_op, debug_handle=debug_handle)
        xnn_graph.xnodes.append(wrapped)