1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker
6*89c4ff92SAndroid Build Coastguard Worker #include "ClLayerSupport.hpp"
7*89c4ff92SAndroid Build Coastguard Worker #include "ClBackendId.hpp"
8*89c4ff92SAndroid Build Coastguard Worker #include "ClBackendModelContext.hpp"
9*89c4ff92SAndroid Build Coastguard Worker
10*89c4ff92SAndroid Build Coastguard Worker #include <armnn/BackendRegistry.hpp>
11*89c4ff92SAndroid Build Coastguard Worker
12*89c4ff92SAndroid Build Coastguard Worker #include <InternalTypes.hpp>
13*89c4ff92SAndroid Build Coastguard Worker #include <LayerSupportCommon.hpp>
14*89c4ff92SAndroid Build Coastguard Worker
15*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/IgnoreUnused.hpp>
16*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/PolymorphicDowncast.hpp>
17*89c4ff92SAndroid Build Coastguard Worker
18*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMCOMPUTECL_ENABLED)
19*89c4ff92SAndroid Build Coastguard Worker #include <aclCommon/ArmComputeUtils.hpp>
20*89c4ff92SAndroid Build Coastguard Worker #include <aclCommon/ArmComputeTensorUtils.hpp>
21*89c4ff92SAndroid Build Coastguard Worker #include "workloads/ClAbsWorkload.hpp"
22*89c4ff92SAndroid Build Coastguard Worker #include "workloads/ClAdditionWorkload.hpp"
23*89c4ff92SAndroid Build Coastguard Worker #include "workloads/ClActivationWorkload.hpp"
24*89c4ff92SAndroid Build Coastguard Worker #include "workloads/ClArgMinMaxWorkload.hpp"
25*89c4ff92SAndroid Build Coastguard Worker #include "workloads/ClBatchMatMulWorkload.hpp"
26*89c4ff92SAndroid Build Coastguard Worker #include "workloads/ClBatchNormalizationFloatWorkload.hpp"
27*89c4ff92SAndroid Build Coastguard Worker #include "workloads/ClBatchToSpaceNdWorkload.hpp"
28*89c4ff92SAndroid Build Coastguard Worker #include "workloads/ClCastWorkload.hpp"
29*89c4ff92SAndroid Build Coastguard Worker #include "workloads/ClChannelShuffleWorkload.hpp"
30*89c4ff92SAndroid Build Coastguard Worker #include "workloads/ClComparisonWorkload.hpp"
31*89c4ff92SAndroid Build Coastguard Worker #include "workloads/ClConstantWorkload.hpp"
32*89c4ff92SAndroid Build Coastguard Worker #include "workloads/ClConvertFp16ToFp32Workload.hpp"
33*89c4ff92SAndroid Build Coastguard Worker #include "workloads/ClConvertFp32ToFp16Workload.hpp"
34*89c4ff92SAndroid Build Coastguard Worker #include "workloads/ClConvolution2dWorkload.hpp"
35*89c4ff92SAndroid Build Coastguard Worker #include "workloads/ClConvolution3dWorkload.hpp"
36*89c4ff92SAndroid Build Coastguard Worker #include "workloads/ClDepthToSpaceWorkload.hpp"
37*89c4ff92SAndroid Build Coastguard Worker #include "workloads/ClDepthwiseConvolutionWorkload.hpp"
38*89c4ff92SAndroid Build Coastguard Worker #include "workloads/ClDequantizeWorkload.hpp"
39*89c4ff92SAndroid Build Coastguard Worker #include "workloads/ClDivisionWorkload.hpp"
40*89c4ff92SAndroid Build Coastguard Worker #include "workloads/ClExpWorkload.hpp"
41*89c4ff92SAndroid Build Coastguard Worker #include "workloads/ClFillWorkload.hpp"
42*89c4ff92SAndroid Build Coastguard Worker #include "workloads/ClFloorFloatWorkload.hpp"
43*89c4ff92SAndroid Build Coastguard Worker #include "workloads/ClFullyConnectedWorkload.hpp"
44*89c4ff92SAndroid Build Coastguard Worker #include "workloads/ClGatherWorkload.hpp"
45*89c4ff92SAndroid Build Coastguard Worker #include "workloads/ClGatherNdWorkload.hpp"
46*89c4ff92SAndroid Build Coastguard Worker #include "workloads/ClInstanceNormalizationWorkload.hpp"
47*89c4ff92SAndroid Build Coastguard Worker #include "workloads/ClL2NormalizationFloatWorkload.hpp"
48*89c4ff92SAndroid Build Coastguard Worker #include "workloads/ClLogWorkload.hpp"
49*89c4ff92SAndroid Build Coastguard Worker #include "workloads/ClLogSoftmaxWorkload.hpp"
50*89c4ff92SAndroid Build Coastguard Worker #include "workloads/ClLogicalAndWorkload.hpp"
51*89c4ff92SAndroid Build Coastguard Worker #include "workloads/ClLogicalNotWorkload.hpp"
52*89c4ff92SAndroid Build Coastguard Worker #include "workloads/ClLogicalOrWorkload.hpp"
53*89c4ff92SAndroid Build Coastguard Worker #include "workloads/ClLstmFloatWorkload.hpp"
54*89c4ff92SAndroid Build Coastguard Worker #include "workloads/ClMaximumWorkload.hpp"
55*89c4ff92SAndroid Build Coastguard Worker #include "workloads/ClMeanWorkload.hpp"
56*89c4ff92SAndroid Build Coastguard Worker #include "workloads/ClConcatWorkload.hpp"
57*89c4ff92SAndroid Build Coastguard Worker #include "workloads/ClMinimumWorkload.hpp"
58*89c4ff92SAndroid Build Coastguard Worker #include "workloads/ClMultiplicationWorkload.hpp"
59*89c4ff92SAndroid Build Coastguard Worker #include "workloads/ClNegWorkload.hpp"
60*89c4ff92SAndroid Build Coastguard Worker #include "workloads/ClNormalizationFloatWorkload.hpp"
61*89c4ff92SAndroid Build Coastguard Worker #include "workloads/ClPadWorkload.hpp"
62*89c4ff92SAndroid Build Coastguard Worker #include "workloads/ClPermuteWorkload.hpp"
63*89c4ff92SAndroid Build Coastguard Worker #include "workloads/ClPooling2dWorkload.hpp"
64*89c4ff92SAndroid Build Coastguard Worker #include "workloads/ClPooling3dWorkload.hpp"
65*89c4ff92SAndroid Build Coastguard Worker #include "workloads/ClPreluWorkload.hpp"
66*89c4ff92SAndroid Build Coastguard Worker #include "workloads/ClQLstmWorkload.hpp"
67*89c4ff92SAndroid Build Coastguard Worker #include "workloads/ClQuantizedLstmWorkload.hpp"
68*89c4ff92SAndroid Build Coastguard Worker #include "workloads/ClQuantizeWorkload.hpp"
69*89c4ff92SAndroid Build Coastguard Worker #include "workloads/ClReduceWorkload.hpp"
70*89c4ff92SAndroid Build Coastguard Worker #include "workloads/ClReshapeWorkload.hpp"
71*89c4ff92SAndroid Build Coastguard Worker #include "workloads/ClResizeWorkload.hpp"
72*89c4ff92SAndroid Build Coastguard Worker #include "workloads/ClRsqrtWorkload.hpp"
73*89c4ff92SAndroid Build Coastguard Worker #include "workloads/ClSinWorkload.hpp"
74*89c4ff92SAndroid Build Coastguard Worker #include "workloads/ClSliceWorkload.hpp"
75*89c4ff92SAndroid Build Coastguard Worker #include "workloads/ClSoftmaxWorkload.hpp"
76*89c4ff92SAndroid Build Coastguard Worker #include "workloads/ClSpaceToBatchNdWorkload.hpp"
77*89c4ff92SAndroid Build Coastguard Worker #include "workloads/ClSpaceToDepthWorkload.hpp"
78*89c4ff92SAndroid Build Coastguard Worker #include "workloads/ClSplitterWorkload.hpp"
79*89c4ff92SAndroid Build Coastguard Worker #include "workloads/ClSqrtWorkload.hpp"
80*89c4ff92SAndroid Build Coastguard Worker #include "workloads/ClStackWorkload.hpp"
81*89c4ff92SAndroid Build Coastguard Worker #include "workloads/ClStridedSliceWorkload.hpp"
82*89c4ff92SAndroid Build Coastguard Worker #include "workloads/ClSubtractionWorkload.hpp"
83*89c4ff92SAndroid Build Coastguard Worker #include "workloads/ClTransposeConvolution2dWorkload.hpp"
84*89c4ff92SAndroid Build Coastguard Worker #include "workloads/ClTransposeWorkload.hpp"
85*89c4ff92SAndroid Build Coastguard Worker #include "workloads/ClUnidirectionalSequenceLstmFloatWorkload.hpp"
86*89c4ff92SAndroid Build Coastguard Worker #endif
87*89c4ff92SAndroid Build Coastguard Worker
88*89c4ff92SAndroid Build Coastguard Worker
89*89c4ff92SAndroid Build Coastguard Worker namespace armnn
90*89c4ff92SAndroid Build Coastguard Worker {
91*89c4ff92SAndroid Build Coastguard Worker
92*89c4ff92SAndroid Build Coastguard Worker namespace
93*89c4ff92SAndroid Build Coastguard Worker {
94*89c4ff92SAndroid Build Coastguard Worker
95*89c4ff92SAndroid Build Coastguard Worker template<unsigned int FilterSize>
IsMatchingSize2d(const TensorInfo & weightInfo)96*89c4ff92SAndroid Build Coastguard Worker bool IsMatchingSize2d(const TensorInfo& weightInfo)
97*89c4ff92SAndroid Build Coastguard Worker {
98*89c4ff92SAndroid Build Coastguard Worker // Width & Height must match.
99*89c4ff92SAndroid Build Coastguard Worker return (weightInfo.GetShape()[3] == FilterSize) && (weightInfo.GetShape()[2] == FilterSize);
100*89c4ff92SAndroid Build Coastguard Worker }
101*89c4ff92SAndroid Build Coastguard Worker
102*89c4ff92SAndroid Build Coastguard Worker template<uint32_t ValidStride>
IsMatchingStride(uint32_t actualStride)103*89c4ff92SAndroid Build Coastguard Worker bool IsMatchingStride(uint32_t actualStride)
104*89c4ff92SAndroid Build Coastguard Worker {
105*89c4ff92SAndroid Build Coastguard Worker return ValidStride == actualStride;
106*89c4ff92SAndroid Build Coastguard Worker }
107*89c4ff92SAndroid Build Coastguard Worker
108*89c4ff92SAndroid Build Coastguard Worker template<uint32_t FirstStride, uint32_t SecondStride, uint32_t... ValidStrides>
IsMatchingStride(uint32_t actualStride)109*89c4ff92SAndroid Build Coastguard Worker bool IsMatchingStride(uint32_t actualStride)
110*89c4ff92SAndroid Build Coastguard Worker {
111*89c4ff92SAndroid Build Coastguard Worker return IsMatchingStride<FirstStride>(actualStride) || IsMatchingStride<SecondStride, ValidStrides...>(actualStride);
112*89c4ff92SAndroid Build Coastguard Worker }
113*89c4ff92SAndroid Build Coastguard Worker
114*89c4ff92SAndroid Build Coastguard Worker template<typename ... Args>
IsClBackendSupported(Optional<std::string &> reasonIfUnsupported,Args...args)115*89c4ff92SAndroid Build Coastguard Worker bool IsClBackendSupported(Optional<std::string&> reasonIfUnsupported, Args... args)
116*89c4ff92SAndroid Build Coastguard Worker {
117*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(reasonIfUnsupported, (args)...);
118*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMCOMPUTECL_ENABLED)
119*89c4ff92SAndroid Build Coastguard Worker return true;
120*89c4ff92SAndroid Build Coastguard Worker #else
121*89c4ff92SAndroid Build Coastguard Worker if (reasonIfUnsupported)
122*89c4ff92SAndroid Build Coastguard Worker {
123*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported.value() = "The armnn library has been built without CL support";
124*89c4ff92SAndroid Build Coastguard Worker }
125*89c4ff92SAndroid Build Coastguard Worker return false;
126*89c4ff92SAndroid Build Coastguard Worker #endif
127*89c4ff92SAndroid Build Coastguard Worker }
128*89c4ff92SAndroid Build Coastguard Worker
// FORWARD_CL_LAYER_SUPPORT_FUNC(expr)
//   With CL compiled out, the expression is discarded and the macro expands to
//   IsClBackendSupported(reasonIfUnsupported). That expansion refers to a variable
//   named `reasonIfUnsupported`, which must be in scope at the expansion site.
//   With CL compiled in (ARMCOMPUTECL_ENABLED), the macro expands to `(expr)` unchanged.
#ifndef ARMCOMPUTECL_ENABLED
#define FORWARD_CL_LAYER_SUPPORT_FUNC(expr) IsClBackendSupported(reasonIfUnsupported)
#else
#define FORWARD_CL_LAYER_SUPPORT_FUNC(expr) (expr)
#endif
134*89c4ff92SAndroid Build Coastguard Worker
#if defined(ARMCOMPUTECL_ENABLED)
/// Runs an Arm Compute Library validation function and converts its status to a bool.
///
/// @param func                 Callable that returns an arm_compute::Status.
/// @param reasonIfUnsupported  Optional output. It receives the ACL error
///                             description when validation does not return ErrorCode::OK.
/// @param args                 Arguments perfectly forwarded to func.
/// @return true iff the returned status has error code ErrorCode::OK.
template<class FuncType, class... Args>
inline bool IsWorkloadSupported(FuncType&& func, Optional<std::string&> reasonIfUnsupported, Args&&... args)
{
    arm_compute::Status validationStatus = func(std::forward<Args>(args)...);

    if (validationStatus.error_code() == arm_compute::ErrorCode::OK)
    {
        return true;
    }

    // Only report the failure text if the caller gave us somewhere to put it.
    if (reasonIfUnsupported)
    {
        reasonIfUnsupported.value() = validationStatus.error_description();
    }
    return false;
}

// FORWARD_WORKLOAD_VALIDATE_FUNC(func, reason, ...) expands to a complete
// `return ...;` statement. It returns from the enclosing function with the
// result of validating `func` against the remaining arguments.
#define FORWARD_WORKLOAD_VALIDATE_FUNC(func, reasonIfUnsupported, ...) \
    return IsWorkloadSupported(func, reasonIfUnsupported, __VA_ARGS__);
#else
// CL compiled out: `func` is never referenced. The remaining arguments are passed
// to IsClBackendSupported, which ignores them and reports "unsupported".
#define FORWARD_WORKLOAD_VALIDATE_FUNC(func, reasonIfUnsupported, ...) \
    return IsClBackendSupported(reasonIfUnsupported, __VA_ARGS__);
#endif
154*89c4ff92SAndroid Build Coastguard Worker
155*89c4ff92SAndroid Build Coastguard Worker template<typename FloatFunc, typename Uint8Func, typename ... Params>
IsSupportedForDataTypeCl(Optional<std::string &> reasonIfUnsupported,DataType dataType,FloatFunc floatFuncPtr,Uint8Func uint8FuncPtr,Params &&...params)156*89c4ff92SAndroid Build Coastguard Worker bool IsSupportedForDataTypeCl(Optional<std::string&> reasonIfUnsupported,
157*89c4ff92SAndroid Build Coastguard Worker DataType dataType,
158*89c4ff92SAndroid Build Coastguard Worker FloatFunc floatFuncPtr,
159*89c4ff92SAndroid Build Coastguard Worker Uint8Func uint8FuncPtr,
160*89c4ff92SAndroid Build Coastguard Worker Params&&... params)
161*89c4ff92SAndroid Build Coastguard Worker {
162*89c4ff92SAndroid Build Coastguard Worker return IsClBackendSupported(reasonIfUnsupported) &&
163*89c4ff92SAndroid Build Coastguard Worker IsSupportedForDataTypeGeneric(reasonIfUnsupported,
164*89c4ff92SAndroid Build Coastguard Worker dataType,
165*89c4ff92SAndroid Build Coastguard Worker floatFuncPtr,
166*89c4ff92SAndroid Build Coastguard Worker floatFuncPtr,
167*89c4ff92SAndroid Build Coastguard Worker uint8FuncPtr,
168*89c4ff92SAndroid Build Coastguard Worker &FalseFunc<>,
169*89c4ff92SAndroid Build Coastguard Worker &FalseFunc<>,
170*89c4ff92SAndroid Build Coastguard Worker std::forward<Params>(params)...);
171*89c4ff92SAndroid Build Coastguard Worker }
172*89c4ff92SAndroid Build Coastguard Worker } // anonymous namespace
173*89c4ff92SAndroid Build Coastguard Worker
ClLayerSupport(const IBackendInternal::IBackendSpecificModelContextPtr & modelContextPtr)174*89c4ff92SAndroid Build Coastguard Worker ClLayerSupport::ClLayerSupport(const IBackendInternal::IBackendSpecificModelContextPtr& modelContextPtr)
175*89c4ff92SAndroid Build Coastguard Worker : m_ModelContextPtr(modelContextPtr)
176*89c4ff92SAndroid Build Coastguard Worker {
177*89c4ff92SAndroid Build Coastguard Worker }
178*89c4ff92SAndroid Build Coastguard Worker
ClLayerSupport()179*89c4ff92SAndroid Build Coastguard Worker ClLayerSupport::ClLayerSupport()
180*89c4ff92SAndroid Build Coastguard Worker : m_ModelContextPtr(nullptr)
181*89c4ff92SAndroid Build Coastguard Worker {
182*89c4ff92SAndroid Build Coastguard Worker }
183*89c4ff92SAndroid Build Coastguard Worker
IsLayerSupported(const LayerType & type,const std::vector<TensorInfo> & infos,const BaseDescriptor & descriptor,const Optional<LstmInputParamsInfo> & lstmParamsInfo,const Optional<QuantizedLstmInputParamsInfo> & quantizedLstmParamsInfo,Optional<std::string &> reasonIfUnsupported) const184*89c4ff92SAndroid Build Coastguard Worker bool ClLayerSupport::IsLayerSupported(const LayerType& type,
185*89c4ff92SAndroid Build Coastguard Worker const std::vector<TensorInfo>& infos,
186*89c4ff92SAndroid Build Coastguard Worker const BaseDescriptor& descriptor,
187*89c4ff92SAndroid Build Coastguard Worker const Optional<LstmInputParamsInfo>& lstmParamsInfo,
188*89c4ff92SAndroid Build Coastguard Worker const Optional<QuantizedLstmInputParamsInfo>& quantizedLstmParamsInfo,
189*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
190*89c4ff92SAndroid Build Coastguard Worker {
191*89c4ff92SAndroid Build Coastguard Worker switch (type)
192*89c4ff92SAndroid Build Coastguard Worker {
193*89c4ff92SAndroid Build Coastguard Worker case LayerType::Activation:
194*89c4ff92SAndroid Build Coastguard Worker return IsActivationSupported(infos[0],
195*89c4ff92SAndroid Build Coastguard Worker infos[1],
196*89c4ff92SAndroid Build Coastguard Worker *(PolymorphicDowncast<const ActivationDescriptor*>(&descriptor)),
197*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
198*89c4ff92SAndroid Build Coastguard Worker case LayerType::Addition:
199*89c4ff92SAndroid Build Coastguard Worker ARMNN_NO_DEPRECATE_WARN_BEGIN
200*89c4ff92SAndroid Build Coastguard Worker return IsAdditionSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
201*89c4ff92SAndroid Build Coastguard Worker ARMNN_NO_DEPRECATE_WARN_END
202*89c4ff92SAndroid Build Coastguard Worker case LayerType::ArgMinMax:
203*89c4ff92SAndroid Build Coastguard Worker return IsArgMinMaxSupported(infos[0],
204*89c4ff92SAndroid Build Coastguard Worker infos[1],
205*89c4ff92SAndroid Build Coastguard Worker *(PolymorphicDowncast<const ArgMinMaxDescriptor*>(&descriptor)),
206*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
207*89c4ff92SAndroid Build Coastguard Worker case LayerType::BatchMatMul:
208*89c4ff92SAndroid Build Coastguard Worker return IsBatchMatMulSupported(infos[0],
209*89c4ff92SAndroid Build Coastguard Worker infos[1],
210*89c4ff92SAndroid Build Coastguard Worker infos[2],
211*89c4ff92SAndroid Build Coastguard Worker *(PolymorphicDowncast<const BatchMatMulDescriptor*>(&descriptor)),
212*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
213*89c4ff92SAndroid Build Coastguard Worker case LayerType::BatchNormalization:
214*89c4ff92SAndroid Build Coastguard Worker return IsBatchNormalizationSupported(infos[0],
215*89c4ff92SAndroid Build Coastguard Worker infos[1],
216*89c4ff92SAndroid Build Coastguard Worker infos[2],
217*89c4ff92SAndroid Build Coastguard Worker infos[3],
218*89c4ff92SAndroid Build Coastguard Worker infos[4],
219*89c4ff92SAndroid Build Coastguard Worker infos[5],
220*89c4ff92SAndroid Build Coastguard Worker *(PolymorphicDowncast<const BatchNormalizationDescriptor*>
221*89c4ff92SAndroid Build Coastguard Worker (&descriptor)),
222*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
223*89c4ff92SAndroid Build Coastguard Worker case LayerType::BatchToSpaceNd:
224*89c4ff92SAndroid Build Coastguard Worker return IsBatchToSpaceNdSupported(infos[0],
225*89c4ff92SAndroid Build Coastguard Worker infos[1],
226*89c4ff92SAndroid Build Coastguard Worker *(PolymorphicDowncast<const BatchToSpaceNdDescriptor*>(&descriptor)),
227*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
228*89c4ff92SAndroid Build Coastguard Worker case LayerType::Cast:
229*89c4ff92SAndroid Build Coastguard Worker return IsCastSupported(infos[0], infos[1], reasonIfUnsupported);
230*89c4ff92SAndroid Build Coastguard Worker case LayerType::ChannelShuffle:
231*89c4ff92SAndroid Build Coastguard Worker return IsChannelShuffleSupported(infos[0],
232*89c4ff92SAndroid Build Coastguard Worker infos[1],
233*89c4ff92SAndroid Build Coastguard Worker *(PolymorphicDowncast<const ChannelShuffleDescriptor*>(&descriptor)),
234*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
235*89c4ff92SAndroid Build Coastguard Worker case LayerType::Comparison:
236*89c4ff92SAndroid Build Coastguard Worker return IsComparisonSupported(infos[0],
237*89c4ff92SAndroid Build Coastguard Worker infos[1],
238*89c4ff92SAndroid Build Coastguard Worker infos[2],
239*89c4ff92SAndroid Build Coastguard Worker *(PolymorphicDowncast<const ComparisonDescriptor*>(&descriptor)),
240*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
241*89c4ff92SAndroid Build Coastguard Worker case LayerType::Concat:
242*89c4ff92SAndroid Build Coastguard Worker {
243*89c4ff92SAndroid Build Coastguard Worker std::vector<const TensorInfo*> inputInfos;
244*89c4ff92SAndroid Build Coastguard Worker for (uint32_t i = 0; i < (infos.size() - 1); i++)
245*89c4ff92SAndroid Build Coastguard Worker {
246*89c4ff92SAndroid Build Coastguard Worker inputInfos.push_back(&infos[i]);
247*89c4ff92SAndroid Build Coastguard Worker }
248*89c4ff92SAndroid Build Coastguard Worker return IsConcatSupported(inputInfos,
249*89c4ff92SAndroid Build Coastguard Worker infos[infos.size() - 1],
250*89c4ff92SAndroid Build Coastguard Worker *(PolymorphicDowncast<const OriginsDescriptor*>(&descriptor)),
251*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
252*89c4ff92SAndroid Build Coastguard Worker }
253*89c4ff92SAndroid Build Coastguard Worker case LayerType::Constant:
254*89c4ff92SAndroid Build Coastguard Worker return IsConstantSupported(infos[0], reasonIfUnsupported);
255*89c4ff92SAndroid Build Coastguard Worker case LayerType::ConvertFp16ToFp32:
256*89c4ff92SAndroid Build Coastguard Worker return IsConvertFp16ToFp32Supported(infos[0], infos[1], reasonIfUnsupported);
257*89c4ff92SAndroid Build Coastguard Worker case LayerType::ConvertFp32ToFp16:
258*89c4ff92SAndroid Build Coastguard Worker return IsConvertFp32ToFp16Supported(infos[0], infos[1], reasonIfUnsupported);
259*89c4ff92SAndroid Build Coastguard Worker case LayerType::Convolution2d:
260*89c4ff92SAndroid Build Coastguard Worker {
261*89c4ff92SAndroid Build Coastguard Worker if (infos.size() != 4)
262*89c4ff92SAndroid Build Coastguard Worker {
263*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException("Invalid number of Convolution2d TensorInfos. "
264*89c4ff92SAndroid Build Coastguard Worker "TensorInfos should be of format: {input, output, weights, biases}.");
265*89c4ff92SAndroid Build Coastguard Worker }
266*89c4ff92SAndroid Build Coastguard Worker
267*89c4ff92SAndroid Build Coastguard Worker auto desc = *(PolymorphicDowncast<const Convolution2dDescriptor*>(&descriptor));
268*89c4ff92SAndroid Build Coastguard Worker if (infos[3] == TensorInfo())
269*89c4ff92SAndroid Build Coastguard Worker {
270*89c4ff92SAndroid Build Coastguard Worker return IsConvolution2dSupported(infos[0],
271*89c4ff92SAndroid Build Coastguard Worker infos[1],
272*89c4ff92SAndroid Build Coastguard Worker desc,
273*89c4ff92SAndroid Build Coastguard Worker infos[2],
274*89c4ff92SAndroid Build Coastguard Worker EmptyOptional(),
275*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
276*89c4ff92SAndroid Build Coastguard Worker }
277*89c4ff92SAndroid Build Coastguard Worker else
278*89c4ff92SAndroid Build Coastguard Worker {
279*89c4ff92SAndroid Build Coastguard Worker return IsConvolution2dSupported(infos[0],
280*89c4ff92SAndroid Build Coastguard Worker infos[1],
281*89c4ff92SAndroid Build Coastguard Worker desc,
282*89c4ff92SAndroid Build Coastguard Worker infos[2],
283*89c4ff92SAndroid Build Coastguard Worker infos[3],
284*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
285*89c4ff92SAndroid Build Coastguard Worker }
286*89c4ff92SAndroid Build Coastguard Worker }
287*89c4ff92SAndroid Build Coastguard Worker case LayerType::Convolution3d:
288*89c4ff92SAndroid Build Coastguard Worker {
289*89c4ff92SAndroid Build Coastguard Worker if (infos.size() != 4)
290*89c4ff92SAndroid Build Coastguard Worker {
291*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException("Invalid number of Convolution3d TensorInfos. "
292*89c4ff92SAndroid Build Coastguard Worker "TensorInfos should be of format: {input, output, weights, biases}.");
293*89c4ff92SAndroid Build Coastguard Worker }
294*89c4ff92SAndroid Build Coastguard Worker
295*89c4ff92SAndroid Build Coastguard Worker auto desc = *(PolymorphicDowncast<const Convolution3dDescriptor*>(&descriptor));
296*89c4ff92SAndroid Build Coastguard Worker if (infos[3] == TensorInfo())
297*89c4ff92SAndroid Build Coastguard Worker {
298*89c4ff92SAndroid Build Coastguard Worker return IsConvolution3dSupported(infos[0],
299*89c4ff92SAndroid Build Coastguard Worker infos[1],
300*89c4ff92SAndroid Build Coastguard Worker desc,
301*89c4ff92SAndroid Build Coastguard Worker infos[2],
302*89c4ff92SAndroid Build Coastguard Worker EmptyOptional(),
303*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
304*89c4ff92SAndroid Build Coastguard Worker }
305*89c4ff92SAndroid Build Coastguard Worker else
306*89c4ff92SAndroid Build Coastguard Worker {
307*89c4ff92SAndroid Build Coastguard Worker return IsConvolution3dSupported(infos[0],
308*89c4ff92SAndroid Build Coastguard Worker infos[1],
309*89c4ff92SAndroid Build Coastguard Worker desc,
310*89c4ff92SAndroid Build Coastguard Worker infos[2],
311*89c4ff92SAndroid Build Coastguard Worker infos[3],
312*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
313*89c4ff92SAndroid Build Coastguard Worker }
314*89c4ff92SAndroid Build Coastguard Worker }
315*89c4ff92SAndroid Build Coastguard Worker case LayerType::DepthToSpace:
316*89c4ff92SAndroid Build Coastguard Worker return IsDepthToSpaceSupported(infos[0],
317*89c4ff92SAndroid Build Coastguard Worker infos[1],
318*89c4ff92SAndroid Build Coastguard Worker *(PolymorphicDowncast<const DepthToSpaceDescriptor*>(&descriptor)),
319*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
320*89c4ff92SAndroid Build Coastguard Worker case LayerType::DepthwiseConvolution2d:
321*89c4ff92SAndroid Build Coastguard Worker {
322*89c4ff92SAndroid Build Coastguard Worker if (infos.size() != 4)
323*89c4ff92SAndroid Build Coastguard Worker {
324*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException("Invalid number of DepthwiseConvolution2d TensorInfos. "
325*89c4ff92SAndroid Build Coastguard Worker "TensorInfos should be of format: {input, output, weights, biases}.");
326*89c4ff92SAndroid Build Coastguard Worker }
327*89c4ff92SAndroid Build Coastguard Worker
328*89c4ff92SAndroid Build Coastguard Worker auto desc = *(PolymorphicDowncast<const DepthwiseConvolution2dDescriptor*>(&descriptor));
329*89c4ff92SAndroid Build Coastguard Worker if (infos[3] == TensorInfo())
330*89c4ff92SAndroid Build Coastguard Worker {
331*89c4ff92SAndroid Build Coastguard Worker return IsDepthwiseConvolutionSupported(infos[0],
332*89c4ff92SAndroid Build Coastguard Worker infos[1],
333*89c4ff92SAndroid Build Coastguard Worker desc,
334*89c4ff92SAndroid Build Coastguard Worker infos[2],
335*89c4ff92SAndroid Build Coastguard Worker EmptyOptional(),
336*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
337*89c4ff92SAndroid Build Coastguard Worker }
338*89c4ff92SAndroid Build Coastguard Worker else
339*89c4ff92SAndroid Build Coastguard Worker {
340*89c4ff92SAndroid Build Coastguard Worker return IsDepthwiseConvolutionSupported(infos[0],
341*89c4ff92SAndroid Build Coastguard Worker infos[1],
342*89c4ff92SAndroid Build Coastguard Worker desc,
343*89c4ff92SAndroid Build Coastguard Worker infos[2],
344*89c4ff92SAndroid Build Coastguard Worker infos[3],
345*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
346*89c4ff92SAndroid Build Coastguard Worker }
347*89c4ff92SAndroid Build Coastguard Worker }
348*89c4ff92SAndroid Build Coastguard Worker case LayerType::Dequantize:
349*89c4ff92SAndroid Build Coastguard Worker return IsDequantizeSupported(infos[0], infos[1], reasonIfUnsupported);
350*89c4ff92SAndroid Build Coastguard Worker case LayerType::Division:
351*89c4ff92SAndroid Build Coastguard Worker ARMNN_NO_DEPRECATE_WARN_BEGIN
352*89c4ff92SAndroid Build Coastguard Worker return IsDivisionSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
353*89c4ff92SAndroid Build Coastguard Worker ARMNN_NO_DEPRECATE_WARN_END
354*89c4ff92SAndroid Build Coastguard Worker case LayerType::ElementwiseBinary:
355*89c4ff92SAndroid Build Coastguard Worker {
356*89c4ff92SAndroid Build Coastguard Worker auto desc = *(PolymorphicDowncast<const ElementwiseBinaryDescriptor *>(&descriptor));
357*89c4ff92SAndroid Build Coastguard Worker
358*89c4ff92SAndroid Build Coastguard Worker switch (desc.m_Operation)
359*89c4ff92SAndroid Build Coastguard Worker {
360*89c4ff92SAndroid Build Coastguard Worker case BinaryOperation::Add:
361*89c4ff92SAndroid Build Coastguard Worker FORWARD_WORKLOAD_VALIDATE_FUNC(ClAdditionValidate,
362*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
363*89c4ff92SAndroid Build Coastguard Worker infos[0],
364*89c4ff92SAndroid Build Coastguard Worker infos[1],
365*89c4ff92SAndroid Build Coastguard Worker infos[2],
366*89c4ff92SAndroid Build Coastguard Worker nullptr);
367*89c4ff92SAndroid Build Coastguard Worker case BinaryOperation::Div:
368*89c4ff92SAndroid Build Coastguard Worker FORWARD_WORKLOAD_VALIDATE_FUNC(ClDivisionWorkloadValidate,
369*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
370*89c4ff92SAndroid Build Coastguard Worker infos[0],
371*89c4ff92SAndroid Build Coastguard Worker infos[1],
372*89c4ff92SAndroid Build Coastguard Worker infos[2],
373*89c4ff92SAndroid Build Coastguard Worker nullptr);
374*89c4ff92SAndroid Build Coastguard Worker case BinaryOperation::Minimum:
375*89c4ff92SAndroid Build Coastguard Worker FORWARD_WORKLOAD_VALIDATE_FUNC(ClMinimumWorkloadValidate,
376*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
377*89c4ff92SAndroid Build Coastguard Worker infos[0],
378*89c4ff92SAndroid Build Coastguard Worker infos[1],
379*89c4ff92SAndroid Build Coastguard Worker infos[2]);
380*89c4ff92SAndroid Build Coastguard Worker case BinaryOperation::Maximum:
381*89c4ff92SAndroid Build Coastguard Worker FORWARD_WORKLOAD_VALIDATE_FUNC(ClMaximumWorkloadValidate,
382*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
383*89c4ff92SAndroid Build Coastguard Worker infos[0],
384*89c4ff92SAndroid Build Coastguard Worker infos[1],
385*89c4ff92SAndroid Build Coastguard Worker infos[2]);
386*89c4ff92SAndroid Build Coastguard Worker case BinaryOperation::Mul:
387*89c4ff92SAndroid Build Coastguard Worker FORWARD_WORKLOAD_VALIDATE_FUNC(ClMultiplicationWorkloadValidate,
388*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
389*89c4ff92SAndroid Build Coastguard Worker infos[0],
390*89c4ff92SAndroid Build Coastguard Worker infos[1],
391*89c4ff92SAndroid Build Coastguard Worker infos[2],
392*89c4ff92SAndroid Build Coastguard Worker nullptr);
393*89c4ff92SAndroid Build Coastguard Worker case BinaryOperation::Sub:
394*89c4ff92SAndroid Build Coastguard Worker FORWARD_WORKLOAD_VALIDATE_FUNC(ClSubtractionValidate,
395*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
396*89c4ff92SAndroid Build Coastguard Worker infos[0],
397*89c4ff92SAndroid Build Coastguard Worker infos[1],
398*89c4ff92SAndroid Build Coastguard Worker infos[2],
399*89c4ff92SAndroid Build Coastguard Worker nullptr);
400*89c4ff92SAndroid Build Coastguard Worker default:
401*89c4ff92SAndroid Build Coastguard Worker return false;
402*89c4ff92SAndroid Build Coastguard Worker }
403*89c4ff92SAndroid Build Coastguard Worker }
404*89c4ff92SAndroid Build Coastguard Worker case LayerType::ElementwiseUnary:
405*89c4ff92SAndroid Build Coastguard Worker return IsElementwiseUnarySupported(infos[0],
406*89c4ff92SAndroid Build Coastguard Worker infos[1],
407*89c4ff92SAndroid Build Coastguard Worker *(PolymorphicDowncast<const ElementwiseUnaryDescriptor*>(&descriptor)),
408*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
409*89c4ff92SAndroid Build Coastguard Worker case LayerType::Fill:
410*89c4ff92SAndroid Build Coastguard Worker return IsFillSupported(infos[0],
411*89c4ff92SAndroid Build Coastguard Worker infos[1],
412*89c4ff92SAndroid Build Coastguard Worker *(PolymorphicDowncast<const FillDescriptor*>(&descriptor)),
413*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
414*89c4ff92SAndroid Build Coastguard Worker case LayerType::Floor:
415*89c4ff92SAndroid Build Coastguard Worker return IsFloorSupported(infos[0], infos[1], reasonIfUnsupported);
416*89c4ff92SAndroid Build Coastguard Worker case LayerType::FullyConnected:
417*89c4ff92SAndroid Build Coastguard Worker return IsFullyConnectedSupported(infos[0],
418*89c4ff92SAndroid Build Coastguard Worker infos[1],
419*89c4ff92SAndroid Build Coastguard Worker infos[2],
420*89c4ff92SAndroid Build Coastguard Worker infos[3],
421*89c4ff92SAndroid Build Coastguard Worker *(PolymorphicDowncast<const FullyConnectedDescriptor*>(&descriptor)),
422*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
423*89c4ff92SAndroid Build Coastguard Worker case LayerType::Gather:
424*89c4ff92SAndroid Build Coastguard Worker return IsGatherSupported(infos[0],
425*89c4ff92SAndroid Build Coastguard Worker infos[1],
426*89c4ff92SAndroid Build Coastguard Worker infos[2],
427*89c4ff92SAndroid Build Coastguard Worker *(PolymorphicDowncast<const GatherDescriptor*>(&descriptor)),
428*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
429*89c4ff92SAndroid Build Coastguard Worker case LayerType::GatherNd:
430*89c4ff92SAndroid Build Coastguard Worker return IsGatherNdSupported(infos[0],
431*89c4ff92SAndroid Build Coastguard Worker infos[1],
432*89c4ff92SAndroid Build Coastguard Worker infos[2],
433*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
434*89c4ff92SAndroid Build Coastguard Worker case LayerType::Input:
435*89c4ff92SAndroid Build Coastguard Worker return IsInputSupported(infos[0], reasonIfUnsupported);
436*89c4ff92SAndroid Build Coastguard Worker case LayerType::InstanceNormalization:
437*89c4ff92SAndroid Build Coastguard Worker return IsInstanceNormalizationSupported(infos[0],
438*89c4ff92SAndroid Build Coastguard Worker infos[1],
439*89c4ff92SAndroid Build Coastguard Worker *(PolymorphicDowncast<const InstanceNormalizationDescriptor*>
440*89c4ff92SAndroid Build Coastguard Worker (&descriptor)),
441*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
442*89c4ff92SAndroid Build Coastguard Worker case LayerType::L2Normalization:
443*89c4ff92SAndroid Build Coastguard Worker return IsL2NormalizationSupported(infos[0],
444*89c4ff92SAndroid Build Coastguard Worker infos[1],
445*89c4ff92SAndroid Build Coastguard Worker *(PolymorphicDowncast<const L2NormalizationDescriptor*>(&descriptor)),
446*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
447*89c4ff92SAndroid Build Coastguard Worker case LayerType::LogicalBinary:
448*89c4ff92SAndroid Build Coastguard Worker return IsLogicalBinarySupported(infos[0],
449*89c4ff92SAndroid Build Coastguard Worker infos[1],
450*89c4ff92SAndroid Build Coastguard Worker infos[2],
451*89c4ff92SAndroid Build Coastguard Worker *(PolymorphicDowncast<const LogicalBinaryDescriptor*>(&descriptor)),
452*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
453*89c4ff92SAndroid Build Coastguard Worker case LayerType::LogSoftmax:
454*89c4ff92SAndroid Build Coastguard Worker return IsLogSoftmaxSupported(infos[0],
455*89c4ff92SAndroid Build Coastguard Worker infos[1],
456*89c4ff92SAndroid Build Coastguard Worker *(PolymorphicDowncast<const LogSoftmaxDescriptor*>(&descriptor)),
457*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
458*89c4ff92SAndroid Build Coastguard Worker case LayerType::Lstm:
459*89c4ff92SAndroid Build Coastguard Worker return IsLstmSupported(infos[0],
460*89c4ff92SAndroid Build Coastguard Worker infos[1],
461*89c4ff92SAndroid Build Coastguard Worker infos[2],
462*89c4ff92SAndroid Build Coastguard Worker infos[3],
463*89c4ff92SAndroid Build Coastguard Worker infos[4],
464*89c4ff92SAndroid Build Coastguard Worker infos[5],
465*89c4ff92SAndroid Build Coastguard Worker infos[6],
466*89c4ff92SAndroid Build Coastguard Worker *(PolymorphicDowncast<const LstmDescriptor*>(&descriptor)),
467*89c4ff92SAndroid Build Coastguard Worker lstmParamsInfo.value(),
468*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
469*89c4ff92SAndroid Build Coastguard Worker case LayerType::Map:
470*89c4ff92SAndroid Build Coastguard Worker return true;
471*89c4ff92SAndroid Build Coastguard Worker case LayerType::MemCopy:
472*89c4ff92SAndroid Build Coastguard Worker return LayerSupportBase::IsMemCopySupported(infos[0], infos[1], reasonIfUnsupported);
473*89c4ff92SAndroid Build Coastguard Worker case LayerType::MemImport:
474*89c4ff92SAndroid Build Coastguard Worker return LayerSupportBase::IsMemImportSupported(infos[0], infos[1], reasonIfUnsupported);
475*89c4ff92SAndroid Build Coastguard Worker case LayerType::Merge:
476*89c4ff92SAndroid Build Coastguard Worker return LayerSupportBase::IsMergeSupported(infos[0],
477*89c4ff92SAndroid Build Coastguard Worker infos[1],
478*89c4ff92SAndroid Build Coastguard Worker infos[2],
479*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
480*89c4ff92SAndroid Build Coastguard Worker case LayerType::Maximum:
481*89c4ff92SAndroid Build Coastguard Worker ARMNN_NO_DEPRECATE_WARN_BEGIN
482*89c4ff92SAndroid Build Coastguard Worker return IsMaximumSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
483*89c4ff92SAndroid Build Coastguard Worker ARMNN_NO_DEPRECATE_WARN_END
484*89c4ff92SAndroid Build Coastguard Worker case LayerType::Mean:
485*89c4ff92SAndroid Build Coastguard Worker return IsMeanSupported(infos[0],
486*89c4ff92SAndroid Build Coastguard Worker infos[1],
487*89c4ff92SAndroid Build Coastguard Worker *(PolymorphicDowncast<const MeanDescriptor*>(&descriptor)),
488*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
489*89c4ff92SAndroid Build Coastguard Worker case LayerType::Minimum:
490*89c4ff92SAndroid Build Coastguard Worker ARMNN_NO_DEPRECATE_WARN_BEGIN
491*89c4ff92SAndroid Build Coastguard Worker return IsMinimumSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
492*89c4ff92SAndroid Build Coastguard Worker ARMNN_NO_DEPRECATE_WARN_END
493*89c4ff92SAndroid Build Coastguard Worker case LayerType::Multiplication:
494*89c4ff92SAndroid Build Coastguard Worker ARMNN_NO_DEPRECATE_WARN_BEGIN
495*89c4ff92SAndroid Build Coastguard Worker return IsMultiplicationSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
496*89c4ff92SAndroid Build Coastguard Worker ARMNN_NO_DEPRECATE_WARN_END
497*89c4ff92SAndroid Build Coastguard Worker case LayerType::Normalization:
498*89c4ff92SAndroid Build Coastguard Worker return IsNormalizationSupported(infos[0],
499*89c4ff92SAndroid Build Coastguard Worker infos[1],
500*89c4ff92SAndroid Build Coastguard Worker *(PolymorphicDowncast<const NormalizationDescriptor*>(&descriptor)),
501*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
502*89c4ff92SAndroid Build Coastguard Worker case LayerType::Output:
503*89c4ff92SAndroid Build Coastguard Worker return IsOutputSupported(infos[0], reasonIfUnsupported);
504*89c4ff92SAndroid Build Coastguard Worker case LayerType::Pad:
505*89c4ff92SAndroid Build Coastguard Worker return IsPadSupported(infos[0],
506*89c4ff92SAndroid Build Coastguard Worker infos[1],
507*89c4ff92SAndroid Build Coastguard Worker *(PolymorphicDowncast<const PadDescriptor*>(&descriptor)),
508*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
509*89c4ff92SAndroid Build Coastguard Worker case LayerType::Permute:
510*89c4ff92SAndroid Build Coastguard Worker return IsPermuteSupported(infos[0],
511*89c4ff92SAndroid Build Coastguard Worker infos[1],
512*89c4ff92SAndroid Build Coastguard Worker *(PolymorphicDowncast<const PermuteDescriptor*>(&descriptor)),
513*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
514*89c4ff92SAndroid Build Coastguard Worker case LayerType::Pooling2d:
515*89c4ff92SAndroid Build Coastguard Worker return IsPooling2dSupported(infos[0],
516*89c4ff92SAndroid Build Coastguard Worker infos[1],
517*89c4ff92SAndroid Build Coastguard Worker *(PolymorphicDowncast<const Pooling2dDescriptor*>(&descriptor)),
518*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
519*89c4ff92SAndroid Build Coastguard Worker case LayerType::Pooling3d:
520*89c4ff92SAndroid Build Coastguard Worker return IsPooling3dSupported(infos[0],
521*89c4ff92SAndroid Build Coastguard Worker infos[1],
522*89c4ff92SAndroid Build Coastguard Worker *(PolymorphicDowncast<const Pooling3dDescriptor*>(&descriptor)),
523*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
524*89c4ff92SAndroid Build Coastguard Worker case LayerType::Prelu:
525*89c4ff92SAndroid Build Coastguard Worker return IsPreluSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
526*89c4ff92SAndroid Build Coastguard Worker case LayerType::QLstm:
527*89c4ff92SAndroid Build Coastguard Worker return IsQLstmSupported(infos[0],
528*89c4ff92SAndroid Build Coastguard Worker infos[1],
529*89c4ff92SAndroid Build Coastguard Worker infos[2],
530*89c4ff92SAndroid Build Coastguard Worker infos[3],
531*89c4ff92SAndroid Build Coastguard Worker infos[4],
532*89c4ff92SAndroid Build Coastguard Worker infos[5],
533*89c4ff92SAndroid Build Coastguard Worker *(PolymorphicDowncast<const QLstmDescriptor*>(&descriptor)),
534*89c4ff92SAndroid Build Coastguard Worker lstmParamsInfo.value(),
535*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
536*89c4ff92SAndroid Build Coastguard Worker case LayerType::Quantize:
537*89c4ff92SAndroid Build Coastguard Worker return IsQuantizeSupported(infos[0], infos[1], reasonIfUnsupported);
538*89c4ff92SAndroid Build Coastguard Worker case LayerType::QuantizedLstm:
539*89c4ff92SAndroid Build Coastguard Worker return IsQuantizedLstmSupported(infos[0],
540*89c4ff92SAndroid Build Coastguard Worker infos[1],
541*89c4ff92SAndroid Build Coastguard Worker infos[2],
542*89c4ff92SAndroid Build Coastguard Worker infos[3],
543*89c4ff92SAndroid Build Coastguard Worker infos[4],
544*89c4ff92SAndroid Build Coastguard Worker quantizedLstmParamsInfo.value(),
545*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
546*89c4ff92SAndroid Build Coastguard Worker case LayerType::Rank:
547*89c4ff92SAndroid Build Coastguard Worker return true;
548*89c4ff92SAndroid Build Coastguard Worker case LayerType::Reduce:
549*89c4ff92SAndroid Build Coastguard Worker return IsReduceSupported(infos[0],
550*89c4ff92SAndroid Build Coastguard Worker infos[1],
551*89c4ff92SAndroid Build Coastguard Worker *(PolymorphicDowncast<const ReduceDescriptor*>(&descriptor)),
552*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
553*89c4ff92SAndroid Build Coastguard Worker case LayerType::Reshape:
554*89c4ff92SAndroid Build Coastguard Worker return IsReshapeSupported(infos[0],
555*89c4ff92SAndroid Build Coastguard Worker infos[1],
556*89c4ff92SAndroid Build Coastguard Worker *(PolymorphicDowncast<const ReshapeDescriptor*>(&descriptor)),
557*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
558*89c4ff92SAndroid Build Coastguard Worker case LayerType::Resize:
559*89c4ff92SAndroid Build Coastguard Worker return IsResizeSupported(infos[0],
560*89c4ff92SAndroid Build Coastguard Worker infos[1],
561*89c4ff92SAndroid Build Coastguard Worker *(PolymorphicDowncast<const ResizeDescriptor*>(&descriptor)),
562*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
563*89c4ff92SAndroid Build Coastguard Worker case LayerType::Shape:
564*89c4ff92SAndroid Build Coastguard Worker return LayerSupportBase::IsShapeSupported(infos[0],
565*89c4ff92SAndroid Build Coastguard Worker infos[1],
566*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
567*89c4ff92SAndroid Build Coastguard Worker case LayerType::Slice:
568*89c4ff92SAndroid Build Coastguard Worker return IsSliceSupported(infos[0],
569*89c4ff92SAndroid Build Coastguard Worker infos[1],
570*89c4ff92SAndroid Build Coastguard Worker *(PolymorphicDowncast<const SliceDescriptor*>(&descriptor)),
571*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
572*89c4ff92SAndroid Build Coastguard Worker case LayerType::Softmax:
573*89c4ff92SAndroid Build Coastguard Worker return IsSoftmaxSupported(infos[0],
574*89c4ff92SAndroid Build Coastguard Worker infos[1],
575*89c4ff92SAndroid Build Coastguard Worker *(PolymorphicDowncast<const SoftmaxDescriptor*>(&descriptor)),
576*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
577*89c4ff92SAndroid Build Coastguard Worker case LayerType::SpaceToBatchNd:
578*89c4ff92SAndroid Build Coastguard Worker return IsSpaceToBatchNdSupported(infos[0],
579*89c4ff92SAndroid Build Coastguard Worker infos[1],
580*89c4ff92SAndroid Build Coastguard Worker *(PolymorphicDowncast<const SpaceToBatchNdDescriptor*>(&descriptor)),
581*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
582*89c4ff92SAndroid Build Coastguard Worker case LayerType::SpaceToDepth:
583*89c4ff92SAndroid Build Coastguard Worker return IsSpaceToDepthSupported(infos[0],
584*89c4ff92SAndroid Build Coastguard Worker infos[1],
585*89c4ff92SAndroid Build Coastguard Worker *(PolymorphicDowncast<const SpaceToDepthDescriptor*>(&descriptor)),
586*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
587*89c4ff92SAndroid Build Coastguard Worker case LayerType::Splitter:
588*89c4ff92SAndroid Build Coastguard Worker {
589*89c4ff92SAndroid Build Coastguard Worker std::vector<TensorInfo> outputInfos;
590*89c4ff92SAndroid Build Coastguard Worker for (uint32_t i = 1; i < infos.size(); i++)
591*89c4ff92SAndroid Build Coastguard Worker {
592*89c4ff92SAndroid Build Coastguard Worker outputInfos.push_back(infos[i]);
593*89c4ff92SAndroid Build Coastguard Worker }
594*89c4ff92SAndroid Build Coastguard Worker return IsSplitterSupported(infos[0],
595*89c4ff92SAndroid Build Coastguard Worker {outputInfos.begin(), outputInfos.end()},
596*89c4ff92SAndroid Build Coastguard Worker *(PolymorphicDowncast<const ViewsDescriptor*>(&descriptor)),
597*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
598*89c4ff92SAndroid Build Coastguard Worker }
599*89c4ff92SAndroid Build Coastguard Worker case LayerType::Stack:
600*89c4ff92SAndroid Build Coastguard Worker {
601*89c4ff92SAndroid Build Coastguard Worker std::vector<const TensorInfo*> inputInfos;
602*89c4ff92SAndroid Build Coastguard Worker for (uint32_t i = 0; i < infos.size() - 1; i++)
603*89c4ff92SAndroid Build Coastguard Worker {
604*89c4ff92SAndroid Build Coastguard Worker inputInfos.push_back(&infos[i]);
605*89c4ff92SAndroid Build Coastguard Worker }
606*89c4ff92SAndroid Build Coastguard Worker return IsStackSupported(inputInfos,
607*89c4ff92SAndroid Build Coastguard Worker infos[infos.size() - 1],
608*89c4ff92SAndroid Build Coastguard Worker *(PolymorphicDowncast<const StackDescriptor*>(&descriptor)),
609*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
610*89c4ff92SAndroid Build Coastguard Worker }
611*89c4ff92SAndroid Build Coastguard Worker case LayerType::StridedSlice:
612*89c4ff92SAndroid Build Coastguard Worker return IsStridedSliceSupported(infos[0],
613*89c4ff92SAndroid Build Coastguard Worker infos[1],
614*89c4ff92SAndroid Build Coastguard Worker *(PolymorphicDowncast<const StridedSliceDescriptor*>(&descriptor)),
615*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
616*89c4ff92SAndroid Build Coastguard Worker case LayerType::Subtraction:
617*89c4ff92SAndroid Build Coastguard Worker ARMNN_NO_DEPRECATE_WARN_BEGIN
618*89c4ff92SAndroid Build Coastguard Worker return IsSubtractionSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
619*89c4ff92SAndroid Build Coastguard Worker ARMNN_NO_DEPRECATE_WARN_END
620*89c4ff92SAndroid Build Coastguard Worker case LayerType::Transpose:
621*89c4ff92SAndroid Build Coastguard Worker return IsTransposeSupported(infos[0],
622*89c4ff92SAndroid Build Coastguard Worker infos[1],
623*89c4ff92SAndroid Build Coastguard Worker *(PolymorphicDowncast<const TransposeDescriptor*>(&descriptor)),
624*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
625*89c4ff92SAndroid Build Coastguard Worker case LayerType::TransposeConvolution2d:
626*89c4ff92SAndroid Build Coastguard Worker {
627*89c4ff92SAndroid Build Coastguard Worker if (infos.size() != 4)
628*89c4ff92SAndroid Build Coastguard Worker {
629*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException("Invalid number of TransposeConvolution2d TensorInfos. "
630*89c4ff92SAndroid Build Coastguard Worker "TensorInfos should be of format: {input, output, weights, biases}.");
631*89c4ff92SAndroid Build Coastguard Worker }
632*89c4ff92SAndroid Build Coastguard Worker
633*89c4ff92SAndroid Build Coastguard Worker auto desc = *(PolymorphicDowncast<const TransposeConvolution2dDescriptor*>(&descriptor));
634*89c4ff92SAndroid Build Coastguard Worker if (infos[3] == TensorInfo())
635*89c4ff92SAndroid Build Coastguard Worker {
636*89c4ff92SAndroid Build Coastguard Worker return IsTransposeConvolution2dSupported(infos[0],
637*89c4ff92SAndroid Build Coastguard Worker infos[1],
638*89c4ff92SAndroid Build Coastguard Worker desc,
639*89c4ff92SAndroid Build Coastguard Worker infos[2],
640*89c4ff92SAndroid Build Coastguard Worker EmptyOptional(),
641*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
642*89c4ff92SAndroid Build Coastguard Worker }
643*89c4ff92SAndroid Build Coastguard Worker else
644*89c4ff92SAndroid Build Coastguard Worker {
645*89c4ff92SAndroid Build Coastguard Worker return IsTransposeConvolution2dSupported(infos[0],
646*89c4ff92SAndroid Build Coastguard Worker infos[1],
647*89c4ff92SAndroid Build Coastguard Worker desc,
648*89c4ff92SAndroid Build Coastguard Worker infos[2],
649*89c4ff92SAndroid Build Coastguard Worker infos[3],
650*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
651*89c4ff92SAndroid Build Coastguard Worker }
652*89c4ff92SAndroid Build Coastguard Worker }
653*89c4ff92SAndroid Build Coastguard Worker case LayerType::UnidirectionalSequenceLstm:
654*89c4ff92SAndroid Build Coastguard Worker return IsUnidirectionalSequenceLstmSupported(infos[0],
655*89c4ff92SAndroid Build Coastguard Worker infos[1],
656*89c4ff92SAndroid Build Coastguard Worker infos[2],
657*89c4ff92SAndroid Build Coastguard Worker infos[3],
658*89c4ff92SAndroid Build Coastguard Worker infos[4],
659*89c4ff92SAndroid Build Coastguard Worker infos[5],
660*89c4ff92SAndroid Build Coastguard Worker *(PolymorphicDowncast<const
661*89c4ff92SAndroid Build Coastguard Worker UnidirectionalSequenceLstmDescriptor*>(&descriptor)),
662*89c4ff92SAndroid Build Coastguard Worker lstmParamsInfo.value(),
663*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
664*89c4ff92SAndroid Build Coastguard Worker case LayerType::Unmap:
665*89c4ff92SAndroid Build Coastguard Worker return true;
666*89c4ff92SAndroid Build Coastguard Worker default:
667*89c4ff92SAndroid Build Coastguard Worker // layers not supported in cl by default:
668*89c4ff92SAndroid Build Coastguard Worker // debug, detectionpostprocess, fakequantization,
669*89c4ff92SAndroid Build Coastguard Worker // precompiled, standin, switch, pooling3d
670*89c4ff92SAndroid Build Coastguard Worker return false;
671*89c4ff92SAndroid Build Coastguard Worker }
672*89c4ff92SAndroid Build Coastguard Worker }
673*89c4ff92SAndroid Build Coastguard Worker
IsActivationSupported(const TensorInfo & input,const TensorInfo & output,const ActivationDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const674*89c4ff92SAndroid Build Coastguard Worker bool ClLayerSupport::IsActivationSupported(const TensorInfo& input,
675*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
676*89c4ff92SAndroid Build Coastguard Worker const ActivationDescriptor& descriptor,
677*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
678*89c4ff92SAndroid Build Coastguard Worker {
679*89c4ff92SAndroid Build Coastguard Worker FORWARD_WORKLOAD_VALIDATE_FUNC(ClActivationWorkloadValidate,
680*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
681*89c4ff92SAndroid Build Coastguard Worker input,
682*89c4ff92SAndroid Build Coastguard Worker output,
683*89c4ff92SAndroid Build Coastguard Worker descriptor);
684*89c4ff92SAndroid Build Coastguard Worker }
685*89c4ff92SAndroid Build Coastguard Worker
IsAdditionSupported(const TensorInfo & input0,const TensorInfo & input1,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const686*89c4ff92SAndroid Build Coastguard Worker bool ClLayerSupport::IsAdditionSupported(const TensorInfo& input0,
687*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& input1,
688*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
689*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
690*89c4ff92SAndroid Build Coastguard Worker {
691*89c4ff92SAndroid Build Coastguard Worker FORWARD_WORKLOAD_VALIDATE_FUNC(ClAdditionValidate,
692*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
693*89c4ff92SAndroid Build Coastguard Worker input0,
694*89c4ff92SAndroid Build Coastguard Worker input1,
695*89c4ff92SAndroid Build Coastguard Worker output,
696*89c4ff92SAndroid Build Coastguard Worker nullptr);
697*89c4ff92SAndroid Build Coastguard Worker }
698*89c4ff92SAndroid Build Coastguard Worker
IsArgMinMaxSupported(const TensorInfo & input,const TensorInfo & output,const ArgMinMaxDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const699*89c4ff92SAndroid Build Coastguard Worker bool ClLayerSupport::IsArgMinMaxSupported(const TensorInfo& input,
700*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
701*89c4ff92SAndroid Build Coastguard Worker const ArgMinMaxDescriptor& descriptor,
702*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
703*89c4ff92SAndroid Build Coastguard Worker {
704*89c4ff92SAndroid Build Coastguard Worker
705*89c4ff92SAndroid Build Coastguard Worker FORWARD_WORKLOAD_VALIDATE_FUNC(ClArgMinMaxWorkloadValidate,
706*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
707*89c4ff92SAndroid Build Coastguard Worker input,
708*89c4ff92SAndroid Build Coastguard Worker output,
709*89c4ff92SAndroid Build Coastguard Worker descriptor);
710*89c4ff92SAndroid Build Coastguard Worker }
711*89c4ff92SAndroid Build Coastguard Worker
IsBatchMatMulSupported(const TensorInfo & inputX,const TensorInfo & inputY,const TensorInfo & output,const BatchMatMulDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const712*89c4ff92SAndroid Build Coastguard Worker bool ClLayerSupport::IsBatchMatMulSupported(const TensorInfo& inputX,
713*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputY,
714*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
715*89c4ff92SAndroid Build Coastguard Worker const BatchMatMulDescriptor& descriptor,
716*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
717*89c4ff92SAndroid Build Coastguard Worker {
718*89c4ff92SAndroid Build Coastguard Worker FORWARD_WORKLOAD_VALIDATE_FUNC(ClBatchMatMulValidate,
719*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
720*89c4ff92SAndroid Build Coastguard Worker inputX,
721*89c4ff92SAndroid Build Coastguard Worker inputY,
722*89c4ff92SAndroid Build Coastguard Worker output,
723*89c4ff92SAndroid Build Coastguard Worker descriptor);
724*89c4ff92SAndroid Build Coastguard Worker }
725*89c4ff92SAndroid Build Coastguard Worker
IsBatchNormalizationSupported(const TensorInfo & input,const TensorInfo & output,const TensorInfo & mean,const TensorInfo & var,const TensorInfo & beta,const TensorInfo & gamma,const BatchNormalizationDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const726*89c4ff92SAndroid Build Coastguard Worker bool ClLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
727*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
728*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& mean,
729*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& var,
730*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& beta,
731*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& gamma,
732*89c4ff92SAndroid Build Coastguard Worker const BatchNormalizationDescriptor& descriptor,
733*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
734*89c4ff92SAndroid Build Coastguard Worker {
735*89c4ff92SAndroid Build Coastguard Worker FORWARD_WORKLOAD_VALIDATE_FUNC(ClBatchNormalizationValidate,
736*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
737*89c4ff92SAndroid Build Coastguard Worker input,
738*89c4ff92SAndroid Build Coastguard Worker output,
739*89c4ff92SAndroid Build Coastguard Worker mean,
740*89c4ff92SAndroid Build Coastguard Worker var,
741*89c4ff92SAndroid Build Coastguard Worker beta,
742*89c4ff92SAndroid Build Coastguard Worker gamma,
743*89c4ff92SAndroid Build Coastguard Worker descriptor,
744*89c4ff92SAndroid Build Coastguard Worker nullptr);
745*89c4ff92SAndroid Build Coastguard Worker }
746*89c4ff92SAndroid Build Coastguard Worker
IsBatchToSpaceNdSupported(const TensorInfo & input,const TensorInfo & output,const BatchToSpaceNdDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const747*89c4ff92SAndroid Build Coastguard Worker bool ClLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
748*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
749*89c4ff92SAndroid Build Coastguard Worker const BatchToSpaceNdDescriptor& descriptor,
750*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
751*89c4ff92SAndroid Build Coastguard Worker {
752*89c4ff92SAndroid Build Coastguard Worker FORWARD_WORKLOAD_VALIDATE_FUNC(ClBatchToSpaceNdWorkloadValidate,
753*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
754*89c4ff92SAndroid Build Coastguard Worker input,
755*89c4ff92SAndroid Build Coastguard Worker output,
756*89c4ff92SAndroid Build Coastguard Worker descriptor);
757*89c4ff92SAndroid Build Coastguard Worker }
758*89c4ff92SAndroid Build Coastguard Worker
IsCastSupported(const TensorInfo & input,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const759*89c4ff92SAndroid Build Coastguard Worker bool ClLayerSupport::IsCastSupported(const TensorInfo& input,
760*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
761*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
762*89c4ff92SAndroid Build Coastguard Worker {
763*89c4ff92SAndroid Build Coastguard Worker FORWARD_WORKLOAD_VALIDATE_FUNC(ClCastValidate,
764*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
765*89c4ff92SAndroid Build Coastguard Worker input,
766*89c4ff92SAndroid Build Coastguard Worker output);
767*89c4ff92SAndroid Build Coastguard Worker }
768*89c4ff92SAndroid Build Coastguard Worker
IsChannelShuffleSupported(const TensorInfo & input,const TensorInfo & output,const ChannelShuffleDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const769*89c4ff92SAndroid Build Coastguard Worker bool ClLayerSupport::IsChannelShuffleSupported(const TensorInfo& input,
770*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
771*89c4ff92SAndroid Build Coastguard Worker const ChannelShuffleDescriptor& descriptor,
772*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
773*89c4ff92SAndroid Build Coastguard Worker {
774*89c4ff92SAndroid Build Coastguard Worker FORWARD_WORKLOAD_VALIDATE_FUNC(ClChannelShuffleValidate,
775*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
776*89c4ff92SAndroid Build Coastguard Worker input,
777*89c4ff92SAndroid Build Coastguard Worker output,
778*89c4ff92SAndroid Build Coastguard Worker descriptor);
779*89c4ff92SAndroid Build Coastguard Worker }
780*89c4ff92SAndroid Build Coastguard Worker
IsComparisonSupported(const TensorInfo & input0,const TensorInfo & input1,const TensorInfo & output,const ComparisonDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const781*89c4ff92SAndroid Build Coastguard Worker bool ClLayerSupport::IsComparisonSupported(const TensorInfo& input0,
782*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& input1,
783*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
784*89c4ff92SAndroid Build Coastguard Worker const ComparisonDescriptor& descriptor,
785*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
786*89c4ff92SAndroid Build Coastguard Worker {
787*89c4ff92SAndroid Build Coastguard Worker FORWARD_WORKLOAD_VALIDATE_FUNC(ClComparisonWorkloadValidate,
788*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
789*89c4ff92SAndroid Build Coastguard Worker input0,
790*89c4ff92SAndroid Build Coastguard Worker input1,
791*89c4ff92SAndroid Build Coastguard Worker output,
792*89c4ff92SAndroid Build Coastguard Worker descriptor);
793*89c4ff92SAndroid Build Coastguard Worker }
794*89c4ff92SAndroid Build Coastguard Worker
IsConcatSupported(const std::vector<const TensorInfo * > inputs,const TensorInfo & output,const OriginsDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const795*89c4ff92SAndroid Build Coastguard Worker bool ClLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs,
796*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
797*89c4ff92SAndroid Build Coastguard Worker const OriginsDescriptor& descriptor,
798*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
799*89c4ff92SAndroid Build Coastguard Worker {
800*89c4ff92SAndroid Build Coastguard Worker if (descriptor.GetNumDimensions() <= descriptor.GetConcatAxis())
801*89c4ff92SAndroid Build Coastguard Worker {
802*89c4ff92SAndroid Build Coastguard Worker SetValueChecked(reasonIfUnsupported, "Cl Concat: Concat axis > Number of dimensions.");
803*89c4ff92SAndroid Build Coastguard Worker return false;
804*89c4ff92SAndroid Build Coastguard Worker }
805*89c4ff92SAndroid Build Coastguard Worker
806*89c4ff92SAndroid Build Coastguard Worker unsigned int concatInnerAxis = (descriptor.GetNumDimensions() - descriptor.GetConcatAxis()) - 1;
807*89c4ff92SAndroid Build Coastguard Worker if(concatInnerAxis < 3) // Width, height, or channels
808*89c4ff92SAndroid Build Coastguard Worker {
809*89c4ff92SAndroid Build Coastguard Worker FORWARD_WORKLOAD_VALIDATE_FUNC(ClConcatWorkloadValidate,
810*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
811*89c4ff92SAndroid Build Coastguard Worker inputs,
812*89c4ff92SAndroid Build Coastguard Worker output,
813*89c4ff92SAndroid Build Coastguard Worker descriptor);
814*89c4ff92SAndroid Build Coastguard Worker }
815*89c4ff92SAndroid Build Coastguard Worker else if (concatInnerAxis == 3)
816*89c4ff92SAndroid Build Coastguard Worker {
817*89c4ff92SAndroid Build Coastguard Worker // We rely on the sub-tensor optimization to handle the batch dimension for 4D tensors. If we can't use
818*89c4ff92SAndroid Build Coastguard Worker // sub-tensors for this then we can't support it. Here is where we check that the sub-tensors will work.
819*89c4ff92SAndroid Build Coastguard Worker for (auto& input : inputs)
820*89c4ff92SAndroid Build Coastguard Worker {
821*89c4ff92SAndroid Build Coastguard Worker if (input && !output.IsTypeSpaceMatch(*input)) // Cannot use sub-tensors if the types are not same space
822*89c4ff92SAndroid Build Coastguard Worker {
823*89c4ff92SAndroid Build Coastguard Worker SetValueChecked(reasonIfUnsupported, "Cl Concat: Types and quantization parameters must match.");
824*89c4ff92SAndroid Build Coastguard Worker return false;
825*89c4ff92SAndroid Build Coastguard Worker }
826*89c4ff92SAndroid Build Coastguard Worker }
827*89c4ff92SAndroid Build Coastguard Worker return true; // Sub-tensors support concat along batch
828*89c4ff92SAndroid Build Coastguard Worker }
829*89c4ff92SAndroid Build Coastguard Worker else // > 4 dimensions not supported.
830*89c4ff92SAndroid Build Coastguard Worker {
831*89c4ff92SAndroid Build Coastguard Worker SetValueChecked(reasonIfUnsupported, "Cl Concat: Maximum of 4 dimensions supported.");
832*89c4ff92SAndroid Build Coastguard Worker return false;
833*89c4ff92SAndroid Build Coastguard Worker }
834*89c4ff92SAndroid Build Coastguard Worker }
835*89c4ff92SAndroid Build Coastguard Worker
IsConstantSupported(const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const836*89c4ff92SAndroid Build Coastguard Worker bool ClLayerSupport::IsConstantSupported(const TensorInfo& output,
837*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
838*89c4ff92SAndroid Build Coastguard Worker {
839*89c4ff92SAndroid Build Coastguard Worker FORWARD_WORKLOAD_VALIDATE_FUNC(ClConstantWorkloadValidate,
840*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
841*89c4ff92SAndroid Build Coastguard Worker output);
842*89c4ff92SAndroid Build Coastguard Worker }
843*89c4ff92SAndroid Build Coastguard Worker
IsConvertFp16ToFp32Supported(const TensorInfo & input,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const844*89c4ff92SAndroid Build Coastguard Worker bool ClLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
845*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
846*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
847*89c4ff92SAndroid Build Coastguard Worker {
848*89c4ff92SAndroid Build Coastguard Worker FORWARD_WORKLOAD_VALIDATE_FUNC(ClConvertFp16ToFp32WorkloadValidate,
849*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
850*89c4ff92SAndroid Build Coastguard Worker input,
851*89c4ff92SAndroid Build Coastguard Worker output);
852*89c4ff92SAndroid Build Coastguard Worker }
853*89c4ff92SAndroid Build Coastguard Worker
IsConvertFp32ToFp16Supported(const TensorInfo & input,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const854*89c4ff92SAndroid Build Coastguard Worker bool ClLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
855*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
856*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
857*89c4ff92SAndroid Build Coastguard Worker {
858*89c4ff92SAndroid Build Coastguard Worker FORWARD_WORKLOAD_VALIDATE_FUNC(ClConvertFp32ToFp16WorkloadValidate,
859*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
860*89c4ff92SAndroid Build Coastguard Worker input,
861*89c4ff92SAndroid Build Coastguard Worker output);
862*89c4ff92SAndroid Build Coastguard Worker }
863*89c4ff92SAndroid Build Coastguard Worker
IsConvolution2dSupported(const TensorInfo & input,const TensorInfo & output,const Convolution2dDescriptor & descriptor,const TensorInfo & weights,const Optional<TensorInfo> & biases,Optional<std::string &> reasonIfUnsupported) const864*89c4ff92SAndroid Build Coastguard Worker bool ClLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
865*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
866*89c4ff92SAndroid Build Coastguard Worker const Convolution2dDescriptor& descriptor,
867*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& weights,
868*89c4ff92SAndroid Build Coastguard Worker const Optional<TensorInfo>& biases,
869*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
870*89c4ff92SAndroid Build Coastguard Worker {
871*89c4ff92SAndroid Build Coastguard Worker bool isFastMathEnabled = false;
872*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMCOMPUTECL_ENABLED)
873*89c4ff92SAndroid Build Coastguard Worker if (m_ModelContextPtr)
874*89c4ff92SAndroid Build Coastguard Worker {
875*89c4ff92SAndroid Build Coastguard Worker if (m_ModelContextPtr.get() != nullptr)
876*89c4ff92SAndroid Build Coastguard Worker {
877*89c4ff92SAndroid Build Coastguard Worker auto modelOptions = dynamic_cast<ClBackendModelContext*>(m_ModelContextPtr.get());
878*89c4ff92SAndroid Build Coastguard Worker if (modelOptions)
879*89c4ff92SAndroid Build Coastguard Worker {
880*89c4ff92SAndroid Build Coastguard Worker isFastMathEnabled = modelOptions->IsFastMathEnabled();
881*89c4ff92SAndroid Build Coastguard Worker }
882*89c4ff92SAndroid Build Coastguard Worker }
883*89c4ff92SAndroid Build Coastguard Worker }
884*89c4ff92SAndroid Build Coastguard Worker #endif
885*89c4ff92SAndroid Build Coastguard Worker
886*89c4ff92SAndroid Build Coastguard Worker FORWARD_WORKLOAD_VALIDATE_FUNC(ClConvolution2dWorkloadValidate,
887*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
888*89c4ff92SAndroid Build Coastguard Worker input,
889*89c4ff92SAndroid Build Coastguard Worker output,
890*89c4ff92SAndroid Build Coastguard Worker descriptor,
891*89c4ff92SAndroid Build Coastguard Worker weights,
892*89c4ff92SAndroid Build Coastguard Worker biases,
893*89c4ff92SAndroid Build Coastguard Worker isFastMathEnabled,
894*89c4ff92SAndroid Build Coastguard Worker nullptr);
895*89c4ff92SAndroid Build Coastguard Worker }
896*89c4ff92SAndroid Build Coastguard Worker
IsConvolution3dSupported(const TensorInfo & input,const TensorInfo & output,const Convolution3dDescriptor & descriptor,const TensorInfo & weights,const Optional<TensorInfo> & biases,Optional<std::string &> reasonIfUnsupported) const897*89c4ff92SAndroid Build Coastguard Worker bool ClLayerSupport::IsConvolution3dSupported(const TensorInfo& input,
898*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
899*89c4ff92SAndroid Build Coastguard Worker const Convolution3dDescriptor& descriptor,
900*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& weights,
901*89c4ff92SAndroid Build Coastguard Worker const Optional<TensorInfo>& biases,
902*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
903*89c4ff92SAndroid Build Coastguard Worker {
904*89c4ff92SAndroid Build Coastguard Worker bool isFastMathEnabled = false;
905*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMCOMPUTECL_ENABLED)
906*89c4ff92SAndroid Build Coastguard Worker if (m_ModelContextPtr)
907*89c4ff92SAndroid Build Coastguard Worker {
908*89c4ff92SAndroid Build Coastguard Worker if (m_ModelContextPtr.get() != nullptr)
909*89c4ff92SAndroid Build Coastguard Worker {
910*89c4ff92SAndroid Build Coastguard Worker auto modelOptions = dynamic_cast<ClBackendModelContext*>(m_ModelContextPtr.get());
911*89c4ff92SAndroid Build Coastguard Worker if (modelOptions)
912*89c4ff92SAndroid Build Coastguard Worker {
913*89c4ff92SAndroid Build Coastguard Worker isFastMathEnabled = modelOptions->IsFastMathEnabled();
914*89c4ff92SAndroid Build Coastguard Worker }
915*89c4ff92SAndroid Build Coastguard Worker }
916*89c4ff92SAndroid Build Coastguard Worker }
917*89c4ff92SAndroid Build Coastguard Worker #endif
918*89c4ff92SAndroid Build Coastguard Worker
919*89c4ff92SAndroid Build Coastguard Worker FORWARD_WORKLOAD_VALIDATE_FUNC(ClConvolution3dWorkloadValidate,
920*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
921*89c4ff92SAndroid Build Coastguard Worker input,
922*89c4ff92SAndroid Build Coastguard Worker output,
923*89c4ff92SAndroid Build Coastguard Worker descriptor,
924*89c4ff92SAndroid Build Coastguard Worker weights,
925*89c4ff92SAndroid Build Coastguard Worker biases,
926*89c4ff92SAndroid Build Coastguard Worker isFastMathEnabled,
927*89c4ff92SAndroid Build Coastguard Worker nullptr);
928*89c4ff92SAndroid Build Coastguard Worker }
929*89c4ff92SAndroid Build Coastguard Worker
IsDequantizeSupported(const TensorInfo & input,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const930*89c4ff92SAndroid Build Coastguard Worker bool ClLayerSupport::IsDequantizeSupported(const TensorInfo& input,
931*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
932*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
933*89c4ff92SAndroid Build Coastguard Worker {
934*89c4ff92SAndroid Build Coastguard Worker FORWARD_WORKLOAD_VALIDATE_FUNC(ClDequantizeWorkloadValidate,
935*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
936*89c4ff92SAndroid Build Coastguard Worker input,
937*89c4ff92SAndroid Build Coastguard Worker output);
938*89c4ff92SAndroid Build Coastguard Worker }
939*89c4ff92SAndroid Build Coastguard Worker
IsDepthToSpaceSupported(const TensorInfo & input,const TensorInfo & output,const DepthToSpaceDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const940*89c4ff92SAndroid Build Coastguard Worker bool ClLayerSupport::IsDepthToSpaceSupported(const TensorInfo& input,
941*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
942*89c4ff92SAndroid Build Coastguard Worker const DepthToSpaceDescriptor& descriptor,
943*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
944*89c4ff92SAndroid Build Coastguard Worker {
945*89c4ff92SAndroid Build Coastguard Worker FORWARD_WORKLOAD_VALIDATE_FUNC(ClDepthToSpaceWorkloadValidate,
946*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
947*89c4ff92SAndroid Build Coastguard Worker input,
948*89c4ff92SAndroid Build Coastguard Worker output,
949*89c4ff92SAndroid Build Coastguard Worker descriptor);
950*89c4ff92SAndroid Build Coastguard Worker }
951*89c4ff92SAndroid Build Coastguard Worker
IsDepthwiseConvolutionSupported(const TensorInfo & input,const TensorInfo & output,const DepthwiseConvolution2dDescriptor & descriptor,const TensorInfo & weights,const Optional<TensorInfo> & biases,Optional<std::string &> reasonIfUnsupported) const952*89c4ff92SAndroid Build Coastguard Worker bool ClLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
953*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
954*89c4ff92SAndroid Build Coastguard Worker const DepthwiseConvolution2dDescriptor& descriptor,
955*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& weights,
956*89c4ff92SAndroid Build Coastguard Worker const Optional<TensorInfo>& biases,
957*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
958*89c4ff92SAndroid Build Coastguard Worker {
959*89c4ff92SAndroid Build Coastguard Worker FORWARD_WORKLOAD_VALIDATE_FUNC(ClDepthwiseConvolutionWorkloadValidate,
960*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
961*89c4ff92SAndroid Build Coastguard Worker input,
962*89c4ff92SAndroid Build Coastguard Worker output,
963*89c4ff92SAndroid Build Coastguard Worker descriptor,
964*89c4ff92SAndroid Build Coastguard Worker weights,
965*89c4ff92SAndroid Build Coastguard Worker biases,
966*89c4ff92SAndroid Build Coastguard Worker nullptr);
967*89c4ff92SAndroid Build Coastguard Worker }
968*89c4ff92SAndroid Build Coastguard Worker
IsDilatedDepthwiseConvolutionSupported(const TensorInfo & input,const TensorInfo & output,const DepthwiseConvolution2dDescriptor & descriptor,const TensorInfo & weights,const Optional<TensorInfo> & biases,Optional<std::string &> reasonIfUnsupported) const969*89c4ff92SAndroid Build Coastguard Worker bool ClLayerSupport::IsDilatedDepthwiseConvolutionSupported(const TensorInfo& input,
970*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
971*89c4ff92SAndroid Build Coastguard Worker const DepthwiseConvolution2dDescriptor& descriptor,
972*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& weights,
973*89c4ff92SAndroid Build Coastguard Worker const Optional<TensorInfo>& biases,
974*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
975*89c4ff92SAndroid Build Coastguard Worker {
976*89c4ff92SAndroid Build Coastguard Worker FORWARD_WORKLOAD_VALIDATE_FUNC(ClDepthwiseConvolutionWorkloadValidate,
977*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
978*89c4ff92SAndroid Build Coastguard Worker input,
979*89c4ff92SAndroid Build Coastguard Worker output,
980*89c4ff92SAndroid Build Coastguard Worker descriptor,
981*89c4ff92SAndroid Build Coastguard Worker weights,
982*89c4ff92SAndroid Build Coastguard Worker biases,
983*89c4ff92SAndroid Build Coastguard Worker nullptr);
984*89c4ff92SAndroid Build Coastguard Worker }
985*89c4ff92SAndroid Build Coastguard Worker
986*89c4ff92SAndroid Build Coastguard Worker
IsDivisionSupported(const TensorInfo & input0,const TensorInfo & input1,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const987*89c4ff92SAndroid Build Coastguard Worker bool ClLayerSupport::IsDivisionSupported(const TensorInfo& input0,
988*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& input1,
989*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
990*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
991*89c4ff92SAndroid Build Coastguard Worker {
992*89c4ff92SAndroid Build Coastguard Worker FORWARD_WORKLOAD_VALIDATE_FUNC(ClDivisionWorkloadValidate,
993*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
994*89c4ff92SAndroid Build Coastguard Worker input0,
995*89c4ff92SAndroid Build Coastguard Worker input1,
996*89c4ff92SAndroid Build Coastguard Worker output,
997*89c4ff92SAndroid Build Coastguard Worker nullptr);
998*89c4ff92SAndroid Build Coastguard Worker }
999*89c4ff92SAndroid Build Coastguard Worker
IsElementwiseUnarySupported(const TensorInfo & input,const TensorInfo & output,const ElementwiseUnaryDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1000*89c4ff92SAndroid Build Coastguard Worker bool ClLayerSupport::IsElementwiseUnarySupported(const TensorInfo& input,
1001*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
1002*89c4ff92SAndroid Build Coastguard Worker const ElementwiseUnaryDescriptor& descriptor,
1003*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
1004*89c4ff92SAndroid Build Coastguard Worker {
1005*89c4ff92SAndroid Build Coastguard Worker switch(descriptor.m_Operation)
1006*89c4ff92SAndroid Build Coastguard Worker {
1007*89c4ff92SAndroid Build Coastguard Worker case UnaryOperation::Abs:
1008*89c4ff92SAndroid Build Coastguard Worker FORWARD_WORKLOAD_VALIDATE_FUNC(ClAbsWorkloadValidate,
1009*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
1010*89c4ff92SAndroid Build Coastguard Worker input,
1011*89c4ff92SAndroid Build Coastguard Worker output);
1012*89c4ff92SAndroid Build Coastguard Worker case UnaryOperation::Exp:
1013*89c4ff92SAndroid Build Coastguard Worker FORWARD_WORKLOAD_VALIDATE_FUNC(ClExpWorkloadValidate,
1014*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
1015*89c4ff92SAndroid Build Coastguard Worker input,
1016*89c4ff92SAndroid Build Coastguard Worker output);
1017*89c4ff92SAndroid Build Coastguard Worker case UnaryOperation::Log:
1018*89c4ff92SAndroid Build Coastguard Worker FORWARD_WORKLOAD_VALIDATE_FUNC(ClLogWorkloadValidate,
1019*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
1020*89c4ff92SAndroid Build Coastguard Worker input,
1021*89c4ff92SAndroid Build Coastguard Worker output);
1022*89c4ff92SAndroid Build Coastguard Worker case UnaryOperation::LogicalNot:
1023*89c4ff92SAndroid Build Coastguard Worker FORWARD_WORKLOAD_VALIDATE_FUNC(ClLogicalNotWorkloadValidate,
1024*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
1025*89c4ff92SAndroid Build Coastguard Worker input,
1026*89c4ff92SAndroid Build Coastguard Worker output);
1027*89c4ff92SAndroid Build Coastguard Worker case UnaryOperation::Neg:
1028*89c4ff92SAndroid Build Coastguard Worker FORWARD_WORKLOAD_VALIDATE_FUNC(ClNegWorkloadValidate,
1029*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
1030*89c4ff92SAndroid Build Coastguard Worker input,
1031*89c4ff92SAndroid Build Coastguard Worker output);
1032*89c4ff92SAndroid Build Coastguard Worker case UnaryOperation::Rsqrt:
1033*89c4ff92SAndroid Build Coastguard Worker FORWARD_WORKLOAD_VALIDATE_FUNC(ClRsqrtWorkloadValidate,
1034*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
1035*89c4ff92SAndroid Build Coastguard Worker input,
1036*89c4ff92SAndroid Build Coastguard Worker output);
1037*89c4ff92SAndroid Build Coastguard Worker case UnaryOperation::Sin:
1038*89c4ff92SAndroid Build Coastguard Worker FORWARD_WORKLOAD_VALIDATE_FUNC(ClSinWorkloadValidate,
1039*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
1040*89c4ff92SAndroid Build Coastguard Worker input,
1041*89c4ff92SAndroid Build Coastguard Worker output);
1042*89c4ff92SAndroid Build Coastguard Worker case UnaryOperation::Sqrt:
1043*89c4ff92SAndroid Build Coastguard Worker FORWARD_WORKLOAD_VALIDATE_FUNC(ClSqrtWorkloadValidate,
1044*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
1045*89c4ff92SAndroid Build Coastguard Worker input,
1046*89c4ff92SAndroid Build Coastguard Worker output);
1047*89c4ff92SAndroid Build Coastguard Worker default:
1048*89c4ff92SAndroid Build Coastguard Worker return false;
1049*89c4ff92SAndroid Build Coastguard Worker }
1050*89c4ff92SAndroid Build Coastguard Worker }
1051*89c4ff92SAndroid Build Coastguard Worker
IsFillSupported(const TensorInfo & input,const TensorInfo & output,const FillDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1052*89c4ff92SAndroid Build Coastguard Worker bool ClLayerSupport::IsFillSupported(const TensorInfo& input,
1053*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
1054*89c4ff92SAndroid Build Coastguard Worker const FillDescriptor& descriptor,
1055*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
1056*89c4ff92SAndroid Build Coastguard Worker {
1057*89c4ff92SAndroid Build Coastguard Worker armnn::IgnoreUnused(input);
1058*89c4ff92SAndroid Build Coastguard Worker armnn::IgnoreUnused(output);
1059*89c4ff92SAndroid Build Coastguard Worker armnn::IgnoreUnused(descriptor);
1060*89c4ff92SAndroid Build Coastguard Worker
1061*89c4ff92SAndroid Build Coastguard Worker return IsClBackendSupported(reasonIfUnsupported);
1062*89c4ff92SAndroid Build Coastguard Worker }
1063*89c4ff92SAndroid Build Coastguard Worker
IsFloorSupported(const TensorInfo & input,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const1064*89c4ff92SAndroid Build Coastguard Worker bool ClLayerSupport::IsFloorSupported(const TensorInfo& input,
1065*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
1066*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
1067*89c4ff92SAndroid Build Coastguard Worker {
1068*89c4ff92SAndroid Build Coastguard Worker FORWARD_WORKLOAD_VALIDATE_FUNC(ClFloorWorkloadValidate,
1069*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
1070*89c4ff92SAndroid Build Coastguard Worker input,
1071*89c4ff92SAndroid Build Coastguard Worker output);
1072*89c4ff92SAndroid Build Coastguard Worker }
1073*89c4ff92SAndroid Build Coastguard Worker
IsFullyConnectedSupported(const TensorInfo & input,const TensorInfo & output,const TensorInfo & weights,const TensorInfo & biases,const FullyConnectedDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1074*89c4ff92SAndroid Build Coastguard Worker bool ClLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
1075*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
1076*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& weights,
1077*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& biases,
1078*89c4ff92SAndroid Build Coastguard Worker const FullyConnectedDescriptor& descriptor,
1079*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
1080*89c4ff92SAndroid Build Coastguard Worker {
1081*89c4ff92SAndroid Build Coastguard Worker FORWARD_WORKLOAD_VALIDATE_FUNC(ClFullyConnectedWorkloadValidate,
1082*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
1083*89c4ff92SAndroid Build Coastguard Worker input,
1084*89c4ff92SAndroid Build Coastguard Worker output,
1085*89c4ff92SAndroid Build Coastguard Worker weights,
1086*89c4ff92SAndroid Build Coastguard Worker biases,
1087*89c4ff92SAndroid Build Coastguard Worker descriptor,
1088*89c4ff92SAndroid Build Coastguard Worker nullptr);
1089*89c4ff92SAndroid Build Coastguard Worker }
1090*89c4ff92SAndroid Build Coastguard Worker
IsGatherSupported(const TensorInfo & input0,const TensorInfo & input1,const TensorInfo & output,const GatherDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1091*89c4ff92SAndroid Build Coastguard Worker bool ClLayerSupport::IsGatherSupported(const TensorInfo& input0,
1092*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& input1,
1093*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
1094*89c4ff92SAndroid Build Coastguard Worker const GatherDescriptor& descriptor,
1095*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
1096*89c4ff92SAndroid Build Coastguard Worker {
1097*89c4ff92SAndroid Build Coastguard Worker FORWARD_WORKLOAD_VALIDATE_FUNC(ClGatherWorkloadValidate,
1098*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
1099*89c4ff92SAndroid Build Coastguard Worker input0,
1100*89c4ff92SAndroid Build Coastguard Worker input1,
1101*89c4ff92SAndroid Build Coastguard Worker output,
1102*89c4ff92SAndroid Build Coastguard Worker descriptor);
1103*89c4ff92SAndroid Build Coastguard Worker }
1104*89c4ff92SAndroid Build Coastguard Worker
IsGatherNdSupported(const TensorInfo & input0,const TensorInfo & input1,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const1105*89c4ff92SAndroid Build Coastguard Worker bool ClLayerSupport::IsGatherNdSupported(const TensorInfo& input0,
1106*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& input1,
1107*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
1108*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
1109*89c4ff92SAndroid Build Coastguard Worker {
1110*89c4ff92SAndroid Build Coastguard Worker FORWARD_WORKLOAD_VALIDATE_FUNC(ClGatherNdWorkloadValidate,
1111*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
1112*89c4ff92SAndroid Build Coastguard Worker input0,
1113*89c4ff92SAndroid Build Coastguard Worker input1,
1114*89c4ff92SAndroid Build Coastguard Worker output);
1115*89c4ff92SAndroid Build Coastguard Worker }
1116*89c4ff92SAndroid Build Coastguard Worker
IsInputSupported(const TensorInfo & input,Optional<std::string &> reasonIfUnsupported) const1117*89c4ff92SAndroid Build Coastguard Worker bool ClLayerSupport::IsInputSupported(const TensorInfo& input,
1118*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
1119*89c4ff92SAndroid Build Coastguard Worker {
1120*89c4ff92SAndroid Build Coastguard Worker return IsClBackendSupported(reasonIfUnsupported, input);
1121*89c4ff92SAndroid Build Coastguard Worker }
1122*89c4ff92SAndroid Build Coastguard Worker
IsInstanceNormalizationSupported(const TensorInfo & input,const TensorInfo & output,const InstanceNormalizationDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1123*89c4ff92SAndroid Build Coastguard Worker bool ClLayerSupport::IsInstanceNormalizationSupported(const TensorInfo& input,
1124*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
1125*89c4ff92SAndroid Build Coastguard Worker const InstanceNormalizationDescriptor& descriptor,
1126*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
1127*89c4ff92SAndroid Build Coastguard Worker {
1128*89c4ff92SAndroid Build Coastguard Worker FORWARD_WORKLOAD_VALIDATE_FUNC(ClInstanceNormalizationWorkloadValidate,
1129*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
1130*89c4ff92SAndroid Build Coastguard Worker input,
1131*89c4ff92SAndroid Build Coastguard Worker output,
1132*89c4ff92SAndroid Build Coastguard Worker descriptor);
1133*89c4ff92SAndroid Build Coastguard Worker }
1134*89c4ff92SAndroid Build Coastguard Worker
IsL2NormalizationSupported(const TensorInfo & input,const TensorInfo & output,const L2NormalizationDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1135*89c4ff92SAndroid Build Coastguard Worker bool ClLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
1136*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
1137*89c4ff92SAndroid Build Coastguard Worker const L2NormalizationDescriptor& descriptor,
1138*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
1139*89c4ff92SAndroid Build Coastguard Worker {
1140*89c4ff92SAndroid Build Coastguard Worker FORWARD_WORKLOAD_VALIDATE_FUNC(ClL2NormalizationWorkloadValidate,
1141*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
1142*89c4ff92SAndroid Build Coastguard Worker input,
1143*89c4ff92SAndroid Build Coastguard Worker output,
1144*89c4ff92SAndroid Build Coastguard Worker descriptor);
1145*89c4ff92SAndroid Build Coastguard Worker }
1146*89c4ff92SAndroid Build Coastguard Worker
IsLogicalBinarySupported(const TensorInfo & input0,const TensorInfo & input1,const TensorInfo & output,const LogicalBinaryDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1147*89c4ff92SAndroid Build Coastguard Worker bool ClLayerSupport::IsLogicalBinarySupported(const TensorInfo& input0,
1148*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& input1,
1149*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
1150*89c4ff92SAndroid Build Coastguard Worker const LogicalBinaryDescriptor& descriptor,
1151*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
1152*89c4ff92SAndroid Build Coastguard Worker {
1153*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(output);
1154*89c4ff92SAndroid Build Coastguard Worker
1155*89c4ff92SAndroid Build Coastguard Worker switch(descriptor.m_Operation)
1156*89c4ff92SAndroid Build Coastguard Worker {
1157*89c4ff92SAndroid Build Coastguard Worker case LogicalBinaryOperation::LogicalAnd:
1158*89c4ff92SAndroid Build Coastguard Worker FORWARD_WORKLOAD_VALIDATE_FUNC(ClLogicalAndWorkloadValidate,
1159*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
1160*89c4ff92SAndroid Build Coastguard Worker input0,
1161*89c4ff92SAndroid Build Coastguard Worker input1,
1162*89c4ff92SAndroid Build Coastguard Worker output);
1163*89c4ff92SAndroid Build Coastguard Worker case LogicalBinaryOperation::LogicalOr:
1164*89c4ff92SAndroid Build Coastguard Worker FORWARD_WORKLOAD_VALIDATE_FUNC(ClLogicalOrWorkloadValidate,
1165*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
1166*89c4ff92SAndroid Build Coastguard Worker input0,
1167*89c4ff92SAndroid Build Coastguard Worker input1,
1168*89c4ff92SAndroid Build Coastguard Worker output);
1169*89c4ff92SAndroid Build Coastguard Worker default:
1170*89c4ff92SAndroid Build Coastguard Worker return false;
1171*89c4ff92SAndroid Build Coastguard Worker }
1172*89c4ff92SAndroid Build Coastguard Worker }
1173*89c4ff92SAndroid Build Coastguard Worker
1174*89c4ff92SAndroid Build Coastguard Worker
IsLogSoftmaxSupported(const TensorInfo & input,const TensorInfo & output,const LogSoftmaxDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1175*89c4ff92SAndroid Build Coastguard Worker bool ClLayerSupport::IsLogSoftmaxSupported(const TensorInfo& input,
1176*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
1177*89c4ff92SAndroid Build Coastguard Worker const LogSoftmaxDescriptor& descriptor,
1178*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
1179*89c4ff92SAndroid Build Coastguard Worker {
1180*89c4ff92SAndroid Build Coastguard Worker FORWARD_WORKLOAD_VALIDATE_FUNC(ClLogSoftmaxWorkloadValidate,
1181*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
1182*89c4ff92SAndroid Build Coastguard Worker input,
1183*89c4ff92SAndroid Build Coastguard Worker output,
1184*89c4ff92SAndroid Build Coastguard Worker descriptor);
1185*89c4ff92SAndroid Build Coastguard Worker }
1186*89c4ff92SAndroid Build Coastguard Worker
IsLstmSupported(const TensorInfo & input,const TensorInfo & outputStateIn,const TensorInfo & cellStateIn,const TensorInfo & scratchBuffer,const TensorInfo & outputStateOut,const TensorInfo & cellStateOut,const TensorInfo & output,const LstmDescriptor & descriptor,const LstmInputParamsInfo & paramsInfo,Optional<std::string &> reasonIfUnsupported) const1187*89c4ff92SAndroid Build Coastguard Worker bool ClLayerSupport::IsLstmSupported(const TensorInfo& input,
1188*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputStateIn,
1189*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& cellStateIn,
1190*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& scratchBuffer,
1191*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputStateOut,
1192*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& cellStateOut,
1193*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
1194*89c4ff92SAndroid Build Coastguard Worker const LstmDescriptor& descriptor,
1195*89c4ff92SAndroid Build Coastguard Worker const LstmInputParamsInfo& paramsInfo,
1196*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
1197*89c4ff92SAndroid Build Coastguard Worker {
1198*89c4ff92SAndroid Build Coastguard Worker FORWARD_WORKLOAD_VALIDATE_FUNC(ClLstmFloatWorkloadValidate,
1199*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
1200*89c4ff92SAndroid Build Coastguard Worker input,
1201*89c4ff92SAndroid Build Coastguard Worker outputStateIn,
1202*89c4ff92SAndroid Build Coastguard Worker cellStateIn,
1203*89c4ff92SAndroid Build Coastguard Worker scratchBuffer,
1204*89c4ff92SAndroid Build Coastguard Worker outputStateOut,
1205*89c4ff92SAndroid Build Coastguard Worker cellStateOut,
1206*89c4ff92SAndroid Build Coastguard Worker output,
1207*89c4ff92SAndroid Build Coastguard Worker descriptor,
1208*89c4ff92SAndroid Build Coastguard Worker paramsInfo);
1209*89c4ff92SAndroid Build Coastguard Worker }
1210*89c4ff92SAndroid Build Coastguard Worker
IsMaximumSupported(const TensorInfo & input0,const TensorInfo & input1,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const1211*89c4ff92SAndroid Build Coastguard Worker bool ClLayerSupport::IsMaximumSupported(const TensorInfo& input0,
1212*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& input1,
1213*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
1214*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
1215*89c4ff92SAndroid Build Coastguard Worker {
1216*89c4ff92SAndroid Build Coastguard Worker FORWARD_WORKLOAD_VALIDATE_FUNC(ClMaximumWorkloadValidate,
1217*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
1218*89c4ff92SAndroid Build Coastguard Worker input0,
1219*89c4ff92SAndroid Build Coastguard Worker input1,
1220*89c4ff92SAndroid Build Coastguard Worker output);
1221*89c4ff92SAndroid Build Coastguard Worker }
1222*89c4ff92SAndroid Build Coastguard Worker
IsMeanSupported(const TensorInfo & input,const TensorInfo & output,const MeanDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1223*89c4ff92SAndroid Build Coastguard Worker bool ClLayerSupport::IsMeanSupported(const TensorInfo& input,
1224*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
1225*89c4ff92SAndroid Build Coastguard Worker const MeanDescriptor& descriptor,
1226*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
1227*89c4ff92SAndroid Build Coastguard Worker {
1228*89c4ff92SAndroid Build Coastguard Worker FORWARD_WORKLOAD_VALIDATE_FUNC(ClMeanValidate,
1229*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
1230*89c4ff92SAndroid Build Coastguard Worker input,
1231*89c4ff92SAndroid Build Coastguard Worker output,
1232*89c4ff92SAndroid Build Coastguard Worker descriptor);
1233*89c4ff92SAndroid Build Coastguard Worker }
1234*89c4ff92SAndroid Build Coastguard Worker
IsMinimumSupported(const TensorInfo & input0,const TensorInfo & input1,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const1235*89c4ff92SAndroid Build Coastguard Worker bool ClLayerSupport::IsMinimumSupported(const TensorInfo& input0,
1236*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& input1,
1237*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
1238*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
1239*89c4ff92SAndroid Build Coastguard Worker {
1240*89c4ff92SAndroid Build Coastguard Worker FORWARD_WORKLOAD_VALIDATE_FUNC(ClMinimumWorkloadValidate,
1241*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
1242*89c4ff92SAndroid Build Coastguard Worker input0,
1243*89c4ff92SAndroid Build Coastguard Worker input1,
1244*89c4ff92SAndroid Build Coastguard Worker output);
1245*89c4ff92SAndroid Build Coastguard Worker }
1246*89c4ff92SAndroid Build Coastguard Worker
IsMultiplicationSupported(const TensorInfo & input0,const TensorInfo & input1,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const1247*89c4ff92SAndroid Build Coastguard Worker bool ClLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
1248*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& input1,
1249*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
1250*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
1251*89c4ff92SAndroid Build Coastguard Worker {
1252*89c4ff92SAndroid Build Coastguard Worker FORWARD_WORKLOAD_VALIDATE_FUNC(ClMultiplicationWorkloadValidate,
1253*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
1254*89c4ff92SAndroid Build Coastguard Worker input0,
1255*89c4ff92SAndroid Build Coastguard Worker input1,
1256*89c4ff92SAndroid Build Coastguard Worker output,
1257*89c4ff92SAndroid Build Coastguard Worker nullptr);
1258*89c4ff92SAndroid Build Coastguard Worker }
1259*89c4ff92SAndroid Build Coastguard Worker
IsNormalizationSupported(const TensorInfo & input,const TensorInfo & output,const NormalizationDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1260*89c4ff92SAndroid Build Coastguard Worker bool ClLayerSupport::IsNormalizationSupported(const TensorInfo& input,
1261*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
1262*89c4ff92SAndroid Build Coastguard Worker const NormalizationDescriptor& descriptor,
1263*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
1264*89c4ff92SAndroid Build Coastguard Worker {
1265*89c4ff92SAndroid Build Coastguard Worker FORWARD_WORKLOAD_VALIDATE_FUNC(ClNormalizationWorkloadValidate, reasonIfUnsupported, input, output, descriptor);
1266*89c4ff92SAndroid Build Coastguard Worker }
1267*89c4ff92SAndroid Build Coastguard Worker
IsOutputSupported(const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const1268*89c4ff92SAndroid Build Coastguard Worker bool ClLayerSupport::IsOutputSupported(const TensorInfo& output,
1269*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
1270*89c4ff92SAndroid Build Coastguard Worker {
1271*89c4ff92SAndroid Build Coastguard Worker return IsClBackendSupported(reasonIfUnsupported, output);
1272*89c4ff92SAndroid Build Coastguard Worker }
1273*89c4ff92SAndroid Build Coastguard Worker
IsPadSupported(const TensorInfo & input,const TensorInfo & output,const PadDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1274*89c4ff92SAndroid Build Coastguard Worker bool ClLayerSupport::IsPadSupported(const TensorInfo& input,
1275*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
1276*89c4ff92SAndroid Build Coastguard Worker const PadDescriptor& descriptor,
1277*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
1278*89c4ff92SAndroid Build Coastguard Worker {
1279*89c4ff92SAndroid Build Coastguard Worker FORWARD_WORKLOAD_VALIDATE_FUNC(ClPadValidate,
1280*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
1281*89c4ff92SAndroid Build Coastguard Worker input,
1282*89c4ff92SAndroid Build Coastguard Worker output,
1283*89c4ff92SAndroid Build Coastguard Worker descriptor);
1284*89c4ff92SAndroid Build Coastguard Worker }
1285*89c4ff92SAndroid Build Coastguard Worker
IsPermuteSupported(const TensorInfo & input,const TensorInfo & output,const PermuteDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1286*89c4ff92SAndroid Build Coastguard Worker bool ClLayerSupport::IsPermuteSupported(const TensorInfo& input,
1287*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
1288*89c4ff92SAndroid Build Coastguard Worker const PermuteDescriptor& descriptor,
1289*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
1290*89c4ff92SAndroid Build Coastguard Worker {
1291*89c4ff92SAndroid Build Coastguard Worker FORWARD_WORKLOAD_VALIDATE_FUNC(ClPermuteWorkloadValidate, reasonIfUnsupported, input, output, descriptor);
1292*89c4ff92SAndroid Build Coastguard Worker }
1293*89c4ff92SAndroid Build Coastguard Worker
IsPooling2dSupported(const TensorInfo & input,const TensorInfo & output,const Pooling2dDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1294*89c4ff92SAndroid Build Coastguard Worker bool ClLayerSupport::IsPooling2dSupported(const TensorInfo& input,
1295*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
1296*89c4ff92SAndroid Build Coastguard Worker const Pooling2dDescriptor& descriptor,
1297*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
1298*89c4ff92SAndroid Build Coastguard Worker {
1299*89c4ff92SAndroid Build Coastguard Worker FORWARD_WORKLOAD_VALIDATE_FUNC(ClPooling2dWorkloadValidate, reasonIfUnsupported, input, output, descriptor);
1300*89c4ff92SAndroid Build Coastguard Worker }
1301*89c4ff92SAndroid Build Coastguard Worker
IsPooling3dSupported(const TensorInfo & input,const TensorInfo & output,const Pooling3dDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1302*89c4ff92SAndroid Build Coastguard Worker bool ClLayerSupport::IsPooling3dSupported(const TensorInfo& input,
1303*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
1304*89c4ff92SAndroid Build Coastguard Worker const Pooling3dDescriptor& descriptor,
1305*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
1306*89c4ff92SAndroid Build Coastguard Worker {
1307*89c4ff92SAndroid Build Coastguard Worker FORWARD_WORKLOAD_VALIDATE_FUNC(ClPooling3dWorkloadValidate, reasonIfUnsupported, input, output, descriptor);
1308*89c4ff92SAndroid Build Coastguard Worker }
1309*89c4ff92SAndroid Build Coastguard Worker
IsPreluSupported(const armnn::TensorInfo & input,const armnn::TensorInfo & alpha,const armnn::TensorInfo & output,armnn::Optional<std::string &> reasonIfUnsupported) const1310*89c4ff92SAndroid Build Coastguard Worker bool ClLayerSupport::IsPreluSupported(const armnn::TensorInfo &input,
1311*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo &alpha,
1312*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo &output,
1313*89c4ff92SAndroid Build Coastguard Worker armnn::Optional<std::string &> reasonIfUnsupported) const
1314*89c4ff92SAndroid Build Coastguard Worker {
1315*89c4ff92SAndroid Build Coastguard Worker FORWARD_WORKLOAD_VALIDATE_FUNC(ClPreluWorkloadValidate, reasonIfUnsupported, input, alpha, output);
1316*89c4ff92SAndroid Build Coastguard Worker }
1317*89c4ff92SAndroid Build Coastguard Worker
IsQLstmSupported(const TensorInfo & input,const TensorInfo & previousOutputIn,const TensorInfo & previousCellStateIn,const TensorInfo & outputStateOut,const TensorInfo & cellStateOut,const TensorInfo & output,const QLstmDescriptor & descriptor,const LstmInputParamsInfo & paramsInfo,Optional<std::string &> reasonIfUnsupported) const1318*89c4ff92SAndroid Build Coastguard Worker bool ClLayerSupport::IsQLstmSupported(const TensorInfo& input,
1319*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& previousOutputIn,
1320*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& previousCellStateIn,
1321*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputStateOut,
1322*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& cellStateOut,
1323*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
1324*89c4ff92SAndroid Build Coastguard Worker const QLstmDescriptor& descriptor,
1325*89c4ff92SAndroid Build Coastguard Worker const LstmInputParamsInfo& paramsInfo,
1326*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
1327*89c4ff92SAndroid Build Coastguard Worker {
1328*89c4ff92SAndroid Build Coastguard Worker if (input.GetDataType() == armnn::DataType::QAsymmS8 &&
1329*89c4ff92SAndroid Build Coastguard Worker previousOutputIn.GetDataType() == armnn::DataType::QAsymmS8 &&
1330*89c4ff92SAndroid Build Coastguard Worker previousCellStateIn.GetDataType() == armnn::DataType::QSymmS16 &&
1331*89c4ff92SAndroid Build Coastguard Worker outputStateOut.GetDataType() == armnn::DataType::QAsymmS8 &&
1332*89c4ff92SAndroid Build Coastguard Worker cellStateOut.GetDataType() == armnn::DataType::QSymmS16 &&
1333*89c4ff92SAndroid Build Coastguard Worker output.GetDataType() == armnn::DataType::QAsymmS8)
1334*89c4ff92SAndroid Build Coastguard Worker {
1335*89c4ff92SAndroid Build Coastguard Worker FORWARD_WORKLOAD_VALIDATE_FUNC(ClQLstmWorkloadValidate,
1336*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
1337*89c4ff92SAndroid Build Coastguard Worker input,
1338*89c4ff92SAndroid Build Coastguard Worker previousCellStateIn,
1339*89c4ff92SAndroid Build Coastguard Worker previousOutputIn,
1340*89c4ff92SAndroid Build Coastguard Worker cellStateOut,
1341*89c4ff92SAndroid Build Coastguard Worker outputStateOut,
1342*89c4ff92SAndroid Build Coastguard Worker output,
1343*89c4ff92SAndroid Build Coastguard Worker descriptor,
1344*89c4ff92SAndroid Build Coastguard Worker paramsInfo);
1345*89c4ff92SAndroid Build Coastguard Worker }
1346*89c4ff92SAndroid Build Coastguard Worker else
1347*89c4ff92SAndroid Build Coastguard Worker {
1348*89c4ff92SAndroid Build Coastguard Worker return false;
1349*89c4ff92SAndroid Build Coastguard Worker }
1350*89c4ff92SAndroid Build Coastguard Worker }
1351*89c4ff92SAndroid Build Coastguard Worker
IsQuantizedLstmSupported(const TensorInfo & input,const TensorInfo & previousCellStateIn,const TensorInfo & previousOutputIn,const TensorInfo & cellStateOut,const TensorInfo & output,const QuantizedLstmInputParamsInfo & paramsInfo,Optional<std::string &> reasonIfUnsupported) const1352*89c4ff92SAndroid Build Coastguard Worker bool ClLayerSupport::IsQuantizedLstmSupported(const TensorInfo& input,
1353*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& previousCellStateIn,
1354*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& previousOutputIn,
1355*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& cellStateOut,
1356*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
1357*89c4ff92SAndroid Build Coastguard Worker const QuantizedLstmInputParamsInfo& paramsInfo,
1358*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
1359*89c4ff92SAndroid Build Coastguard Worker {
1360*89c4ff92SAndroid Build Coastguard Worker FORWARD_WORKLOAD_VALIDATE_FUNC(ClQuantizedLstmWorkloadValidate,
1361*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
1362*89c4ff92SAndroid Build Coastguard Worker input,
1363*89c4ff92SAndroid Build Coastguard Worker previousCellStateIn,
1364*89c4ff92SAndroid Build Coastguard Worker previousOutputIn,
1365*89c4ff92SAndroid Build Coastguard Worker cellStateOut,
1366*89c4ff92SAndroid Build Coastguard Worker output,
1367*89c4ff92SAndroid Build Coastguard Worker paramsInfo);
1368*89c4ff92SAndroid Build Coastguard Worker }
1369*89c4ff92SAndroid Build Coastguard Worker
IsQuantizeSupported(const TensorInfo & input,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const1370*89c4ff92SAndroid Build Coastguard Worker bool ClLayerSupport::IsQuantizeSupported(const TensorInfo& input,
1371*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
1372*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
1373*89c4ff92SAndroid Build Coastguard Worker {
1374*89c4ff92SAndroid Build Coastguard Worker FORWARD_WORKLOAD_VALIDATE_FUNC(ClQuantizeWorkloadValidate,
1375*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
1376*89c4ff92SAndroid Build Coastguard Worker input,
1377*89c4ff92SAndroid Build Coastguard Worker output);
1378*89c4ff92SAndroid Build Coastguard Worker }
1379*89c4ff92SAndroid Build Coastguard Worker
IsReduceSupported(const TensorInfo & input,const TensorInfo & output,const ReduceDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1380*89c4ff92SAndroid Build Coastguard Worker bool ClLayerSupport::IsReduceSupported(const TensorInfo& input,
1381*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
1382*89c4ff92SAndroid Build Coastguard Worker const ReduceDescriptor& descriptor,
1383*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
1384*89c4ff92SAndroid Build Coastguard Worker {
1385*89c4ff92SAndroid Build Coastguard Worker FORWARD_WORKLOAD_VALIDATE_FUNC(ClReduceWorkloadValidate,
1386*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
1387*89c4ff92SAndroid Build Coastguard Worker input,
1388*89c4ff92SAndroid Build Coastguard Worker output,
1389*89c4ff92SAndroid Build Coastguard Worker descriptor);
1390*89c4ff92SAndroid Build Coastguard Worker }
1391*89c4ff92SAndroid Build Coastguard Worker
IsReshapeSupported(const TensorInfo & input,const TensorInfo & output,const ReshapeDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1392*89c4ff92SAndroid Build Coastguard Worker bool ClLayerSupport::IsReshapeSupported(const TensorInfo& input,
1393*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
1394*89c4ff92SAndroid Build Coastguard Worker const ReshapeDescriptor& descriptor,
1395*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
1396*89c4ff92SAndroid Build Coastguard Worker {
1397*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(descriptor);
1398*89c4ff92SAndroid Build Coastguard Worker FORWARD_WORKLOAD_VALIDATE_FUNC(ClReshapeWorkloadValidate, reasonIfUnsupported, input, output);
1399*89c4ff92SAndroid Build Coastguard Worker }
1400*89c4ff92SAndroid Build Coastguard Worker
IsResizeSupported(const TensorInfo & input,const TensorInfo & output,const ResizeDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1401*89c4ff92SAndroid Build Coastguard Worker bool ClLayerSupport::IsResizeSupported(const TensorInfo& input,
1402*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
1403*89c4ff92SAndroid Build Coastguard Worker const ResizeDescriptor& descriptor,
1404*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
1405*89c4ff92SAndroid Build Coastguard Worker {
1406*89c4ff92SAndroid Build Coastguard Worker FORWARD_WORKLOAD_VALIDATE_FUNC(ClResizeWorkloadValidate, reasonIfUnsupported, input, output, descriptor);
1407*89c4ff92SAndroid Build Coastguard Worker }
1408*89c4ff92SAndroid Build Coastguard Worker
IsSliceSupported(const TensorInfo & input,const TensorInfo & output,const SliceDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1409*89c4ff92SAndroid Build Coastguard Worker bool ClLayerSupport::IsSliceSupported(const TensorInfo& input,
1410*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
1411*89c4ff92SAndroid Build Coastguard Worker const SliceDescriptor& descriptor,
1412*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
1413*89c4ff92SAndroid Build Coastguard Worker {
1414*89c4ff92SAndroid Build Coastguard Worker FORWARD_WORKLOAD_VALIDATE_FUNC(ClSliceWorkloadValidate, reasonIfUnsupported, input, output, descriptor);
1415*89c4ff92SAndroid Build Coastguard Worker }
1416*89c4ff92SAndroid Build Coastguard Worker
IsSoftmaxSupported(const TensorInfo & input,const TensorInfo & output,const SoftmaxDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1417*89c4ff92SAndroid Build Coastguard Worker bool ClLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
1418*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
1419*89c4ff92SAndroid Build Coastguard Worker const SoftmaxDescriptor& descriptor,
1420*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
1421*89c4ff92SAndroid Build Coastguard Worker {
1422*89c4ff92SAndroid Build Coastguard Worker FORWARD_WORKLOAD_VALIDATE_FUNC(ClSoftmaxWorkloadValidate, reasonIfUnsupported, input, output, descriptor);
1423*89c4ff92SAndroid Build Coastguard Worker }
1424*89c4ff92SAndroid Build Coastguard Worker
IsSpaceToBatchNdSupported(const TensorInfo & input,const TensorInfo & output,const SpaceToBatchNdDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1425*89c4ff92SAndroid Build Coastguard Worker bool ClLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
1426*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
1427*89c4ff92SAndroid Build Coastguard Worker const SpaceToBatchNdDescriptor& descriptor,
1428*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
1429*89c4ff92SAndroid Build Coastguard Worker {
1430*89c4ff92SAndroid Build Coastguard Worker FORWARD_WORKLOAD_VALIDATE_FUNC(ClSpaceToBatchNdWorkloadValidate,
1431*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
1432*89c4ff92SAndroid Build Coastguard Worker input,
1433*89c4ff92SAndroid Build Coastguard Worker output,
1434*89c4ff92SAndroid Build Coastguard Worker descriptor);
1435*89c4ff92SAndroid Build Coastguard Worker }
1436*89c4ff92SAndroid Build Coastguard Worker
IsSpaceToDepthSupported(const TensorInfo & input,const TensorInfo & output,const SpaceToDepthDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1437*89c4ff92SAndroid Build Coastguard Worker bool ClLayerSupport::IsSpaceToDepthSupported(const TensorInfo& input,
1438*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
1439*89c4ff92SAndroid Build Coastguard Worker const SpaceToDepthDescriptor& descriptor,
1440*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
1441*89c4ff92SAndroid Build Coastguard Worker {
1442*89c4ff92SAndroid Build Coastguard Worker FORWARD_WORKLOAD_VALIDATE_FUNC(ClSpaceToDepthWorkloadValidate,
1443*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
1444*89c4ff92SAndroid Build Coastguard Worker input,
1445*89c4ff92SAndroid Build Coastguard Worker output,
1446*89c4ff92SAndroid Build Coastguard Worker descriptor);
1447*89c4ff92SAndroid Build Coastguard Worker }
1448*89c4ff92SAndroid Build Coastguard Worker
IsSplitterSupported(const TensorInfo & input,const std::vector<std::reference_wrapper<TensorInfo>> & outputs,const ViewsDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1449*89c4ff92SAndroid Build Coastguard Worker bool ClLayerSupport::IsSplitterSupported(const TensorInfo& input,
1450*89c4ff92SAndroid Build Coastguard Worker const std::vector<std::reference_wrapper<TensorInfo>>& outputs,
1451*89c4ff92SAndroid Build Coastguard Worker const ViewsDescriptor& descriptor,
1452*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
1453*89c4ff92SAndroid Build Coastguard Worker {
1454*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMCOMPUTECL_ENABLED)
1455*89c4ff92SAndroid Build Coastguard Worker // Split along the last dimension, cannot use sub-tensors
1456*89c4ff92SAndroid Build Coastguard Worker // as width and height of the sub-tensors do not match
1457*89c4ff92SAndroid Build Coastguard Worker // the width and height of the parent tensor
1458*89c4ff92SAndroid Build Coastguard Worker // in case of input with more than 2D.
1459*89c4ff92SAndroid Build Coastguard Worker std::set<unsigned int> splitAxis = ComputeSplitAxis(descriptor, input.GetShape());
1460*89c4ff92SAndroid Build Coastguard Worker if (descriptor.GetNumDimensions() > 2 && splitAxis.size() == 1 &&
1461*89c4ff92SAndroid Build Coastguard Worker *splitAxis.begin() == descriptor.GetNumDimensions() - 1 )
1462*89c4ff92SAndroid Build Coastguard Worker {
1463*89c4ff92SAndroid Build Coastguard Worker FORWARD_WORKLOAD_VALIDATE_FUNC(ClSplitterWorkloadValidate,
1464*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
1465*89c4ff92SAndroid Build Coastguard Worker input,
1466*89c4ff92SAndroid Build Coastguard Worker outputs,
1467*89c4ff92SAndroid Build Coastguard Worker *splitAxis.begin());
1468*89c4ff92SAndroid Build Coastguard Worker }
1469*89c4ff92SAndroid Build Coastguard Worker #endif
1470*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(descriptor);
1471*89c4ff92SAndroid Build Coastguard Worker for (auto output : outputs)
1472*89c4ff92SAndroid Build Coastguard Worker {
1473*89c4ff92SAndroid Build Coastguard Worker if (!input.IsTypeSpaceMatch(output)) // Cannot use sub-tensors if the types are not same space
1474*89c4ff92SAndroid Build Coastguard Worker {
1475*89c4ff92SAndroid Build Coastguard Worker SetValueChecked(reasonIfUnsupported, "Cl Splitter: Types and quantization parameters must match.");
1476*89c4ff92SAndroid Build Coastguard Worker return false;
1477*89c4ff92SAndroid Build Coastguard Worker }
1478*89c4ff92SAndroid Build Coastguard Worker }
1479*89c4ff92SAndroid Build Coastguard Worker return true;
1480*89c4ff92SAndroid Build Coastguard Worker }
1481*89c4ff92SAndroid Build Coastguard Worker
IsStackSupported(const std::vector<const TensorInfo * > & inputs,const TensorInfo & output,const StackDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1482*89c4ff92SAndroid Build Coastguard Worker bool ClLayerSupport::IsStackSupported(const std::vector<const TensorInfo*>& inputs,
1483*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
1484*89c4ff92SAndroid Build Coastguard Worker const StackDescriptor& descriptor,
1485*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
1486*89c4ff92SAndroid Build Coastguard Worker {
1487*89c4ff92SAndroid Build Coastguard Worker FORWARD_WORKLOAD_VALIDATE_FUNC(ClStackWorkloadValidate,
1488*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
1489*89c4ff92SAndroid Build Coastguard Worker inputs,
1490*89c4ff92SAndroid Build Coastguard Worker output,
1491*89c4ff92SAndroid Build Coastguard Worker descriptor);
1492*89c4ff92SAndroid Build Coastguard Worker }
1493*89c4ff92SAndroid Build Coastguard Worker
IsStridedSliceSupported(const TensorInfo & input,const TensorInfo & output,const StridedSliceDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1494*89c4ff92SAndroid Build Coastguard Worker bool ClLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
1495*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
1496*89c4ff92SAndroid Build Coastguard Worker const StridedSliceDescriptor& descriptor,
1497*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
1498*89c4ff92SAndroid Build Coastguard Worker {
1499*89c4ff92SAndroid Build Coastguard Worker FORWARD_WORKLOAD_VALIDATE_FUNC(ClStridedSliceWorkloadValidate,
1500*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
1501*89c4ff92SAndroid Build Coastguard Worker input,
1502*89c4ff92SAndroid Build Coastguard Worker output,
1503*89c4ff92SAndroid Build Coastguard Worker descriptor);
1504*89c4ff92SAndroid Build Coastguard Worker }
1505*89c4ff92SAndroid Build Coastguard Worker
IsSubtractionSupported(const TensorInfo & input0,const TensorInfo & input1,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const1506*89c4ff92SAndroid Build Coastguard Worker bool ClLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
1507*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& input1,
1508*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
1509*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
1510*89c4ff92SAndroid Build Coastguard Worker {
1511*89c4ff92SAndroid Build Coastguard Worker FORWARD_WORKLOAD_VALIDATE_FUNC(ClSubtractionValidate,
1512*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
1513*89c4ff92SAndroid Build Coastguard Worker input0,
1514*89c4ff92SAndroid Build Coastguard Worker input1,
1515*89c4ff92SAndroid Build Coastguard Worker output,
1516*89c4ff92SAndroid Build Coastguard Worker nullptr);
1517*89c4ff92SAndroid Build Coastguard Worker }
1518*89c4ff92SAndroid Build Coastguard Worker
IsTransposeConvolution2dSupported(const TensorInfo & input,const TensorInfo & output,const TransposeConvolution2dDescriptor & descriptor,const TensorInfo & weights,const Optional<TensorInfo> & biases,Optional<std::string &> reasonIfUnsupported) const1519*89c4ff92SAndroid Build Coastguard Worker bool ClLayerSupport::IsTransposeConvolution2dSupported(const TensorInfo& input,
1520*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
1521*89c4ff92SAndroid Build Coastguard Worker const TransposeConvolution2dDescriptor& descriptor,
1522*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& weights,
1523*89c4ff92SAndroid Build Coastguard Worker const Optional<TensorInfo>& biases,
1524*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
1525*89c4ff92SAndroid Build Coastguard Worker {
1526*89c4ff92SAndroid Build Coastguard Worker FORWARD_WORKLOAD_VALIDATE_FUNC(ClTransposeConvolution2dWorkloadValidate,
1527*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
1528*89c4ff92SAndroid Build Coastguard Worker input,
1529*89c4ff92SAndroid Build Coastguard Worker output,
1530*89c4ff92SAndroid Build Coastguard Worker descriptor,
1531*89c4ff92SAndroid Build Coastguard Worker weights,
1532*89c4ff92SAndroid Build Coastguard Worker biases);
1533*89c4ff92SAndroid Build Coastguard Worker }
1534*89c4ff92SAndroid Build Coastguard Worker
IsTransposeSupported(const TensorInfo & input,const TensorInfo & output,const TransposeDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1535*89c4ff92SAndroid Build Coastguard Worker bool ClLayerSupport::IsTransposeSupported(const TensorInfo& input,
1536*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
1537*89c4ff92SAndroid Build Coastguard Worker const TransposeDescriptor& descriptor,
1538*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
1539*89c4ff92SAndroid Build Coastguard Worker {
1540*89c4ff92SAndroid Build Coastguard Worker FORWARD_WORKLOAD_VALIDATE_FUNC(ClTransposeWorkloadValidate, reasonIfUnsupported, input, output, descriptor);
1541*89c4ff92SAndroid Build Coastguard Worker }
1542*89c4ff92SAndroid Build Coastguard Worker
IsUnidirectionalSequenceLstmSupported(const TensorInfo & input,const TensorInfo & outputStateIn,const TensorInfo & cellStateIn,const TensorInfo & outputStateOut,const TensorInfo & cellStateOut,const TensorInfo & output,const UnidirectionalSequenceLstmDescriptor & descriptor,const LstmInputParamsInfo & paramsInfo,Optional<std::string &> reasonIfUnsupported) const1543*89c4ff92SAndroid Build Coastguard Worker bool ClLayerSupport::IsUnidirectionalSequenceLstmSupported(const TensorInfo& input,
1544*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputStateIn,
1545*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& cellStateIn,
1546*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputStateOut,
1547*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& cellStateOut,
1548*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& output,
1549*89c4ff92SAndroid Build Coastguard Worker const UnidirectionalSequenceLstmDescriptor& descriptor,
1550*89c4ff92SAndroid Build Coastguard Worker const LstmInputParamsInfo& paramsInfo,
1551*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const
1552*89c4ff92SAndroid Build Coastguard Worker {
1553*89c4ff92SAndroid Build Coastguard Worker FORWARD_WORKLOAD_VALIDATE_FUNC(ClUnidirectionalSequenceLstmFloatWorkloadValidate,
1554*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported,
1555*89c4ff92SAndroid Build Coastguard Worker input,
1556*89c4ff92SAndroid Build Coastguard Worker outputStateIn,
1557*89c4ff92SAndroid Build Coastguard Worker cellStateIn,
1558*89c4ff92SAndroid Build Coastguard Worker outputStateOut,
1559*89c4ff92SAndroid Build Coastguard Worker cellStateOut,
1560*89c4ff92SAndroid Build Coastguard Worker output,
1561*89c4ff92SAndroid Build Coastguard Worker descriptor,
1562*89c4ff92SAndroid Build Coastguard Worker paramsInfo);
1563*89c4ff92SAndroid Build Coastguard Worker }
1564*89c4ff92SAndroid Build Coastguard Worker
1565*89c4ff92SAndroid Build Coastguard Worker } // namespace armnn
1566