xref: /aosp_15_r20/external/ComputeLibrary/tests/validation/fixtures/ScaleFixture.h (revision c217d954acce2dbc11938adb493fc0abd69584f3)
1*c217d954SCole Faust /*
2*c217d954SCole Faust  * Copyright (c) 2017-2022 Arm Limited.
3*c217d954SCole Faust  *
4*c217d954SCole Faust  * SPDX-License-Identifier: MIT
5*c217d954SCole Faust  *
6*c217d954SCole Faust  * Permission is hereby granted, free of charge, to any person obtaining a copy
7*c217d954SCole Faust  * of this software and associated documentation files (the "Software"), to
8*c217d954SCole Faust  * deal in the Software without restriction, including without limitation the
9*c217d954SCole Faust  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10*c217d954SCole Faust  * sell copies of the Software, and to permit persons to whom the Software is
11*c217d954SCole Faust  * furnished to do so, subject to the following conditions:
12*c217d954SCole Faust  *
13*c217d954SCole Faust  * The above copyright notice and this permission notice shall be included in all
14*c217d954SCole Faust  * copies or substantial portions of the Software.
15*c217d954SCole Faust  *
16*c217d954SCole Faust  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17*c217d954SCole Faust  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18*c217d954SCole Faust  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19*c217d954SCole Faust  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20*c217d954SCole Faust  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21*c217d954SCole Faust  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22*c217d954SCole Faust  * SOFTWARE.
23*c217d954SCole Faust  */
24*c217d954SCole Faust #ifndef ARM_COMPUTE_TEST_SCALE_FIXTURE
25*c217d954SCole Faust #define ARM_COMPUTE_TEST_SCALE_FIXTURE
26*c217d954SCole Faust 
27*c217d954SCole Faust #include "tests/framework/Fixture.h"
28*c217d954SCole Faust #include "tests/validation/reference/Permute.h"
29*c217d954SCole Faust #include "tests/validation/reference/Scale.h"
30*c217d954SCole Faust 
31*c217d954SCole Faust namespace arm_compute
32*c217d954SCole Faust {
33*c217d954SCole Faust namespace test
34*c217d954SCole Faust {
35*c217d954SCole Faust namespace validation
36*c217d954SCole Faust {
37*c217d954SCole Faust template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
38*c217d954SCole Faust class ScaleValidationGenericFixture : public framework::Fixture
39*c217d954SCole Faust {
40*c217d954SCole Faust public:
41*c217d954SCole Faust     template <typename...>
setup(TensorShape shape,DataType data_type,QuantizationInfo quantization_info,DataLayout data_layout,InterpolationPolicy policy,BorderMode border_mode,SamplingPolicy sampling_policy,bool align_corners,bool mixed_layout,QuantizationInfo output_quantization_info)42*c217d954SCole Faust     void setup(TensorShape shape, DataType data_type, QuantizationInfo quantization_info, DataLayout data_layout, InterpolationPolicy policy, BorderMode border_mode, SamplingPolicy sampling_policy,
43*c217d954SCole Faust                bool align_corners, bool mixed_layout, QuantizationInfo output_quantization_info)
44*c217d954SCole Faust     {
45*c217d954SCole Faust         _shape                    = shape;
46*c217d954SCole Faust         _policy                   = policy;
47*c217d954SCole Faust         _border_mode              = border_mode;
48*c217d954SCole Faust         _sampling_policy          = sampling_policy;
49*c217d954SCole Faust         _data_type                = data_type;
50*c217d954SCole Faust         _input_quantization_info  = quantization_info;
51*c217d954SCole Faust         _output_quantization_info = output_quantization_info;
52*c217d954SCole Faust         _align_corners            = align_corners;
53*c217d954SCole Faust         _mixed_layout             = mixed_layout;
54*c217d954SCole Faust 
55*c217d954SCole Faust         generate_scale(shape);
56*c217d954SCole Faust 
57*c217d954SCole Faust         std::mt19937                            generator(library->seed());
58*c217d954SCole Faust         std::uniform_int_distribution<uint32_t> distribution_u8(0, 255);
59*c217d954SCole Faust         _constant_border_value = static_cast<T>(distribution_u8(generator));
60*c217d954SCole Faust 
61*c217d954SCole Faust         _target    = compute_target(shape, data_layout);
62*c217d954SCole Faust         _reference = compute_reference(shape);
63*c217d954SCole Faust     }
64*c217d954SCole Faust 
65*c217d954SCole Faust protected:
mix_layout(FunctionType & layer,TensorType & src,TensorType & dst)66*c217d954SCole Faust     void mix_layout(FunctionType &layer, TensorType &src, TensorType &dst)
67*c217d954SCole Faust     {
68*c217d954SCole Faust         const DataLayout data_layout = src.info()->data_layout();
69*c217d954SCole Faust         // Test Multi DataLayout graph cases, when the data layout changes after configure
70*c217d954SCole Faust         src.info()->set_data_layout(data_layout == DataLayout::NCHW ? DataLayout::NHWC : DataLayout::NCHW);
71*c217d954SCole Faust         dst.info()->set_data_layout(data_layout == DataLayout::NCHW ? DataLayout::NHWC : DataLayout::NCHW);
72*c217d954SCole Faust 
73*c217d954SCole Faust         // Compute Convolution function
74*c217d954SCole Faust         layer.run();
75*c217d954SCole Faust 
76*c217d954SCole Faust         // Reinstating original data layout for the test suite to properly check the values
77*c217d954SCole Faust         src.info()->set_data_layout(data_layout);
78*c217d954SCole Faust         dst.info()->set_data_layout(data_layout);
79*c217d954SCole Faust     }
80*c217d954SCole Faust 
generate_scale(const TensorShape & shape)81*c217d954SCole Faust     void generate_scale(const TensorShape &shape)
82*c217d954SCole Faust     {
83*c217d954SCole Faust         static constexpr float _min_scale{ 0.25f };
84*c217d954SCole Faust         static constexpr float _max_scale{ 3.f };
85*c217d954SCole Faust 
86*c217d954SCole Faust         constexpr float max_width{ 8192.0f };
87*c217d954SCole Faust         constexpr float max_height{ 6384.0f };
88*c217d954SCole Faust         const float     min_width{ 1.f };
89*c217d954SCole Faust         const float     min_height{ 1.f };
90*c217d954SCole Faust 
91*c217d954SCole Faust         std::mt19937                          generator(library->seed());
92*c217d954SCole Faust         std::uniform_real_distribution<float> distribution_float(_min_scale, _max_scale);
93*c217d954SCole Faust 
94*c217d954SCole Faust         auto generate = [&](size_t input_size, float min_output, float max_output) -> float
95*c217d954SCole Faust         {
96*c217d954SCole Faust             const float generated_scale = distribution_float(generator);
97*c217d954SCole Faust             const float output_size     = utility::clamp(static_cast<float>(input_size) * generated_scale, min_output, max_output);
98*c217d954SCole Faust             return output_size / input_size;
99*c217d954SCole Faust         };
100*c217d954SCole Faust 
101*c217d954SCole Faust         // Input shape is always given in NCHW layout. NHWC is dealt by permute in compute_target()
102*c217d954SCole Faust         const int idx_width  = get_data_layout_dimension_index(DataLayout::NCHW, DataLayoutDimension::WIDTH);
103*c217d954SCole Faust         const int idx_height = get_data_layout_dimension_index(DataLayout::NCHW, DataLayoutDimension::HEIGHT);
104*c217d954SCole Faust 
105*c217d954SCole Faust         _scale_x = generate(shape[idx_width], min_width, max_width);
106*c217d954SCole Faust         _scale_y = generate(shape[idx_height], min_height, max_height);
107*c217d954SCole Faust     }
108*c217d954SCole Faust 
109*c217d954SCole Faust     template <typename U>
fill(U && tensor)110*c217d954SCole Faust     void fill(U &&tensor)
111*c217d954SCole Faust     {
112*c217d954SCole Faust         if(tensor.data_type() == DataType::F32)
113*c217d954SCole Faust         {
114*c217d954SCole Faust             std::uniform_real_distribution<float> distribution(-5.0f, 5.0f);
115*c217d954SCole Faust             library->fill(tensor, distribution, 0);
116*c217d954SCole Faust         }
117*c217d954SCole Faust         else if(tensor.data_type() == DataType::F16)
118*c217d954SCole Faust         {
119*c217d954SCole Faust             arm_compute::utils::uniform_real_distribution_16bit<half> distribution{ -5.0f, 5.0f };
120*c217d954SCole Faust             library->fill(tensor, distribution, 0);
121*c217d954SCole Faust         }
122*c217d954SCole Faust         else if(is_data_type_quantized(tensor.data_type()))
123*c217d954SCole Faust         {
124*c217d954SCole Faust             std::uniform_int_distribution<> distribution(0, 100);
125*c217d954SCole Faust             library->fill(tensor, distribution, 0);
126*c217d954SCole Faust         }
127*c217d954SCole Faust         else
128*c217d954SCole Faust         {
129*c217d954SCole Faust             library->fill_tensor_uniform(tensor, 0);
130*c217d954SCole Faust         }
131*c217d954SCole Faust     }
132*c217d954SCole Faust 
compute_target(TensorShape shape,DataLayout data_layout)133*c217d954SCole Faust     TensorType compute_target(TensorShape shape, DataLayout data_layout)
134*c217d954SCole Faust     {
135*c217d954SCole Faust         // Change shape in case of NHWC.
136*c217d954SCole Faust         if(data_layout == DataLayout::NHWC)
137*c217d954SCole Faust         {
138*c217d954SCole Faust             permute(shape, PermutationVector(2U, 0U, 1U));
139*c217d954SCole Faust         }
140*c217d954SCole Faust 
141*c217d954SCole Faust         // Create tensors
142*c217d954SCole Faust         TensorType src = create_tensor<TensorType>(shape, _data_type, 1, _input_quantization_info, data_layout);
143*c217d954SCole Faust 
144*c217d954SCole Faust         const int idx_width  = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
145*c217d954SCole Faust         const int idx_height = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
146*c217d954SCole Faust 
147*c217d954SCole Faust         TensorShape shape_scaled(shape);
148*c217d954SCole Faust         shape_scaled.set(idx_width, shape[idx_width] * _scale_x, /* apply_dim_correction = */ false);
149*c217d954SCole Faust         shape_scaled.set(idx_height, shape[idx_height] * _scale_y, /* apply_dim_correction = */ false);
150*c217d954SCole Faust         TensorType dst = create_tensor<TensorType>(shape_scaled, _data_type, 1, _output_quantization_info, data_layout);
151*c217d954SCole Faust 
152*c217d954SCole Faust         // Create and configure function
153*c217d954SCole Faust         FunctionType scale;
154*c217d954SCole Faust 
155*c217d954SCole Faust         scale.configure(&src, &dst, ScaleKernelInfo{ _policy, _border_mode, _constant_border_value, _sampling_policy, /* use_padding */ false, _align_corners });
156*c217d954SCole Faust 
157*c217d954SCole Faust         ARM_COMPUTE_ASSERT(src.info()->is_resizable());
158*c217d954SCole Faust         ARM_COMPUTE_ASSERT(dst.info()->is_resizable());
159*c217d954SCole Faust 
160*c217d954SCole Faust         add_padding_x({ &src, &dst }, data_layout);
161*c217d954SCole Faust 
162*c217d954SCole Faust         // Allocate tensors
163*c217d954SCole Faust         src.allocator()->allocate();
164*c217d954SCole Faust         dst.allocator()->allocate();
165*c217d954SCole Faust         ARM_COMPUTE_ASSERT(!src.info()->is_resizable());
166*c217d954SCole Faust         ARM_COMPUTE_ASSERT(!dst.info()->is_resizable());
167*c217d954SCole Faust 
168*c217d954SCole Faust         // Fill tensors
169*c217d954SCole Faust         fill(AccessorType(src));
170*c217d954SCole Faust 
171*c217d954SCole Faust         if(_mixed_layout)
172*c217d954SCole Faust         {
173*c217d954SCole Faust             mix_layout(scale, src, dst);
174*c217d954SCole Faust         }
175*c217d954SCole Faust         else
176*c217d954SCole Faust         {
177*c217d954SCole Faust             // Compute function
178*c217d954SCole Faust             scale.run();
179*c217d954SCole Faust         }
180*c217d954SCole Faust         return dst;
181*c217d954SCole Faust     }
182*c217d954SCole Faust 
compute_reference(const TensorShape & shape)183*c217d954SCole Faust     SimpleTensor<T> compute_reference(const TensorShape &shape)
184*c217d954SCole Faust     {
185*c217d954SCole Faust         // Create reference
186*c217d954SCole Faust         SimpleTensor<T> src{ shape, _data_type, 1, _input_quantization_info };
187*c217d954SCole Faust 
188*c217d954SCole Faust         // Fill reference
189*c217d954SCole Faust         fill(src);
190*c217d954SCole Faust 
191*c217d954SCole Faust         return reference::scale<T>(src, _scale_x, _scale_y, _policy, _border_mode, _constant_border_value, _sampling_policy, /* ceil_policy_scale */ false, _align_corners, _output_quantization_info);
192*c217d954SCole Faust     }
193*c217d954SCole Faust 
194*c217d954SCole Faust     TensorType          _target{};
195*c217d954SCole Faust     SimpleTensor<T>     _reference{};
196*c217d954SCole Faust     TensorShape         _shape{};
197*c217d954SCole Faust     InterpolationPolicy _policy{};
198*c217d954SCole Faust     BorderMode          _border_mode{};
199*c217d954SCole Faust     T                   _constant_border_value{};
200*c217d954SCole Faust     SamplingPolicy      _sampling_policy{};
201*c217d954SCole Faust     DataType            _data_type{};
202*c217d954SCole Faust     QuantizationInfo    _input_quantization_info{};
203*c217d954SCole Faust     QuantizationInfo    _output_quantization_info{};
204*c217d954SCole Faust     bool                _align_corners{ false };
205*c217d954SCole Faust     bool                _mixed_layout{ false };
206*c217d954SCole Faust     float               _scale_x{ 1.f };
207*c217d954SCole Faust     float               _scale_y{ 1.f };
208*c217d954SCole Faust };
209*c217d954SCole Faust 
210*c217d954SCole Faust template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool mixed_layout = false>
211*c217d954SCole Faust class ScaleValidationQuantizedFixture : public ScaleValidationGenericFixture<TensorType, AccessorType, FunctionType, T>
212*c217d954SCole Faust {
213*c217d954SCole Faust public:
214*c217d954SCole Faust     template <typename...>
setup(TensorShape shape,DataType data_type,QuantizationInfo quantization_info,DataLayout data_layout,InterpolationPolicy policy,BorderMode border_mode,SamplingPolicy sampling_policy,bool align_corners)215*c217d954SCole Faust     void setup(TensorShape shape, DataType data_type, QuantizationInfo quantization_info, DataLayout data_layout, InterpolationPolicy policy, BorderMode border_mode, SamplingPolicy sampling_policy,
216*c217d954SCole Faust                bool align_corners)
217*c217d954SCole Faust     {
218*c217d954SCole Faust         ScaleValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(shape,
219*c217d954SCole Faust                                                                                         data_type,
220*c217d954SCole Faust                                                                                         quantization_info,
221*c217d954SCole Faust                                                                                         data_layout,
222*c217d954SCole Faust                                                                                         policy,
223*c217d954SCole Faust                                                                                         border_mode,
224*c217d954SCole Faust                                                                                         sampling_policy,
225*c217d954SCole Faust                                                                                         align_corners,
226*c217d954SCole Faust                                                                                         mixed_layout,
227*c217d954SCole Faust                                                                                         quantization_info);
228*c217d954SCole Faust     }
229*c217d954SCole Faust };
230*c217d954SCole Faust template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool mixed_layout = false>
231*c217d954SCole Faust class ScaleValidationDifferentOutputQuantizedFixture : public ScaleValidationGenericFixture<TensorType, AccessorType, FunctionType, T>
232*c217d954SCole Faust {
233*c217d954SCole Faust public:
234*c217d954SCole Faust     template <typename...>
setup(TensorShape shape,DataType data_type,QuantizationInfo input_quantization_info,QuantizationInfo output_quantization_info,DataLayout data_layout,InterpolationPolicy policy,BorderMode border_mode,SamplingPolicy sampling_policy,bool align_corners)235*c217d954SCole Faust     void setup(TensorShape shape, DataType data_type, QuantizationInfo input_quantization_info, QuantizationInfo output_quantization_info, DataLayout data_layout, InterpolationPolicy policy,
236*c217d954SCole Faust                BorderMode border_mode, SamplingPolicy sampling_policy,
237*c217d954SCole Faust                bool align_corners)
238*c217d954SCole Faust     {
239*c217d954SCole Faust         ScaleValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(shape,
240*c217d954SCole Faust                                                                                         data_type,
241*c217d954SCole Faust                                                                                         input_quantization_info,
242*c217d954SCole Faust                                                                                         data_layout,
243*c217d954SCole Faust                                                                                         policy,
244*c217d954SCole Faust                                                                                         border_mode,
245*c217d954SCole Faust                                                                                         sampling_policy,
246*c217d954SCole Faust                                                                                         align_corners,
247*c217d954SCole Faust                                                                                         mixed_layout,
248*c217d954SCole Faust                                                                                         output_quantization_info);
249*c217d954SCole Faust     }
250*c217d954SCole Faust };
251*c217d954SCole Faust template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool mixed_layout = false>
252*c217d954SCole Faust class ScaleValidationFixture : public ScaleValidationGenericFixture<TensorType, AccessorType, FunctionType, T>
253*c217d954SCole Faust {
254*c217d954SCole Faust public:
255*c217d954SCole Faust     template <typename...>
setup(TensorShape shape,DataType data_type,DataLayout data_layout,InterpolationPolicy policy,BorderMode border_mode,SamplingPolicy sampling_policy,bool align_corners)256*c217d954SCole Faust     void setup(TensorShape shape, DataType data_type, DataLayout data_layout, InterpolationPolicy policy, BorderMode border_mode, SamplingPolicy sampling_policy, bool align_corners)
257*c217d954SCole Faust     {
258*c217d954SCole Faust         ScaleValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(shape,
259*c217d954SCole Faust                                                                                         data_type,
260*c217d954SCole Faust                                                                                         QuantizationInfo(),
261*c217d954SCole Faust                                                                                         data_layout,
262*c217d954SCole Faust                                                                                         policy,
263*c217d954SCole Faust                                                                                         border_mode,
264*c217d954SCole Faust                                                                                         sampling_policy,
265*c217d954SCole Faust                                                                                         align_corners,
266*c217d954SCole Faust                                                                                         mixed_layout,
267*c217d954SCole Faust                                                                                         QuantizationInfo());
268*c217d954SCole Faust     }
269*c217d954SCole Faust };
270*c217d954SCole Faust } // namespace validation
271*c217d954SCole Faust } // namespace test
272*c217d954SCole Faust } // namespace arm_compute
273*c217d954SCole Faust #endif /* ARM_COMPUTE_TEST_SCALE_FIXTURE */
274