xref: /aosp_15_r20/external/executorch/backends/qualcomm/runtime/backends/QnnGraphCommon.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Qualcomm Innovation Center, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 #include <executorch/backends/qualcomm/runtime/backends/QnnGraphCommon.h>
9 namespace executorch {
10 namespace backends {
11 namespace qnn {
12 
13 using executorch::runtime::Error;
14 
Configure(const std::string & graph_name)15 Error QnnGraph::Configure(const std::string& graph_name) {
16   // create qnn backend
17   const QnnInterface& qnn_interface = implementation_.GetQnnInterface();
18   Qnn_ErrorHandle_t error = QNN_SUCCESS;
19 
20   std::vector<const QnnGraph_Config_t*> temp_graph_config;
21   ET_CHECK_OR_RETURN_ERROR(
22       MakeConfig(temp_graph_config) == Error::Ok,
23       Internal,
24       "Fail to make graph config.");
25 
26   if (handle_.count(graph_name)) {
27     QNN_EXECUTORCH_LOG_ERROR(
28         "Graph '%s' has been configured.", graph_name.c_str());
29     return Error::Ok;
30   }
31 
32   Qnn_GraphHandle_t graph_handle = nullptr;
33   if (context_->GetCacheState() == QnnBackendCache::DESERIALIZE) {
34     // retrieve QNN Graph
35     error = qnn_interface.qnn_graph_retrieve(
36         context_->GetHandle(), graph_name.c_str(), &graph_handle);
37     if (error != QNN_SUCCESS) {
38       QNN_EXECUTORCH_LOG_ERROR(
39           "Can't retrieve graph "
40           "%s from context. Error %d.",
41           graph_name.c_str(),
42           QNN_GET_ERROR_CODE(error));
43       return Error::Internal;
44     }
45   } else if (
46       context_->GetCacheState() == QnnBackendCache::SERIALIZE ||
47       context_->GetCacheState() == QnnBackendCache::ONLINE_PREPARE) {
48     Qnn_ErrorHandle_t error = qnn_interface.qnn_graph_create(
49         context_->GetHandle(),
50         graph_name.c_str(),
51         temp_graph_config.empty() ? nullptr : temp_graph_config.data(),
52         &graph_handle);
53 
54     if (error != QNN_SUCCESS) {
55       QNN_EXECUTORCH_LOG_ERROR(
56           "qnn_graph_create failed. Error  %d", QNN_GET_ERROR_CODE(error));
57       return Error::Internal;
58     }
59   } else {
60     QNN_EXECUTORCH_LOG_ERROR("QNN context cache is invalid.");
61     return Error::Internal;
62   }
63 
64   // book keep valid handle of created graph
65   handle_[graph_name] = graph_handle;
66   // The profiler needs to be created after the backend is created.
67   profile_[graph_name] =
68       std::make_unique<QnnProfile>(implementation_, backend_, profile_level_);
69   return Error::Ok;
70 }
71 
GraphExecute(const std::string & graph_name,const std::vector<Qnn_Tensor_t> & input_tensor_structs,std::vector<Qnn_Tensor_t> & output_tensor_structs)72 Qnn_ErrorHandle_t QnnGraph::GraphExecute(
73     const std::string& graph_name,
74     const std::vector<Qnn_Tensor_t>& input_tensor_structs,
75     std::vector<Qnn_Tensor_t>& output_tensor_structs) {
76   if (!handle_.count(graph_name)) {
77     QNN_EXECUTORCH_LOG_ERROR(
78         "graph name: %s does not exist.", graph_name.c_str());
79     return QNN_COMMON_ERROR_GENERAL;
80   }
81 
82   return implementation_.GetQnnInterface().qnn_graph_execute(
83       handle_[graph_name],
84       input_tensor_structs.data(),
85       input_tensor_structs.size(),
86       output_tensor_structs.data(),
87       output_tensor_structs.size(),
88       profile_[graph_name]->GetHandle(),
89       /*signalHandle=*/nullptr);
90 };
91 
EnsureTensorInQnnGraph(const std::string & graph_name,const std::shared_ptr<TensorWrapper> & tensor_wrapper)92 Error QnnGraph::EnsureTensorInQnnGraph(
93     const std::string& graph_name,
94     const std::shared_ptr<TensorWrapper>& tensor_wrapper) {
95   const QnnInterface& qnn_interface = implementation_.GetQnnInterface();
96   Qnn_ErrorHandle_t error = QNN_SUCCESS;
97 
98   if (!tensor_wrapper->IsTensorCreated()) {
99     Qnn_Tensor_t tensor = tensor_wrapper->CloneTensorStruct();
100 
101     error = qnn_interface.qnn_tensor_create_graph_tensor(
102         handle_[graph_name], &tensor);
103 
104     int name_conflict_count = 0;
105     while (error == QNN_TENSOR_ERROR_NAME_HASH_COLLISION) {
106       const std::string& old_name = tensor_wrapper->GetName();
107 
108       std::string new_name =
109           old_name + "_" + std::to_string(name_conflict_count);
110       tensor_wrapper->SetName(new_name);
111       QNN_VER_PTR(tensor)->name = new_name.c_str();
112 
113       QNN_EXECUTORCH_LOG_INFO(
114           "tensor name %s hash collision, change to %s",
115           old_name.c_str(),
116           new_name.c_str());
117 
118       // update
119       name_conflict_count++;
120       error = qnn_interface.qnn_tensor_create_graph_tensor(
121           handle_[graph_name], &tensor);
122     }
123     tensor_wrapper->UpdateQnnTensorMeta(tensor);
124     tensor_wrapper->SetTensorCreated();
125   }
126   return Error::Ok;
127 }
128 } // namespace qnn
129 } // namespace backends
130 } // namespace executorch
131