from typing import Dict

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper

import torch

from .node_visitor import NodeVisitor, register_node_visitor
from .qnn_constants import OpScatterNd, QNN_OP_PACKAGE_NAME_QTI_AISW


@register_node_visitor
class IndexPutVisitor(NodeVisitor):
    """Node visitor that emits a QNN ScatterNd op for ``aten.index_put.default``."""

    target = ["aten.index_put.default"]

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
    ) -> PyQnnWrapper.PyQnnOpWrapper:
        """Build the ScatterNd op wrapper for an ``index_put`` node.

        The op receives three inputs, in this order: the tensor being written
        into (``node.args[0]``), the indices (``node.args[1]``) and the values
        (``node.args[2]``). Its single output is the tensor of ``node`` itself.

        Args:
            node: The ``aten.index_put.default`` node to lower.
            nodes_to_wrappers: Mapping handed through to ``define_tensor`` for
                every tensor defined here.

        Returns:
            The op wrapper with its input and output tensors attached.

        Raises:
            AssertionError: If ``node.args[1]`` does not contain exactly one
                ``torch.fx.Node``; combining several index tensors is not
                supported yet.
        """
        native_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE

        input_node = node.args[0]
        input_tensor = self.get_tensor(input_node, node)
        input_tensor_wrapper = self.define_tensor(
            input_node,
            input_tensor,
            native_type,
            nodes_to_wrappers,
            is_input_tensor=True,
        )

        indices_arg = node.args[1]
        index_nodes = [n for n in indices_arg if isinstance(n, torch.fx.Node)]
        # TODO consider to write a pass to combine to one input tensor for indices
        # Validate before touching any index tensor: with no usable index the
        # torch.cat below would otherwise fail first with an unrelated error.
        # An explicit raise (rather than `assert`) keeps the check under -O.
        if len(index_nodes) != 1:
            raise AssertionError("Not support mutilple indices tensor")

        # Entries that are None are skipped.
        index_tensors = [
            self.get_tensor(idx, idx) for idx in indices_arg if idx is not None
        ]

        # Flatten every index tensor to 1-D, join them end to end, then add a
        # leading axis so the result is a 2-D tensor of shape (1, total).
        flat_indices = [torch.flatten(idx) for idx in index_tensors]
        indices_qnn = torch.cat(flat_indices).unsqueeze(0)

        indices_tensor_wrapper = self.define_tensor(
            index_nodes[0],
            indices_qnn,
            native_type,
            nodes_to_wrappers,
            is_input_tensor=True,
        )

        value_node = node.args[2]
        value_tensor = self.get_tensor(value_node, node)
        value_tensor_wrapper = self.define_tensor(
            value_node,
            value_tensor,
            native_type,
            nodes_to_wrappers,
            is_input_tensor=True,
        )

        output_tensor = self.get_tensor(node, node)
        output_tensor_wrapper = self.define_tensor(
            node,
            output_tensor,
            native_type,
            nodes_to_wrappers,
            is_input_tensor=False,
        )

        index_put_op = PyQnnWrapper.PyQnnOpWrapper(
            node.name,
            QNN_OP_PACKAGE_NAME_QTI_AISW,
            OpScatterNd.op_name,
        )
        index_put_op.AddInputTensors(
            [input_tensor_wrapper, indices_tensor_wrapper, value_tensor_wrapper]
        )
        index_put_op.AddOutputTensors([output_tensor_wrapper])

        return index_put_op