# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict

import torch
from executorch.backends.xnnpack.operators.node_visitor import (
    get_input_node,
    InputTypeToIndex,
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
    XNNFullyConnected,
    XNNGraph,
    XNode,
)

from executorch.backends.xnnpack.utils.xnnpack_constants import (
    XNN_FLAG_TRANSPOSE_WEIGHTS,
)


@register_node_visitor
class AddmmVisitor(NodeVisitor):
    """Serialize ``aten.addmm.default`` as an XNNPACK fully-connected node.

    ``addmm(bias, mat1, mat2)`` is emitted as an ``XNNFullyConnected`` whose
    input is ``mat1`` (arg 1), filter is ``mat2`` (arg 2) and bias is
    ``bias`` (arg 0).
    """

    target = "aten.addmm.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        xnn_graph: XNNGraph,
        vals_to_ids: Dict[torch.fx.Node, int],
        debug_handle: int,
    ) -> None:
        """Append an ``XNNFullyConnected`` node for ``node`` to ``xnn_graph``.

        Args:
            node: The ``aten.addmm.default`` FX node to lower.
            xnn_graph: Graph being built. Its ``xnodes`` list is appended to.
            vals_to_ids: Mapping from FX nodes to serialized value ids. It is
                filled for this node's tensors by
                ``define_nodes_tensor_inputs_outputs`` before it is read.
            debug_handle: Debug handle attached to the serialized node.

        Raises:
            RuntimeError: If ``beta`` or ``alpha`` is not 1. The fully-connected
                op has no scaling terms, so such a node cannot be represented.
        """
        # addmm computes beta * bias + alpha * (mat1 @ mat2). beta/alpha are
        # keyword-only, so they can only appear in kwargs. Check them before
        # touching xnn_graph so that an unsupported node leaves it unmodified.
        beta = node.kwargs.get("beta", 1)
        alpha = node.kwargs.get("alpha", 1)
        if beta != 1 or alpha != 1:
            raise RuntimeError(
                f"XNNPACK addmm lowering only supports beta=1 and alpha=1, "
                f"got beta={beta}, alpha={alpha} for node {node}"
            )

        # addmm argument order is (bias, mat1, mat2).
        input_type_map = InputTypeToIndex(node_input=1, node_weight=2, node_bias=0)
        self.define_nodes_tensor_inputs_outputs(
            node, xnn_graph, vals_to_ids, input_type_map=input_type_map
        )

        bias_id = vals_to_ids[get_input_node(node, 0)]
        input_id = vals_to_ids[get_input_node(node, 1)]
        filter_id = vals_to_ids[get_input_node(node, 2)]
        output_id = vals_to_ids[node]

        # mat1 @ mat2 multiplies by mat2 as laid out (in_features x
        # out_features). XNNPACK's fully-connected op expects the transposed
        # layout by default, so ask it to treat the filter as already
        # transposed.
        flag = XNN_FLAG_TRANSPOSE_WEIGHTS

        # NOTE(review): XNNPACK expects a bias with one element per output
        # channel. addmm also accepts a bias broadcast from other shapes;
        # confirm that the partitioner or define_nodes_tensor_inputs_outputs
        # handles non-1D biases.
        ser_node = XNode(
            xnode_union=XNNFullyConnected(
                input1_id=input_id,
                filter_id=filter_id,
                bias_id=bias_id,
                output_id=output_id,
                flags=flag,
            ),
            debug_handle=debug_handle,
        )
        xnn_graph.xnodes.append(ser_node)