xref: /aosp_15_r20/external/ComputeLibrary/tests/validation/fixtures/AddMulAddFixture.h (revision c217d954acce2dbc11938adb493fc0abd69584f3)
1*c217d954SCole Faust /*
2*c217d954SCole Faust  * Copyright (c) 2023 Arm Limited.
3*c217d954SCole Faust  *
4*c217d954SCole Faust  * SPDX-License-Identifier: MIT
5*c217d954SCole Faust  *
6*c217d954SCole Faust  * Permission is hereby granted, free of charge, to any person obtaining a copy
7*c217d954SCole Faust  * of this software and associated documentation files (the "Software"), to
8*c217d954SCole Faust  * deal in the Software without restriction, including without limitation the
9*c217d954SCole Faust  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10*c217d954SCole Faust  * sell copies of the Software, and to permit persons to whom the Software is
11*c217d954SCole Faust  * furnished to do so, subject to the following conditions:
12*c217d954SCole Faust  *
13*c217d954SCole Faust  * The above copyright notice and this permission notice shall be included in all
14*c217d954SCole Faust  * copies or substantial portions of the Software.
15*c217d954SCole Faust  *
16*c217d954SCole Faust  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17*c217d954SCole Faust  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18*c217d954SCole Faust  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19*c217d954SCole Faust  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20*c217d954SCole Faust  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21*c217d954SCole Faust  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22*c217d954SCole Faust  * SOFTWARE.
23*c217d954SCole Faust  */
24*c217d954SCole Faust 
25*c217d954SCole Faust #ifndef TESTS_VALIDATION_FIXTURES_ADDMULADDFIXTURE
26*c217d954SCole Faust #define TESTS_VALIDATION_FIXTURES_ADDMULADDFIXTURE
27*c217d954SCole Faust 
28*c217d954SCole Faust #include "arm_compute/core/TensorShape.h"
29*c217d954SCole Faust #include "arm_compute/core/Types.h"
30*c217d954SCole Faust #include "tests/AssetsLibrary.h"
31*c217d954SCole Faust #include "tests/Globals.h"
32*c217d954SCole Faust #include "tests/IAccessor.h"
33*c217d954SCole Faust #include "tests/framework/Asserts.h"
34*c217d954SCole Faust #include "tests/framework/Fixture.h"
35*c217d954SCole Faust #include "tests/validation/Helpers.h"
36*c217d954SCole Faust #include "tests/validation/reference/ActivationLayer.h"
37*c217d954SCole Faust #include "tests/validation/reference/ArithmeticOperations.h"
38*c217d954SCole Faust #include "tests/validation/reference/DequantizationLayer.h"
39*c217d954SCole Faust #include "tests/validation/reference/PixelWiseMultiplication.h"
40*c217d954SCole Faust #include "tests/validation/reference/QuantizationLayer.h"
41*c217d954SCole Faust 
42*c217d954SCole Faust namespace arm_compute
43*c217d954SCole Faust {
44*c217d954SCole Faust namespace test
45*c217d954SCole Faust {
46*c217d954SCole Faust namespace validation
47*c217d954SCole Faust {
48*c217d954SCole Faust template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
49*c217d954SCole Faust class AddMulAddGenericFixture : public framework::Fixture
50*c217d954SCole Faust {
51*c217d954SCole Faust public:
52*c217d954SCole Faust     template <typename...>
setup(const TensorShape & shape,DataType data_type,ActivationLayerInfo & act_info,bool interm_out)53*c217d954SCole Faust     void setup(const TensorShape &shape, DataType data_type, ActivationLayerInfo &act_info, bool interm_out)
54*c217d954SCole Faust     {
55*c217d954SCole Faust         compute_target(shape, data_type, act_info, interm_out);
56*c217d954SCole Faust     }
57*c217d954SCole Faust 
58*c217d954SCole Faust protected:
59*c217d954SCole Faust     template <typename U>
fill(U && tensor,int i,DataType data_type)60*c217d954SCole Faust     void fill(U &&tensor, int i, DataType data_type)
61*c217d954SCole Faust     {
62*c217d954SCole Faust         switch(data_type)
63*c217d954SCole Faust         {
64*c217d954SCole Faust             case DataType::F32:
65*c217d954SCole Faust                 library->fill_tensor_uniform(tensor, i, -10.f, 10.f);
66*c217d954SCole Faust                 break;
67*c217d954SCole Faust             case DataType::F16:
68*c217d954SCole Faust                 library->fill_tensor_uniform(tensor, i, -1.f, 1.f);
69*c217d954SCole Faust                 break;
70*c217d954SCole Faust             default:
71*c217d954SCole Faust                 library->fill_tensor_uniform(tensor, i);
72*c217d954SCole Faust                 break;
73*c217d954SCole Faust         }
74*c217d954SCole Faust     }
75*c217d954SCole Faust 
compute_target(const TensorShape & shape,DataType data_type,ActivationLayerInfo & act_info,bool interm_out)76*c217d954SCole Faust     void compute_target(const TensorShape &shape, DataType data_type, ActivationLayerInfo &act_info, bool interm_out)
77*c217d954SCole Faust     {
78*c217d954SCole Faust         TensorShape b_shape(shape.x());
79*c217d954SCole Faust 
80*c217d954SCole Faust         // Create tensors
81*c217d954SCole Faust         TensorType input1       = create_tensor<TensorType>(shape, data_type, 1, _input1_qinfo);
82*c217d954SCole Faust         TensorType input2       = create_tensor<TensorType>(shape, data_type, 1, _input2_qinfo);
83*c217d954SCole Faust         TensorType bn_mul       = create_tensor<TensorType>(b_shape, data_type, 1, _bn_mul_qinfo);
84*c217d954SCole Faust         TensorType bn_add       = create_tensor<TensorType>(b_shape, data_type, 1, _bn_add_qinfo);
85*c217d954SCole Faust         TensorType add_output   = create_tensor<TensorType>(shape, data_type, 1, _add_output_qinfo);
86*c217d954SCole Faust         TensorType final_output = create_tensor<TensorType>(shape, data_type, 1, _final_output_qinfo);
87*c217d954SCole Faust 
88*c217d954SCole Faust         // Create and configure function
89*c217d954SCole Faust         FunctionType add_mul_add;
90*c217d954SCole Faust         add_mul_add.configure(&input1, &input2, &bn_mul, &bn_add, interm_out ? &add_output : nullptr, &final_output, ConvertPolicy::SATURATE, act_info);
91*c217d954SCole Faust 
92*c217d954SCole Faust         // Allocate tensors
93*c217d954SCole Faust         input1.allocator()->allocate();
94*c217d954SCole Faust         input2.allocator()->allocate();
95*c217d954SCole Faust         bn_mul.allocator()->allocate();
96*c217d954SCole Faust         bn_add.allocator()->allocate();
97*c217d954SCole Faust 
98*c217d954SCole Faust         if(interm_out)
99*c217d954SCole Faust         {
100*c217d954SCole Faust             add_output.allocator()->allocate();
101*c217d954SCole Faust         }
102*c217d954SCole Faust 
103*c217d954SCole Faust         final_output.allocator()->allocate();
104*c217d954SCole Faust 
105*c217d954SCole Faust         // Fill tensors
106*c217d954SCole Faust         fill(AccessorType(input1), 0, data_type);
107*c217d954SCole Faust         fill(AccessorType(input2), 1, data_type);
108*c217d954SCole Faust         fill(AccessorType(bn_mul), 2, data_type);
109*c217d954SCole Faust         fill(AccessorType(bn_add), 3, data_type);
110*c217d954SCole Faust 
111*c217d954SCole Faust         // // Compute function
112*c217d954SCole Faust         add_mul_add.run();
113*c217d954SCole Faust 
114*c217d954SCole Faust         _target = std::move(final_output);
115*c217d954SCole Faust 
116*c217d954SCole Faust         if(interm_out)
117*c217d954SCole Faust         {
118*c217d954SCole Faust             _interm_target = std::move(add_output);
119*c217d954SCole Faust         }
120*c217d954SCole Faust     }
121*c217d954SCole Faust 
122*c217d954SCole Faust     TensorType      _target{};
123*c217d954SCole Faust     TensorType      _interm_target{};
124*c217d954SCole Faust     SimpleTensor<T> _reference{};
125*c217d954SCole Faust     SimpleTensor<T> _interm_reference{};
126*c217d954SCole Faust 
127*c217d954SCole Faust     QuantizationInfo _input1_qinfo{};
128*c217d954SCole Faust     QuantizationInfo _input2_qinfo{};
129*c217d954SCole Faust     QuantizationInfo _bn_mul_qinfo{};
130*c217d954SCole Faust     QuantizationInfo _bn_add_qinfo{};
131*c217d954SCole Faust     QuantizationInfo _add_output_qinfo{};
132*c217d954SCole Faust     QuantizationInfo _final_output_qinfo{};
133*c217d954SCole Faust };
134*c217d954SCole Faust 
135*c217d954SCole Faust template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool interm_out>
136*c217d954SCole Faust class AddMulAddFloatValidationFixture : public AddMulAddGenericFixture<TensorType, AccessorType, FunctionType, T>
137*c217d954SCole Faust {
138*c217d954SCole Faust public:
139*c217d954SCole Faust     using Parent = AddMulAddGenericFixture<TensorType, AccessorType, FunctionType, T>;
140*c217d954SCole Faust 
141*c217d954SCole Faust     template <typename...>
setup(const TensorShape & shape,DataType data_type,ActivationLayerInfo act_info)142*c217d954SCole Faust     void setup(const TensorShape &shape, DataType data_type, ActivationLayerInfo act_info)
143*c217d954SCole Faust     {
144*c217d954SCole Faust         Parent::setup(shape, data_type, act_info, interm_out);
145*c217d954SCole Faust         compute_reference(shape, data_type, act_info);
146*c217d954SCole Faust     }
147*c217d954SCole Faust 
148*c217d954SCole Faust     // Compute Reference is moved outside of the generic fixture because with the quantized data types,
149*c217d954SCole Faust     // it becomes a very different implementation with intermediate tensors' data types being always float.
150*c217d954SCole Faust     // This way the reference calculations are more readable and the size of the classes will be smaller
151*c217d954SCole Faust     // due to unrepeated fill() and target() methods.
compute_reference(const TensorShape & shape,DataType data_type,ActivationLayerInfo & act_info)152*c217d954SCole Faust     void compute_reference(const TensorShape &shape, DataType data_type, ActivationLayerInfo &act_info)
153*c217d954SCole Faust     {
154*c217d954SCole Faust         TensorShape b_shape(shape.x());
155*c217d954SCole Faust 
156*c217d954SCole Faust         // Create reference
157*c217d954SCole Faust         SimpleTensor<T> input1{ shape, data_type };
158*c217d954SCole Faust         SimpleTensor<T> input2{ shape, data_type };
159*c217d954SCole Faust         SimpleTensor<T> bn_mul{ b_shape, data_type };
160*c217d954SCole Faust         SimpleTensor<T> bn_add{ b_shape, data_type };
161*c217d954SCole Faust         SimpleTensor<T> add_output{ shape, data_type, 1 };
162*c217d954SCole Faust 
163*c217d954SCole Faust         SimpleTensor<T> bn_mul_out{ shape, data_type };
164*c217d954SCole Faust         SimpleTensor<T> bn_add_out{ shape, data_type };
165*c217d954SCole Faust 
166*c217d954SCole Faust         // Fill reference
167*c217d954SCole Faust         Parent::fill(input1, 0, data_type);
168*c217d954SCole Faust         Parent::fill(input2, 1, data_type);
169*c217d954SCole Faust         Parent::fill(bn_mul, 2, data_type);
170*c217d954SCole Faust         Parent::fill(bn_add, 3, data_type);
171*c217d954SCole Faust 
172*c217d954SCole Faust         reference::arithmetic_operation<T>(reference::ArithmeticOperation::ADD, input1, input2, add_output, ConvertPolicy::SATURATE);
173*c217d954SCole Faust         bn_mul_out = reference::pixel_wise_multiplication<T, T, T>(add_output, bn_mul, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_UP, data_type);
174*c217d954SCole Faust         reference::arithmetic_operation<T>(reference::ArithmeticOperation::ADD, bn_mul_out, bn_add, bn_add_out, ConvertPolicy::SATURATE);
175*c217d954SCole Faust 
176*c217d954SCole Faust         if(interm_out)
177*c217d954SCole Faust         {
178*c217d954SCole Faust             Parent::_interm_reference = std::move(add_output);
179*c217d954SCole Faust         }
180*c217d954SCole Faust 
181*c217d954SCole Faust         if(act_info.enabled() && act_info.activation() != ActivationLayerInfo::ActivationFunction::IDENTITY)
182*c217d954SCole Faust         {
183*c217d954SCole Faust             Parent::_reference = reference::activation_layer(bn_add_out, act_info);
184*c217d954SCole Faust         }
185*c217d954SCole Faust         else
186*c217d954SCole Faust         {
187*c217d954SCole Faust             Parent::_reference = std::move(bn_add_out);
188*c217d954SCole Faust         }
189*c217d954SCole Faust     }
190*c217d954SCole Faust };
191*c217d954SCole Faust 
192*c217d954SCole Faust template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool interm_out>
193*c217d954SCole Faust class AddMulAddQuantizedValidationFixture : public AddMulAddGenericFixture<TensorType, AccessorType, FunctionType, T>
194*c217d954SCole Faust {
195*c217d954SCole Faust public:
196*c217d954SCole Faust     using Parent = AddMulAddGenericFixture<TensorType, AccessorType, FunctionType, T>;
197*c217d954SCole Faust 
198*c217d954SCole Faust     template <typename...>
setup(const TensorShape & shape,DataType data_type,ActivationLayerInfo act_info,QuantizationInfo input1_qinfo,QuantizationInfo input2_qinfo,QuantizationInfo bn_mul_qinfo,QuantizationInfo bn_add_qinfo,QuantizationInfo add_output_qinfo,QuantizationInfo final_output_qinfo)199*c217d954SCole Faust     void setup(const TensorShape &shape, DataType data_type, ActivationLayerInfo act_info,
200*c217d954SCole Faust                QuantizationInfo input1_qinfo, QuantizationInfo input2_qinfo, QuantizationInfo bn_mul_qinfo,
201*c217d954SCole Faust                QuantizationInfo bn_add_qinfo, QuantizationInfo add_output_qinfo, QuantizationInfo final_output_qinfo)
202*c217d954SCole Faust     {
203*c217d954SCole Faust         // Quantization arguments moved to class attributes to prevent long function declerations
204*c217d954SCole Faust         Parent::_input1_qinfo       = input1_qinfo;
205*c217d954SCole Faust         Parent::_input2_qinfo       = input2_qinfo;
206*c217d954SCole Faust         Parent::_bn_mul_qinfo       = bn_mul_qinfo;
207*c217d954SCole Faust         Parent::_bn_add_qinfo       = bn_add_qinfo;
208*c217d954SCole Faust         Parent::_add_output_qinfo   = add_output_qinfo;
209*c217d954SCole Faust         Parent::_final_output_qinfo = final_output_qinfo;
210*c217d954SCole Faust 
211*c217d954SCole Faust         Parent::setup(shape, data_type, act_info, interm_out);
212*c217d954SCole Faust         compute_reference(shape, data_type, act_info);
213*c217d954SCole Faust     }
214*c217d954SCole Faust 
215*c217d954SCole Faust     // Compute Reference is moved outside of the generic fixture because with the quantized data types,
216*c217d954SCole Faust     // it becomes a very different implementation with intermediate tensors' data types being always float.
217*c217d954SCole Faust     // This way the reference calculations are more readable and the size of the classes will be smaller
218*c217d954SCole Faust     // due to unrepeated fill() and target() methods.
compute_reference(const TensorShape & shape,DataType data_type,ActivationLayerInfo & act_info)219*c217d954SCole Faust     void compute_reference(const TensorShape &shape, DataType data_type, ActivationLayerInfo &act_info)
220*c217d954SCole Faust     {
221*c217d954SCole Faust         TensorShape b_shape(shape.x());
222*c217d954SCole Faust 
223*c217d954SCole Faust         // Create reference
224*c217d954SCole Faust         SimpleTensor<T> input1{ shape, data_type, 1, Parent::_input1_qinfo };
225*c217d954SCole Faust         SimpleTensor<T> input2{ shape, data_type, 1, Parent::_input2_qinfo };
226*c217d954SCole Faust         SimpleTensor<T> bn_mul{ b_shape, data_type, 1, Parent::_bn_mul_qinfo };
227*c217d954SCole Faust         SimpleTensor<T> bn_add{ b_shape, data_type, 1, Parent::_bn_add_qinfo };
228*c217d954SCole Faust 
229*c217d954SCole Faust         // Fill input tensors
230*c217d954SCole Faust         Parent::fill(input1, 0, data_type);
231*c217d954SCole Faust         Parent::fill(input2, 1, data_type);
232*c217d954SCole Faust         Parent::fill(bn_mul, 2, data_type);
233*c217d954SCole Faust         Parent::fill(bn_add, 3, data_type);
234*c217d954SCole Faust 
235*c217d954SCole Faust         SimpleTensor<float> input1_dequantized = reference::dequantization_layer<float>(input1);
236*c217d954SCole Faust         SimpleTensor<float> input2_dequantized = reference::dequantization_layer<float>(input2);
237*c217d954SCole Faust         SimpleTensor<float> bn_mul_dequantized = reference::dequantization_layer<float>(bn_mul);
238*c217d954SCole Faust         SimpleTensor<float> bn_add_dequantized = reference::dequantization_layer<float>(bn_add);
239*c217d954SCole Faust 
240*c217d954SCole Faust         SimpleTensor<float> add_output_dequantized{ shape, DataType::F32 };
241*c217d954SCole Faust         SimpleTensor<float> bn_add_out_dequantized{ shape, DataType::F32 };
242*c217d954SCole Faust 
243*c217d954SCole Faust         reference::arithmetic_operation<float>(reference::ArithmeticOperation::ADD, input1_dequantized, input2_dequantized, add_output_dequantized, ConvertPolicy::SATURATE);
244*c217d954SCole Faust         SimpleTensor<float> bn_mul_out_dequantized = reference::pixel_wise_multiplication<float, float, float>(add_output_dequantized, bn_mul_dequantized, 1.f, ConvertPolicy::SATURATE,
245*c217d954SCole Faust                                                                                                                RoundingPolicy::TO_NEAREST_UP, DataType::F32);
246*c217d954SCole Faust         reference::arithmetic_operation<float>(reference::ArithmeticOperation::ADD, bn_mul_out_dequantized, bn_add_dequantized, bn_add_out_dequantized, ConvertPolicy::SATURATE);
247*c217d954SCole Faust 
248*c217d954SCole Faust         if(interm_out)
249*c217d954SCole Faust         {
250*c217d954SCole Faust             Parent::_interm_reference = reference::quantization_layer<float, T>(add_output_dequantized, data_type, Parent::_add_output_qinfo);
251*c217d954SCole Faust         }
252*c217d954SCole Faust 
253*c217d954SCole Faust         if(act_info.enabled() && act_info.activation() != ActivationLayerInfo::ActivationFunction::IDENTITY)
254*c217d954SCole Faust         {
255*c217d954SCole Faust             SimpleTensor<T> ref = reference::quantization_layer<float, T>(bn_add_out_dequantized, data_type, Parent::_final_output_qinfo);
256*c217d954SCole Faust             Parent::_reference  = reference::activation_layer(ref, act_info);
257*c217d954SCole Faust         }
258*c217d954SCole Faust         else
259*c217d954SCole Faust         {
260*c217d954SCole Faust             Parent::_reference = reference::quantization_layer<float, T>(bn_add_out_dequantized, data_type, Parent::_final_output_qinfo);
261*c217d954SCole Faust         }
262*c217d954SCole Faust     }
263*c217d954SCole Faust };
264*c217d954SCole Faust } // namespace validation
265*c217d954SCole Faust } // namespace test
266*c217d954SCole Faust } // namespace arm_compute
267*c217d954SCole Faust 
268*c217d954SCole Faust #endif /* TESTS_VALIDATION_FIXTURES_ADDMULADDFIXTURE */
269