xref: /aosp_15_r20/external/executorch/backends/xnnpack/operators/op_addmm.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7from typing import Dict
8
9import torch
10from executorch.backends.xnnpack.operators.node_visitor import (
11    get_input_node,
12    InputTypeToIndex,
13    NodeVisitor,
14    register_node_visitor,
15)
16from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
17    XNNFullyConnected,
18    XNNGraph,
19    XNode,
20)
21
22from executorch.backends.xnnpack.utils.xnnpack_constants import (
23    XNN_FLAG_TRANSPOSE_WEIGHTS,
24)
25
26
@register_node_visitor
class AddmmVisitor(NodeVisitor):
    """Lower ``aten.addmm.default`` to an XNNPACK fully-connected node.

    ``addmm(bias, input, mat2, *, beta=1, alpha=1)`` computes
    ``beta * bias + alpha * (input @ mat2)``. The serialized
    ``XNNFullyConnected`` node carries no scaling factors, so only the
    default scaling (``beta == 1`` and ``alpha == 1``) can be lowered here.
    Other values are rejected rather than silently miscompiled.
    """

    target = "aten.addmm.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        xnn_graph: XNNGraph,
        vals_to_ids: Dict[torch.fx.Node, int],
        debug_handle: int,
    ) -> None:
        """Serialize ``node`` as an ``XNNFullyConnected`` appended to ``xnn_graph``.

        Args:
            node: The ``aten.addmm.default`` FX node to lower. Its positional
                args are ``(bias, input, mat2)``.
            xnn_graph: Graph whose ``xnodes`` list receives the new node.
            vals_to_ids: Map from FX nodes to XNNPACK value ids. It is passed to
                ``define_nodes_tensor_inputs_outputs`` and then read to look up
                the ids of this node's inputs and output.
            debug_handle: Debug handle attached to the serialized node.

        Raises:
            ValueError: If ``beta`` or ``alpha`` is not 1. The fully-connected
                node emitted here has no way to express that scaling.
        """
        # Check before touching the graph. XNNFullyConnected has no scale
        # parameters, so emitting it for non-default scaling would compute
        # plain bias + input @ mat2, which is a wrong result rather than an error.
        beta = node.kwargs.get("beta", 1)
        alpha = node.kwargs.get("alpha", 1)
        if beta != 1 or alpha != 1:
            raise ValueError(
                f"{self.target} with beta={beta}, alpha={alpha} cannot be lowered "
                "to an XNNPACK fully-connected node; only beta=1 and alpha=1 "
                "are supported."
            )

        # addmm's positional order is (bias, input, mat2). Tell the helper
        # which argument plays which role when it defines the tensors.
        input_type_map = InputTypeToIndex(node_input=1, node_weight=2, node_bias=0)
        self.define_nodes_tensor_inputs_outputs(
            node, xnn_graph, vals_to_ids, input_type_map=input_type_map
        )

        # bias: addmm's ``self`` argument
        bias_id = vals_to_ids[get_input_node(node, 0)]

        # input: addmm's ``mat1`` argument
        input_id = vals_to_ids[get_input_node(node, 1)]

        # filter: addmm's ``mat2`` argument
        filter_id = vals_to_ids[get_input_node(node, 2)]

        # output
        output_id = vals_to_ids[node]

        # mat2 enters ``input @ mat2`` as [in_features, out_features]. The
        # transpose flag presumably tells XNNPACK to read the filter in that
        # layout instead of its default [out, in] layout; confirm against
        # xnnpack.h if this changes.
        flag = XNN_FLAG_TRANSPOSE_WEIGHTS

        ser_node = XNode(
            xnode_union=XNNFullyConnected(
                input1_id=input_id,
                filter_id=filter_id,
                bias_id=bias_id,
                output_id=output_id,
                flags=flag,
            ),
            debug_handle=debug_handle,
        )
        xnn_graph.xnodes.append(ser_node)
71