xref: /aosp_15_r20/external/executorch/backends/qualcomm/builders/op_skip_ops.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
6
7from typing import Dict
8
9import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
10
11import torch
12
13from .node_visitor import NodeVisitor, register_node_visitor
14
15
class OpSkipOps(NodeVisitor):
    """
    Parent class for handling skip ops.

    ``define_node`` is intentionally a no-op here: it builds nothing,
    leaves ``nodes_to_wrappers`` untouched and returns ``None``.
    Subclasses may override it when a skipped node still needs
    bookkeeping.
    """

    def __init__(self, *visitor_args) -> None:
        # Forward all positional construction arguments to NodeVisitor as-is.
        super().__init__(*visitor_args)

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
    ) -> None:
        """Intentionally do nothing for ``node`` and return ``None``."""
        return None
30
31
@register_node_visitor
class OpGetItem(OpSkipOps):
    """
    Handle ``getitem`` nodes without emitting an op.

    The wrapper selected from the producer node (``node.args[0]``) by the
    single integer index (``node.args[1]``) is re-registered in
    ``nodes_to_wrappers`` under this node's own name as output 0.
    """

    target = ["getitem"]

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
    ) -> None:
        """
        Alias the selected producer output under ``node.name``.

        Args:
            node: the ``getitem`` node. ``node.args[0]`` is the producer
                node and ``node.args[1]`` the index into its outputs.
            nodes_to_wrappers: updated in place. NOTE(review): the annotation
                says ``Dict[torch.fx.Node, TensorWrapper]``, but this method
                reads and writes it as
                ``Dict[str, Dict[int, PyQnnWrapper.TensorWrapper]]``
                (node name -> output index -> wrapper). The annotation is
                left as-is, presumably to match the NodeVisitor signature —
                confirm.

        Raises:
            AssertionError: if the index is a tuple or list, or if no
                wrappers are recorded for the producer node.
        """
        source, index = node.args[0], node.args[1]
        if isinstance(index, (tuple, list)):
            raise AssertionError(
                f"Invalid number of index for {node.name}: {len(index)}"
            )

        producer_wrappers = nodes_to_wrappers.get(source.name)
        if producer_wrappers is None:
            # Fail with a descriptive message instead of an opaque
            # AttributeError from calling .get() on None.
            raise AssertionError(
                f"No tensor wrappers recorded for {source.name}, "
                f"the input of {node.name}"
            )

        # Keep the nested {output_index: wrapper} layout used for every node.
        # An index the producer did not register still maps to None, exactly
        # as the dict.get lookup always did.
        nodes_to_wrappers[node.name] = {0: producer_wrappers.get(index)}
55