1 // 2 // Copyright © 2017-2019,2021-2023 Arm Ltd and Contributors. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #pragma once 7 8 #include <armnn/BackendId.hpp> 9 #include <armnn/BackendOptions.hpp> 10 #include <armnn/Descriptors.hpp> 11 #include <armnn/Optional.hpp> 12 #include <functional> 13 #include <memory> 14 #include <string> 15 #include <utility> 16 #include <vector> 17 18 namespace armnn 19 { 20 class ILayerSupport; 21 class TensorInfo; 22 struct LstmInputParamsInfo; 23 struct QuantizedLstmInputParamsInfo; 24 25 // This handle calls its own IsXXXLayerSupported() functions which then call the polymorphic 26 // ILayerSupport::IsXXXLayerSupported() at the framework level so there is no risk of VTable misalignment. 27 // This is to make ILayerSupport in its abstract form a solely Backend interface alongside a 28 // separate ABI stable frontend class free of virtual functions via an added layer of indirection. 29 class LayerSupportHandle 30 { 31 public: LayerSupportHandle(std::shared_ptr<ILayerSupport> layerSupport)32 explicit LayerSupportHandle(std::shared_ptr<ILayerSupport> layerSupport) 33 : m_LayerSupport(std::move(layerSupport)), m_BackendId(Compute::Undefined) {}; 34 LayerSupportHandle(std::shared_ptr<ILayerSupport> layerSupport,const BackendId & backendId)35 explicit LayerSupportHandle(std::shared_ptr<ILayerSupport> layerSupport, const BackendId& backendId) 36 : m_LayerSupport(std::move(layerSupport)), m_BackendId(backendId) {}; 37 38 bool IsBackendRegistered() const; 39 40 bool IsActivationSupported(const TensorInfo& input, 41 const TensorInfo& output, 42 const ActivationDescriptor& descriptor, 43 Optional<std::string&> reasonIfUnsupported = EmptyOptional()); 44 45 ARMNN_DEPRECATED_MSG_REMOVAL_DATE("Use IsElementwiseBinarySupported instead", "24.02") 46 bool IsAdditionSupported(const TensorInfo& input0, 47 const TensorInfo& input1, 48 const TensorInfo& output, 49 Optional<std::string&> reasonIfUnsupported = EmptyOptional()); 50 51 bool IsArgMinMaxSupported(const TensorInfo& input, 52 const TensorInfo& output, 53 const ArgMinMaxDescriptor& descriptor, 54 Optional<std::string&> reasonIfUnsupported = EmptyOptional()); 55 56 bool IsBatchMatMulSupported(const TensorInfo& input0, 57 const TensorInfo& input1, 58 const TensorInfo& output, 59 const BatchMatMulDescriptor& descriptor, 60 Optional<std::string&> reasonIfUnsupported = EmptyOptional()); 61 62 bool IsBatchNormalizationSupported(const TensorInfo& input, 63 const TensorInfo& output, 64 const TensorInfo& mean, 65 const TensorInfo& var, 66 const TensorInfo& beta, 67 const TensorInfo& gamma, 68 const BatchNormalizationDescriptor& descriptor, 69 Optional<std::string&> reasonIfUnsupported = EmptyOptional()); 70 71 bool IsBatchToSpaceNdSupported(const TensorInfo& input, 72 const TensorInfo& output, 73 const BatchToSpaceNdDescriptor& descriptor, 74 Optional<std::string&> reasonIfUnsupported = EmptyOptional()); 75 76 bool IsCastSupported(const TensorInfo& input, 77 const TensorInfo& output, 78 Optional<std::string&> reasonIfUnsupported = EmptyOptional()); 79 80 bool IsChannelShuffleSupported(const TensorInfo& input, 81 const TensorInfo& output, 82 const ChannelShuffleDescriptor& descriptor, 83 Optional<std::string&> reasonIfUnsupported = EmptyOptional()); 84 85 bool IsComparisonSupported(const TensorInfo& input0, 86 const TensorInfo& input1, 87 const TensorInfo& output, 88 const ComparisonDescriptor& descriptor, 89 Optional<std::string&> reasonIfUnsupported = EmptyOptional()); 90 91 bool IsConcatSupported(const std::vector<const TensorInfo*> inputs, 92 const TensorInfo& output, 93 const OriginsDescriptor& descriptor, 94 Optional<std::string&> reasonIfUnsupported = EmptyOptional()); 95 96 bool IsConstantSupported(const TensorInfo& output, 97 Optional<std::string&> reasonIfUnsupported = EmptyOptional()); 98 99 bool IsConvertFp16ToFp32Supported(const TensorInfo& input, 100 const TensorInfo& output, 101 Optional<std::string&> reasonIfUnsupported = EmptyOptional()); 102 103 bool IsConvertFp32ToFp16Supported(const TensorInfo& input, 104 const TensorInfo& output, 105 Optional<std::string&> reasonIfUnsupported = EmptyOptional()); 106 107 bool IsConvolution2dSupported(const TensorInfo& input, 108 const TensorInfo& output, 109 const Convolution2dDescriptor& descriptor, 110 const TensorInfo& weights, 111 const Optional<TensorInfo>& biases, 112 Optional<std::string&> reasonIfUnsupported = EmptyOptional()); 113 114 bool IsConvolution3dSupported(const TensorInfo& input, 115 const TensorInfo& output, 116 const Convolution3dDescriptor& descriptor, 117 const TensorInfo& weights, 118 const Optional<TensorInfo>& biases, 119 Optional<std::string&> reasonIfUnsupported = EmptyOptional()); 120 121 bool IsDebugSupported(const TensorInfo& input, 122 const TensorInfo& output, 123 Optional<std::string&> reasonIfUnsupported = EmptyOptional()); 124 125 bool IsDepthToSpaceSupported(const TensorInfo& input, 126 const TensorInfo& output, 127 const DepthToSpaceDescriptor& descriptor, 128 Optional<std::string&> reasonIfUnsupported = EmptyOptional()); 129 130 bool IsDepthwiseConvolutionSupported( 131 const TensorInfo& input, 132 const TensorInfo& output, 133 const DepthwiseConvolution2dDescriptor& descriptor, 134 const TensorInfo& weights, 135 const Optional<TensorInfo>& biases, 136 Optional<std::string&> reasonIfUnsupported = EmptyOptional()); 137 138 bool IsDequantizeSupported(const TensorInfo& input, 139 const TensorInfo& output, 140 Optional<std::string&> reasonIfUnsupported = EmptyOptional()); 141 142 bool IsDetectionPostProcessSupported(const TensorInfo& boxEncodings, 143 const TensorInfo& scores, 144 const TensorInfo& anchors, 145 const TensorInfo& detectionBoxes, 146 const TensorInfo& detectionClasses, 147 const TensorInfo& detectionScores, 148 const TensorInfo& numDetections, 149 const DetectionPostProcessDescriptor& descriptor, 150 Optional<std::string&> reasonIfUnsupported = EmptyOptional()); 151 152 bool IsDilatedDepthwiseConvolutionSupported( 153 const TensorInfo& input, 154 const TensorInfo& output, 155 const DepthwiseConvolution2dDescriptor& descriptor, 156 const TensorInfo& weights, 157 const Optional<TensorInfo>& biases, 158 Optional<std::string&> reasonIfUnsupported = EmptyOptional()); 159 160 ARMNN_DEPRECATED_MSG_REMOVAL_DATE("Use IsElementwiseBinarySupported instead", "24.02") 161 bool IsDivisionSupported(const TensorInfo& input0, 162 const TensorInfo& input1, 163 const TensorInfo& output, 164 Optional<std::string&> reasonIfUnsupported = EmptyOptional()); 165 166 bool IsElementwiseBinarySupported(const TensorInfo& input0, 167 const TensorInfo& input1, 168 const TensorInfo& output, 169 const ElementwiseBinaryDescriptor& descriptor, 170 Optional<std::string&> reasonIfUnsupported = EmptyOptional()); 171 172 bool IsElementwiseUnarySupported(const TensorInfo& input, 173 const TensorInfo& output, 174 const ElementwiseUnaryDescriptor& descriptor, 175 Optional<std::string&> reasonIfUnsupported = EmptyOptional()); 176 177 bool IsFakeQuantizationSupported(const TensorInfo& input, 178 const FakeQuantizationDescriptor& descriptor, 179 Optional<std::string&> reasonIfUnsupported = EmptyOptional()); 180 181 bool IsFillSupported(const TensorInfo& input, 182 const TensorInfo& output, 183 const FillDescriptor& descriptor, 184 Optional<std::string&> reasonIfUnsupported = EmptyOptional()); 185 186 bool IsFloorSupported(const TensorInfo& input, 187 const TensorInfo& output, 188 Optional<std::string&> reasonIfUnsupported = EmptyOptional()); 189 190 bool IsFullyConnectedSupported(const TensorInfo& input, 191 const TensorInfo& output, 192 const TensorInfo& weights, 193 const TensorInfo& biases, 194 const FullyConnectedDescriptor& descriptor, 195 Optional<std::string&> reasonIfUnsupported = EmptyOptional()); 196 197 bool IsGatherSupported(const TensorInfo& input0, 198 const TensorInfo& input1, 199 const TensorInfo& output, 200 const GatherDescriptor& descriptor, 201 Optional<std::string&> reasonIfUnsupported = EmptyOptional()); 202 203 bool IsGatherNdSupported(const TensorInfo& input0, 204 const TensorInfo& input1, 205 const TensorInfo& output, 206 Optional<std::string&> reasonIfUnsupported = EmptyOptional()); 207 208 bool IsInputSupported(const TensorInfo& input, 209 Optional<std::string&> reasonIfUnsupported = EmptyOptional()); 210 211 bool IsInstanceNormalizationSupported( 212 const TensorInfo& input, 213 const TensorInfo& output, 214 const InstanceNormalizationDescriptor& descriptor, 215 Optional<std::string&> reasonIfUnsupported = EmptyOptional()); 216 217 bool IsL2NormalizationSupported(const TensorInfo& input, 218 const TensorInfo& output, 219 const L2NormalizationDescriptor& descriptor, 220 Optional<std::string&> reasonIfUnsupported = EmptyOptional()); 221 222 bool IsLogicalBinarySupported(const TensorInfo& input0, 223 const TensorInfo& input1, 224 const TensorInfo& output, 225 const LogicalBinaryDescriptor& descriptor, 226 Optional<std::string&> reasonIfUnsupported = EmptyOptional()); 227 228 bool IsLogicalUnarySupported(const TensorInfo& input, 229 const TensorInfo& output, 230 const ElementwiseUnaryDescriptor& descriptor, 231 Optional<std::string&> reasonIfUnsupported = EmptyOptional()); 232 233 bool IsLogSoftmaxSupported(const TensorInfo& input, 234 const TensorInfo& output, 235 const LogSoftmaxDescriptor& descriptor, 236 Optional<std::string&> reasonIfUnsupported = EmptyOptional()); 237 238 bool IsLstmSupported(const TensorInfo& input, 239 const TensorInfo& outputStateIn, 240 const TensorInfo& cellStateIn, 241 const TensorInfo& scratchBuffer, 242 const TensorInfo& outputStateOut, 243 const TensorInfo& cellStateOut, 244 const TensorInfo& output, 245 const LstmDescriptor& descriptor, 246 const LstmInputParamsInfo& paramsInfo, 247 Optional<std::string&> reasonIfUnsupported = EmptyOptional()); 248 249 ARMNN_DEPRECATED_MSG_REMOVAL_DATE("Use IsElementwiseBinarySupported instead", "24.02") 250 bool IsMaximumSupported(const TensorInfo& input0, 251 const TensorInfo& input1, 252 const TensorInfo& output, 253 Optional<std::string&> reasonIfUnsupported = EmptyOptional()); 254 255 bool IsMeanSupported(const TensorInfo& input, 256 const TensorInfo& output, 257 const MeanDescriptor& descriptor, 258 Optional<std::string&> reasonIfUnsupported = EmptyOptional()); 259 260 bool IsMemCopySupported(const TensorInfo& input, 261 const TensorInfo& output, 262 Optional<std::string&> reasonIfUnsupported = EmptyOptional()); 263 264 bool IsMemImportSupported(const TensorInfo& input, 265 const TensorInfo& output, 266 Optional<std::string&> reasonIfUnsupported = EmptyOptional()); 267 268 bool IsMergeSupported(const TensorInfo& input0, 269 const TensorInfo& input1, 270 const TensorInfo& output, 271 Optional<std::string&> reasonIfUnsupported = EmptyOptional()); 272 273 ARMNN_DEPRECATED_MSG_REMOVAL_DATE("Use IsElementwiseBinarySupported instead", "24.02") 274 bool IsMinimumSupported(const TensorInfo& input0, 275 const TensorInfo& input1, 276 const TensorInfo& output, 277 Optional<std::string&> reasonIfUnsupported = EmptyOptional()); 278 279 ARMNN_DEPRECATED_MSG_REMOVAL_DATE("Use IsElementwiseBinarySupported instead", "24.02") 280 bool IsMultiplicationSupported(const TensorInfo& input0, 281 const TensorInfo& input1, 282 const TensorInfo& output, 283 Optional<std::string&> reasonIfUnsupported = EmptyOptional()); 284 285 bool IsNormalizationSupported(const TensorInfo& input, 286 const TensorInfo& output, 287 const NormalizationDescriptor& descriptor, 288 Optional<std::string&> reasonIfUnsupported = EmptyOptional()); 289 290 bool IsOutputSupported(const TensorInfo& output, 291 Optional<std::string&> reasonIfUnsupported = EmptyOptional()); 292 293 bool IsPadSupported(const TensorInfo& input, 294 const TensorInfo& output, 295 const PadDescriptor& descriptor, 296 Optional<std::string&> reasonIfUnsupported = EmptyOptional()); 297 298 bool IsPermuteSupported(const TensorInfo& input, 299 const TensorInfo& output, 300 const PermuteDescriptor& descriptor, 301 Optional<std::string&> reasonIfUnsupported = EmptyOptional()); 302 303 bool IsPooling2dSupported(const TensorInfo& input, 304 const TensorInfo& output, 305 const Pooling2dDescriptor& descriptor, 306 Optional<std::string&> reasonIfUnsupported = EmptyOptional()); 307 308 bool IsPooling3dSupported(const TensorInfo& input, 309 const TensorInfo& output, 310 const Pooling3dDescriptor& descriptor, 311 Optional<std::string&> reasonIfUnsupported = EmptyOptional()); 312 313 bool IsPreCompiledSupported(const TensorInfo& input, 314 const PreCompiledDescriptor& descriptor, 315 Optional<std::string&> reasonIfUnsupported = EmptyOptional()); 316 317 bool IsPreluSupported(const TensorInfo& input, 318 const TensorInfo& alpha, 319 const TensorInfo& output, 320 Optional<std::string&> reasonIfUnsupported = EmptyOptional()); 321 322 bool IsQuantizeSupported(const TensorInfo& input, 323 const TensorInfo& output, 324 Optional<std::string&> reasonIfUnsupported = EmptyOptional()); 325 326 bool IsQLstmSupported(const TensorInfo& input, 327 const TensorInfo& previousOutputIn, 328 const TensorInfo& previousCellStateIn, 329 const TensorInfo& outputStateOut, 330 const TensorInfo& cellStateOut, 331 const TensorInfo& output, 332 const QLstmDescriptor& descriptor, 333 const LstmInputParamsInfo& paramsInfo, 334 Optional<std::string&> reasonIfUnsupported = EmptyOptional()); 335 336 bool IsQuantizedLstmSupported(const TensorInfo& input, 337 const TensorInfo& previousCellStateIn, 338 const TensorInfo& previousOutputIn, 339 const TensorInfo& cellStateOut, 340 const TensorInfo& output, 341 const QuantizedLstmInputParamsInfo& paramsInfo, 342 Optional<std::string&> reasonIfUnsupported = EmptyOptional()); 343 344 bool IsRankSupported(const TensorInfo& input, 345 const TensorInfo& output, 346 Optional<std::string&> reasonIfUnsupported = EmptyOptional()); 347 348 bool IsReduceSupported(const TensorInfo& input, 349 const TensorInfo& output, 350 const ReduceDescriptor& descriptor, 351 Optional<std::string&> reasonIfUnsupported = EmptyOptional()); 352 353 bool IsReshapeSupported(const TensorInfo& input, 354 const TensorInfo& output, 355 const ReshapeDescriptor& descriptor, 356 Optional<std::string&> reasonIfUnsupported = EmptyOptional()); 357 358 bool IsResizeSupported(const TensorInfo& input, 359 const TensorInfo& output, 360 const ResizeDescriptor& descriptor, 361 Optional<std::string&> reasonIfUnsupported = EmptyOptional()); 362 363 bool IsShapeSupported(const TensorInfo& input, 364 const TensorInfo& output, 365 Optional<std::string&> reasonIfUnsupported = EmptyOptional()); 366 367 bool IsSliceSupported(const TensorInfo& input, 368 const TensorInfo& output, 369 const SliceDescriptor& descriptor, 370 Optional<std::string&> reasonIfUnsupported = EmptyOptional()); 371 372 bool IsSoftmaxSupported(const TensorInfo& input, 373 const TensorInfo& output, 374 const SoftmaxDescriptor& descriptor, 375 Optional<std::string&> reasonIfUnsupported = EmptyOptional()); 376 377 bool IsSpaceToBatchNdSupported(const TensorInfo& input, 378 const TensorInfo& output, 379 const SpaceToBatchNdDescriptor& descriptor, 380 Optional<std::string&> reasonIfUnsupported = EmptyOptional()); 381 382 bool IsSpaceToDepthSupported(const TensorInfo& input, 383 const TensorInfo& output, 384 const SpaceToDepthDescriptor& descriptor, 385 Optional<std::string&> reasonIfUnsupported = EmptyOptional()); 386 387 bool IsSplitterSupported(const TensorInfo& input, 388 const std::vector<std::reference_wrapper<TensorInfo>>& outputs, 389 const ViewsDescriptor& descriptor, 390 Optional<std::string&> reasonIfUnsupported = EmptyOptional()); 391 392 bool IsStackSupported(const std::vector<const TensorInfo*>& inputs, 393 const TensorInfo& output, 394 const StackDescriptor& descriptor, 395 Optional<std::string&> reasonIfUnsupported = EmptyOptional()); 396 397 bool IsStandInSupported(const std::vector<const TensorInfo*>& inputs, 398 const std::vector<const TensorInfo*>& outputs, 399 const StandInDescriptor& descriptor, 400 Optional<std::string&> reasonIfUnsupported = EmptyOptional()); 401 402 403 bool IsStridedSliceSupported(const TensorInfo& input, 404 const TensorInfo& output, 405 const StridedSliceDescriptor& descriptor, 406 Optional<std::string&> reasonIfUnsupported = EmptyOptional()); 407 408 ARMNN_DEPRECATED_MSG_REMOVAL_DATE("Use IsElementwiseBinarySupported instead", "24.02") 409 bool IsSubtractionSupported(const TensorInfo& input0, 410 const TensorInfo& input1, 411 const TensorInfo& output, 412 Optional<std::string&> reasonIfUnsupported = EmptyOptional()); 413 414 bool IsSwitchSupported(const TensorInfo& input0, 415 const TensorInfo& input1, 416 const TensorInfo& output0, 417 const TensorInfo& output1, 418 Optional<std::string&> reasonIfUnsupported = EmptyOptional()); 419 420 bool IsTransposeConvolution2dSupported( 421 const TensorInfo& input, 422 const TensorInfo& output, 423 const TransposeConvolution2dDescriptor& descriptor, 424 const TensorInfo& weights, 425 const Optional<TensorInfo>& biases, 426 Optional<std::string&> reasonIfUnsupported = EmptyOptional()); 427 428 bool IsTransposeSupported(const TensorInfo& input, 429 const TensorInfo& output, 430 const TransposeDescriptor& descriptor, 431 Optional<std::string&> reasonIfUnsupported = EmptyOptional()); 432 433 bool IsUnidirectionalSequenceLstmSupported( 434 const TensorInfo& input, 435 const TensorInfo& outputStateIn, 436 const TensorInfo& cellStateIn, 437 const TensorInfo& outputStateOut, 438 const TensorInfo& cellStateOut, 439 const TensorInfo& output, 440 const LstmDescriptor& descriptor, 441 const LstmInputParamsInfo& paramsInfo, 442 Optional<std::string&> reasonIfUnsupported = EmptyOptional()); 443 444 private: 445 std::shared_ptr<ILayerSupport> m_LayerSupport; 446 const BackendId m_BackendId; 447 }; 448 449 /// Convenience function to retrieve the ILayerSupportHandle for a backend 450 LayerSupportHandle GetILayerSupportByBackendId(const armnn::BackendId& backend); 451 452 /// Convenience function to check if a capability exists in a BackendCapabilites struct 453 bool HasCapability(const std::string& name,const BackendCapabilities& capabilities); 454 455 /// Convenience function to check if a capability exists in a backend 456 bool HasCapability(const std::string& name, const armnn::BackendId& backend); 457 458 /// Convenience function to check if a given capability matches a capability in a BackendCapabilities struct 459 bool HasCapability(const BackendOptions::BackendOption& capability, const BackendCapabilities& capabilities); 460 461 /// Convenience function to check if a given capability matches a capability in a backend 462 bool HasCapability(const BackendOptions::BackendOption& backendOption, const armnn::BackendId& backend); 463 464 /// Returns a BackendCapability if the backend lists the capability 465 /// The BackendCapability must then be inspected to check whether or not that BackendCapability is supported 466 /// Otherwise returns an EmptyOptional if the BackendCapability is unlisted 467 Optional<const BackendOptions::BackendOption> GetCapability(const std::string& backendCapabilityName, 468 const BackendCapabilities& capabilities); 469 470 /// Returns a BackendCapability if the backend lists the capability 471 /// The BackendCapability must then be inspected to check whether or not that BackendCapability is supported 472 /// Otherwise returns an EmptyOptional if the BackendCapability is unlisted 473 Optional<const BackendOptions::BackendOption> GetCapability(const std::string& backendCapabilityName, 474 const armnn::BackendId& backend); 475 476 /// Returns the number of cached files if backend supports caching 477 unsigned int GetNumberOfCacheFiles(const armnn::BackendId& backend); 478 479 } 480