1 /* 2 * Copyright (c) Qualcomm Innovation Center, Inc. 3 * All rights reserved. 4 * 5 * This source code is licensed under the BSD-style license found in the 6 * LICENSE file in the root directory of this source tree. 7 */ 8 #pragma once 9 10 #include <executorch/backends/qualcomm/aot/wrappers/QuantizeParamsWrapper.h> 11 #include <executorch/backends/qualcomm/runtime/Logging.h> 12 #include <executorch/runtime/core/error.h> 13 14 #include <memory> 15 #include <string> 16 17 #include "QnnTypes.h" 18 19 #define QNN_VER_PTR(x) (&((x).v1)) 20 namespace executorch { 21 namespace backends { 22 namespace qnn { 23 class TensorWrapper { 24 public: 25 explicit TensorWrapper( 26 const std::string& tensor_name, 27 Qnn_TensorType_t tensor_type, 28 Qnn_DataType_t data_type, 29 std::unique_ptr<QuantizeParamsWrapper> quantize_params, 30 std::uint32_t rank, 31 const std::uint32_t dims[], 32 std::uint32_t bytes, 33 const void* data = nullptr, 34 bool copy_data = false); 35 36 executorch::runtime::Error FillDataBuffer( 37 const void* data, 38 bool copy_data = false); 39 40 executorch::runtime::Error AllocateDataBuffer(); 41 42 // update qnn tensor meta 43 // this function is used to recover metadata from QNN context binary. 44 void UpdateQnnTensorMeta(const Qnn_Tensor_t& tensor_src); 45 CloneTensorStruct()46 Qnn_Tensor_t CloneTensorStruct() const { 47 return tensor_; 48 }; 49 50 // Return true if the tensor_handle_ is not null, and has been created: IsTensorCreated()51 bool IsTensorCreated() const { 52 return created_; 53 }; 54 SetTensorCreated()55 void SetTensorCreated() { 56 created_ = true; 57 } 58 59 // Return true if the tensor is static: IsTensorStatic()60 bool IsTensorStatic() const { 61 return QNN_VER_PTR(tensor_)->type == QNN_TENSOR_TYPE_STATIC; 62 }; 63 GetDims()64 std::uint32_t* GetDims() const { 65 return QNN_VER_PTR(tensor_)->dimensions; 66 }; 67 GetDataType()68 Qnn_DataType_t GetDataType() const { 69 return QNN_VER_PTR(tensor_)->dataType; 70 }; 71 GetMemHandle()72 Qnn_MemHandle_t const GetMemHandle() { 73 return QNN_VER_PTR(tensor_)->memHandle; 74 }; 75 GetMemType()76 Qnn_TensorMemType_t GetMemType() const { 77 return QNN_VER_PTR(tensor_)->memType; 78 }; 79 GetQuantizeParams()80 Qnn_QuantizeParams_t GetQuantizeParams() const { 81 return QNN_VER_PTR(tensor_)->quantizeParams; 82 } 83 GetName()84 const std::string& GetName() const { 85 return qnn_tensor_name_; 86 }; 87 GetRank()88 std::uint32_t GetRank() const { 89 return QNN_VER_PTR(tensor_)->rank; 90 }; 91 GetBytes()92 std::uint32_t GetBytes() const { 93 return bytes_; 94 }; 95 GetStaticTensorData()96 const void* GetStaticTensorData() const { 97 return QNN_VER_PTR(tensor_)->clientBuf.data; 98 }; 99 100 executorch::runtime::Error SetName(const std::string& name); 101 102 executorch::runtime::Error SetMemHandle(Qnn_MemHandle_t mem_handle); 103 104 private: 105 // need this to handle QNN_TENSOR_ERROR_NAME_HASH_COLLISION 106 std::string qnn_tensor_name_; 107 std::unique_ptr<QuantizeParamsWrapper> quantize_param_wrapper_; 108 std::vector<std::uint32_t> dims_; 109 std::uint32_t bytes_{0}; 110 std::unique_ptr<char[]> owned_data_; 111 bool created_{false}; 112 113 Qnn_Tensor_t tensor_ = QNN_TENSOR_INIT; 114 }; 115 // base function for Create TensorWrapper 116 std::shared_ptr<TensorWrapper> CreateTensorWrapper( 117 const std::string& tensor_name, 118 Qnn_TensorType_t tensor_type, 119 Qnn_DataType_t data_type, 120 std::unique_ptr<QuantizeParamsWrapper> quantize_param_wrapper, 121 std::uint32_t rank, 122 const std::uint32_t dims[], 123 std::uint32_t bytes = 0, 124 const void* data = nullptr, 125 bool copy_data = false); 126 127 // Factory function to create TensorWrapper 128 std::shared_ptr<TensorWrapper> CreateTensorWrapper( 129 Qnn_TensorType_t tensor_type, 130 Qnn_DataType_t data_type, 131 std::unique_ptr<QuantizeParamsWrapper> quantize_param_wrapper, 132 std::uint32_t rank, 133 const std::uint32_t dims[], 134 std::uint32_t bytes, 135 const void* data = nullptr, 136 bool copy_data = false); 137 138 std::shared_ptr<TensorWrapper> CreateTensorWrapper(const Qnn_Tensor_t& tensor); 139 140 // Utility to get size in bytes of QNN data type 141 std::uint32_t GetDataTypeSize(Qnn_DataType_t data_type); 142 } // namespace qnn 143 } // namespace backends 144 } // namespace executorch 145