# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict

import torch
from executorch.backends.xnnpack._passes.fuse_activation_pass import FuseActivationPass
from executorch.backends.xnnpack.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.xnnpack.operators.quant_params import QuantParams
from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
    XNNAdd,
    XNNGraph,
    XNode,
)

from executorch.backends.xnnpack.utils.utils import get_input_node


@register_node_visitor
class AddVisitor(NodeVisitor):
    """Lower an ``aten.add.Tensor`` FX node into an XNNPACK ``XNNAdd`` xnode."""

    target = "aten.add.Tensor"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def _define_operand(
        self,
        operand: torch.fx.Node,
        xnn_graph: XNNGraph,
        vals_to_ids: Dict[torch.fx.Node, int],
    ) -> int:
        """Register one input tensor of the add and return its serialized id.

        Input quantization parameters are derived via
        ``QuantParams.from_inputs`` using this visitor's exported program.
        """
        input_qparams = QuantParams.from_inputs(operand, self._exported_program)
        self.define_tensor(
            operand,
            xnn_graph,
            vals_to_ids,
            quant_params=input_qparams,
        )
        return vals_to_ids[operand]

    def define_node(
        self,
        node: torch.fx.Node,
        xnn_graph: XNNGraph,
        vals_to_ids: Dict[torch.fx.Node, int],
        debug_handle: int,
    ) -> None:
        """Serialize ``node`` (an elementwise add) and append it to ``xnn_graph``.

        Args:
            node: The ``aten.add.Tensor`` FX node being lowered.
            xnn_graph: Graph under construction; receives the new xnode.
            vals_to_ids: Mapping from FX nodes to serialized tensor ids;
                ``define_tensor`` populates it as tensors are defined.
            debug_handle: Handle forwarded unchanged onto the emitted xnode.
        """
        # Operands are defined strictly in positional order (0 then 1) so the
        # ids handed out by define_tensor are assigned deterministically.
        first_id, second_id = [
            self._define_operand(get_input_node(node, position), xnn_graph, vals_to_ids)
            for position in (0, 1)
        ]

        # Whatever FuseActivationPass recorded for this node is attached to the
        # xnode as output_min_max (presumably a fused clamp range — confirm).
        fused_range = FuseActivationPass.get_fused_activation(node)

        self.define_tensor(
            node,
            xnn_graph,
            vals_to_ids,
            quant_params=QuantParams.from_outputs(node),
        )
        result_id = vals_to_ids[node]

        add_op = XNNAdd(
            input1_id=first_id,
            input2_id=second_id,
            output_id=result_id,
            flags=0,
        )
        xnn_graph.xnodes.append(
            XNode(
                xnode_union=add_op,
                debug_handle=debug_handle,
                output_min_max=fused_range,
            )
        )