xref: /aosp_15_r20/external/executorch/backends/qualcomm/builders/op_index_put.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1from typing import Dict
2
3import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
4
5import torch
6
7from .node_visitor import NodeVisitor, register_node_visitor
8from .qnn_constants import OpScatterNd, QNN_OP_PACKAGE_NAME_QTI_AISW
9
10
@register_node_visitor
class IndexPutVisitor(NodeVisitor):
    """Lower ``aten.index_put.default`` to a QNN ``ScatterNd`` op.

    The op is assembled from ``node.args[0]`` (the tensor written into),
    ``node.args[1]`` (the list of optional index tensors) and ``node.args[2]``
    (the values). Only one non-``None`` index tensor is supported.
    """

    target = ["aten.index_put.default"]

    def __init__(self, *args) -> None:
        super().__init__(*args)

    @staticmethod
    def _single_index_node(indices) -> torch.fx.Node:
        """Return the only ``torch.fx.Node`` in ``indices``; raise otherwise."""
        index_nodes = [idx for idx in indices if isinstance(idx, torch.fx.Node)]
        # Raise explicitly rather than using ``assert``: asserts are stripped
        # under ``python -O``, and several index tensors would then be silently
        # concatenated into one (wrong) indices tensor.
        # TODO consider to write a pass to combine to one input tensor for indices
        if len(index_nodes) > 1:
            raise AssertionError("Not support mutilple indices tensor")
        if not index_nodes:
            raise AssertionError(
                "index_put lowering requires exactly one non-None indices tensor"
            )
        return index_nodes[0]

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
    ) -> PyQnnWrapper.PyQnnOpWrapper:
        """Create the ScatterNd op wrapper for an ``index_put`` node.

        Args:
            node: The ``aten.index_put.default`` node being lowered.
            nodes_to_wrappers: Tensor-wrapper mapping forwarded to every
                ``define_tensor`` call.

        Returns:
            A ``PyQnnOpWrapper`` for ``ScatterNd`` whose inputs are
            ``[input, indices, values]`` and whose output is ``node``'s tensor.

        Raises:
            AssertionError: If the indices list does not contain exactly one
                tensor node. This is checked before any tensor is defined.
        """
        input_node = node.args[0]
        value_node = node.args[2]
        # NOTE(review): node.args[3] (``accumulate``), when present, is never
        # read, so an accumulating index_put would be lowered like a plain
        # overwrite -- confirm such nodes are rejected before this visitor.

        # Validate first so the check happens before any define_tensor call.
        index_node = self._single_index_node(node.args[1])

        input_tensor_wrapper = self.define_tensor(
            input_node,
            self.get_tensor(input_node, node),
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=True,
        )

        # Flatten the index tensor and add a leading dim to get a 2-D tensor.
        # NOTE(review): this produces shape (1, N); confirm it matches the
        # index layout ScatterNd expects when N > 1.
        index_tensor = self.get_tensor(index_node, index_node)
        indices_qnn = torch.flatten(index_tensor).unsqueeze(0)
        indices_tensor_wrapper = self.define_tensor(
            index_node,
            indices_qnn,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=True,
        )

        value_tensor_wrapper = self.define_tensor(
            value_node,
            self.get_tensor(value_node, node),
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=True,
        )

        output_tensor_wrapper = self.define_tensor(
            node,
            self.get_tensor(node, node),
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=False,
        )

        index_put_op = PyQnnWrapper.PyQnnOpWrapper(
            node.name,
            QNN_OP_PACKAGE_NAME_QTI_AISW,
            OpScatterNd.op_name,
        )
        index_put_op.AddInputTensors(
            [input_tensor_wrapper, indices_tensor_wrapper, value_tensor_wrapper]
        )
        index_put_op.AddOutputTensors([output_tensor_wrapper])

        return index_put_op
83        return index_put_op
84