xref: /aosp_15_r20/external/android-nn-driver/ArmnnDriverImpl.hpp (revision 3e777be0405cee09af5d5785ff37f7cfb5bee59a)
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include "DriverOptions.hpp"
9 
10 #include <HalInterfaces.h>
11 
12 #ifdef ARMNN_ANDROID_R
13 using namespace android::nn::hal;
14 #endif
15 
16 #ifdef ARMNN_ANDROID_S
17 using namespace android::hardware;
18 #endif
19 
20 namespace V1_0 = ::android::hardware::neuralnetworks::V1_0;
21 namespace V1_1 = ::android::hardware::neuralnetworks::V1_1;
22 
23 #ifdef ARMNN_ANDROID_NN_V1_2 // Using ::android::hardware::neuralnetworks::V1_2
24 namespace V1_2 = ::android::hardware::neuralnetworks::V1_2;
25 #endif
26 
27 #ifdef ARMNN_ANDROID_NN_V1_3 // Using ::android::hardware::neuralnetworks::V1_3
28 namespace V1_2 = ::android::hardware::neuralnetworks::V1_2;
29 namespace V1_3 = ::android::hardware::neuralnetworks::V1_3;
30 #endif
31 
32 namespace armnn_driver
33 {
34 
/// Simple aggregate that bundles a callback together with a context value.
///
/// Both members are public and stored by value; the struct declares no
/// constructors, so it can be brace-initialised as {callback, ctx}.
template <class Callback, class Context>
struct CallbackContext
{
    Callback callback; ///< The callback object (type chosen by the user of the template).
    Context  ctx;      ///< The context value kept alongside the callback.
};
41 
42 template<typename HalPolicy>
43 class ArmnnDriverImpl
44 {
45 public:
46     using HalModel                     = typename HalPolicy::Model;
47     using HalGetSupportedOperations_cb = typename HalPolicy::getSupportedOperations_cb;
48     using HalErrorStatus               = typename HalPolicy::ErrorStatus;
49 
50     static Return<void> getSupportedOperations(
51             const armnn::IRuntimePtr& runtime,
52             const DriverOptions& options,
53             const HalModel& model,
54             HalGetSupportedOperations_cb);
55 
56     static Return<V1_0::ErrorStatus> prepareModel(
57             const armnn::IRuntimePtr& runtime,
58             const armnn::IGpuAccTunedParametersPtr& clTunedParameters,
59             const DriverOptions& options,
60             const HalModel& model,
61             const android::sp<V1_0::IPreparedModelCallback>& cb,
62             bool float32ToFloat16 = false);
63 
64     static Return<V1_0::DeviceStatus> getStatus();
65 
66 };
67 
68 } // namespace armnn_driver
69