xref: /aosp_15_r20/external/android-nn-driver/1.2/ArmnnDriverImpl.cpp (revision 3e777be0405cee09af5d5785ff37f7cfb5bee59a)
1*3e777be0SXin Li //
2*3e777be0SXin Li // Copyright © 2017, 2023 Arm Ltd. All rights reserved.
3*3e777be0SXin Li // SPDX-License-Identifier: MIT
4*3e777be0SXin Li //
5*3e777be0SXin Li 
6*3e777be0SXin Li #include "ArmnnDriverImpl.hpp"
7*3e777be0SXin Li #include "../ArmnnPreparedModel_1_2.hpp"
8*3e777be0SXin Li #include "../ModelToINetworkConverter.hpp"
9*3e777be0SXin Li #include "../SystemPropertiesUtils.hpp"
10*3e777be0SXin Li 
11*3e777be0SXin Li #include <armnnDeserializer/IDeserializer.hpp>
12*3e777be0SXin Li 
13*3e777be0SXin Li #include <log/log.h>
14*3e777be0SXin Li #include <sys/stat.h>
15*3e777be0SXin Li #include <chrono>
16*3e777be0SXin Li 
namespace
{

// System-property keys under which this driver advertises its relative
// performance figures (execution time / power usage) per operand type.
// NOTE(review): presumably read via SystemPropertiesUtils when the driver
// builds its capabilities — confirm against getCapabilities in this file.
// The differing "ArmNN."/"Armnn." prefixes below are intentional-looking but
// inconsistent; preserved as-is since the property names are external contract.
const char *g_RelaxedFloat32toFloat16PerformanceExecTime    = "ArmNN.relaxedFloat32toFloat16Performance.execTime";
const char *g_RelaxedFloat32toFloat16PerformancePowerUsage  = "ArmNN.relaxedFloat32toFloat16Performance.powerUsage";

const char *g_OperandTypeTensorFloat32PerformanceExecTime   = "Armnn.operandTypeTensorFloat32Performance.execTime";
const char *g_OperandTypeTensorFloat32PerformancePowerUsage = "Armnn.operandTypeTensorFloat32Performance.powerUsage";

const char *g_OperandTypeFloat32PerformanceExecTime         = "Armnn.operandTypeFloat32Performance.execTime";
const char *g_OperandTypeFloat32PerformancePowerUsage       = "Armnn.operandTypeFloat32Performance.powerUsage";

const char *g_OperandTypeTensorFloat16PerformanceExecTime   = "Armnn.operandTypeTensorFloat16Performance.execTime";
const char *g_OperandTypeTensorFloat16PerformancePowerUsage = "Armnn.operandTypeTensorFloat16Performance.powerUsage";

const char *g_OperandTypeFloat16PerformanceExecTime         = "Armnn.operandTypeFloat16Performance.execTime";
const char *g_OperandTypeFloat16PerformancePowerUsage       = "Armnn.operandTypeFloat16Performance.powerUsage";

const char *g_OperandTypeTensorQuant8AsymmPerformanceExecTime =
        "Armnn.operandTypeTensorQuant8AsymmPerformance.execTime";
const char *g_OperandTypeTensorQuant8AsymmPerformancePowerUsage =
        "Armnn.operandTypeTensorQuant8AsymmPerformance.powerUsage";

const char *g_OperandTypeTensorQuant16SymmPerformanceExecTime =
        "Armnn.operandTypeTensorQuant16SymmPerformance.execTime";
const char *g_OperandTypeTensorQuant16SymmPerformancePowerUsage =
        "Armnn.operandTypeTensorQuant16SymmPerformance.powerUsage";

const char *g_OperandTypeTensorQuant8SymmPerformanceExecTime =
        "Armnn.operandTypeTensorQuant8SymmPerformance.execTime";
const char *g_OperandTypeTensorQuant8SymmPerformancePowerUsage =
        "Armnn.operandTypeTensorQuant8SymmPerformance.powerUsage";

const char *g_OperandTypeTensorQuant8SymmPerChannelPerformanceExecTime =
    "Armnn.operandTypeTensorQuant8SymmPerChannelPerformance.execTime";
const char *g_OperandTypeTensorQuant8SymmPerChannelPerformancePowerUsage =
    "Armnn.operandTypeTensorQuant8SymmPerChannelPerformance.powerUsage";


const char *g_OperandTypeTensorInt32PerformanceExecTime     = "Armnn.operandTypeTensorInt32Performance.execTime";
const char *g_OperandTypeTensorInt32PerformancePowerUsage   = "Armnn.operandTypeTensorInt32Performance.powerUsage";

const char *g_OperandTypeInt32PerformanceExecTime           = "Armnn.operandTypeInt32Performance.execTime";
const char *g_OperandTypeInt32PerformancePowerUsage         = "Armnn.operandTypeInt32Performance.powerUsage";


// Delivers the prepareModel result (status + prepared model, which may be
// nullptr on failure) to the client through the HIDL callback. The transport
// status of the callback itself is checked and only logged on failure.
void NotifyCallbackAndCheck(const android::sp<V1_2::IPreparedModelCallback>& callback,
                            V1_0::ErrorStatus errorStatus,
                            const android::sp<V1_2::IPreparedModel>& preparedModelPtr)
{
    Return<void> returned = callback->notify_1_2(errorStatus, preparedModelPtr);
    // This check is required, if the callback fails and it isn't checked it will bring down the service
    if (!returned.isOk())
    {
        ALOGE("ArmnnDriverImpl::prepareModel: hidl callback failed to return properly: %s ",
              returned.description().c_str());
    }
}

// Logs the failure message, reports the error (with a null prepared model)
// through the callback, and returns the same error so callers can propagate it.
Return<V1_0::ErrorStatus> FailPrepareModel(V1_0::ErrorStatus error,
                                           const std::string& message,
                                           const android::sp<V1_2::IPreparedModelCallback>& callback)
{
    ALOGW("ArmnnDriverImpl::prepareModel: %s", message.c_str());
    NotifyCallbackAndCheck(callback, error, nullptr);
    return error;
}

} // anonymous namespace
86*3e777be0SXin Li 
87*3e777be0SXin Li namespace armnn_driver
88*3e777be0SXin Li {
89*3e777be0SXin Li namespace hal_1_2
90*3e777be0SXin Li {
91*3e777be0SXin Li 
prepareArmnnModel_1_2(const armnn::IRuntimePtr & runtime,const armnn::IGpuAccTunedParametersPtr & clTunedParameters,const DriverOptions & options,const V1_2::Model & model,const android::hardware::hidl_vec<android::hardware::hidl_handle> & modelCacheHandle,const android::hardware::hidl_vec<android::hardware::hidl_handle> & dataCacheHandle,const HidlToken & token,const android::sp<V1_2::IPreparedModelCallback> & cb,bool float32ToFloat16)92*3e777be0SXin Li Return<V1_0::ErrorStatus> ArmnnDriverImpl::prepareArmnnModel_1_2(
93*3e777be0SXin Li        const armnn::IRuntimePtr& runtime,
94*3e777be0SXin Li        const armnn::IGpuAccTunedParametersPtr& clTunedParameters,
95*3e777be0SXin Li        const DriverOptions& options,
96*3e777be0SXin Li        const V1_2::Model& model,
97*3e777be0SXin Li        const android::hardware::hidl_vec<android::hardware::hidl_handle>& modelCacheHandle,
98*3e777be0SXin Li        const android::hardware::hidl_vec<android::hardware::hidl_handle>& dataCacheHandle,
99*3e777be0SXin Li        const HidlToken& token,
100*3e777be0SXin Li        const android::sp<V1_2::IPreparedModelCallback>& cb,
101*3e777be0SXin Li        bool float32ToFloat16)
102*3e777be0SXin Li {
103*3e777be0SXin Li     ALOGV("ArmnnDriverImpl::prepareArmnnModel_1_2()");
104*3e777be0SXin Li 
105*3e777be0SXin Li     std::chrono::time_point<std::chrono::system_clock> prepareModelTimepoint = std::chrono::system_clock::now();
106*3e777be0SXin Li 
107*3e777be0SXin Li     if (cb.get() == nullptr)
108*3e777be0SXin Li     {
109*3e777be0SXin Li         ALOGW("ArmnnDriverImpl::prepareModel: Invalid callback passed to prepareModel");
110*3e777be0SXin Li         return V1_0::ErrorStatus::INVALID_ARGUMENT;
111*3e777be0SXin Li     }
112*3e777be0SXin Li 
113*3e777be0SXin Li     if (!runtime)
114*3e777be0SXin Li     {
115*3e777be0SXin Li         return FailPrepareModel(V1_0::ErrorStatus::DEVICE_UNAVAILABLE, "Device unavailable", cb);
116*3e777be0SXin Li     }
117*3e777be0SXin Li 
118*3e777be0SXin Li     if (!android::nn::validateModel(model))
119*3e777be0SXin Li     {
120*3e777be0SXin Li         return FailPrepareModel(V1_0::ErrorStatus::INVALID_ARGUMENT, "Invalid model passed as input", cb);
121*3e777be0SXin Li     }
122*3e777be0SXin Li 
123*3e777be0SXin Li     // Deliberately ignore any unsupported operations requested by the options -
124*3e777be0SXin Li     // at this point we're being asked to prepare a model that we've already declared support for
125*3e777be0SXin Li     // and the operation indices may be different to those in getSupportedOperations anyway.
126*3e777be0SXin Li     std::set<unsigned int> unsupportedOperations;
127*3e777be0SXin Li     ModelToINetworkConverter<HalPolicy> modelConverter(options.GetBackends(),
128*3e777be0SXin Li                                                        model,
129*3e777be0SXin Li                                                        unsupportedOperations);
130*3e777be0SXin Li 
131*3e777be0SXin Li     if (modelConverter.GetConversionResult() != ConversionResult::Success)
132*3e777be0SXin Li     {
133*3e777be0SXin Li         FailPrepareModel(V1_0::ErrorStatus::GENERAL_FAILURE, "ModelToINetworkConverter failed", cb);
134*3e777be0SXin Li         return V1_0::ErrorStatus::NONE;
135*3e777be0SXin Li     }
136*3e777be0SXin Li 
137*3e777be0SXin Li     // Serialize the network graph to a .armnn file if an output directory
138*3e777be0SXin Li     // has been specified in the drivers' arguments.
139*3e777be0SXin Li     std::vector<uint8_t> dataCacheData;
140*3e777be0SXin Li     bool serializeToFile = dataCacheHandle.size() < 1 ? false : true;
141*3e777be0SXin Li     auto serializedNetworkFileName =
142*3e777be0SXin Li         SerializeNetwork(*modelConverter.GetINetwork(),
143*3e777be0SXin Li                          options.GetRequestInputsAndOutputsDumpDir(),
144*3e777be0SXin Li                          dataCacheData,
145*3e777be0SXin Li                          serializeToFile);
146*3e777be0SXin Li 
147*3e777be0SXin Li     // Optimize the network
148*3e777be0SXin Li     armnn::IOptimizedNetworkPtr optNet(nullptr, nullptr);
149*3e777be0SXin Li     armnn::OptimizerOptionsOpaque OptOptions;
150*3e777be0SXin Li     OptOptions.SetReduceFp32ToFp16(float32ToFloat16);
151*3e777be0SXin Li     OptOptions.SetProfilingEnabled(options.IsGpuProfilingEnabled());
152*3e777be0SXin Li 
153*3e777be0SXin Li     int cachedFd = -1;
154*3e777be0SXin Li     bool saveCachedNetwork = options.SaveCachedNetwork();
155*3e777be0SXin Li 
156*3e777be0SXin Li     unsigned int numberOfCachedModelFiles = 0;
157*3e777be0SXin Li     if (modelCacheHandle.size() > 0)
158*3e777be0SXin Li     {
159*3e777be0SXin Li         unsigned int index = 0;
160*3e777be0SXin Li         for (auto& backend : options.GetBackends())
161*3e777be0SXin Li         {
162*3e777be0SXin Li             // modelCacheHandle size should be equal to numberOfCachedModelFiles
163*3e777be0SXin Li             // modelCacheHandle vector should be in same order as backends
164*3e777be0SXin Li             auto numberOfCacheFiles = GetNumberOfCacheFiles(backend);
165*3e777be0SXin Li             if (numberOfCacheFiles > 0)
166*3e777be0SXin Li             {
167*3e777be0SXin Li                 numberOfCachedModelFiles += numberOfCacheFiles;
168*3e777be0SXin Li                 if (modelCacheHandle[index]->numFds == 1)
169*3e777be0SXin Li                 {
170*3e777be0SXin Li                     if (backend == armnn::Compute::GpuAcc)
171*3e777be0SXin Li                     {
172*3e777be0SXin Li                         cachedFd = modelCacheHandle[index]->data[0];
173*3e777be0SXin Li                         saveCachedNetwork = true;
174*3e777be0SXin Li                     }
175*3e777be0SXin Li                 }
176*3e777be0SXin Li                 index += numberOfCachedModelFiles;
177*3e777be0SXin Li             }
178*3e777be0SXin Li         }
179*3e777be0SXin Li     }
180*3e777be0SXin Li 
181*3e777be0SXin Li     armnn::BackendOptions gpuAcc("GpuAcc",
182*3e777be0SXin Li     {
183*3e777be0SXin Li         { "FastMathEnabled", options.IsFastMathEnabled() },
184*3e777be0SXin Li         { "SaveCachedNetwork", saveCachedNetwork },
185*3e777be0SXin Li         { "CachedNetworkFilePath", options.GetCachedNetworkFilePath() },
186*3e777be0SXin Li         { "MLGOTuningFilePath", options.GetClMLGOTunedParametersFile() },
187*3e777be0SXin Li         { "CachedFileDescriptor", cachedFd }
188*3e777be0SXin Li     });
189*3e777be0SXin Li 
190*3e777be0SXin Li     armnn::BackendOptions cpuAcc("CpuAcc",
191*3e777be0SXin Li     {
192*3e777be0SXin Li         { "FastMathEnabled", options.IsFastMathEnabled() },
193*3e777be0SXin Li         { "NumberOfThreads", options.GetNumberOfThreads() }
194*3e777be0SXin Li     });
195*3e777be0SXin Li     OptOptions.AddModelOption(gpuAcc);
196*3e777be0SXin Li     OptOptions.AddModelOption(cpuAcc);
197*3e777be0SXin Li 
198*3e777be0SXin Li     std::vector<std::string> errMessages;
199*3e777be0SXin Li     try
200*3e777be0SXin Li     {
201*3e777be0SXin Li         optNet = armnn::Optimize(*modelConverter.GetINetwork(),
202*3e777be0SXin Li                                  options.GetBackends(),
203*3e777be0SXin Li                                  runtime->GetDeviceSpec(),
204*3e777be0SXin Li                                  OptOptions,
205*3e777be0SXin Li                                  errMessages);
206*3e777be0SXin Li     }
207*3e777be0SXin Li     catch (std::exception &e)
208*3e777be0SXin Li     {
209*3e777be0SXin Li         std::stringstream message;
210*3e777be0SXin Li         message << "Exception (" << e.what() << ") caught from optimize.";
211*3e777be0SXin Li         FailPrepareModel(V1_0::ErrorStatus::GENERAL_FAILURE, message.str(), cb);
212*3e777be0SXin Li         return V1_0::ErrorStatus::NONE;
213*3e777be0SXin Li     }
214*3e777be0SXin Li 
215*3e777be0SXin Li     // Check that the optimized network is valid.
216*3e777be0SXin Li     if (!optNet)
217*3e777be0SXin Li     {
218*3e777be0SXin Li         std::stringstream message;
219*3e777be0SXin Li         message << "Invalid optimized network";
220*3e777be0SXin Li         for (const std::string& msg : errMessages)
221*3e777be0SXin Li         {
222*3e777be0SXin Li             message << "\n" << msg;
223*3e777be0SXin Li         }
224*3e777be0SXin Li         FailPrepareModel(V1_0::ErrorStatus::GENERAL_FAILURE, message.str(), cb);
225*3e777be0SXin Li         return V1_0::ErrorStatus::NONE;
226*3e777be0SXin Li     }
227*3e777be0SXin Li 
228*3e777be0SXin Li     // Export the optimized network graph to a dot file if an output dump directory
229*3e777be0SXin Li     // has been specified in the drivers' arguments.
230*3e777be0SXin Li     std::string dotGraphFileName = ExportNetworkGraphToDotFile(*optNet,
231*3e777be0SXin Li                                                                options.GetRequestInputsAndOutputsDumpDir());
232*3e777be0SXin Li 
233*3e777be0SXin Li     // Load it into the runtime.
234*3e777be0SXin Li     armnn::NetworkId netId = 0;
235*3e777be0SXin Li     std::string msg;
236*3e777be0SXin Li     armnn::INetworkProperties networkProperties(options.isAsyncModelExecutionEnabled(),
237*3e777be0SXin Li                                                 MemorySource::Undefined,
238*3e777be0SXin Li                                                 MemorySource::Undefined,
239*3e777be0SXin Li                                                 options.IsGpuProfilingEnabled());
240*3e777be0SXin Li 
241*3e777be0SXin Li     auto numInputs  = getMainModel(model).inputIndexes.size();
242*3e777be0SXin Li     auto numOutputs = getMainModel(model).outputIndexes.size();
243*3e777be0SXin Li     try
244*3e777be0SXin Li     {
245*3e777be0SXin Li         if (runtime->LoadNetwork(netId, move(optNet), msg, networkProperties) != armnn::Status::Success)
246*3e777be0SXin Li         {
247*3e777be0SXin Li             return FailPrepareModel(V1_0::ErrorStatus::GENERAL_FAILURE, msg, cb);
248*3e777be0SXin Li         }
249*3e777be0SXin Li     }
250*3e777be0SXin Li     catch (std::exception& e)
251*3e777be0SXin Li     {
252*3e777be0SXin Li         std::stringstream message;
253*3e777be0SXin Li         message << "Exception (" << e.what()<< ") caught from LoadNetwork.";
254*3e777be0SXin Li         FailPrepareModel(V1_0::ErrorStatus::GENERAL_FAILURE, message.str(), cb);
255*3e777be0SXin Li         return V1_0::ErrorStatus::NONE;
256*3e777be0SXin Li     }
257*3e777be0SXin Li 
258*3e777be0SXin Li     // Now that we have a networkId for the graph rename the exported files to use it
259*3e777be0SXin Li     // so that we can associate the graph file and the input/output tensor exported files
260*3e777be0SXin Li     RenameExportedFiles(serializedNetworkFileName,
261*3e777be0SXin Li                         dotGraphFileName,
262*3e777be0SXin Li                         options.GetRequestInputsAndOutputsDumpDir(),
263*3e777be0SXin Li                         netId);
264*3e777be0SXin Li 
265*3e777be0SXin Li     std::unique_ptr<ArmnnPreparedModel_1_2<hal_1_2::HalPolicy>> preparedModel(
266*3e777be0SXin Li             new ArmnnPreparedModel_1_2<hal_1_2::HalPolicy>(
267*3e777be0SXin Li                     netId,
268*3e777be0SXin Li                     runtime.get(),
269*3e777be0SXin Li                     model,
270*3e777be0SXin Li                     options.GetRequestInputsAndOutputsDumpDir(),
271*3e777be0SXin Li                     options.IsGpuProfilingEnabled(),
272*3e777be0SXin Li                     options.isAsyncModelExecutionEnabled(),
273*3e777be0SXin Li                     options.getNoOfArmnnThreads(),
274*3e777be0SXin Li                     options.isImportEnabled(),
275*3e777be0SXin Li                     options.isExportEnabled()));
276*3e777be0SXin Li 
277*3e777be0SXin Li     // Run a single 'dummy' inference of the model. This means that CL kernels will get compiled (and tuned if
278*3e777be0SXin Li     // this is enabled) before the first 'real' inference which removes the overhead of the first inference.
279*3e777be0SXin Li     // Only run this if the GpuAcc backend has been added to options
280*3e777be0SXin Li     if (std::find(options.GetBackends().begin(),
281*3e777be0SXin Li                   options.GetBackends().end(),
282*3e777be0SXin Li                   armnn::Compute::GpuAcc) != options.GetBackends().end())
283*3e777be0SXin Li     {
284*3e777be0SXin Li         if (!preparedModel->ExecuteWithDummyInputs(numInputs, numOutputs))
285*3e777be0SXin Li         {
286*3e777be0SXin Li             return FailPrepareModel(V1_0::ErrorStatus::GENERAL_FAILURE, "Network could not be executed", cb);
287*3e777be0SXin Li         }
288*3e777be0SXin Li 
289*3e777be0SXin Li         if (clTunedParameters &&
290*3e777be0SXin Li             options.GetClTunedParametersMode() == armnn::IGpuAccTunedParameters::Mode::UpdateTunedParameters)
291*3e777be0SXin Li         {
292*3e777be0SXin Li             // Now that we've done one inference the CL kernel parameters will have been tuned,
293*3e777be0SXin Li             // so save the updated file.
294*3e777be0SXin Li             try
295*3e777be0SXin Li             {
296*3e777be0SXin Li                 clTunedParameters->Save(options.GetClTunedParametersFile().c_str());
297*3e777be0SXin Li             }
298*3e777be0SXin Li             catch (std::exception& error)
299*3e777be0SXin Li             {
300*3e777be0SXin Li                 ALOGE("ArmnnDriverImpl::prepareModel: Failed to save CL tuned parameters file '%s': %s",
301*3e777be0SXin Li                       options.GetClTunedParametersFile().c_str(), error.what());
302*3e777be0SXin Li             }
303*3e777be0SXin Li         }
304*3e777be0SXin Li     }
305*3e777be0SXin Li 
306*3e777be0SXin Li     size_t hashValue = 0;
307*3e777be0SXin Li     // Cache the model
308*3e777be0SXin Li     if (dataCacheHandle.size() > 0)
309*3e777be0SXin Li     {
310*3e777be0SXin Li         // Cache the Arm NN model, should be only 1
311*3e777be0SXin Li         if (dataCacheHandle.size() != 1)
312*3e777be0SXin Li         {
313*3e777be0SXin Li             NotifyCallbackAndCheck(cb, V1_0::ErrorStatus::NONE, preparedModel.release());
314*3e777be0SXin Li             return V1_0::ErrorStatus::NONE;
315*3e777be0SXin Li         }
316*3e777be0SXin Li 
317*3e777be0SXin Li         if (dataCacheHandle[0]->numFds != 1)
318*3e777be0SXin Li         {
319*3e777be0SXin Li             ALOGW("ArmnnDriverImpl::prepareArmnnModel_1_3: Cannot cache the data, numFds != 1.");
320*3e777be0SXin Li             NotifyCallbackAndCheck(cb, V1_0::ErrorStatus::NONE, preparedModel.release());
321*3e777be0SXin Li             return V1_0::ErrorStatus::NONE;
322*3e777be0SXin Li         }
323*3e777be0SXin Li 
324*3e777be0SXin Li         if (dataCacheHandle[0]->data[0] < 0)
325*3e777be0SXin Li         {
326*3e777be0SXin Li             ALOGW("ArmnnDriverImpl::prepareArmnnModel_1_3: Cannot cache the data, fd < 0");
327*3e777be0SXin Li             NotifyCallbackAndCheck(cb, V1_0::ErrorStatus::NONE, preparedModel.release());
328*3e777be0SXin Li             return V1_0::ErrorStatus::NONE;
329*3e777be0SXin Li         }
330*3e777be0SXin Li 
331*3e777be0SXin Li         int dataCacheFileAccessMode = fcntl(dataCacheHandle[0]->data[0], F_GETFL) & O_ACCMODE;
332*3e777be0SXin Li         if (dataCacheFileAccessMode != O_RDWR)
333*3e777be0SXin Li         {
334*3e777be0SXin Li             ALOGW("ArmnnDriverImpl::prepareModelFromCache_1_2(): Invalid Access Mode.");
335*3e777be0SXin Li             NotifyCallbackAndCheck(cb, V1_0::ErrorStatus::NONE, preparedModel.release());
336*3e777be0SXin Li             return V1_0::ErrorStatus::NONE;
337*3e777be0SXin Li         }
338*3e777be0SXin Li 
339*3e777be0SXin Li         write(dataCacheHandle[0]->data[0], dataCacheData.data(), dataCacheData.size());
340*3e777be0SXin Li         hashValue = CacheDataHandlerInstance().Hash(dataCacheData);
341*3e777be0SXin Li     }
342*3e777be0SXin Li 
343*3e777be0SXin Li     if (modelCacheHandle.size() > 0)
344*3e777be0SXin Li     {
345*3e777be0SXin Li         if (modelCacheHandle.size() != numberOfCachedModelFiles)
346*3e777be0SXin Li         {
347*3e777be0SXin Li             NotifyCallbackAndCheck(cb, V1_0::ErrorStatus::NONE, preparedModel.release());
348*3e777be0SXin Li             return V1_0::ErrorStatus::NONE;
349*3e777be0SXin Li         }
350*3e777be0SXin Li         for (uint32_t i = 0; i < modelCacheHandle.size(); ++i)
351*3e777be0SXin Li         {
352*3e777be0SXin Li             if (modelCacheHandle[i]->numFds == 1)
353*3e777be0SXin Li             {
354*3e777be0SXin Li                 int modelCacheFileAccessMode = fcntl(modelCacheHandle[i]->data[0], F_GETFL) & O_ACCMODE;
355*3e777be0SXin Li                 if (modelCacheFileAccessMode != O_RDONLY)
356*3e777be0SXin Li                 {
357*3e777be0SXin Li                     struct stat statBuffer;
358*3e777be0SXin Li                     if (fstat(modelCacheHandle[i]->data[0], &statBuffer) == 0)
359*3e777be0SXin Li                     {
360*3e777be0SXin Li                         long modelDataSize = statBuffer.st_size;
361*3e777be0SXin Li                         if (modelDataSize > 0)
362*3e777be0SXin Li                         {
363*3e777be0SXin Li                             std::vector <uint8_t> modelData(modelDataSize);
364*3e777be0SXin Li                             pread(modelCacheHandle[i]->data[0], modelData.data(), modelData.size(), 0);
365*3e777be0SXin Li                             hashValue ^= CacheDataHandlerInstance().Hash(modelData);
366*3e777be0SXin Li                         }
367*3e777be0SXin Li                     }
368*3e777be0SXin Li                 }
369*3e777be0SXin Li             }
370*3e777be0SXin Li         }
371*3e777be0SXin Li     }
372*3e777be0SXin Li     if (hashValue != 0)
373*3e777be0SXin Li     {
374*3e777be0SXin Li         CacheDataHandlerInstance().Register(token, hashValue, dataCacheData.size());
375*3e777be0SXin Li     }
376*3e777be0SXin Li 
377*3e777be0SXin Li     NotifyCallbackAndCheck(cb, V1_0::ErrorStatus::NONE, preparedModel.release());
378*3e777be0SXin Li 
379*3e777be0SXin Li     ALOGV("ArmnnDriverImpl::prepareModel cache timing = %lld µs", std::chrono::duration_cast<std::chrono::microseconds>
380*3e777be0SXin Li          (std::chrono::system_clock::now() - prepareModelTimepoint).count());
381*3e777be0SXin Li 
382*3e777be0SXin Li     return V1_0::ErrorStatus::NONE;
383*3e777be0SXin Li }
384*3e777be0SXin Li 
prepareModelFromCache(const armnn::IRuntimePtr & runtime,const DriverOptions & options,const android::hardware::hidl_vec<android::hardware::hidl_handle> & modelCacheHandle,const android::hardware::hidl_vec<android::hardware::hidl_handle> & dataCacheHandle,const HidlToken & token,const android::sp<V1_2::IPreparedModelCallback> & cb,bool float32ToFloat16)385*3e777be0SXin Li Return<V1_0::ErrorStatus> ArmnnDriverImpl::prepareModelFromCache(
386*3e777be0SXin Li     const armnn::IRuntimePtr& runtime,
387*3e777be0SXin Li     const DriverOptions& options,
388*3e777be0SXin Li     const android::hardware::hidl_vec<android::hardware::hidl_handle>& modelCacheHandle,
389*3e777be0SXin Li     const android::hardware::hidl_vec<android::hardware::hidl_handle>& dataCacheHandle,
390*3e777be0SXin Li     const HidlToken& token,
391*3e777be0SXin Li     const android::sp<V1_2::IPreparedModelCallback>& cb,
392*3e777be0SXin Li     bool float32ToFloat16)
393*3e777be0SXin Li {
394*3e777be0SXin Li     ALOGV("ArmnnDriverImpl::prepareModelFromCache()");
395*3e777be0SXin Li     std::chrono::time_point<std::chrono::system_clock> modelFromCacheTimepoint = std::chrono::system_clock::now();
396*3e777be0SXin Li 
397*3e777be0SXin Li     if (cb.get() == nullptr)
398*3e777be0SXin Li     {
399*3e777be0SXin Li         ALOGW("ArmnnDriverImpl::prepareModelFromCache: Invalid callback passed to prepareModel");
400*3e777be0SXin Li         return V1_0::ErrorStatus::INVALID_ARGUMENT;
401*3e777be0SXin Li     }
402*3e777be0SXin Li 
403*3e777be0SXin Li     if (!runtime)
404*3e777be0SXin Li     {
405*3e777be0SXin Li         return FailPrepareModel(V1_0::ErrorStatus::DEVICE_UNAVAILABLE, "Device unavailable", cb);
406*3e777be0SXin Li     }
407*3e777be0SXin Li 
408*3e777be0SXin Li     if (token.size() != ANEURALNETWORKS_BYTE_SIZE_OF_CACHE_TOKEN)
409*3e777be0SXin Li     {
410*3e777be0SXin Li         FailPrepareModel(V1_0::ErrorStatus::INVALID_ARGUMENT, "Invalid token passed!", cb);
411*3e777be0SXin Li         return V1_0::ErrorStatus::INVALID_ARGUMENT;
412*3e777be0SXin Li     }
413*3e777be0SXin Li 
414*3e777be0SXin Li     // DataCacheHandle size should always be 1
415*3e777be0SXin Li     // Arm NN model
416*3e777be0SXin Li     if (dataCacheHandle.size() != 1)
417*3e777be0SXin Li     {
418*3e777be0SXin Li         FailPrepareModel(V1_0::ErrorStatus::GENERAL_FAILURE, "No data cache!", cb);
419*3e777be0SXin Li         return V1_0::ErrorStatus::GENERAL_FAILURE;
420*3e777be0SXin Li     }
421*3e777be0SXin Li 
422*3e777be0SXin Li     // Check if model files cached they match the expected value
423*3e777be0SXin Li     unsigned int numberOfCachedModelFiles = 0;
424*3e777be0SXin Li     for (auto& backend : options.GetBackends())
425*3e777be0SXin Li     {
426*3e777be0SXin Li         numberOfCachedModelFiles += GetNumberOfCacheFiles(backend);
427*3e777be0SXin Li     }
428*3e777be0SXin Li     if (modelCacheHandle.size() != numberOfCachedModelFiles)
429*3e777be0SXin Li     {
430*3e777be0SXin Li         FailPrepareModel(V1_0::ErrorStatus::GENERAL_FAILURE, "Invalid model cache!", cb);
431*3e777be0SXin Li         return V1_0::ErrorStatus::GENERAL_FAILURE;
432*3e777be0SXin Li     }
433*3e777be0SXin Li 
434*3e777be0SXin Li     if (dataCacheHandle[0]->numFds != 1)
435*3e777be0SXin Li     {
436*3e777be0SXin Li         ALOGW("ArmnnDriverImpl::prepareModelFromCache: Cannot read from the cache data, numFds != 1.");
437*3e777be0SXin Li         FailPrepareModel(V1_0::ErrorStatus::GENERAL_FAILURE, "No data cache!", cb);
438*3e777be0SXin Li         return V1_0::ErrorStatus::GENERAL_FAILURE;
439*3e777be0SXin Li     }
440*3e777be0SXin Li 
441*3e777be0SXin Li     if (dataCacheHandle[0]->data[0] < 0)
442*3e777be0SXin Li     {
443*3e777be0SXin Li         ALOGW("ArmnnDriverImpl::prepareModelFromCache: Cannot read from the cache data, fd < 0");
444*3e777be0SXin Li         FailPrepareModel(V1_0::ErrorStatus::GENERAL_FAILURE, "No data cache!", cb);
445*3e777be0SXin Li         return V1_0::ErrorStatus::GENERAL_FAILURE;
446*3e777be0SXin Li     }
447*3e777be0SXin Li 
448*3e777be0SXin Li     int dataCacheFileAccessMode = fcntl(dataCacheHandle[0]->data[0], F_GETFL) & O_ACCMODE;
449*3e777be0SXin Li     if (dataCacheFileAccessMode != O_RDWR)
450*3e777be0SXin Li     {
451*3e777be0SXin Li         FailPrepareModel(V1_0::ErrorStatus::GENERAL_FAILURE, "Invalid Access Mode!", cb);
452*3e777be0SXin Li         return V1_0::ErrorStatus::GENERAL_FAILURE;
453*3e777be0SXin Li     }
454*3e777be0SXin Li 
455*3e777be0SXin Li     auto dataSize = CacheDataHandlerInstance().GetCacheSize(token);
456*3e777be0SXin Li     if (dataSize == 0)
457*3e777be0SXin Li     {
458*3e777be0SXin Li         ALOGW("ArmnnDriverImpl::prepareModelFromCache: Invalid data to deserialize!");
459*3e777be0SXin Li         FailPrepareModel(V1_0::ErrorStatus::GENERAL_FAILURE, "Invalid data to deserialize!", cb);
460*3e777be0SXin Li         return V1_0::ErrorStatus::GENERAL_FAILURE;
461*3e777be0SXin Li     }
462*3e777be0SXin Li 
463*3e777be0SXin Li     int offset = 0;
464*3e777be0SXin Li     {
465*3e777be0SXin Li         struct stat statBuffer;
466*3e777be0SXin Li         if (fstat(dataCacheHandle[0]->data[0], &statBuffer) == 0)
467*3e777be0SXin Li         {
468*3e777be0SXin Li             unsigned long bufferSize = statBuffer.st_size;
469*3e777be0SXin Li             if (bufferSize != dataSize)
470*3e777be0SXin Li             {
471*3e777be0SXin Li                 ALOGW("ArmnnDriverImpl::prepareModelFromCache: Invalid data to deserialize!");
472*3e777be0SXin Li                 FailPrepareModel(V1_0::ErrorStatus::GENERAL_FAILURE, "Invalid data to deserialize!", cb);
473*3e777be0SXin Li                 return V1_0::ErrorStatus::GENERAL_FAILURE;
474*3e777be0SXin Li             }
475*3e777be0SXin Li         }
476*3e777be0SXin Li     }
477*3e777be0SXin Li     std::vector<uint8_t> dataCacheData(dataSize);
478*3e777be0SXin Li     pread(dataCacheHandle[0]->data[0], dataCacheData.data(), dataCacheData.size(), offset);
479*3e777be0SXin Li     auto hashValue = CacheDataHandlerInstance().Hash(dataCacheData);
480*3e777be0SXin Li 
481*3e777be0SXin Li     int gpuAccCachedFd = -1;
482*3e777be0SXin Li     bool saveCachedNetwork = false;
483*3e777be0SXin Li     if (modelCacheHandle.size() > 0)
484*3e777be0SXin Li     {
485*3e777be0SXin Li         unsigned int index = 0;
486*3e777be0SXin Li         for (auto& backend : options.GetBackends())
487*3e777be0SXin Li         {
488*3e777be0SXin Li             // modelCacheHandle size should be equal to numberOfCachedModelFiles
489*3e777be0SXin Li             // modelCacheHandle vector should be in same order as backends
490*3e777be0SXin Li             auto numberOfCacheFiles = GetNumberOfCacheFiles(backend);
491*3e777be0SXin Li             if (numberOfCacheFiles > 0)
492*3e777be0SXin Li             {
493*3e777be0SXin Li                 if (modelCacheHandle[index]->numFds != 1)
494*3e777be0SXin Li                 {
495*3e777be0SXin Li                     ALOGW("ArmnnDriverImpl::prepareModelFromCache: Cannot read from the model cache, numFds != 1.");
496*3e777be0SXin Li                     FailPrepareModel(V1_0::ErrorStatus::GENERAL_FAILURE,
497*3e777be0SXin Li                                      "Cannot read from the model cache, numFds != 1.", cb);
498*3e777be0SXin Li                     return V1_0::ErrorStatus::GENERAL_FAILURE;
499*3e777be0SXin Li                 }
500*3e777be0SXin Li                 auto cachedFd = modelCacheHandle[index]->data[0];
501*3e777be0SXin Li 
502*3e777be0SXin Li                 int modelCacheFileAccessMode = fcntl(cachedFd, F_GETFL) & O_ACCMODE;
503*3e777be0SXin Li                 if (modelCacheFileAccessMode != O_RDWR)
504*3e777be0SXin Li                 {
505*3e777be0SXin Li                     FailPrepareModel(V1_0::ErrorStatus::GENERAL_FAILURE, "Invalid Access Mode!", cb);
506*3e777be0SXin Li                     return V1_0::ErrorStatus::GENERAL_FAILURE;
507*3e777be0SXin Li                 }
508*3e777be0SXin Li 
509*3e777be0SXin Li                 struct stat statBuffer;
510*3e777be0SXin Li                 if (cachedFd != -1 && fstat(cachedFd, &statBuffer) == 0)
511*3e777be0SXin Li                 {
512*3e777be0SXin Li                     long modelDataSize = statBuffer.st_size;
513*3e777be0SXin Li                     if (modelDataSize <= 0)
514*3e777be0SXin Li                     {
515*3e777be0SXin Li                         FailPrepareModel(V1_0::ErrorStatus::GENERAL_FAILURE, "Wrong cached model size!", cb);
516*3e777be0SXin Li                         return V1_0::ErrorStatus::NONE;
517*3e777be0SXin Li                     }
518*3e777be0SXin Li                     std::vector<uint8_t> modelData(modelDataSize);
519*3e777be0SXin Li                     pread(cachedFd, modelData.data(), modelData.size(), 0);
520*3e777be0SXin Li                     hashValue ^= CacheDataHandlerInstance().Hash(modelData);
521*3e777be0SXin Li 
522*3e777be0SXin Li                     // For GpuAcc numberOfCachedFiles is 1
523*3e777be0SXin Li                     if (backend == armnn::Compute::GpuAcc)
524*3e777be0SXin Li                     {
525*3e777be0SXin Li                         gpuAccCachedFd = cachedFd;
526*3e777be0SXin Li                     }
527*3e777be0SXin Li                 }
528*3e777be0SXin Li                 index += numberOfCacheFiles;
529*3e777be0SXin Li             }
530*3e777be0SXin Li         }
531*3e777be0SXin Li     }
532*3e777be0SXin Li 
533*3e777be0SXin Li     if (!CacheDataHandlerInstance().Validate(token, hashValue, dataCacheData.size()))
534*3e777be0SXin Li     {
535*3e777be0SXin Li         ALOGW("ArmnnDriverImpl::prepareModelFromCache: ValidateHash() failed!");
536*3e777be0SXin Li         FailPrepareModel(V1_0::ErrorStatus::GENERAL_FAILURE, "ValidateHash Failed!", cb);
537*3e777be0SXin Li         return V1_0::ErrorStatus::GENERAL_FAILURE;
538*3e777be0SXin Li     }
539*3e777be0SXin Li 
540*3e777be0SXin Li     // Deserialize the network..
541*3e777be0SXin Li     armnn::INetworkPtr network = armnn::INetworkPtr(nullptr, [](armnn::INetwork*){});
542*3e777be0SXin Li     try
543*3e777be0SXin Li     {
544*3e777be0SXin Li         network = armnnDeserializer::IDeserializer::Create()->CreateNetworkFromBinary(dataCacheData);
545*3e777be0SXin Li     }
546*3e777be0SXin Li     catch (std::exception& e)
547*3e777be0SXin Li     {
548*3e777be0SXin Li         std::stringstream message;
549*3e777be0SXin Li         message << "Exception (" << e.what() << ") caught from Deserializer.";
550*3e777be0SXin Li         FailPrepareModel(V1_0::ErrorStatus::GENERAL_FAILURE, message.str(), cb);
551*3e777be0SXin Li         return V1_0::ErrorStatus::GENERAL_FAILURE;
552*3e777be0SXin Li     }
553*3e777be0SXin Li 
554*3e777be0SXin Li     // Optimize the network
555*3e777be0SXin Li     armnn::IOptimizedNetworkPtr optNet(nullptr, nullptr);
556*3e777be0SXin Li     armnn::OptimizerOptionsOpaque OptOptions;
557*3e777be0SXin Li     OptOptions.SetReduceFp32ToFp16(float32ToFloat16);
558*3e777be0SXin Li     OptOptions.SetProfilingEnabled(options.IsGpuProfilingEnabled());
559*3e777be0SXin Li 
560*3e777be0SXin Li     armnn::BackendOptions gpuAcc("GpuAcc",
561*3e777be0SXin Li                                  {
562*3e777be0SXin Li                                          {"FastMathEnabled",       options.IsFastMathEnabled()},
563*3e777be0SXin Li                                          {"SaveCachedNetwork",     saveCachedNetwork},
564*3e777be0SXin Li                                          {"CachedNetworkFilePath", options.GetCachedNetworkFilePath()},
565*3e777be0SXin Li                                          {"MLGOTuningFilePath",    options.GetClMLGOTunedParametersFile()},
566*3e777be0SXin Li                                          {"CachedFileDescriptor",  gpuAccCachedFd}
567*3e777be0SXin Li                                  });
568*3e777be0SXin Li 
569*3e777be0SXin Li     armnn::BackendOptions cpuAcc("CpuAcc",
570*3e777be0SXin Li                                  {
571*3e777be0SXin Li                                          {"FastMathEnabled", options.IsFastMathEnabled()},
572*3e777be0SXin Li                                          {"NumberOfThreads", options.GetNumberOfThreads()}
573*3e777be0SXin Li                                  });
574*3e777be0SXin Li     OptOptions.AddModelOption(gpuAcc);
575*3e777be0SXin Li     OptOptions.AddModelOption(cpuAcc);
576*3e777be0SXin Li 
577*3e777be0SXin Li     std::vector<std::string> errMessages;
578*3e777be0SXin Li     try
579*3e777be0SXin Li     {
580*3e777be0SXin Li         optNet = armnn::Optimize(*network.get(),
581*3e777be0SXin Li                                  options.GetBackends(),
582*3e777be0SXin Li                                  runtime->GetDeviceSpec(),
583*3e777be0SXin Li                                  OptOptions,
584*3e777be0SXin Li                                  errMessages);
585*3e777be0SXin Li     }
586*3e777be0SXin Li     catch (std::exception& e)
587*3e777be0SXin Li     {
588*3e777be0SXin Li         std::stringstream message;
589*3e777be0SXin Li         message << "Exception (" << e.what() << ") caught from optimize.";
590*3e777be0SXin Li         FailPrepareModel(V1_0::ErrorStatus::GENERAL_FAILURE, message.str(), cb);
591*3e777be0SXin Li         return V1_0::ErrorStatus::NONE;
592*3e777be0SXin Li     }
593*3e777be0SXin Li 
594*3e777be0SXin Li     // Check that the optimized network is valid.
595*3e777be0SXin Li     if (!optNet)
596*3e777be0SXin Li     {
597*3e777be0SXin Li         std::stringstream message;
598*3e777be0SXin Li         message << "Invalid optimized network";
599*3e777be0SXin Li         for (const std::string& msg : errMessages)
600*3e777be0SXin Li         {
601*3e777be0SXin Li             message << "\n" << msg;
602*3e777be0SXin Li         }
603*3e777be0SXin Li         FailPrepareModel(V1_0::ErrorStatus::GENERAL_FAILURE, message.str(), cb);
604*3e777be0SXin Li         return V1_0::ErrorStatus::NONE;
605*3e777be0SXin Li     }
606*3e777be0SXin Li 
607*3e777be0SXin Li     // Export the optimized network graph to a dot file if an output dump directory
608*3e777be0SXin Li     // has been specified in the drivers' arguments.
609*3e777be0SXin Li     std::string dotGraphFileName = ExportNetworkGraphToDotFile(*optNet,
610*3e777be0SXin Li                                                                options.GetRequestInputsAndOutputsDumpDir());
611*3e777be0SXin Li 
612*3e777be0SXin Li     // Load it into the runtime.
613*3e777be0SXin Li     armnn::NetworkId netId = 0;
614*3e777be0SXin Li     std::string msg;
615*3e777be0SXin Li     armnn::INetworkProperties networkProperties(options.isAsyncModelExecutionEnabled(),
616*3e777be0SXin Li                                                 MemorySource::Undefined,
617*3e777be0SXin Li                                                 MemorySource::Undefined,
618*3e777be0SXin Li                                                 options.IsGpuProfilingEnabled());
619*3e777be0SXin Li 
620*3e777be0SXin Li     try
621*3e777be0SXin Li     {
622*3e777be0SXin Li         if (runtime->LoadNetwork(netId, move(optNet), msg, networkProperties) != armnn::Status::Success)
623*3e777be0SXin Li         {
624*3e777be0SXin Li             return FailPrepareModel(V1_0::ErrorStatus::GENERAL_FAILURE, msg, cb);
625*3e777be0SXin Li         }
626*3e777be0SXin Li     }
627*3e777be0SXin Li     catch (std::exception& e)
628*3e777be0SXin Li     {
629*3e777be0SXin Li         std::stringstream message;
630*3e777be0SXin Li         message << "Exception (" << e.what() << ") caught from LoadNetwork.";
631*3e777be0SXin Li         FailPrepareModel(V1_0::ErrorStatus::GENERAL_FAILURE, message.str(), cb);
632*3e777be0SXin Li         return V1_0::ErrorStatus::NONE;
633*3e777be0SXin Li     }
634*3e777be0SXin Li 
635*3e777be0SXin Li     std::unique_ptr<ArmnnPreparedModel_1_2<hal_1_2::HalPolicy>> preparedModel(
636*3e777be0SXin Li             new ArmnnPreparedModel_1_2<hal_1_2::HalPolicy>(
637*3e777be0SXin Li                     netId,
638*3e777be0SXin Li                     runtime.get(),
639*3e777be0SXin Li                     options.GetRequestInputsAndOutputsDumpDir(),
640*3e777be0SXin Li                     options.IsGpuProfilingEnabled(),
641*3e777be0SXin Li                     options.isAsyncModelExecutionEnabled(),
642*3e777be0SXin Li                     options.getNoOfArmnnThreads(),
643*3e777be0SXin Li                     options.isImportEnabled(),
644*3e777be0SXin Li                     options.isExportEnabled(),
645*3e777be0SXin Li                     true));
646*3e777be0SXin Li 
647*3e777be0SXin Li     NotifyCallbackAndCheck(cb, V1_0::ErrorStatus::NONE, preparedModel.release());
648*3e777be0SXin Li 
649*3e777be0SXin Li     ALOGV("ArmnnDriverImpl::prepareModelFromCache cache timing = %lld µs",
650*3e777be0SXin Li           std::chrono::duration_cast<std::chrono::microseconds>
651*3e777be0SXin Li           (std::chrono::system_clock::now() - modelFromCacheTimepoint).count());
652*3e777be0SXin Li 
653*3e777be0SXin Li     return V1_0::ErrorStatus::NONE;
654*3e777be0SXin Li }
655*3e777be0SXin Li 
getCapabilities_1_2(const armnn::IRuntimePtr & runtime,V1_2::IDevice::getCapabilities_1_2_cb cb)656*3e777be0SXin Li Return<void> ArmnnDriverImpl::getCapabilities_1_2(const armnn::IRuntimePtr& runtime,
657*3e777be0SXin Li                                                   V1_2::IDevice::getCapabilities_1_2_cb cb)
658*3e777be0SXin Li {
659*3e777be0SXin Li     ALOGV("hal_1_2::ArmnnDriverImpl::getCapabilities()");
660*3e777be0SXin Li 
661*3e777be0SXin Li     V1_2::Capabilities capabilities;
662*3e777be0SXin Li 
663*3e777be0SXin Li     float defaultValue = .1f;
664*3e777be0SXin Li 
665*3e777be0SXin Li     if (runtime)
666*3e777be0SXin Li     {
667*3e777be0SXin Li         capabilities.relaxedFloat32toFloat16PerformanceScalar.execTime =
668*3e777be0SXin Li                 ParseSystemProperty(g_RelaxedFloat32toFloat16PerformanceExecTime, defaultValue);
669*3e777be0SXin Li 
670*3e777be0SXin Li         capabilities.relaxedFloat32toFloat16PerformanceScalar.powerUsage =
671*3e777be0SXin Li                 ParseSystemProperty(g_RelaxedFloat32toFloat16PerformancePowerUsage, defaultValue);
672*3e777be0SXin Li 
673*3e777be0SXin Li         capabilities.relaxedFloat32toFloat16PerformanceTensor.execTime =
674*3e777be0SXin Li                 ParseSystemProperty(g_RelaxedFloat32toFloat16PerformanceExecTime, defaultValue);
675*3e777be0SXin Li 
676*3e777be0SXin Li         capabilities.relaxedFloat32toFloat16PerformanceTensor.powerUsage =
677*3e777be0SXin Li                 ParseSystemProperty(g_RelaxedFloat32toFloat16PerformancePowerUsage, defaultValue);
678*3e777be0SXin Li 
679*3e777be0SXin Li         // Set the base value for all operand types
680*3e777be0SXin Li         #if defined(ARMNN_ANDROID_R) || defined(ARMNN_ANDROID_S)
681*3e777be0SXin Li         capabilities.operandPerformance = nonExtensionOperandPerformance<HalVersion::V1_2>({FLT_MAX, FLT_MAX});
682*3e777be0SXin Li         #else
683*3e777be0SXin Li         capabilities.operandPerformance = nonExtensionOperandPerformance({FLT_MAX, FLT_MAX});
684*3e777be0SXin Li         #endif
685*3e777be0SXin Li 
686*3e777be0SXin Li         // Load supported operand types
687*3e777be0SXin Li         update(&capabilities.operandPerformance, V1_2::OperandType::TENSOR_FLOAT32,
688*3e777be0SXin Li                 {
689*3e777be0SXin Li                     .execTime = ParseSystemProperty(g_OperandTypeTensorFloat32PerformanceExecTime, defaultValue),
690*3e777be0SXin Li                     .powerUsage = ParseSystemProperty(g_OperandTypeTensorFloat32PerformancePowerUsage, defaultValue)
691*3e777be0SXin Li                 });
692*3e777be0SXin Li 
693*3e777be0SXin Li         update(&capabilities.operandPerformance, V1_2::OperandType::FLOAT32,
694*3e777be0SXin Li                 {
695*3e777be0SXin Li                     .execTime = ParseSystemProperty(g_OperandTypeFloat32PerformanceExecTime, defaultValue),
696*3e777be0SXin Li                     .powerUsage = ParseSystemProperty(g_OperandTypeFloat32PerformancePowerUsage, defaultValue)
697*3e777be0SXin Li                 });
698*3e777be0SXin Li 
699*3e777be0SXin Li         update(&capabilities.operandPerformance, V1_2::OperandType::TENSOR_FLOAT16,
700*3e777be0SXin Li                 {
701*3e777be0SXin Li                     .execTime = ParseSystemProperty(g_OperandTypeTensorFloat16PerformanceExecTime, defaultValue),
702*3e777be0SXin Li                     .powerUsage = ParseSystemProperty(g_OperandTypeTensorFloat16PerformancePowerUsage, defaultValue)
703*3e777be0SXin Li                 });
704*3e777be0SXin Li 
705*3e777be0SXin Li         update(&capabilities.operandPerformance, V1_2::OperandType::FLOAT16,
706*3e777be0SXin Li                 {
707*3e777be0SXin Li                     .execTime = ParseSystemProperty(g_OperandTypeFloat16PerformanceExecTime, defaultValue),
708*3e777be0SXin Li                     .powerUsage = ParseSystemProperty(g_OperandTypeFloat16PerformancePowerUsage, defaultValue)
709*3e777be0SXin Li                 });
710*3e777be0SXin Li 
711*3e777be0SXin Li         update(&capabilities.operandPerformance, V1_2::OperandType::TENSOR_QUANT8_ASYMM,
712*3e777be0SXin Li                 {
713*3e777be0SXin Li                     .execTime = ParseSystemProperty(g_OperandTypeTensorQuant8AsymmPerformanceExecTime, defaultValue),
714*3e777be0SXin Li                     .powerUsage = ParseSystemProperty(g_OperandTypeTensorQuant8AsymmPerformancePowerUsage, defaultValue)
715*3e777be0SXin Li                 });
716*3e777be0SXin Li 
717*3e777be0SXin Li         update(&capabilities.operandPerformance, V1_2::OperandType::TENSOR_QUANT8_SYMM,
718*3e777be0SXin Li                 {
719*3e777be0SXin Li                     .execTime = ParseSystemProperty(g_OperandTypeTensorQuant8SymmPerformanceExecTime, defaultValue),
720*3e777be0SXin Li                     .powerUsage = ParseSystemProperty(g_OperandTypeTensorQuant8SymmPerformancePowerUsage, defaultValue)
721*3e777be0SXin Li                 });
722*3e777be0SXin Li 
723*3e777be0SXin Li         update(&capabilities.operandPerformance, V1_2::OperandType::TENSOR_QUANT16_SYMM,
724*3e777be0SXin Li                 {
725*3e777be0SXin Li                     .execTime = ParseSystemProperty(g_OperandTypeTensorQuant16SymmPerformanceExecTime, defaultValue),
726*3e777be0SXin Li                     .powerUsage = ParseSystemProperty(g_OperandTypeTensorQuant16SymmPerformancePowerUsage, defaultValue)
727*3e777be0SXin Li                 });
728*3e777be0SXin Li 
729*3e777be0SXin Li         update(&capabilities.operandPerformance, V1_2::OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL,
730*3e777be0SXin Li                {
731*3e777be0SXin Li                    .execTime =
732*3e777be0SXin Li                    ParseSystemProperty(g_OperandTypeTensorQuant8SymmPerChannelPerformanceExecTime, defaultValue),
733*3e777be0SXin Li                    .powerUsage =
734*3e777be0SXin Li                    ParseSystemProperty(g_OperandTypeTensorQuant8SymmPerChannelPerformancePowerUsage, defaultValue)
735*3e777be0SXin Li                });
736*3e777be0SXin Li 
737*3e777be0SXin Li         update(&capabilities.operandPerformance, V1_2::OperandType::TENSOR_INT32,
738*3e777be0SXin Li                 {
739*3e777be0SXin Li                     .execTime = ParseSystemProperty(g_OperandTypeTensorInt32PerformanceExecTime, defaultValue),
740*3e777be0SXin Li                     .powerUsage = ParseSystemProperty(g_OperandTypeTensorInt32PerformancePowerUsage, defaultValue)
741*3e777be0SXin Li                 });
742*3e777be0SXin Li 
743*3e777be0SXin Li         update(&capabilities.operandPerformance, V1_2::OperandType::INT32,
744*3e777be0SXin Li                 {
745*3e777be0SXin Li                     .execTime = ParseSystemProperty(g_OperandTypeInt32PerformanceExecTime, defaultValue),
746*3e777be0SXin Li                     .powerUsage = ParseSystemProperty(g_OperandTypeInt32PerformancePowerUsage, defaultValue)
747*3e777be0SXin Li                 });
748*3e777be0SXin Li 
749*3e777be0SXin Li         cb(V1_0::ErrorStatus::NONE, capabilities);
750*3e777be0SXin Li     }
751*3e777be0SXin Li     else
752*3e777be0SXin Li     {
753*3e777be0SXin Li         capabilities.relaxedFloat32toFloat16PerformanceScalar.execTime   = 0;
754*3e777be0SXin Li         capabilities.relaxedFloat32toFloat16PerformanceScalar.powerUsage = 0;
755*3e777be0SXin Li         capabilities.relaxedFloat32toFloat16PerformanceTensor.execTime   = 0;
756*3e777be0SXin Li         capabilities.relaxedFloat32toFloat16PerformanceTensor.powerUsage = 0;
757*3e777be0SXin Li 
758*3e777be0SXin Li         // Set the base value for all operand types
759*3e777be0SXin Li         #if defined(ARMNN_ANDROID_R) || defined(ARMNN_ANDROID_S)
760*3e777be0SXin Li         capabilities.operandPerformance = nonExtensionOperandPerformance<HalVersion::V1_2>({0.f, 0.0f});
761*3e777be0SXin Li         #else
762*3e777be0SXin Li         capabilities.operandPerformance = nonExtensionOperandPerformance({0.f, 0.0f});
763*3e777be0SXin Li         #endif
764*3e777be0SXin Li 
765*3e777be0SXin Li         cb(V1_0::ErrorStatus::DEVICE_UNAVAILABLE, capabilities);
766*3e777be0SXin Li     }
767*3e777be0SXin Li 
768*3e777be0SXin Li     return Void();
769*3e777be0SXin Li }
770*3e777be0SXin Li 
771*3e777be0SXin Li } // namespace hal_1_2
772*3e777be0SXin Li } // namespace armnn_driver