xref: /aosp_15_r20/external/ComputeLibrary/tests/validation/fixtures/AddMulAddFixture.h (revision c217d954acce2dbc11938adb493fc0abd69584f3)
1 /*
2  * Copyright (c) 2023 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 
25 #ifndef TESTS_VALIDATION_FIXTURES_ADDMULADDFIXTURE
26 #define TESTS_VALIDATION_FIXTURES_ADDMULADDFIXTURE
27 
28 #include "arm_compute/core/TensorShape.h"
29 #include "arm_compute/core/Types.h"
30 #include "tests/AssetsLibrary.h"
31 #include "tests/Globals.h"
32 #include "tests/IAccessor.h"
33 #include "tests/framework/Asserts.h"
34 #include "tests/framework/Fixture.h"
35 #include "tests/validation/Helpers.h"
36 #include "tests/validation/reference/ActivationLayer.h"
37 #include "tests/validation/reference/ArithmeticOperations.h"
38 #include "tests/validation/reference/DequantizationLayer.h"
39 #include "tests/validation/reference/PixelWiseMultiplication.h"
40 #include "tests/validation/reference/QuantizationLayer.h"
41 
42 namespace arm_compute
43 {
44 namespace test
45 {
46 namespace validation
47 {
48 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
49 class AddMulAddGenericFixture : public framework::Fixture
50 {
51 public:
52     template <typename...>
setup(const TensorShape & shape,DataType data_type,ActivationLayerInfo & act_info,bool interm_out)53     void setup(const TensorShape &shape, DataType data_type, ActivationLayerInfo &act_info, bool interm_out)
54     {
55         compute_target(shape, data_type, act_info, interm_out);
56     }
57 
58 protected:
59     template <typename U>
fill(U && tensor,int i,DataType data_type)60     void fill(U &&tensor, int i, DataType data_type)
61     {
62         switch(data_type)
63         {
64             case DataType::F32:
65                 library->fill_tensor_uniform(tensor, i, -10.f, 10.f);
66                 break;
67             case DataType::F16:
68                 library->fill_tensor_uniform(tensor, i, -1.f, 1.f);
69                 break;
70             default:
71                 library->fill_tensor_uniform(tensor, i);
72                 break;
73         }
74     }
75 
compute_target(const TensorShape & shape,DataType data_type,ActivationLayerInfo & act_info,bool interm_out)76     void compute_target(const TensorShape &shape, DataType data_type, ActivationLayerInfo &act_info, bool interm_out)
77     {
78         TensorShape b_shape(shape.x());
79 
80         // Create tensors
81         TensorType input1       = create_tensor<TensorType>(shape, data_type, 1, _input1_qinfo);
82         TensorType input2       = create_tensor<TensorType>(shape, data_type, 1, _input2_qinfo);
83         TensorType bn_mul       = create_tensor<TensorType>(b_shape, data_type, 1, _bn_mul_qinfo);
84         TensorType bn_add       = create_tensor<TensorType>(b_shape, data_type, 1, _bn_add_qinfo);
85         TensorType add_output   = create_tensor<TensorType>(shape, data_type, 1, _add_output_qinfo);
86         TensorType final_output = create_tensor<TensorType>(shape, data_type, 1, _final_output_qinfo);
87 
88         // Create and configure function
89         FunctionType add_mul_add;
90         add_mul_add.configure(&input1, &input2, &bn_mul, &bn_add, interm_out ? &add_output : nullptr, &final_output, ConvertPolicy::SATURATE, act_info);
91 
92         // Allocate tensors
93         input1.allocator()->allocate();
94         input2.allocator()->allocate();
95         bn_mul.allocator()->allocate();
96         bn_add.allocator()->allocate();
97 
98         if(interm_out)
99         {
100             add_output.allocator()->allocate();
101         }
102 
103         final_output.allocator()->allocate();
104 
105         // Fill tensors
106         fill(AccessorType(input1), 0, data_type);
107         fill(AccessorType(input2), 1, data_type);
108         fill(AccessorType(bn_mul), 2, data_type);
109         fill(AccessorType(bn_add), 3, data_type);
110 
111         // // Compute function
112         add_mul_add.run();
113 
114         _target = std::move(final_output);
115 
116         if(interm_out)
117         {
118             _interm_target = std::move(add_output);
119         }
120     }
121 
122     TensorType      _target{};
123     TensorType      _interm_target{};
124     SimpleTensor<T> _reference{};
125     SimpleTensor<T> _interm_reference{};
126 
127     QuantizationInfo _input1_qinfo{};
128     QuantizationInfo _input2_qinfo{};
129     QuantizationInfo _bn_mul_qinfo{};
130     QuantizationInfo _bn_add_qinfo{};
131     QuantizationInfo _add_output_qinfo{};
132     QuantizationInfo _final_output_qinfo{};
133 };
134 
135 template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool interm_out>
136 class AddMulAddFloatValidationFixture : public AddMulAddGenericFixture<TensorType, AccessorType, FunctionType, T>
137 {
138 public:
139     using Parent = AddMulAddGenericFixture<TensorType, AccessorType, FunctionType, T>;
140 
141     template <typename...>
setup(const TensorShape & shape,DataType data_type,ActivationLayerInfo act_info)142     void setup(const TensorShape &shape, DataType data_type, ActivationLayerInfo act_info)
143     {
144         Parent::setup(shape, data_type, act_info, interm_out);
145         compute_reference(shape, data_type, act_info);
146     }
147 
148     // Compute Reference is moved outside of the generic fixture because with the quantized data types,
149     // it becomes a very different implementation with intermediate tensors' data types being always float.
150     // This way the reference calculations are more readable and the size of the classes will be smaller
151     // due to unrepeated fill() and target() methods.
compute_reference(const TensorShape & shape,DataType data_type,ActivationLayerInfo & act_info)152     void compute_reference(const TensorShape &shape, DataType data_type, ActivationLayerInfo &act_info)
153     {
154         TensorShape b_shape(shape.x());
155 
156         // Create reference
157         SimpleTensor<T> input1{ shape, data_type };
158         SimpleTensor<T> input2{ shape, data_type };
159         SimpleTensor<T> bn_mul{ b_shape, data_type };
160         SimpleTensor<T> bn_add{ b_shape, data_type };
161         SimpleTensor<T> add_output{ shape, data_type, 1 };
162 
163         SimpleTensor<T> bn_mul_out{ shape, data_type };
164         SimpleTensor<T> bn_add_out{ shape, data_type };
165 
166         // Fill reference
167         Parent::fill(input1, 0, data_type);
168         Parent::fill(input2, 1, data_type);
169         Parent::fill(bn_mul, 2, data_type);
170         Parent::fill(bn_add, 3, data_type);
171 
172         reference::arithmetic_operation<T>(reference::ArithmeticOperation::ADD, input1, input2, add_output, ConvertPolicy::SATURATE);
173         bn_mul_out = reference::pixel_wise_multiplication<T, T, T>(add_output, bn_mul, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_UP, data_type);
174         reference::arithmetic_operation<T>(reference::ArithmeticOperation::ADD, bn_mul_out, bn_add, bn_add_out, ConvertPolicy::SATURATE);
175 
176         if(interm_out)
177         {
178             Parent::_interm_reference = std::move(add_output);
179         }
180 
181         if(act_info.enabled() && act_info.activation() != ActivationLayerInfo::ActivationFunction::IDENTITY)
182         {
183             Parent::_reference = reference::activation_layer(bn_add_out, act_info);
184         }
185         else
186         {
187             Parent::_reference = std::move(bn_add_out);
188         }
189     }
190 };
191 
192 template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool interm_out>
193 class AddMulAddQuantizedValidationFixture : public AddMulAddGenericFixture<TensorType, AccessorType, FunctionType, T>
194 {
195 public:
196     using Parent = AddMulAddGenericFixture<TensorType, AccessorType, FunctionType, T>;
197 
198     template <typename...>
setup(const TensorShape & shape,DataType data_type,ActivationLayerInfo act_info,QuantizationInfo input1_qinfo,QuantizationInfo input2_qinfo,QuantizationInfo bn_mul_qinfo,QuantizationInfo bn_add_qinfo,QuantizationInfo add_output_qinfo,QuantizationInfo final_output_qinfo)199     void setup(const TensorShape &shape, DataType data_type, ActivationLayerInfo act_info,
200                QuantizationInfo input1_qinfo, QuantizationInfo input2_qinfo, QuantizationInfo bn_mul_qinfo,
201                QuantizationInfo bn_add_qinfo, QuantizationInfo add_output_qinfo, QuantizationInfo final_output_qinfo)
202     {
203         // Quantization arguments moved to class attributes to prevent long function declerations
204         Parent::_input1_qinfo       = input1_qinfo;
205         Parent::_input2_qinfo       = input2_qinfo;
206         Parent::_bn_mul_qinfo       = bn_mul_qinfo;
207         Parent::_bn_add_qinfo       = bn_add_qinfo;
208         Parent::_add_output_qinfo   = add_output_qinfo;
209         Parent::_final_output_qinfo = final_output_qinfo;
210 
211         Parent::setup(shape, data_type, act_info, interm_out);
212         compute_reference(shape, data_type, act_info);
213     }
214 
215     // Compute Reference is moved outside of the generic fixture because with the quantized data types,
216     // it becomes a very different implementation with intermediate tensors' data types being always float.
217     // This way the reference calculations are more readable and the size of the classes will be smaller
218     // due to unrepeated fill() and target() methods.
compute_reference(const TensorShape & shape,DataType data_type,ActivationLayerInfo & act_info)219     void compute_reference(const TensorShape &shape, DataType data_type, ActivationLayerInfo &act_info)
220     {
221         TensorShape b_shape(shape.x());
222 
223         // Create reference
224         SimpleTensor<T> input1{ shape, data_type, 1, Parent::_input1_qinfo };
225         SimpleTensor<T> input2{ shape, data_type, 1, Parent::_input2_qinfo };
226         SimpleTensor<T> bn_mul{ b_shape, data_type, 1, Parent::_bn_mul_qinfo };
227         SimpleTensor<T> bn_add{ b_shape, data_type, 1, Parent::_bn_add_qinfo };
228 
229         // Fill input tensors
230         Parent::fill(input1, 0, data_type);
231         Parent::fill(input2, 1, data_type);
232         Parent::fill(bn_mul, 2, data_type);
233         Parent::fill(bn_add, 3, data_type);
234 
235         SimpleTensor<float> input1_dequantized = reference::dequantization_layer<float>(input1);
236         SimpleTensor<float> input2_dequantized = reference::dequantization_layer<float>(input2);
237         SimpleTensor<float> bn_mul_dequantized = reference::dequantization_layer<float>(bn_mul);
238         SimpleTensor<float> bn_add_dequantized = reference::dequantization_layer<float>(bn_add);
239 
240         SimpleTensor<float> add_output_dequantized{ shape, DataType::F32 };
241         SimpleTensor<float> bn_add_out_dequantized{ shape, DataType::F32 };
242 
243         reference::arithmetic_operation<float>(reference::ArithmeticOperation::ADD, input1_dequantized, input2_dequantized, add_output_dequantized, ConvertPolicy::SATURATE);
244         SimpleTensor<float> bn_mul_out_dequantized = reference::pixel_wise_multiplication<float, float, float>(add_output_dequantized, bn_mul_dequantized, 1.f, ConvertPolicy::SATURATE,
245                                                                                                                RoundingPolicy::TO_NEAREST_UP, DataType::F32);
246         reference::arithmetic_operation<float>(reference::ArithmeticOperation::ADD, bn_mul_out_dequantized, bn_add_dequantized, bn_add_out_dequantized, ConvertPolicy::SATURATE);
247 
248         if(interm_out)
249         {
250             Parent::_interm_reference = reference::quantization_layer<float, T>(add_output_dequantized, data_type, Parent::_add_output_qinfo);
251         }
252 
253         if(act_info.enabled() && act_info.activation() != ActivationLayerInfo::ActivationFunction::IDENTITY)
254         {
255             SimpleTensor<T> ref = reference::quantization_layer<float, T>(bn_add_out_dequantized, data_type, Parent::_final_output_qinfo);
256             Parent::_reference  = reference::activation_layer(ref, act_info);
257         }
258         else
259         {
260             Parent::_reference = reference::quantization_layer<float, T>(bn_add_out_dequantized, data_type, Parent::_final_output_qinfo);
261         }
262     }
263 };
264 } // namespace validation
265 } // namespace test
266 } // namespace arm_compute
267 
268 #endif /* TESTS_VALIDATION_FIXTURES_ADDMULADDFIXTURE */
269