1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker
6*89c4ff92SAndroid Build Coastguard Worker #pragma once
7*89c4ff92SAndroid Build Coastguard Worker
8*89c4ff92SAndroid Build Coastguard Worker #include <Layer.hpp>
9*89c4ff92SAndroid Build Coastguard Worker
10*89c4ff92SAndroid Build Coastguard Worker #include <tosaCommon/TosaMappings.hpp>
11*89c4ff92SAndroid Build Coastguard Worker #include <tosaCommon/operatorMappings/TosaOperatorUtils.hpp>
12*89c4ff92SAndroid Build Coastguard Worker
13*89c4ff92SAndroid Build Coastguard Worker #include <doctest/doctest.h>
14*89c4ff92SAndroid Build Coastguard Worker #include <numeric>
15*89c4ff92SAndroid Build Coastguard Worker
16*89c4ff92SAndroid Build Coastguard Worker using namespace armnn;
17*89c4ff92SAndroid Build Coastguard Worker using namespace tosa;
18*89c4ff92SAndroid Build Coastguard Worker
VerifyTosaAttribute(const BaseDescriptor & descriptor,const TosaAttributeBase * attribute,std::vector<int32_t> inputShape,std::vector<int32_t> outputShape,LayerType type,uint32_t mappingOpNumber=0)19*89c4ff92SAndroid Build Coastguard Worker inline void VerifyTosaAttribute(const BaseDescriptor& descriptor,
20*89c4ff92SAndroid Build Coastguard Worker const TosaAttributeBase* attribute,
21*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> inputShape,
22*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> outputShape,
23*89c4ff92SAndroid Build Coastguard Worker LayerType type,
24*89c4ff92SAndroid Build Coastguard Worker uint32_t mappingOpNumber = 0)
25*89c4ff92SAndroid Build Coastguard Worker {
26*89c4ff92SAndroid Build Coastguard Worker switch (type)
27*89c4ff92SAndroid Build Coastguard Worker {
28*89c4ff92SAndroid Build Coastguard Worker case LayerType::Convolution2d:
29*89c4ff92SAndroid Build Coastguard Worker {
30*89c4ff92SAndroid Build Coastguard Worker auto conv2dDesc = PolymorphicDowncast<const Convolution2dDescriptor*>(&descriptor);
31*89c4ff92SAndroid Build Coastguard Worker std::vector<int> pad = {static_cast<int>(conv2dDesc->m_PadTop),
32*89c4ff92SAndroid Build Coastguard Worker static_cast<int>(conv2dDesc->m_PadBottom),
33*89c4ff92SAndroid Build Coastguard Worker static_cast<int>(conv2dDesc->m_PadLeft),
34*89c4ff92SAndroid Build Coastguard Worker static_cast<int>(conv2dDesc->m_PadRight)};
35*89c4ff92SAndroid Build Coastguard Worker
36*89c4ff92SAndroid Build Coastguard Worker std::vector<int> dilation = {static_cast<int>(conv2dDesc->m_DilationY),
37*89c4ff92SAndroid Build Coastguard Worker static_cast<int>(conv2dDesc->m_DilationX)};
38*89c4ff92SAndroid Build Coastguard Worker std::vector<int> stride = {static_cast<int>(conv2dDesc->m_StrideY),
39*89c4ff92SAndroid Build Coastguard Worker static_cast<int>(conv2dDesc->m_StrideX)};
40*89c4ff92SAndroid Build Coastguard Worker TosaConvAttribute convAttribute(attribute);
41*89c4ff92SAndroid Build Coastguard Worker CHECK(pad == convAttribute.pad());
42*89c4ff92SAndroid Build Coastguard Worker CHECK(dilation == convAttribute.dilation());
43*89c4ff92SAndroid Build Coastguard Worker CHECK(stride == convAttribute.stride());
44*89c4ff92SAndroid Build Coastguard Worker break;
45*89c4ff92SAndroid Build Coastguard Worker }
46*89c4ff92SAndroid Build Coastguard Worker case LayerType::Pooling2d:
47*89c4ff92SAndroid Build Coastguard Worker {
48*89c4ff92SAndroid Build Coastguard Worker auto poolDesc = PolymorphicDowncast<const Pooling2dDescriptor*>(&descriptor);
49*89c4ff92SAndroid Build Coastguard Worker std::vector<int> pad = {static_cast<int>(poolDesc->m_PadTop),
50*89c4ff92SAndroid Build Coastguard Worker static_cast<int>(poolDesc->m_PadBottom),
51*89c4ff92SAndroid Build Coastguard Worker static_cast<int>(poolDesc->m_PadLeft),
52*89c4ff92SAndroid Build Coastguard Worker static_cast<int>(poolDesc->m_PadRight)};
53*89c4ff92SAndroid Build Coastguard Worker
54*89c4ff92SAndroid Build Coastguard Worker bool avgPoolIgnoreValue =
55*89c4ff92SAndroid Build Coastguard Worker (poolDesc->m_PoolType == PoolingAlgorithm::Average) &&
56*89c4ff92SAndroid Build Coastguard Worker (poolDesc->m_PaddingMethod == PaddingMethod::IgnoreValue);
57*89c4ff92SAndroid Build Coastguard Worker if (avgPoolIgnoreValue)
58*89c4ff92SAndroid Build Coastguard Worker {
59*89c4ff92SAndroid Build Coastguard Worker if (mappingOpNumber == 0)
60*89c4ff92SAndroid Build Coastguard Worker {
61*89c4ff92SAndroid Build Coastguard Worker if (poolDesc->m_DataLayout == DataLayout::NHWC)
62*89c4ff92SAndroid Build Coastguard Worker {
63*89c4ff92SAndroid Build Coastguard Worker pad = {0,
64*89c4ff92SAndroid Build Coastguard Worker 0,
65*89c4ff92SAndroid Build Coastguard Worker static_cast<int>(poolDesc->m_PadTop),
66*89c4ff92SAndroid Build Coastguard Worker static_cast<int>(poolDesc->m_PadBottom),
67*89c4ff92SAndroid Build Coastguard Worker static_cast<int>(poolDesc->m_PadLeft),
68*89c4ff92SAndroid Build Coastguard Worker static_cast<int>(poolDesc->m_PadRight),
69*89c4ff92SAndroid Build Coastguard Worker 0,
70*89c4ff92SAndroid Build Coastguard Worker 0
71*89c4ff92SAndroid Build Coastguard Worker };
72*89c4ff92SAndroid Build Coastguard Worker }
73*89c4ff92SAndroid Build Coastguard Worker else
74*89c4ff92SAndroid Build Coastguard Worker {
75*89c4ff92SAndroid Build Coastguard Worker pad = {0,
76*89c4ff92SAndroid Build Coastguard Worker 0,
77*89c4ff92SAndroid Build Coastguard Worker 0,
78*89c4ff92SAndroid Build Coastguard Worker 0,
79*89c4ff92SAndroid Build Coastguard Worker static_cast<int>(poolDesc->m_PadTop),
80*89c4ff92SAndroid Build Coastguard Worker static_cast<int>(poolDesc->m_PadBottom),
81*89c4ff92SAndroid Build Coastguard Worker static_cast<int>(poolDesc->m_PadLeft),
82*89c4ff92SAndroid Build Coastguard Worker static_cast<int>(poolDesc->m_PadRight)
83*89c4ff92SAndroid Build Coastguard Worker };
84*89c4ff92SAndroid Build Coastguard Worker }
85*89c4ff92SAndroid Build Coastguard Worker
86*89c4ff92SAndroid Build Coastguard Worker TosaPadAttribute padAttribute(attribute);
87*89c4ff92SAndroid Build Coastguard Worker
88*89c4ff92SAndroid Build Coastguard Worker CHECK(pad == padAttribute.padding());
89*89c4ff92SAndroid Build Coastguard Worker CHECK(0.0f == padAttribute.pad_const_fp());
90*89c4ff92SAndroid Build Coastguard Worker CHECK(0 == padAttribute.pad_const_int());
91*89c4ff92SAndroid Build Coastguard Worker
92*89c4ff92SAndroid Build Coastguard Worker break;
93*89c4ff92SAndroid Build Coastguard Worker }
94*89c4ff92SAndroid Build Coastguard Worker pad = {0, 0, 0, 0};
95*89c4ff92SAndroid Build Coastguard Worker }
96*89c4ff92SAndroid Build Coastguard Worker
97*89c4ff92SAndroid Build Coastguard Worker std::vector<int> kernel = {static_cast<int>(poolDesc->m_PoolHeight),
98*89c4ff92SAndroid Build Coastguard Worker static_cast<int>(poolDesc->m_PoolWidth)};
99*89c4ff92SAndroid Build Coastguard Worker std::vector<int> stride = {static_cast<int>(poolDesc->m_StrideY),
100*89c4ff92SAndroid Build Coastguard Worker static_cast<int>(poolDesc->m_StrideX)};
101*89c4ff92SAndroid Build Coastguard Worker TosaPoolAttribute poolAttribute(attribute);
102*89c4ff92SAndroid Build Coastguard Worker CHECK(pad == poolAttribute.pad());
103*89c4ff92SAndroid Build Coastguard Worker CHECK(kernel == poolAttribute.kernel());
104*89c4ff92SAndroid Build Coastguard Worker CHECK(stride == poolAttribute.stride());
105*89c4ff92SAndroid Build Coastguard Worker break;
106*89c4ff92SAndroid Build Coastguard Worker }
107*89c4ff92SAndroid Build Coastguard Worker case LayerType::Reshape:
108*89c4ff92SAndroid Build Coastguard Worker {
109*89c4ff92SAndroid Build Coastguard Worker auto reshapeDesc = PolymorphicDowncast<const ReshapeDescriptor*>(&descriptor);
110*89c4ff92SAndroid Build Coastguard Worker TosaReshapeAttribute reshapeAttribute(attribute);
111*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> shapeAttrib = reshapeAttribute.new_shape();
112*89c4ff92SAndroid Build Coastguard Worker
113*89c4ff92SAndroid Build Coastguard Worker CHECK(GetTosaTensorShape(reshapeDesc->m_TargetShape) == shapeAttrib);
114*89c4ff92SAndroid Build Coastguard Worker CHECK(outputShape == shapeAttrib);
115*89c4ff92SAndroid Build Coastguard Worker
116*89c4ff92SAndroid Build Coastguard Worker auto numInputElements = std::accumulate(std::begin(inputShape),
117*89c4ff92SAndroid Build Coastguard Worker std::end(inputShape),
118*89c4ff92SAndroid Build Coastguard Worker 1,
119*89c4ff92SAndroid Build Coastguard Worker std::multiplies<int32_t>());
120*89c4ff92SAndroid Build Coastguard Worker auto numAttributeShapeElements = std::accumulate(std::begin(shapeAttrib),
121*89c4ff92SAndroid Build Coastguard Worker std::end(shapeAttrib),
122*89c4ff92SAndroid Build Coastguard Worker 1,
123*89c4ff92SAndroid Build Coastguard Worker std::multiplies<int32_t>());
124*89c4ff92SAndroid Build Coastguard Worker CHECK(numInputElements == numAttributeShapeElements);
125*89c4ff92SAndroid Build Coastguard Worker
126*89c4ff92SAndroid Build Coastguard Worker break;
127*89c4ff92SAndroid Build Coastguard Worker }
128*89c4ff92SAndroid Build Coastguard Worker case LayerType::Slice:
129*89c4ff92SAndroid Build Coastguard Worker {
130*89c4ff92SAndroid Build Coastguard Worker auto sliceDesc = PolymorphicDowncast<const SliceDescriptor*>(&descriptor);
131*89c4ff92SAndroid Build Coastguard Worker TosaSliceAttribute reshapeAttribute(attribute);
132*89c4ff92SAndroid Build Coastguard Worker
133*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> begin(sliceDesc->m_Begin.begin(), sliceDesc->m_Begin.end());
134*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> size(sliceDesc->m_Size.begin(), sliceDesc->m_Size.end());
135*89c4ff92SAndroid Build Coastguard Worker
136*89c4ff92SAndroid Build Coastguard Worker CHECK(begin == reshapeAttribute.start());
137*89c4ff92SAndroid Build Coastguard Worker CHECK(size == reshapeAttribute.size());
138*89c4ff92SAndroid Build Coastguard Worker
139*89c4ff92SAndroid Build Coastguard Worker CHECK(begin.size() == inputShape.size());
140*89c4ff92SAndroid Build Coastguard Worker CHECK(size.size() == inputShape.size());
141*89c4ff92SAndroid Build Coastguard Worker
142*89c4ff92SAndroid Build Coastguard Worker CHECK(begin.size() == outputShape.size());
143*89c4ff92SAndroid Build Coastguard Worker CHECK(size.size() == outputShape.size());
144*89c4ff92SAndroid Build Coastguard Worker
145*89c4ff92SAndroid Build Coastguard Worker break;
146*89c4ff92SAndroid Build Coastguard Worker }
147*89c4ff92SAndroid Build Coastguard Worker case LayerType::TransposeConvolution2d:
148*89c4ff92SAndroid Build Coastguard Worker {
149*89c4ff92SAndroid Build Coastguard Worker auto transposeConv2dDesc = PolymorphicDowncast<const TransposeConvolution2dDescriptor*>(&descriptor);
150*89c4ff92SAndroid Build Coastguard Worker std::vector<int> outPad = {-static_cast<int>(transposeConv2dDesc->m_PadTop),
151*89c4ff92SAndroid Build Coastguard Worker -static_cast<int>(transposeConv2dDesc->m_PadBottom),
152*89c4ff92SAndroid Build Coastguard Worker -static_cast<int>(transposeConv2dDesc->m_PadLeft),
153*89c4ff92SAndroid Build Coastguard Worker -static_cast<int>(transposeConv2dDesc->m_PadRight)};
154*89c4ff92SAndroid Build Coastguard Worker std::vector<int> stride = {static_cast<int>(transposeConv2dDesc->m_StrideY),
155*89c4ff92SAndroid Build Coastguard Worker static_cast<int>(transposeConv2dDesc->m_StrideX)};
156*89c4ff92SAndroid Build Coastguard Worker TosaTransposeConvAttribute transposeConvAttribute(attribute);
157*89c4ff92SAndroid Build Coastguard Worker CHECK(outPad == transposeConvAttribute.out_pad());
158*89c4ff92SAndroid Build Coastguard Worker CHECK(stride == transposeConvAttribute.stride());
159*89c4ff92SAndroid Build Coastguard Worker break;
160*89c4ff92SAndroid Build Coastguard Worker }
161*89c4ff92SAndroid Build Coastguard Worker case LayerType::Transpose:
162*89c4ff92SAndroid Build Coastguard Worker {
163*89c4ff92SAndroid Build Coastguard Worker auto transposeDesc = PolymorphicDowncast<const TransposeDescriptor*>(&descriptor);
164*89c4ff92SAndroid Build Coastguard Worker std::vector<int> outPerm(transposeDesc->m_DimMappings.begin(), transposeDesc->m_DimMappings.end());
165*89c4ff92SAndroid Build Coastguard Worker TosaTransposeAttribute transposeAttribute(attribute);
166*89c4ff92SAndroid Build Coastguard Worker CHECK(outPerm == transposeAttribute.perms());
167*89c4ff92SAndroid Build Coastguard Worker break;
168*89c4ff92SAndroid Build Coastguard Worker }
169*89c4ff92SAndroid Build Coastguard Worker default:
170*89c4ff92SAndroid Build Coastguard Worker break;
171*89c4ff92SAndroid Build Coastguard Worker }
172*89c4ff92SAndroid Build Coastguard Worker return;
173*89c4ff92SAndroid Build Coastguard Worker }
174*89c4ff92SAndroid Build Coastguard Worker
AssertTosaOneToOneMappingBasicBlock(TosaSerializationBasicBlock * basicBlock,std::vector<std::vector<int32_t>> inputShape,std::vector<std::vector<int32_t>> outputShape,Op tosaOp,Attribute tosaAttribute,const BaseDescriptor & descriptor,LayerType type,DType dataType=DType_FP32)175*89c4ff92SAndroid Build Coastguard Worker inline void AssertTosaOneToOneMappingBasicBlock(TosaSerializationBasicBlock* basicBlock,
176*89c4ff92SAndroid Build Coastguard Worker std::vector<std::vector<int32_t>> inputShape,
177*89c4ff92SAndroid Build Coastguard Worker std::vector<std::vector<int32_t>> outputShape,
178*89c4ff92SAndroid Build Coastguard Worker Op tosaOp,
179*89c4ff92SAndroid Build Coastguard Worker Attribute tosaAttribute,
180*89c4ff92SAndroid Build Coastguard Worker const BaseDescriptor& descriptor,
181*89c4ff92SAndroid Build Coastguard Worker LayerType type,
182*89c4ff92SAndroid Build Coastguard Worker DType dataType = DType_FP32)
183*89c4ff92SAndroid Build Coastguard Worker {
184*89c4ff92SAndroid Build Coastguard Worker uint32_t numInputs = static_cast<uint32_t>(inputShape.size());
185*89c4ff92SAndroid Build Coastguard Worker uint32_t numInputTensors = static_cast<uint32_t>(inputShape.size());
186*89c4ff92SAndroid Build Coastguard Worker uint32_t numOutputs = static_cast<uint32_t>(outputShape.size());
187*89c4ff92SAndroid Build Coastguard Worker std::string operatorString = TosaOpToString(tosaOp);
188*89c4ff92SAndroid Build Coastguard Worker
189*89c4ff92SAndroid Build Coastguard Worker // The number of tensors in the block can be different if there are constant layers, as they are created separately.
190*89c4ff92SAndroid Build Coastguard Worker if(type == LayerType::Convolution2d)
191*89c4ff92SAndroid Build Coastguard Worker {
192*89c4ff92SAndroid Build Coastguard Worker numInputTensors = PolymorphicDowncast<const Convolution2dDescriptor*>(&descriptor)->m_BiasEnabled ? 3 : 2;
193*89c4ff92SAndroid Build Coastguard Worker }
194*89c4ff92SAndroid Build Coastguard Worker
195*89c4ff92SAndroid Build Coastguard Worker std::string blockStr = operatorString + "_block_";
196*89c4ff92SAndroid Build Coastguard Worker CHECK(basicBlock->GetName().find(blockStr) != std::string::npos);
197*89c4ff92SAndroid Build Coastguard Worker CHECK(basicBlock->GetInputs().size() == numInputTensors);
198*89c4ff92SAndroid Build Coastguard Worker CHECK(basicBlock->GetOutputs().size() == numOutputs);
199*89c4ff92SAndroid Build Coastguard Worker CHECK(basicBlock->GetOperators().size() == 1);
200*89c4ff92SAndroid Build Coastguard Worker CHECK(basicBlock->GetTensors().size() == (numInputs + numOutputs));
201*89c4ff92SAndroid Build Coastguard Worker
202*89c4ff92SAndroid Build Coastguard Worker TosaSerializationOperator* op = basicBlock->GetOperators().at(0);
203*89c4ff92SAndroid Build Coastguard Worker CHECK(op->GetInputTensorNames().size() == numInputTensors);
204*89c4ff92SAndroid Build Coastguard Worker CHECK(op->GetOutputTensorNames().size() == numOutputs);
205*89c4ff92SAndroid Build Coastguard Worker
206*89c4ff92SAndroid Build Coastguard Worker for (uint32_t i = 0; i < numInputs; i++)
207*89c4ff92SAndroid Build Coastguard Worker {
208*89c4ff92SAndroid Build Coastguard Worker std::basic_string<char> blockInputName = basicBlock->GetInputs()[i];
209*89c4ff92SAndroid Build Coastguard Worker std::basic_string<char> operatorInputName = op->GetInputTensorNames()[i];
210*89c4ff92SAndroid Build Coastguard Worker std::basic_string<char> tensorName = basicBlock->GetTensors()[i]->GetName();
211*89c4ff92SAndroid Build Coastguard Worker
212*89c4ff92SAndroid Build Coastguard Worker std::string opStr = "input" + std::to_string(i) + "_";
213*89c4ff92SAndroid Build Coastguard Worker
214*89c4ff92SAndroid Build Coastguard Worker CHECK(blockInputName == operatorInputName);
215*89c4ff92SAndroid Build Coastguard Worker CHECK(tensorName == operatorInputName);
216*89c4ff92SAndroid Build Coastguard Worker CHECK(blockInputName.find(opStr) != std::string::npos);
217*89c4ff92SAndroid Build Coastguard Worker }
218*89c4ff92SAndroid Build Coastguard Worker
219*89c4ff92SAndroid Build Coastguard Worker for (uint32_t i = 0; i < numOutputs; i++)
220*89c4ff92SAndroid Build Coastguard Worker {
221*89c4ff92SAndroid Build Coastguard Worker std::basic_string<char> blockOutputName = basicBlock->GetOutputs()[i];
222*89c4ff92SAndroid Build Coastguard Worker std::basic_string<char> operatorOutputName = op->GetOutputTensorNames()[i];
223*89c4ff92SAndroid Build Coastguard Worker std::basic_string<char> tensorName = basicBlock->GetTensors()[numInputs + i]->GetName();
224*89c4ff92SAndroid Build Coastguard Worker
225*89c4ff92SAndroid Build Coastguard Worker std::string opStr = "output" + std::to_string(i) + "_";
226*89c4ff92SAndroid Build Coastguard Worker if (tosaOp == Op_CONST)
227*89c4ff92SAndroid Build Coastguard Worker {
228*89c4ff92SAndroid Build Coastguard Worker opStr = "constant_";
229*89c4ff92SAndroid Build Coastguard Worker }
230*89c4ff92SAndroid Build Coastguard Worker
231*89c4ff92SAndroid Build Coastguard Worker CHECK(blockOutputName == operatorOutputName);
232*89c4ff92SAndroid Build Coastguard Worker CHECK(tensorName == operatorOutputName);
233*89c4ff92SAndroid Build Coastguard Worker CHECK(blockOutputName.find(opStr) != std::string::npos);
234*89c4ff92SAndroid Build Coastguard Worker }
235*89c4ff92SAndroid Build Coastguard Worker
236*89c4ff92SAndroid Build Coastguard Worker CHECK(op->GetAttributeType() == tosaAttribute);
237*89c4ff92SAndroid Build Coastguard Worker CHECK(op->GetOp() == tosaOp);
238*89c4ff92SAndroid Build Coastguard Worker
239*89c4ff92SAndroid Build Coastguard Worker for (uint32_t i = 0; i < numInputs; i++)
240*89c4ff92SAndroid Build Coastguard Worker {
241*89c4ff92SAndroid Build Coastguard Worker TosaSerializationTensor* tensor = basicBlock->GetTensors()[i];
242*89c4ff92SAndroid Build Coastguard Worker CHECK(tensor->GetDtype() == dataType);
243*89c4ff92SAndroid Build Coastguard Worker CHECK(tensor->GetData().size() == 0);
244*89c4ff92SAndroid Build Coastguard Worker CHECK(tensor->GetShape() == inputShape[static_cast<unsigned long int>(i)]);
245*89c4ff92SAndroid Build Coastguard Worker }
246*89c4ff92SAndroid Build Coastguard Worker
247*89c4ff92SAndroid Build Coastguard Worker for (uint32_t i = 0; i < numOutputs; i++)
248*89c4ff92SAndroid Build Coastguard Worker {
249*89c4ff92SAndroid Build Coastguard Worker TosaSerializationTensor* tensor = basicBlock->GetTensors()[i + inputShape.size()];
250*89c4ff92SAndroid Build Coastguard Worker CHECK(tensor->GetDtype() == dataType);
251*89c4ff92SAndroid Build Coastguard Worker CHECK(tensor->GetShape() == outputShape[static_cast<unsigned long int>(i)]);
252*89c4ff92SAndroid Build Coastguard Worker if (tosaOp != Op_CONST)
253*89c4ff92SAndroid Build Coastguard Worker {
254*89c4ff92SAndroid Build Coastguard Worker // Const tensors contain data.
255*89c4ff92SAndroid Build Coastguard Worker CHECK(tensor->GetData().size() == 0);
256*89c4ff92SAndroid Build Coastguard Worker }
257*89c4ff92SAndroid Build Coastguard Worker }
258*89c4ff92SAndroid Build Coastguard Worker
259*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> input = {};
260*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> output = {};
261*89c4ff92SAndroid Build Coastguard Worker
262*89c4ff92SAndroid Build Coastguard Worker if (!inputShape.empty())
263*89c4ff92SAndroid Build Coastguard Worker {
264*89c4ff92SAndroid Build Coastguard Worker input = inputShape[0];
265*89c4ff92SAndroid Build Coastguard Worker }
266*89c4ff92SAndroid Build Coastguard Worker
267*89c4ff92SAndroid Build Coastguard Worker if (!outputShape.empty())
268*89c4ff92SAndroid Build Coastguard Worker {
269*89c4ff92SAndroid Build Coastguard Worker output = outputShape[0];
270*89c4ff92SAndroid Build Coastguard Worker }
271*89c4ff92SAndroid Build Coastguard Worker
272*89c4ff92SAndroid Build Coastguard Worker VerifyTosaAttribute(descriptor,
273*89c4ff92SAndroid Build Coastguard Worker op->GetAttribute(),
274*89c4ff92SAndroid Build Coastguard Worker input,
275*89c4ff92SAndroid Build Coastguard Worker output,
276*89c4ff92SAndroid Build Coastguard Worker type);
277*89c4ff92SAndroid Build Coastguard Worker }