xref: /aosp_15_r20/external/armnn/include/armnn/BackendHelper.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017-2019,2021-2023 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include <armnn/BackendId.hpp>
9 #include <armnn/BackendOptions.hpp>
10 #include <armnn/Descriptors.hpp>
11 #include <armnn/Optional.hpp>
12 #include <functional>
13 #include <memory>
14 #include <string>
15 #include <utility>
16 #include <vector>
17 
18 namespace armnn
19 {
// Forward declarations: these types are only named in declarations in this header, so their full
// definitions are not required here. Keep each class-key (struct vs class) identical to the type's
// definition - some compilers (e.g. MSVC) warn on a mismatch and encode the class-key in mangled names.
struct LstmInputParamsInfo;
struct QuantizedLstmInputParamsInfo;
class ILayerSupport;
class TensorInfo;
24 
25 // This handle calls its own IsXXXLayerSupported() functions which then call the polymorphic
26 // ILayerSupport::IsXXXLayerSupported() at the framework level so there is no risk of VTable misalignment.
27 // This is to make ILayerSupport in its abstract form a solely Backend interface alongside a
28 // separate ABI stable frontend class free of virtual functions via an added layer of indirection.
29 class LayerSupportHandle
30 {
31 public:
LayerSupportHandle(std::shared_ptr<ILayerSupport> layerSupport)32     explicit LayerSupportHandle(std::shared_ptr<ILayerSupport> layerSupport)
33         : m_LayerSupport(std::move(layerSupport)), m_BackendId(Compute::Undefined) {};
34 
LayerSupportHandle(std::shared_ptr<ILayerSupport> layerSupport,const BackendId & backendId)35     explicit LayerSupportHandle(std::shared_ptr<ILayerSupport> layerSupport, const BackendId& backendId)
36         : m_LayerSupport(std::move(layerSupport)), m_BackendId(backendId) {};
37 
38     bool IsBackendRegistered() const;
39 
40     bool IsActivationSupported(const TensorInfo& input,
41                                const TensorInfo& output,
42                                const ActivationDescriptor& descriptor,
43                                Optional<std::string&> reasonIfUnsupported = EmptyOptional());
44 
45     ARMNN_DEPRECATED_MSG_REMOVAL_DATE("Use IsElementwiseBinarySupported instead", "24.02")
46     bool IsAdditionSupported(const TensorInfo& input0,
47                              const TensorInfo& input1,
48                              const TensorInfo& output,
49                              Optional<std::string&> reasonIfUnsupported = EmptyOptional());
50 
51     bool IsArgMinMaxSupported(const TensorInfo& input,
52                               const TensorInfo& output,
53                               const ArgMinMaxDescriptor& descriptor,
54                               Optional<std::string&> reasonIfUnsupported = EmptyOptional());
55 
56     bool IsBatchMatMulSupported(const TensorInfo& input0,
57                                 const TensorInfo& input1,
58                                 const TensorInfo& output,
59                                 const BatchMatMulDescriptor& descriptor,
60                                 Optional<std::string&> reasonIfUnsupported = EmptyOptional());
61 
62     bool IsBatchNormalizationSupported(const TensorInfo& input,
63                                        const TensorInfo& output,
64                                        const TensorInfo& mean,
65                                        const TensorInfo& var,
66                                        const TensorInfo& beta,
67                                        const TensorInfo& gamma,
68                                        const BatchNormalizationDescriptor& descriptor,
69                                        Optional<std::string&> reasonIfUnsupported = EmptyOptional());
70 
71     bool IsBatchToSpaceNdSupported(const TensorInfo& input,
72                                    const TensorInfo& output,
73                                    const BatchToSpaceNdDescriptor& descriptor,
74                                    Optional<std::string&> reasonIfUnsupported = EmptyOptional());
75 
76     bool IsCastSupported(const TensorInfo& input,
77                          const TensorInfo& output,
78                          Optional<std::string&> reasonIfUnsupported = EmptyOptional());
79 
80     bool IsChannelShuffleSupported(const TensorInfo& input,
81                                    const TensorInfo& output,
82                                    const ChannelShuffleDescriptor& descriptor,
83                                    Optional<std::string&> reasonIfUnsupported = EmptyOptional());
84 
85     bool IsComparisonSupported(const TensorInfo& input0,
86                                const TensorInfo& input1,
87                                const TensorInfo& output,
88                                const ComparisonDescriptor& descriptor,
89                                Optional<std::string&> reasonIfUnsupported = EmptyOptional());
90 
91     bool IsConcatSupported(const std::vector<const TensorInfo*> inputs,
92                            const TensorInfo& output,
93                            const OriginsDescriptor& descriptor,
94                            Optional<std::string&> reasonIfUnsupported = EmptyOptional());
95 
96     bool IsConstantSupported(const TensorInfo& output,
97                              Optional<std::string&> reasonIfUnsupported = EmptyOptional());
98 
99     bool IsConvertFp16ToFp32Supported(const TensorInfo& input,
100                                       const TensorInfo& output,
101                                       Optional<std::string&> reasonIfUnsupported = EmptyOptional());
102 
103     bool IsConvertFp32ToFp16Supported(const TensorInfo& input,
104                                       const TensorInfo& output,
105                                       Optional<std::string&> reasonIfUnsupported = EmptyOptional());
106 
107     bool IsConvolution2dSupported(const TensorInfo& input,
108                                   const TensorInfo& output,
109                                   const Convolution2dDescriptor& descriptor,
110                                   const TensorInfo& weights,
111                                   const Optional<TensorInfo>& biases,
112                                   Optional<std::string&> reasonIfUnsupported = EmptyOptional());
113 
114     bool IsConvolution3dSupported(const TensorInfo& input,
115                                   const TensorInfo& output,
116                                   const Convolution3dDescriptor& descriptor,
117                                   const TensorInfo& weights,
118                                   const Optional<TensorInfo>& biases,
119                                   Optional<std::string&> reasonIfUnsupported = EmptyOptional());
120 
121     bool IsDebugSupported(const TensorInfo& input,
122                           const TensorInfo& output,
123                           Optional<std::string&> reasonIfUnsupported = EmptyOptional());
124 
125     bool IsDepthToSpaceSupported(const TensorInfo& input,
126                                  const TensorInfo& output,
127                                  const DepthToSpaceDescriptor& descriptor,
128                                  Optional<std::string&> reasonIfUnsupported = EmptyOptional());
129 
130     bool IsDepthwiseConvolutionSupported(
131             const TensorInfo& input,
132             const TensorInfo& output,
133             const DepthwiseConvolution2dDescriptor& descriptor,
134             const TensorInfo& weights,
135             const Optional<TensorInfo>& biases,
136             Optional<std::string&> reasonIfUnsupported = EmptyOptional());
137 
138     bool IsDequantizeSupported(const TensorInfo& input,
139                                const TensorInfo& output,
140                                Optional<std::string&> reasonIfUnsupported = EmptyOptional());
141 
142     bool IsDetectionPostProcessSupported(const TensorInfo& boxEncodings,
143                                          const TensorInfo& scores,
144                                          const TensorInfo& anchors,
145                                          const TensorInfo& detectionBoxes,
146                                          const TensorInfo& detectionClasses,
147                                          const TensorInfo& detectionScores,
148                                          const TensorInfo& numDetections,
149                                          const DetectionPostProcessDescriptor& descriptor,
150                                          Optional<std::string&> reasonIfUnsupported = EmptyOptional());
151 
152     bool IsDilatedDepthwiseConvolutionSupported(
153             const TensorInfo& input,
154             const TensorInfo& output,
155             const DepthwiseConvolution2dDescriptor& descriptor,
156             const TensorInfo& weights,
157             const Optional<TensorInfo>& biases,
158             Optional<std::string&> reasonIfUnsupported = EmptyOptional());
159 
160     ARMNN_DEPRECATED_MSG_REMOVAL_DATE("Use IsElementwiseBinarySupported instead", "24.02")
161     bool IsDivisionSupported(const TensorInfo& input0,
162                              const TensorInfo& input1,
163                              const TensorInfo& output,
164                              Optional<std::string&> reasonIfUnsupported = EmptyOptional());
165 
166     bool IsElementwiseBinarySupported(const TensorInfo& input0,
167                                       const TensorInfo& input1,
168                                       const TensorInfo& output,
169                                       const ElementwiseBinaryDescriptor& descriptor,
170                                       Optional<std::string&> reasonIfUnsupported = EmptyOptional());
171 
172     bool IsElementwiseUnarySupported(const TensorInfo& input,
173                                      const TensorInfo& output,
174                                      const ElementwiseUnaryDescriptor& descriptor,
175                                      Optional<std::string&> reasonIfUnsupported = EmptyOptional());
176 
177     bool IsFakeQuantizationSupported(const TensorInfo& input,
178                                      const FakeQuantizationDescriptor& descriptor,
179                                      Optional<std::string&> reasonIfUnsupported = EmptyOptional());
180 
181     bool IsFillSupported(const TensorInfo& input,
182                          const TensorInfo& output,
183                          const FillDescriptor& descriptor,
184                          Optional<std::string&> reasonIfUnsupported = EmptyOptional());
185 
186     bool IsFloorSupported(const TensorInfo& input,
187                           const TensorInfo& output,
188                           Optional<std::string&> reasonIfUnsupported = EmptyOptional());
189 
190     bool IsFullyConnectedSupported(const TensorInfo& input,
191                                    const TensorInfo& output,
192                                    const TensorInfo& weights,
193                                    const TensorInfo& biases,
194                                    const FullyConnectedDescriptor& descriptor,
195                                    Optional<std::string&> reasonIfUnsupported = EmptyOptional());
196 
197     bool IsGatherSupported(const TensorInfo& input0,
198                            const TensorInfo& input1,
199                            const TensorInfo& output,
200                            const GatherDescriptor& descriptor,
201                            Optional<std::string&> reasonIfUnsupported = EmptyOptional());
202 
203     bool IsGatherNdSupported(const TensorInfo& input0,
204                              const TensorInfo& input1,
205                              const TensorInfo& output,
206                              Optional<std::string&> reasonIfUnsupported = EmptyOptional());
207 
208     bool IsInputSupported(const TensorInfo& input,
209                           Optional<std::string&> reasonIfUnsupported = EmptyOptional());
210 
211     bool IsInstanceNormalizationSupported(
212             const TensorInfo& input,
213             const TensorInfo& output,
214             const InstanceNormalizationDescriptor& descriptor,
215             Optional<std::string&> reasonIfUnsupported = EmptyOptional());
216 
217     bool IsL2NormalizationSupported(const TensorInfo& input,
218                                     const TensorInfo& output,
219                                     const L2NormalizationDescriptor& descriptor,
220                                     Optional<std::string&> reasonIfUnsupported = EmptyOptional());
221 
222     bool IsLogicalBinarySupported(const TensorInfo& input0,
223                                   const TensorInfo& input1,
224                                   const TensorInfo& output,
225                                   const LogicalBinaryDescriptor& descriptor,
226                                   Optional<std::string&> reasonIfUnsupported = EmptyOptional());
227 
228     bool IsLogicalUnarySupported(const TensorInfo& input,
229                                  const TensorInfo& output,
230                                  const ElementwiseUnaryDescriptor& descriptor,
231                                  Optional<std::string&> reasonIfUnsupported = EmptyOptional());
232 
233     bool IsLogSoftmaxSupported(const TensorInfo& input,
234                                const TensorInfo& output,
235                                const LogSoftmaxDescriptor& descriptor,
236                                Optional<std::string&> reasonIfUnsupported = EmptyOptional());
237 
238     bool IsLstmSupported(const TensorInfo& input,
239                          const TensorInfo& outputStateIn,
240                          const TensorInfo& cellStateIn,
241                          const TensorInfo& scratchBuffer,
242                          const TensorInfo& outputStateOut,
243                          const TensorInfo& cellStateOut,
244                          const TensorInfo& output,
245                          const LstmDescriptor& descriptor,
246                          const LstmInputParamsInfo& paramsInfo,
247                          Optional<std::string&> reasonIfUnsupported = EmptyOptional());
248 
249     ARMNN_DEPRECATED_MSG_REMOVAL_DATE("Use IsElementwiseBinarySupported instead", "24.02")
250     bool IsMaximumSupported(const TensorInfo& input0,
251                             const TensorInfo& input1,
252                             const TensorInfo& output,
253                             Optional<std::string&> reasonIfUnsupported = EmptyOptional());
254 
255     bool IsMeanSupported(const TensorInfo& input,
256                          const TensorInfo& output,
257                          const MeanDescriptor& descriptor,
258                          Optional<std::string&> reasonIfUnsupported = EmptyOptional());
259 
260     bool IsMemCopySupported(const TensorInfo& input,
261                             const TensorInfo& output,
262                             Optional<std::string&> reasonIfUnsupported = EmptyOptional());
263 
264     bool IsMemImportSupported(const TensorInfo& input,
265                               const TensorInfo& output,
266                               Optional<std::string&> reasonIfUnsupported = EmptyOptional());
267 
268     bool IsMergeSupported(const TensorInfo& input0,
269                           const TensorInfo& input1,
270                           const TensorInfo& output,
271                           Optional<std::string&> reasonIfUnsupported = EmptyOptional());
272 
273     ARMNN_DEPRECATED_MSG_REMOVAL_DATE("Use IsElementwiseBinarySupported instead", "24.02")
274     bool IsMinimumSupported(const TensorInfo& input0,
275                             const TensorInfo& input1,
276                             const TensorInfo& output,
277                             Optional<std::string&> reasonIfUnsupported = EmptyOptional());
278 
279     ARMNN_DEPRECATED_MSG_REMOVAL_DATE("Use IsElementwiseBinarySupported instead", "24.02")
280     bool IsMultiplicationSupported(const TensorInfo& input0,
281                                    const TensorInfo& input1,
282                                    const TensorInfo& output,
283                                    Optional<std::string&> reasonIfUnsupported = EmptyOptional());
284 
285     bool IsNormalizationSupported(const TensorInfo& input,
286                                   const TensorInfo& output,
287                                   const NormalizationDescriptor& descriptor,
288                                   Optional<std::string&> reasonIfUnsupported = EmptyOptional());
289 
290     bool IsOutputSupported(const TensorInfo& output,
291                            Optional<std::string&> reasonIfUnsupported = EmptyOptional());
292 
293     bool IsPadSupported(const TensorInfo& input,
294                         const TensorInfo& output,
295                         const PadDescriptor& descriptor,
296                         Optional<std::string&> reasonIfUnsupported = EmptyOptional());
297 
298     bool IsPermuteSupported(const TensorInfo& input,
299                             const TensorInfo& output,
300                             const PermuteDescriptor& descriptor,
301                             Optional<std::string&> reasonIfUnsupported = EmptyOptional());
302 
303     bool IsPooling2dSupported(const TensorInfo& input,
304                               const TensorInfo& output,
305                               const Pooling2dDescriptor& descriptor,
306                               Optional<std::string&> reasonIfUnsupported = EmptyOptional());
307 
308     bool IsPooling3dSupported(const TensorInfo& input,
309                               const TensorInfo& output,
310                               const Pooling3dDescriptor& descriptor,
311                               Optional<std::string&> reasonIfUnsupported = EmptyOptional());
312 
313     bool IsPreCompiledSupported(const TensorInfo& input,
314                                 const PreCompiledDescriptor& descriptor,
315                                 Optional<std::string&> reasonIfUnsupported = EmptyOptional());
316 
317     bool IsPreluSupported(const TensorInfo& input,
318                           const TensorInfo& alpha,
319                           const TensorInfo& output,
320                           Optional<std::string&> reasonIfUnsupported = EmptyOptional());
321 
322     bool IsQuantizeSupported(const TensorInfo& input,
323                              const TensorInfo& output,
324                              Optional<std::string&> reasonIfUnsupported = EmptyOptional());
325 
326     bool IsQLstmSupported(const TensorInfo& input,
327                           const TensorInfo& previousOutputIn,
328                           const TensorInfo& previousCellStateIn,
329                           const TensorInfo& outputStateOut,
330                           const TensorInfo& cellStateOut,
331                           const TensorInfo& output,
332                           const QLstmDescriptor& descriptor,
333                           const LstmInputParamsInfo& paramsInfo,
334                           Optional<std::string&> reasonIfUnsupported = EmptyOptional());
335 
336     bool IsQuantizedLstmSupported(const TensorInfo& input,
337                                   const TensorInfo& previousCellStateIn,
338                                   const TensorInfo& previousOutputIn,
339                                   const TensorInfo& cellStateOut,
340                                   const TensorInfo& output,
341                                   const QuantizedLstmInputParamsInfo& paramsInfo,
342                                   Optional<std::string&> reasonIfUnsupported = EmptyOptional());
343 
344     bool IsRankSupported(const TensorInfo& input,
345                          const TensorInfo& output,
346                          Optional<std::string&> reasonIfUnsupported = EmptyOptional());
347 
348     bool IsReduceSupported(const TensorInfo& input,
349                            const TensorInfo& output,
350                            const ReduceDescriptor& descriptor,
351                            Optional<std::string&> reasonIfUnsupported = EmptyOptional());
352 
353     bool IsReshapeSupported(const TensorInfo& input,
354                             const TensorInfo& output,
355                             const ReshapeDescriptor& descriptor,
356                             Optional<std::string&> reasonIfUnsupported = EmptyOptional());
357 
358     bool IsResizeSupported(const TensorInfo& input,
359                            const TensorInfo& output,
360                            const ResizeDescriptor& descriptor,
361                            Optional<std::string&> reasonIfUnsupported = EmptyOptional());
362 
363     bool IsShapeSupported(const TensorInfo& input,
364                           const TensorInfo& output,
365                           Optional<std::string&> reasonIfUnsupported = EmptyOptional());
366 
367     bool IsSliceSupported(const TensorInfo& input,
368                           const TensorInfo& output,
369                           const SliceDescriptor& descriptor,
370                           Optional<std::string&> reasonIfUnsupported = EmptyOptional());
371 
372     bool IsSoftmaxSupported(const TensorInfo& input,
373                             const TensorInfo& output,
374                             const SoftmaxDescriptor& descriptor,
375                             Optional<std::string&> reasonIfUnsupported = EmptyOptional());
376 
377     bool IsSpaceToBatchNdSupported(const TensorInfo& input,
378                                    const TensorInfo& output,
379                                    const SpaceToBatchNdDescriptor& descriptor,
380                                    Optional<std::string&> reasonIfUnsupported = EmptyOptional());
381 
382     bool IsSpaceToDepthSupported(const TensorInfo& input,
383                                  const TensorInfo& output,
384                                  const SpaceToDepthDescriptor& descriptor,
385                                  Optional<std::string&> reasonIfUnsupported = EmptyOptional());
386 
387     bool IsSplitterSupported(const TensorInfo& input,
388                              const std::vector<std::reference_wrapper<TensorInfo>>& outputs,
389                              const ViewsDescriptor& descriptor,
390                              Optional<std::string&> reasonIfUnsupported = EmptyOptional());
391 
392     bool IsStackSupported(const std::vector<const TensorInfo*>& inputs,
393                           const TensorInfo& output,
394                           const StackDescriptor& descriptor,
395                           Optional<std::string&> reasonIfUnsupported = EmptyOptional());
396 
397     bool IsStandInSupported(const std::vector<const TensorInfo*>& inputs,
398                             const std::vector<const TensorInfo*>& outputs,
399                             const StandInDescriptor& descriptor,
400                             Optional<std::string&> reasonIfUnsupported = EmptyOptional());
401 
402 
403     bool IsStridedSliceSupported(const TensorInfo& input,
404                                  const TensorInfo& output,
405                                  const StridedSliceDescriptor& descriptor,
406                                  Optional<std::string&> reasonIfUnsupported = EmptyOptional());
407 
408     ARMNN_DEPRECATED_MSG_REMOVAL_DATE("Use IsElementwiseBinarySupported instead", "24.02")
409     bool IsSubtractionSupported(const TensorInfo& input0,
410                                 const TensorInfo& input1,
411                                 const TensorInfo& output,
412                                 Optional<std::string&> reasonIfUnsupported = EmptyOptional());
413 
414     bool IsSwitchSupported(const TensorInfo& input0,
415                            const TensorInfo& input1,
416                            const TensorInfo& output0,
417                            const TensorInfo& output1,
418                            Optional<std::string&> reasonIfUnsupported = EmptyOptional());
419 
420     bool IsTransposeConvolution2dSupported(
421             const TensorInfo& input,
422             const TensorInfo& output,
423             const TransposeConvolution2dDescriptor& descriptor,
424             const TensorInfo& weights,
425             const Optional<TensorInfo>& biases,
426             Optional<std::string&> reasonIfUnsupported = EmptyOptional());
427 
428     bool IsTransposeSupported(const TensorInfo& input,
429                               const TensorInfo& output,
430                               const TransposeDescriptor& descriptor,
431                               Optional<std::string&> reasonIfUnsupported = EmptyOptional());
432 
433     bool IsUnidirectionalSequenceLstmSupported(
434         const TensorInfo& input,
435         const TensorInfo& outputStateIn,
436         const TensorInfo& cellStateIn,
437         const TensorInfo& outputStateOut,
438         const TensorInfo& cellStateOut,
439         const TensorInfo& output,
440         const LstmDescriptor& descriptor,
441         const LstmInputParamsInfo& paramsInfo,
442         Optional<std::string&> reasonIfUnsupported = EmptyOptional());
443 
444 private:
445     std::shared_ptr<ILayerSupport> m_LayerSupport;
446     const BackendId m_BackendId;
447 };
448 
449 /// Convenience function to retrieve the ILayerSupportHandle for a backend
450 LayerSupportHandle GetILayerSupportByBackendId(const armnn::BackendId& backend);
451 
452 /// Convenience function to check if a capability exists in a BackendCapabilites struct
453 bool HasCapability(const std::string& name,const BackendCapabilities& capabilities);
454 
455 /// Convenience function to check if a capability exists in a backend
456 bool HasCapability(const std::string& name, const armnn::BackendId& backend);
457 
458 /// Convenience function to check if a given capability matches a  capability in a BackendCapabilities struct
459 bool HasCapability(const BackendOptions::BackendOption& capability, const BackendCapabilities& capabilities);
460 
461 /// Convenience function to check if a given capability matches a  capability in a backend
462 bool HasCapability(const BackendOptions::BackendOption& backendOption, const armnn::BackendId& backend);
463 
464 /// Returns a BackendCapability if the backend lists the capability
465 /// The BackendCapability must then be inspected to check whether or not that BackendCapability is supported
466 /// Otherwise returns an EmptyOptional if the BackendCapability is unlisted
467 Optional<const BackendOptions::BackendOption> GetCapability(const std::string& backendCapabilityName,
468                                                             const BackendCapabilities& capabilities);
469 
470 /// Returns a BackendCapability if the backend lists the capability
471 /// The BackendCapability must then be inspected to check whether or not that BackendCapability is supported
472 /// Otherwise returns an EmptyOptional if the BackendCapability is unlisted
473 Optional<const BackendOptions::BackendOption> GetCapability(const std::string& backendCapabilityName,
474                                                             const armnn::BackendId& backend);
475 
476 /// Returns the number of cached files if backend supports caching
477 unsigned int GetNumberOfCacheFiles(const armnn::BackendId& backend);
478 
479 }
480