# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper

import torch

from .node_visitor import NodeVisitor, register_node_visitor


class OpSkipOps(NodeVisitor):
    """
    Parent class for handling skip ops.

    The ``define_node`` implemented here is a no-op: it neither creates
    anything nor touches ``nodes_to_wrappers``. Subclasses may override it.
    """

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
    ) -> None:
        """Intentionally do nothing and return None."""
        return


@register_node_visitor
class OpGetItem(OpSkipOps):
    """
    Handle ``getitem`` nodes without emitting an op of their own.

    The selected output of the source node is aliased as this node's output
    by reusing the already-registered tensor wrapper.
    """

    target = ["getitem"]

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
    ) -> None:
        """
        Register the source node's ``idx``-th wrapper as output 0 of ``node``.

        Args:
            node: The ``getitem`` node. ``node.args[0]`` is the source node and
                ``node.args[1]`` is the output index being selected.
            nodes_to_wrappers: Although annotated otherwise, this mapping is
                used here as ``Dict[str, Dict[int, PyQnnWrapper.TensorWrapper]]``,
                i.e. keyed by node name, then by output index.

        Raises:
            AssertionError: If the index is a tuple or list, or if no wrappers
                have been registered for the source node.
        """
        idx = node.args[1]
        if isinstance(idx, (tuple, list)):
            raise AssertionError(
                f"Invalid number of index for {node.name}: {len(idx)}"
            )
        source = node.args[0]
        # Use .get() rather than [] so a defaultdict-backed mapping does not
        # gain an empty entry for the source as a side effect of this lookup.
        source_wrappers = nodes_to_wrappers.get(source.name)
        if source_wrappers is None:
            # Presumably the source's own visitor should have populated this
            # entry earlier; without this check the failure surfaced as an
            # opaque "'NoneType' object has no attribute 'get'".
            raise AssertionError(
                f"No tensor wrappers registered for {source.name}, "
                f"the input of {node.name}"
            )
        # The inner .get() keeps the existing behavior of storing None when
        # the source did not register a wrapper for this particular index.
        nodes_to_wrappers[node.name] = {0: source_wrappers.get(idx)}
        return