1*c217d954SCole Faust /* 2*c217d954SCole Faust * Copyright (c) 2017-2022 Arm Limited. 3*c217d954SCole Faust * 4*c217d954SCole Faust * SPDX-License-Identifier: MIT 5*c217d954SCole Faust * 6*c217d954SCole Faust * Permission is hereby granted, free of charge, to any person obtaining a copy 7*c217d954SCole Faust * of this software and associated documentation files (the "Software"), to 8*c217d954SCole Faust * deal in the Software without restriction, including without limitation the 9*c217d954SCole Faust * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or 10*c217d954SCole Faust * sell copies of the Software, and to permit persons to whom the Software is 11*c217d954SCole Faust * furnished to do so, subject to the following conditions: 12*c217d954SCole Faust * 13*c217d954SCole Faust * The above copyright notice and this permission notice shall be included in all 14*c217d954SCole Faust * copies or substantial portions of the Software. 15*c217d954SCole Faust * 16*c217d954SCole Faust * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17*c217d954SCole Faust * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18*c217d954SCole Faust * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19*c217d954SCole Faust * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20*c217d954SCole Faust * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21*c217d954SCole Faust * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22*c217d954SCole Faust * SOFTWARE. 23*c217d954SCole Faust */ 24*c217d954SCole Faust #include "arm_compute/core/TensorShape.h" 25*c217d954SCole Faust #include "arm_compute/core/Types.h" 26*c217d954SCole Faust #include "arm_compute/core/utils/misc/ShapeCalculator.h" 27*c217d954SCole Faust #include "tests/AssetsLibrary.h" 28*c217d954SCole Faust #include "tests/Globals.h" 29*c217d954SCole Faust #include "tests/IAccessor.h" 30*c217d954SCole Faust #include "tests/framework/Asserts.h" 31*c217d954SCole Faust #include "tests/framework/Fixture.h" 32*c217d954SCole Faust #include "tests/validation/Helpers.h" 33*c217d954SCole Faust #include "tests/validation/reference/DeconvolutionLayer.h" 34*c217d954SCole Faust 35*c217d954SCole Faust #include <random> 36*c217d954SCole Faust 37*c217d954SCole Faust namespace arm_compute 38*c217d954SCole Faust { 39*c217d954SCole Faust namespace test 40*c217d954SCole Faust { 41*c217d954SCole Faust namespace validation 42*c217d954SCole Faust { 43*c217d954SCole Faust using namespace arm_compute::misc::shape_calculator; 44*c217d954SCole Faust 45*c217d954SCole Faust template <typename TensorType, typename AccessorType, typename FunctionType, typename T, typename TW> 46*c217d954SCole Faust class DeconvolutionLayerFixtureBase : public framework::Fixture 47*c217d954SCole Faust { 48*c217d954SCole Faust public: 49*c217d954SCole Faust using TBias = typename std::conditional < std::is_same<typename std::decay<T>::type, uint8_t>::value || std::is_same<typename std::decay<T>::type, int8_t>::value, int32_t, T >::type; 50*c217d954SCole Faust 51*c217d954SCole Faust public: 52*c217d954SCole Faust template <typename...> setup(TensorShape input_shape,TensorShape weights_shape,TensorShape bias_shape,TensorShape output_shape,PadStrideInfo info,DataType data_type,DataType weights_data_type,DataLayout data_layout,QuantizationInfo input_quantization_info,QuantizationInfo output_quantization_info,QuantizationInfo weights_quantization_info,bool add_bias)53*c217d954SCole Faust void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, PadStrideInfo info, 54*c217d954SCole Faust DataType data_type, DataType weights_data_type, DataLayout data_layout, 55*c217d954SCole Faust QuantizationInfo input_quantization_info, QuantizationInfo output_quantization_info, QuantizationInfo weights_quantization_info, bool add_bias) 56*c217d954SCole Faust { 57*c217d954SCole Faust _data_type = data_type; 58*c217d954SCole Faust _weights_data_type = weights_data_type; 59*c217d954SCole Faust _bias_data_type = is_data_type_quantized_asymmetric(data_type) ? DataType::S32 : data_type; 60*c217d954SCole Faust _data_layout = data_layout; 61*c217d954SCole Faust _input_quantization_info = input_quantization_info; 62*c217d954SCole Faust _output_quantization_info = output_quantization_info; 63*c217d954SCole Faust _weights_quantization_info = weights_quantization_info; 64*c217d954SCole Faust 65*c217d954SCole Faust _target = compute_target(input_shape, weights_shape, bias_shape, output_shape, info, add_bias); 66*c217d954SCole Faust _reference = compute_reference(input_shape, weights_shape, bias_shape, output_shape, info, add_bias); 67*c217d954SCole Faust } 68*c217d954SCole Faust 69*c217d954SCole Faust protected: 70*c217d954SCole Faust template <typename U> fill(U && tensor,int i)71*c217d954SCole Faust void fill(U &&tensor, int i) 72*c217d954SCole Faust { 73*c217d954SCole Faust switch(tensor.data_type()) 74*c217d954SCole Faust { 75*c217d954SCole Faust case DataType::QASYMM8: 76*c217d954SCole Faust { 77*c217d954SCole Faust std::pair<int, int> bounds = get_quantized_bounds(tensor.quantization_info(), -1.0f, 1.0f); 78*c217d954SCole Faust std::uniform_int_distribution<uint32_t> distribution(bounds.first, bounds.second); 79*c217d954SCole Faust library->fill(tensor, distribution, i); 80*c217d954SCole Faust break; 81*c217d954SCole Faust } 82*c217d954SCole Faust case DataType::QASYMM8_SIGNED: 83*c217d954SCole Faust { 84*c217d954SCole Faust std::pair<int, int> bounds = get_quantized_qasymm8_signed_bounds(tensor.quantization_info(), -1.0f, 1.0f); 85*c217d954SCole Faust std::uniform_int_distribution<int32_t> distribution(bounds.first, bounds.second); 86*c217d954SCole Faust library->fill(tensor, distribution, i); 87*c217d954SCole Faust break; 88*c217d954SCole Faust } 89*c217d954SCole Faust case DataType::QSYMM8_PER_CHANNEL: 90*c217d954SCole Faust { 91*c217d954SCole Faust int min_bound = 128; 92*c217d954SCole Faust int max_bound = -127; 93*c217d954SCole Faust for(size_t i = 0; i < _input_quantization_info.scale().size(); i++) 94*c217d954SCole Faust { 95*c217d954SCole Faust std::pair<int, int> bounds = get_symm_quantized_per_channel_bounds(tensor.quantization_info(), -1.0f, 1.0f); 96*c217d954SCole Faust if(bounds.first < min_bound) 97*c217d954SCole Faust { 98*c217d954SCole Faust min_bound = bounds.first; 99*c217d954SCole Faust } 100*c217d954SCole Faust if(bounds.second > max_bound) 101*c217d954SCole Faust { 102*c217d954SCole Faust max_bound = bounds.second; 103*c217d954SCole Faust } 104*c217d954SCole Faust } 105*c217d954SCole Faust std::uniform_int_distribution<int32_t> distribution(min_bound, max_bound); 106*c217d954SCole Faust library->fill(tensor, distribution, i); 107*c217d954SCole Faust break; 108*c217d954SCole Faust } 109*c217d954SCole Faust case DataType::S32: 110*c217d954SCole Faust { 111*c217d954SCole Faust std::uniform_int_distribution<int32_t> distribution(-100, 100); 112*c217d954SCole Faust library->fill(tensor, distribution, i); 113*c217d954SCole Faust break; 114*c217d954SCole Faust } 115*c217d954SCole Faust case DataType::F16: 116*c217d954SCole Faust { 117*c217d954SCole Faust arm_compute::utils::uniform_real_distribution_16bit<half> distribution{ -1.0f, 1.0f }; 118*c217d954SCole Faust library->fill(tensor, distribution, i); 119*c217d954SCole Faust break; 120*c217d954SCole Faust } 121*c217d954SCole Faust case DataType::F32: 122*c217d954SCole Faust { 123*c217d954SCole Faust std::uniform_real_distribution<float> distribution(-1.0f, 1.0f); 124*c217d954SCole Faust library->fill(tensor, distribution, i); 125*c217d954SCole Faust break; 126*c217d954SCole Faust } 127*c217d954SCole Faust default: 128*c217d954SCole Faust library->fill_tensor_uniform(tensor, i); 129*c217d954SCole Faust } 130*c217d954SCole Faust } 131*c217d954SCole Faust 132*c217d954SCole Faust template <typename U> fill_zeros(U && tensor)133*c217d954SCole Faust void fill_zeros(U &&tensor) 134*c217d954SCole Faust { 135*c217d954SCole Faust switch(tensor.data_type()) 136*c217d954SCole Faust { 137*c217d954SCole Faust case DataType::S32: 138*c217d954SCole Faust { 139*c217d954SCole Faust library->fill_tensor_value(tensor, 0); 140*c217d954SCole Faust break; 141*c217d954SCole Faust } 142*c217d954SCole Faust case DataType::F16: 143*c217d954SCole Faust library->fill_tensor_value(tensor, static_cast<half>(0.0f)); 144*c217d954SCole Faust break; 145*c217d954SCole Faust case DataType::F32: 146*c217d954SCole Faust library->fill_tensor_value(tensor, static_cast<float>(0.0f)); 147*c217d954SCole Faust break; 148*c217d954SCole Faust default: 149*c217d954SCole Faust ARM_COMPUTE_ERROR("Not supported"); 150*c217d954SCole Faust } 151*c217d954SCole Faust } 152*c217d954SCole Faust compute_target(TensorShape input_shape,TensorShape weights_shape,const TensorShape bias_shape,TensorShape output_shape,const PadStrideInfo & info,bool add_bias)153*c217d954SCole Faust TensorType compute_target(TensorShape input_shape, TensorShape weights_shape, const TensorShape bias_shape, TensorShape output_shape, 154*c217d954SCole Faust const PadStrideInfo &info, bool add_bias) 155*c217d954SCole Faust { 156*c217d954SCole Faust if(_data_layout == DataLayout::NHWC) 157*c217d954SCole Faust { 158*c217d954SCole Faust permute(input_shape, PermutationVector(2U, 0U, 1U)); 159*c217d954SCole Faust permute(weights_shape, PermutationVector(2U, 0U, 1U)); 160*c217d954SCole Faust permute(output_shape, PermutationVector(2U, 0U, 1U)); 161*c217d954SCole Faust } 162*c217d954SCole Faust 163*c217d954SCole Faust // Create tensors 164*c217d954SCole Faust TensorType src = create_tensor<TensorType>(input_shape, _data_type, 1, _input_quantization_info, _data_layout); 165*c217d954SCole Faust TensorType weights = create_tensor<TensorType>(weights_shape, _weights_data_type, 1, _weights_quantization_info, _data_layout); 166*c217d954SCole Faust TensorType bias = create_tensor<TensorType>(bias_shape, _bias_data_type, 1, _input_quantization_info, _data_layout); 167*c217d954SCole Faust TensorType dst = create_tensor<TensorType>(output_shape, _data_type, 1, _output_quantization_info, _data_layout); 168*c217d954SCole Faust 169*c217d954SCole Faust // Create and configure function 170*c217d954SCole Faust FunctionType conv; 171*c217d954SCole Faust conv.configure(&src, &weights, add_bias ? &bias : nullptr, &dst, info); 172*c217d954SCole Faust 173*c217d954SCole Faust ARM_COMPUTE_ASSERT(src.info()->is_resizable()); 174*c217d954SCole Faust ARM_COMPUTE_ASSERT(weights.info()->is_resizable()); 175*c217d954SCole Faust if(add_bias) 176*c217d954SCole Faust { 177*c217d954SCole Faust ARM_COMPUTE_ASSERT(bias.info()->is_resizable()); 178*c217d954SCole Faust } 179*c217d954SCole Faust ARM_COMPUTE_ASSERT(dst.info()->is_resizable()); 180*c217d954SCole Faust 181*c217d954SCole Faust // Allocate tensors 182*c217d954SCole Faust src.allocator()->allocate(); 183*c217d954SCole Faust weights.allocator()->allocate(); 184*c217d954SCole Faust if(add_bias) 185*c217d954SCole Faust { 186*c217d954SCole Faust bias.allocator()->allocate(); 187*c217d954SCole Faust } 188*c217d954SCole Faust dst.allocator()->allocate(); 189*c217d954SCole Faust 190*c217d954SCole Faust ARM_COMPUTE_ASSERT(!src.info()->is_resizable()); 191*c217d954SCole Faust ARM_COMPUTE_ASSERT(!weights.info()->is_resizable()); 192*c217d954SCole Faust if(add_bias) 193*c217d954SCole Faust { 194*c217d954SCole Faust ARM_COMPUTE_ASSERT(!bias.info()->is_resizable()); 195*c217d954SCole Faust } 196*c217d954SCole Faust ARM_COMPUTE_ASSERT(!dst.info()->is_resizable()); 197*c217d954SCole Faust 198*c217d954SCole Faust // Fill tensors 199*c217d954SCole Faust fill(AccessorType(src), 0); 200*c217d954SCole Faust fill(AccessorType(weights), 1); 201*c217d954SCole Faust if(add_bias) 202*c217d954SCole Faust { 203*c217d954SCole Faust fill(AccessorType(bias), 2); 204*c217d954SCole Faust } 205*c217d954SCole Faust 206*c217d954SCole Faust // Compute DeconvolutionLayer function 207*c217d954SCole Faust conv.run(); 208*c217d954SCole Faust return dst; 209*c217d954SCole Faust } 210*c217d954SCole Faust compute_reference(const TensorShape & input_shape,const TensorShape & weights_shape,const TensorShape & bias_shape,const TensorShape & output_shape,const PadStrideInfo & info,bool add_bias)211*c217d954SCole Faust SimpleTensor<T> compute_reference(const TensorShape &input_shape, const TensorShape &weights_shape, const TensorShape &bias_shape, const TensorShape &output_shape, 212*c217d954SCole Faust const PadStrideInfo &info, bool add_bias) 213*c217d954SCole Faust { 214*c217d954SCole Faust // Create reference 215*c217d954SCole Faust SimpleTensor<T> src{ input_shape, _data_type, 1, _input_quantization_info }; 216*c217d954SCole Faust SimpleTensor<TW> weights{ weights_shape, _weights_data_type, 1, _weights_quantization_info }; 217*c217d954SCole Faust SimpleTensor<TBias> bias{ bias_shape, _bias_data_type, 1, _input_quantization_info }; 218*c217d954SCole Faust 219*c217d954SCole Faust // Fill reference 220*c217d954SCole Faust fill(src, 0); 221*c217d954SCole Faust fill(weights, 1); 222*c217d954SCole Faust 223*c217d954SCole Faust if(add_bias) 224*c217d954SCole Faust { 225*c217d954SCole Faust fill(bias, 2); 226*c217d954SCole Faust } 227*c217d954SCole Faust else 228*c217d954SCole Faust { 229*c217d954SCole Faust fill_zeros(bias); 230*c217d954SCole Faust } 231*c217d954SCole Faust return reference::deconvolution_layer<T, TW>(src, weights, bias, output_shape, info, _output_quantization_info); 232*c217d954SCole Faust } 233*c217d954SCole Faust 234*c217d954SCole Faust TensorType _target{}; 235*c217d954SCole Faust SimpleTensor<T> _reference{}; 236*c217d954SCole Faust DataType _data_type{}; 237*c217d954SCole Faust DataType _weights_data_type{}; 238*c217d954SCole Faust DataType _bias_data_type{}; 239*c217d954SCole Faust DataLayout _data_layout{}; 240*c217d954SCole Faust QuantizationInfo _input_quantization_info{}; 241*c217d954SCole Faust QuantizationInfo _output_quantization_info{}; 242*c217d954SCole Faust QuantizationInfo _weights_quantization_info{}; 243*c217d954SCole Faust }; 244*c217d954SCole Faust 245*c217d954SCole Faust template <typename TensorType, typename AccessorType, typename FunctionType, typename T, unsigned int kernel_size_x, unsigned int kernel_size_y> 246*c217d954SCole Faust class DeconvolutionValidationFixture : public DeconvolutionLayerFixtureBase<TensorType, AccessorType, FunctionType, T, T> 247*c217d954SCole Faust { 248*c217d954SCole Faust public: 249*c217d954SCole Faust template <typename...> setup(TensorShape input_shape,unsigned int sx,unsigned int sy,unsigned int padx,unsigned int pady,unsigned int num_kernels,DataType data_type,DataLayout data_layout,bool add_bias)250*c217d954SCole Faust void setup(TensorShape input_shape, unsigned int sx, unsigned int sy, unsigned int padx, unsigned int pady, 251*c217d954SCole Faust unsigned int num_kernels, DataType data_type, DataLayout data_layout, bool add_bias) 252*c217d954SCole Faust { 253*c217d954SCole Faust ARM_COMPUTE_ERROR_ON_MSG(kernel_size_x != kernel_size_y, "Only square kernels supported"); 254*c217d954SCole Faust const TensorShape weights_shape(kernel_size_x, kernel_size_y, input_shape.z(), num_kernels); 255*c217d954SCole Faust const TensorShape bias_shape(num_kernels); 256*c217d954SCole Faust const PadStrideInfo info(sx, sy, padx, pady, DimensionRoundingType::CEIL); 257*c217d954SCole Faust auto out_dim = deconvolution_output_dimensions(input_shape.x(), input_shape.y(), kernel_size_x, kernel_size_y, info); 258*c217d954SCole Faust TensorInfo input_info(input_shape, 1, data_type); 259*c217d954SCole Faust TensorInfo weights_info(weights_shape, 1, data_type); 260*c217d954SCole Faust TensorShape output_shape = compute_deconvolution_output_shape(out_dim, input_info, weights_info); 261*c217d954SCole Faust DeconvolutionLayerFixtureBase<TensorType, AccessorType, FunctionType, T, T>::setup(input_shape, weights_shape, bias_shape, output_shape, info, data_type, data_type, data_layout, QuantizationInfo(), 262*c217d954SCole Faust QuantizationInfo(), QuantizationInfo(), add_bias); 263*c217d954SCole Faust } 264*c217d954SCole Faust }; 265*c217d954SCole Faust 266*c217d954SCole Faust template <typename TensorType, typename AccessorType, typename FunctionType, typename T, unsigned int kernel_size_x, unsigned int kernel_size_y> 267*c217d954SCole Faust class DeconvolutionValidationAsymmFixture : public DeconvolutionLayerFixtureBase<TensorType, AccessorType, FunctionType, T, T> 268*c217d954SCole Faust { 269*c217d954SCole Faust public: 270*c217d954SCole Faust template <typename...> setup(TensorShape input_shape,unsigned int sx,unsigned int sy,unsigned int pad_left,unsigned int pad_right,unsigned int pad_top,unsigned int pad_bottom,unsigned int num_kernels,DataType data_type,DataLayout data_layout,bool add_bias)271*c217d954SCole Faust void setup(TensorShape input_shape, unsigned int sx, unsigned int sy, unsigned int pad_left, unsigned int pad_right, unsigned int pad_top, 272*c217d954SCole Faust unsigned int pad_bottom, unsigned int num_kernels, DataType data_type, DataLayout data_layout, bool add_bias) 273*c217d954SCole Faust { 274*c217d954SCole Faust ARM_COMPUTE_ERROR_ON_MSG(kernel_size_x != kernel_size_y, "Only square kernels supported"); 275*c217d954SCole Faust const TensorShape weights_shape(kernel_size_x, kernel_size_y, input_shape.z(), num_kernels); 276*c217d954SCole Faust const TensorShape bias_shape(num_kernels); 277*c217d954SCole Faust const PadStrideInfo info(sx, sy, pad_left, pad_right, pad_top, pad_bottom, DimensionRoundingType::CEIL); 278*c217d954SCole Faust auto out_dim = deconvolution_output_dimensions(input_shape.x(), input_shape.y(), kernel_size_x, kernel_size_y, info); 279*c217d954SCole Faust TensorInfo input_info(input_shape, 1, data_type); 280*c217d954SCole Faust TensorInfo weights_info(weights_shape, 1, data_type); 281*c217d954SCole Faust TensorShape output_shape = compute_deconvolution_output_shape(out_dim, input_info, weights_info); 282*c217d954SCole Faust DeconvolutionLayerFixtureBase<TensorType, AccessorType, FunctionType, T, T>::setup(input_shape, weights_shape, bias_shape, output_shape, info, data_type, data_type, data_layout, QuantizationInfo(), 283*c217d954SCole Faust QuantizationInfo(), QuantizationInfo(), add_bias); 284*c217d954SCole Faust } 285*c217d954SCole Faust }; 286*c217d954SCole Faust 287*c217d954SCole Faust template <typename TensorType, typename AccessorType, typename FunctionType, typename T, unsigned int kernel_size_x, unsigned int kernel_size_y> 288*c217d954SCole Faust class DeconvolutionValidationQuantizedFixture : public DeconvolutionLayerFixtureBase<TensorType, AccessorType, FunctionType, T, T> 289*c217d954SCole Faust { 290*c217d954SCole Faust public: 291*c217d954SCole Faust template <typename...> setup(TensorShape input_shape,unsigned int sx,unsigned int sy,unsigned int padx,unsigned int pady,unsigned int num_kernels,DataType data_type,DataLayout data_layout,QuantizationInfo input_quantization_info,QuantizationInfo output_quantization_info,bool add_bias)292*c217d954SCole Faust void setup(TensorShape input_shape, unsigned int sx, unsigned int sy, unsigned int padx, unsigned int pady, 293*c217d954SCole Faust unsigned int num_kernels, DataType data_type, DataLayout data_layout, QuantizationInfo input_quantization_info, QuantizationInfo output_quantization_info, bool add_bias) 294*c217d954SCole Faust { 295*c217d954SCole Faust ARM_COMPUTE_ERROR_ON_MSG(kernel_size_x != kernel_size_y, "Only square kernels supported"); 296*c217d954SCole Faust const TensorShape weights_shape(kernel_size_x, kernel_size_y, input_shape.z(), num_kernels); 297*c217d954SCole Faust const TensorShape bias_shape(num_kernels); 298*c217d954SCole Faust const PadStrideInfo info(sx, sy, padx, pady, DimensionRoundingType::CEIL); 299*c217d954SCole Faust auto out_dim = deconvolution_output_dimensions(input_shape.x(), input_shape.y(), kernel_size_x, kernel_size_y, info); 300*c217d954SCole Faust TensorInfo input_info(input_shape, 1, data_type, input_quantization_info); 301*c217d954SCole Faust TensorInfo weights_info(weights_shape, 1, data_type, input_quantization_info); 302*c217d954SCole Faust TensorShape output_shape = compute_deconvolution_output_shape(out_dim, input_info, weights_info); 303*c217d954SCole Faust DeconvolutionLayerFixtureBase<TensorType, AccessorType, FunctionType, T, T>::setup(input_shape, weights_shape, bias_shape, output_shape, info, data_type, data_type, data_layout, 304*c217d954SCole Faust input_quantization_info, 305*c217d954SCole Faust output_quantization_info, input_quantization_info, add_bias); 306*c217d954SCole Faust } 307*c217d954SCole Faust }; 308*c217d954SCole Faust 309*c217d954SCole Faust template <typename TensorType, typename AccessorType, typename FunctionType, typename T, typename TW, unsigned int kernel_size_x, unsigned int kernel_size_y> 310*c217d954SCole Faust class DeconvolutionValidationQuantizedPerChannelFixture : public DeconvolutionLayerFixtureBase<TensorType, AccessorType, FunctionType, T, TW> 311*c217d954SCole Faust { 312*c217d954SCole Faust public: 313*c217d954SCole Faust template <typename...> setup(TensorShape input_shape,unsigned int sx,unsigned int sy,unsigned int padx,unsigned int pady,unsigned int num_kernels,DataType data_type,DataLayout data_layout,QuantizationInfo input_quantization_info,QuantizationInfo output_quantization_info,bool add_bias,DataType weights_data_type)314*c217d954SCole Faust void setup(TensorShape input_shape, unsigned int sx, unsigned int sy, unsigned int padx, unsigned int pady, 315*c217d954SCole Faust unsigned int num_kernels, DataType data_type, DataLayout data_layout, QuantizationInfo input_quantization_info, QuantizationInfo output_quantization_info, bool add_bias, 316*c217d954SCole Faust DataType weights_data_type) 317*c217d954SCole Faust { 318*c217d954SCole Faust ARM_COMPUTE_ERROR_ON_MSG(kernel_size_x != kernel_size_y, "Only square kernels supported"); 319*c217d954SCole Faust const TensorShape weights_shape(kernel_size_x, kernel_size_y, input_shape.z(), num_kernels); 320*c217d954SCole Faust const TensorShape bias_shape(num_kernels); 321*c217d954SCole Faust const PadStrideInfo info(sx, sy, padx, pady, DimensionRoundingType::CEIL); 322*c217d954SCole Faust auto out_dim = deconvolution_output_dimensions(input_shape.x(), input_shape.y(), kernel_size_x, kernel_size_y, info); 323*c217d954SCole Faust TensorInfo input_info(input_shape, 1, data_type, input_quantization_info); 324*c217d954SCole Faust TensorInfo weights_info(weights_shape, 1, weights_data_type, input_quantization_info); 325*c217d954SCole Faust TensorShape output_shape = compute_deconvolution_output_shape(out_dim, input_info, weights_info); 326*c217d954SCole Faust 327*c217d954SCole Faust std::vector<float> weights_scales{}; 328*c217d954SCole Faust std::mt19937 gen(library->seed()); 329*c217d954SCole Faust std::uniform_real_distribution<float> dis(0.01f, 1.f); 330*c217d954SCole Faust for(size_t i = 0; i < output_shape[2]; ++i) 331*c217d954SCole Faust { 332*c217d954SCole Faust weights_scales.push_back(dis(gen)); 333*c217d954SCole Faust } 334*c217d954SCole Faust DeconvolutionLayerFixtureBase<TensorType, AccessorType, FunctionType, T, TW>::setup(input_shape, weights_shape, bias_shape, output_shape, info, data_type, weights_data_type, data_layout, 335*c217d954SCole Faust input_quantization_info, 336*c217d954SCole Faust output_quantization_info, QuantizationInfo(weights_scales), add_bias); 337*c217d954SCole Faust } 338*c217d954SCole Faust }; 339*c217d954SCole Faust 340*c217d954SCole Faust } // namespace validation 341*c217d954SCole Faust } // namespace test 342*c217d954SCole Faust } // namespace arm_compute 343