# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import warnings
from typing import cast, Dict, List, Optional

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper

import numpy as np
import torch

from .node_visitor import NodeVisitor, register_node_visitor
from .qnn_constants import OpTile, QNN_OP_PACKAGE_NAME_QTI_AISW


@register_node_visitor
class Expand(NodeVisitor):
    """Node visitor that lowers ``aten.expand_copy.default`` to a QNN Tile op."""

    target = ["aten.expand_copy.default"]

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
    ) -> Optional[PyQnnWrapper.PyQnnOpWrapper]:
        """Build a QNN Tile op equivalent to ``node`` (an ``expand_copy``).

        Args:
            node: The ``expand_copy`` FX node. ``node.args[0]`` is the input
                node and ``node.args[1]`` the requested sizes (``-1`` keeps
                the corresponding input dimension).
            nodes_to_wrappers: Mapping from FX nodes to tensor wrappers,
                forwarded to ``define_tensor`` for both input and output.

        Returns:
            The configured Tile op wrapper. Returns ``None`` after emitting a
            warning when the input rank is lower than the output rank, because
            that case cannot be expressed by the per-axis ``multiples`` built
            here.
        """
        input_node = node.args[0]
        input_tensor = self.get_tensor(input_node, node)
        input_tensor_wrapper = self.define_tensor(
            input_node,
            input_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=True,
        )

        output_tensor = self.get_tensor(node, node)
        output_tensor_wrapper = self.define_tensor(
            node,
            output_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=False,
        )

        sizes = cast(List[int], node.args[1])

        input_shape = input_tensor.shape
        input_dims = len(input_shape)
        output_dims = len(output_tensor.size())

        # An expand that prepends new leading dims changes the rank. The
        # multiples below have one entry per *input* axis, so that case is not
        # handled here. NOTE(review): presumably the caller treats ``None`` as
        # "not lowered" -- confirm.
        if input_dims < output_dims:
            warnings.warn(
                f"[QNN Delegate Op Builder]: The rank of input tensor: {input_dims} is less than the rank of output tensor: {output_dims}.",
                stacklevel=1,
            )
            return None

        # Only size-1 input axes can be broadcast by expand. Each such axis is
        # repeated ``sizes[i]`` times. A ``-1`` (keep) or any non-1 input axis
        # gets a multiple of 1.
        multiples = [1] * input_dims
        for axis, dim in enumerate(input_shape):
            if dim == 1 and sizes[axis] != -1:
                multiples[axis] = sizes[axis]
        # The multiples parameter is a 1-D tensor with one entry per axis.
        multiples_shape = [input_dims]

        tile_op = PyQnnWrapper.PyQnnOpWrapper(
            node.name,
            QNN_OP_PACKAGE_NAME_QTI_AISW,
            OpTile.op_name,
        )
        tile_op.AddInputTensors([input_tensor_wrapper])
        tile_op.AddOutputTensors([output_tensor_wrapper])
        # NOTE(review): meaning of the trailing ``True`` is defined by
        # PyQnnWrapperAdaptor (not visible here) -- confirm before changing.
        tile_op.AddTensorParam(
            OpTile.param_multiples,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
            len(multiples_shape),
            multiples_shape,
            np.array(multiples, dtype=np.uint32),
            True,
        )
        return tile_op