xref: /aosp_15_r20/external/android-nn-driver/1.2/HalPolicy.hpp (revision 3e777be0405cee09af5d5785ff37f7cfb5bee59a)
1 //
2 // Copyright © 2019-2023 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include "../ConversionUtils.hpp"
9 #include "../ConversionUtils_1_2.hpp"
10 
11 #include <HalInterfaces.h>
12 
13 #include <armnn/Types.hpp>
14 
15 namespace V1_2 = ::android::hardware::neuralnetworks::V1_2;
16 
17 namespace armnn_driver
18 {
19 class DriverOptions;
20 namespace hal_1_2
21 {
22 
23 class HalPolicy
24 {
25 public:
26     using Model                     = V1_2::Model;
27     using Operand                   = V1_2::Operand;
28     using OperandLifeTime           = V1_0::OperandLifeTime;
29     using OperandType               = V1_2::OperandType;
30     using Operation                 = V1_2::Operation;
31     using OperationType             = V1_2::OperationType;
32     using ExecutionCallback         = V1_2::IExecutionCallback;
33     using getSupportedOperations_cb = V1_2::IDevice::getSupportedOperations_1_2_cb;
34     using ErrorStatus               = V1_0::ErrorStatus;
35     using DeviceType                = V1_2::DeviceType;
36 
37     static DeviceType GetDeviceTypeFromOptions(const DriverOptions& options);
38 
39     static bool ConvertOperation(const Operation& operation, const Model& model, ConversionData& data);
40 
41 private:
42     static bool ConvertArgMinMax(const Operation& operation,
43                                  const Model& model,
44                                  ConversionData& data,
45                                  armnn::ArgMinMaxFunction argMinMaxFunction);
46 
47     static bool ConvertAveragePool2d(const Operation& operation, const Model& model, ConversionData& data);
48 
49     static bool ConvertBatchToSpaceNd(const Operation& operation, const Model& model, ConversionData& data);
50 
51     static bool ConvertCast(const Operation& operation, const Model& model, ConversionData& data);
52 
53     static bool ConvertChannelShuffle(const Operation& operation, const Model& model, ConversionData& data);
54 
55     static bool ConvertComparison(const Operation& operation,
56                                   const Model& model,
57                                   ConversionData& data,
58                                   armnn::ComparisonOperation comparisonOperation);
59 
60     static bool ConvertConcatenation(const Operation& operation, const Model& model, ConversionData& data);
61 
62     static bool ConvertConv2d(const Operation& operation, const Model& model, ConversionData& data);
63 
64     static bool ConvertDepthToSpace(const Operation& operation, const Model& model, ConversionData& data);
65 
66     static bool ConvertDepthwiseConv2d(const Operation& operation, const Model& model, ConversionData& data);
67 
68     static bool ConvertDequantize(const Operation& operation, const Model& model, ConversionData& data);
69 
70     static bool ConvertExpandDims(const Operation& operation, const Model& model, ConversionData& data);
71 
72     static bool ConvertElementwiseBinary(const Operation& operation,
73                                          const Model& model,
74                                          ConversionData& data,
75                                          armnn::BinaryOperation binaryOperation);
76 
77     static bool ConvertElementwiseUnary(const Operation& operation,
78                                         const Model& model,
79                                         ConversionData& data,
80                                         armnn::UnaryOperation unaryOperation);
81 
82     static bool ConvertFloor(const Operation& operation, const Model& model, ConversionData& data);
83 
84     static bool ConvertFullyConnected(const Operation& operation, const Model& model, ConversionData& data);
85 
86     static bool ConvertGather(const Operation& operation, const Model& model, ConversionData& data);
87 
88     static bool ConvertGroupedConv2d(const Operation& operation, const Model& model, ConversionData& data);
89 
90     static bool ConvertInstanceNormalization(const Operation& operation, const Model& model, ConversionData& data);
91 
92     static bool ConvertL2Normalization(const Operation& operation, const Model& model, ConversionData& data);
93 
94     static bool ConvertL2Pool2d(const Operation& operation, const Model& model, ConversionData& data);
95 
96     static bool ConvertLocalResponseNormalization(const Operation& operation,
97                                                   const Model& model,
98                                                   ConversionData& data);
99 
100     static bool ConvertLogistic(const Operation& operation, const Model& model, ConversionData& data);
101 
102     static bool ConvertLogSoftmax(const Operation& operation, const Model& model, ConversionData& data);
103 
104     static bool ConvertLstm(const Operation& operation, const Model& model, ConversionData& data);
105 
106     static bool ConvertMaxPool2d(const Operation& operation, const Model& model, ConversionData& data);
107 
108     static bool ConvertMean(const Operation& operation, const Model& model, ConversionData& data);
109 
110     static bool ConvertPad(const Operation& operation, const Model& model, ConversionData& data);
111 
112     static bool ConvertPadV2(const Operation& operation, const Model& model, ConversionData& data);
113 
114     static bool ConvertPrelu(const Operation& operation, const Model& model, ConversionData& data);
115 
116     static bool ConvertQuantize(const Operation& operation, const Model& model, ConversionData& data);
117 
118     static bool ConvertQuantized16BitLstm(const Operation& operation, const Model& model, ConversionData& data);
119 
120     static bool ConvertReduce(const Operation& operation,
121                               const Model& model,
122                               ConversionData& data,
123                               ReduceOperation reduce_operation);
124 
125     static bool ConvertReLu(const Operation& operation, const Model& model, ConversionData& data);
126 
127     static bool ConvertReLu1(const Operation& operation, const Model& model, ConversionData& data);
128 
129     static bool ConvertReLu6(const Operation& operation, const Model& model, ConversionData& data);
130 
131     static bool ConvertReshape(const Operation& operation, const Model& model, ConversionData& data);
132 
133     static bool ConvertResize(const Operation& operation,
134                               const Model& model,
135                               ConversionData& data,
136                               armnn::ResizeMethod resizeMethod);
137 
138     static bool ConvertSoftmax(const Operation& operation, const Model& model, ConversionData& data);
139 
140     static bool ConvertSpaceToBatchNd(const Operation& operation, const Model& model, ConversionData& data);
141 
142     static bool ConvertSpaceToDepth(const Operation& operation, const Model& model, ConversionData& data);
143 
144     static bool ConvertSqrt(const Operation& operation, const Model& model, ConversionData& data);
145 
146     static bool ConvertSqueeze(const Operation& operation, const Model& model, ConversionData& data);
147 
148     static bool ConvertStridedSlice(const Operation& operation, const Model& model, ConversionData& data);
149 
150     static bool ConvertTanH(const Operation& operation, const Model& model, ConversionData& data);
151 
152     static bool ConvertTranspose(const Operation& operation, const Model& model, ConversionData& data);
153 
154     static bool ConvertTransposeConv2d(const Operation& operation, const Model& model, ConversionData& data);
155 
156     static bool ConvertUnidirectionalSequenceLstm(const Operation& operation,
157                                                   const Model& model,
158                                                   ConversionData& data);
159 };
160 
161 } // namespace hal_1_2
162 } // namespace armnn_driver
163