// xref: /aosp_15_r20/external/executorch/backends/qualcomm/aot/wrappers/TensorWrapper.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
/*
 * Copyright (c) Qualcomm Innovation Center, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
#pragma once

#include <executorch/backends/qualcomm/aot/wrappers/QuantizeParamsWrapper.h>
#include <executorch/backends/qualcomm/runtime/Logging.h>
#include <executorch/runtime/core/error.h>

#include <cstdint>
#include <memory>
#include <string>
#include <vector>

#include "QnnTypes.h"

// Yields a pointer to the `v1` member of a versioned QNN struct such as
// Qnn_Tensor_t, so fields can be reached as QNN_VER_PTR(tensor)->field.
#define QNN_VER_PTR(tensor) (&((tensor).v1))
20 namespace executorch {
21 namespace backends {
22 namespace qnn {
23 class TensorWrapper {
24  public:
25   explicit TensorWrapper(
26       const std::string& tensor_name,
27       Qnn_TensorType_t tensor_type,
28       Qnn_DataType_t data_type,
29       std::unique_ptr<QuantizeParamsWrapper> quantize_params,
30       std::uint32_t rank,
31       const std::uint32_t dims[],
32       std::uint32_t bytes,
33       const void* data = nullptr,
34       bool copy_data = false);
35 
36   executorch::runtime::Error FillDataBuffer(
37       const void* data,
38       bool copy_data = false);
39 
40   executorch::runtime::Error AllocateDataBuffer();
41 
42   // update qnn tensor meta
43   // this function is used to recover metadata from QNN context binary.
44   void UpdateQnnTensorMeta(const Qnn_Tensor_t& tensor_src);
45 
CloneTensorStruct()46   Qnn_Tensor_t CloneTensorStruct() const {
47     return tensor_;
48   };
49 
50   // Return true if the tensor_handle_ is not null, and has been created:
IsTensorCreated()51   bool IsTensorCreated() const {
52     return created_;
53   };
54 
SetTensorCreated()55   void SetTensorCreated() {
56     created_ = true;
57   }
58 
59   // Return true if the tensor is static:
IsTensorStatic()60   bool IsTensorStatic() const {
61     return QNN_VER_PTR(tensor_)->type == QNN_TENSOR_TYPE_STATIC;
62   };
63 
GetDims()64   std::uint32_t* GetDims() const {
65     return QNN_VER_PTR(tensor_)->dimensions;
66   };
67 
GetDataType()68   Qnn_DataType_t GetDataType() const {
69     return QNN_VER_PTR(tensor_)->dataType;
70   };
71 
GetMemHandle()72   Qnn_MemHandle_t const GetMemHandle() {
73     return QNN_VER_PTR(tensor_)->memHandle;
74   };
75 
GetMemType()76   Qnn_TensorMemType_t GetMemType() const {
77     return QNN_VER_PTR(tensor_)->memType;
78   };
79 
GetQuantizeParams()80   Qnn_QuantizeParams_t GetQuantizeParams() const {
81     return QNN_VER_PTR(tensor_)->quantizeParams;
82   }
83 
GetName()84   const std::string& GetName() const {
85     return qnn_tensor_name_;
86   };
87 
GetRank()88   std::uint32_t GetRank() const {
89     return QNN_VER_PTR(tensor_)->rank;
90   };
91 
GetBytes()92   std::uint32_t GetBytes() const {
93     return bytes_;
94   };
95 
GetStaticTensorData()96   const void* GetStaticTensorData() const {
97     return QNN_VER_PTR(tensor_)->clientBuf.data;
98   };
99 
100   executorch::runtime::Error SetName(const std::string& name);
101 
102   executorch::runtime::Error SetMemHandle(Qnn_MemHandle_t mem_handle);
103 
104  private:
105   // need this to handle QNN_TENSOR_ERROR_NAME_HASH_COLLISION
106   std::string qnn_tensor_name_;
107   std::unique_ptr<QuantizeParamsWrapper> quantize_param_wrapper_;
108   std::vector<std::uint32_t> dims_;
109   std::uint32_t bytes_{0};
110   std::unique_ptr<char[]> owned_data_;
111   bool created_{false};
112 
113   Qnn_Tensor_t tensor_ = QNN_TENSOR_INIT;
114 };
115 // base function for Create TensorWrapper
116 std::shared_ptr<TensorWrapper> CreateTensorWrapper(
117     const std::string& tensor_name,
118     Qnn_TensorType_t tensor_type,
119     Qnn_DataType_t data_type,
120     std::unique_ptr<QuantizeParamsWrapper> quantize_param_wrapper,
121     std::uint32_t rank,
122     const std::uint32_t dims[],
123     std::uint32_t bytes = 0,
124     const void* data = nullptr,
125     bool copy_data = false);
126 
127 // Factory function to create TensorWrapper
128 std::shared_ptr<TensorWrapper> CreateTensorWrapper(
129     Qnn_TensorType_t tensor_type,
130     Qnn_DataType_t data_type,
131     std::unique_ptr<QuantizeParamsWrapper> quantize_param_wrapper,
132     std::uint32_t rank,
133     const std::uint32_t dims[],
134     std::uint32_t bytes,
135     const void* data = nullptr,
136     bool copy_data = false);
137 
138 std::shared_ptr<TensorWrapper> CreateTensorWrapper(const Qnn_Tensor_t& tensor);
139 
140 // Utility to get size in bytes of QNN data type
141 std::uint32_t GetDataTypeSize(Qnn_DataType_t data_type);
142 } // namespace qnn
143 } // namespace backends
144 } // namespace executorch
145