/*
 * Copyright (c) 2017-2022 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#ifndef ARM_COMPUTE_TEST_GEMMLOWP_FIXTURE
#define ARM_COMPUTE_TEST_GEMMLOWP_FIXTURE

#include "arm_compute/core/utils/quantization/AsymmHelpers.h"
#include "tests/framework/Fixture.h"
#include "tests/validation/Validation.h"
#include "tests/validation/reference/GEMMLowp.h"

namespace arm_compute
{
namespace test
{
namespace validation
{
namespace
{
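/** Fill a tensor (or accessor) with values appropriate to its data type.
 *
 * @param[in,out] tensor Tensor to fill.
 * @param[in]     i      Seed offset forwarded to the library fill, so different tensors get different data.
 */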
template <typename U>
void fill(U &&tensor, int i)
{
    switch(tensor.data_type())
    {
        case DataType::QSYMM8_PER_CHANNEL:
        {
            int min_bound = 128;
            int max_bound = -127;
            for(size_t j = 0; j < tensor.quantization_info().scale().size(); j++)
            {
                std::pair<int, int> bounds = get_symm_quantized_per_channel_bounds(tensor.quantization_info(), -1.0f, 1.0f, j);
                if(bounds.first < min_bound)
                {
                    min_bound = bounds.first;
                }
                if(bounds.second > max_bound)
                {
                    max_bound = bounds.second;
                }
            }
            std::uniform_int_distribution<int32_t> distribution(min_bound, max_bound);
            library->fill(tensor, distribution, i);
            break;
        }
        case DataType::QASYMM8:
        {
            std::uniform_int_distribution<uint32_t> distribution(1, 254);
            library->fill(tensor, distribution, i);
            break;
        }
        case DataType::S32:
        {
            std::uniform_int_distribution<int32_t> distribution(-20000, 20000);
            library->fill(tensor, distribution, i);
            break;
        }
        case DataType::F16:
        {
            arm_compute::utils::uniform_real_distribution_16bit<half> distribution{ -1.0f, 1.0f };
            library->fill(tensor, distribution, i);
            break;
        }
        case DataType::F32:
        {
            std::uniform_real_distribution<float> distribution(-1.0f, 1.0f);
            library->fill(tensor, distribution, i);
            break;
        }
        default:
            library->fill_tensor_uniform(tensor, i);
    }
}

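/** Configure, allocate, fill and run the GEMMLowp function under test and return its output tensor.
 *
 * The output data type is S32 unless an output stage other than NONE is requested, in which case it matches
 * the data type of input A. When @p run_twice is true the function is run once, the inputs are re-filled with
 * new seeds, and it is run again so that variable inputs are exercised.
 */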
template <typename TensorType, typename AccessorType, typename FunctionType, bool reinterpret_input_as_3d, bool reinterpret_output_as_3d, typename OutputType, bool is_fused = false, bool run_twice = false>
TensorType compute_gemmlowp_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_output, int32_t a_offset, int32_t b_offset,
                                   GEMMLowpOutputStageInfo output_stage = GEMMLowpOutputStageInfo(), DataType data_type_a = DataType::QASYMM8, DataType data_type_b = DataType::QASYMM8,
                                   QuantizationInfo b_qinfo = QuantizationInfo(), bool reshape_b_only_on_first_run = false)
{
    // Create tensors
    DataType data_type_output = output_stage.type == GEMMLowpOutputStageType::NONE ? DataType::S32 : data_type_a;

    TensorType a      = create_tensor<TensorType>(shape_a, data_type_a, 1);
    TensorType b      = create_tensor<TensorType>(shape_b, data_type_b, 1); // The GEMM output before the output stage mismatches if data_type_output is passed here; to be investigated
    TensorType output = create_tensor<TensorType>(shape_output, data_type_output, 1);

    a.info()->set_quantization_info(QuantizationInfo(1.0f / 255, a_offset));

    if(data_type_b == DataType::QSYMM8_PER_CHANNEL)
    {
        b.info()->set_quantization_info(b_qinfo);
    }
    else
    {
        b.info()->set_quantization_info(QuantizationInfo(1.0f / 255, b_offset));
    }
    TensorType bias;
    if(is_fused)
    {
        TensorShape bias_shape(shape_b[0]);
        bias = create_tensor<TensorType>(bias_shape, DataType::S32, 1);
    }

    // Create and configure function
    // The GEMMInfo includes the depth values used when the input/output is reinterpreted as 3D
    FunctionType gemmlowp;
    gemmlowp.configure(&a, &b, is_fused ? &bias : nullptr, &output, GEMMInfo(false, false, reshape_b_only_on_first_run, (reinterpret_output_as_3d ? shape_output[2] : 0), reinterpret_input_as_3d, false,
                                                                             output_stage));

    ARM_COMPUTE_ASSERT(a.info()->is_resizable());
    ARM_COMPUTE_ASSERT(b.info()->is_resizable());
    ARM_COMPUTE_ASSERT(output.info()->is_resizable());

    add_padding_x({ &a, &b, &output });

    // Allocate tensors
    a.allocator()->allocate();
    b.allocator()->allocate();
    output.allocator()->allocate();

    ARM_COMPUTE_ASSERT(!a.info()->is_resizable());
    ARM_COMPUTE_ASSERT(!b.info()->is_resizable());
    ARM_COMPUTE_ASSERT(!output.info()->is_resizable());

    // Fill tensors
    fill(AccessorType(a), 0);
    fill(AccessorType(b), 1);

    if(is_fused)
    {
        ARM_COMPUTE_ASSERT(bias.info()->is_resizable());
        bias.allocator()->allocate();
        ARM_COMPUTE_ASSERT(!bias.info()->is_resizable());
        fill(AccessorType(bias), 2);
    }

    // Run with variable inputs.
    if(run_twice)
    {
        gemmlowp.run();
        fill(AccessorType(a), 3); // Fill tensors with new seed after run
        fill(AccessorType(b), 4);
        if(is_fused)
        {
            fill(AccessorType(bias), 5);
        }
    }

    // Compute GEMM function
    gemmlowp.run();
    return output;
}

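/** Compute the reference result with reference::gemmlowp_matrix_multiply_core.
 *
 * The same fill seeds as compute_gemmlowp_target are used so that target and reference operate on identical data.
 * When @p pretranspose_A / @p pretranspose_B are set, the inputs are transposed back to the layout expected by the
 * reference implementation (see the note inside the function).
 */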
template <bool reinterpret_input_as_3d, typename TI = uint8_t, typename TW = uint8_t, bool pretranspose_A = false, bool pretranspose_B = false, bool run_twice = false>
SimpleTensor<int32_t> compute_gemmlowp_reference(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_output, int32_t a_offset, int32_t b_offset,
                                                 DataType data_type_a = DataType::QASYMM8, DataType data_type_b = DataType::QASYMM8, QuantizationInfo b_qinfo = QuantizationInfo())
{
    TensorShape shape_a_to_use = shape_a;
    if(reinterpret_input_as_3d)
    {
        // Collapse the second and third dimension if the input is 3D
        shape_a_to_use.collapse(2U, 1U);
    }

    // Create reference
    SimpleTensor<TI> a{ shape_a_to_use, data_type_a, 1 };
    SimpleTensor<TW> b{ shape_b, data_type_b, 1, data_type_b == DataType::QSYMM8_PER_CHANNEL ? b_qinfo : QuantizationInfo(1.0f / 255, b_offset) };

    TensorShape shape_a_to_use_transposed{ shape_a_to_use };
    TensorShape shape_b_transposed{ shape_b };

    shape_a_to_use_transposed.set(0, shape_a_to_use[1]);
    shape_a_to_use_transposed.set(1, shape_a_to_use[0]);
    shape_b_transposed.set(0, shape_b[1]);
    shape_b_transposed.set(1, shape_b[0]);

    SimpleTensor<TI> a_transposed{ shape_a_to_use_transposed, data_type_a, 1 };
    SimpleTensor<TW> b_transposed{ shape_b_transposed, data_type_b, 1, data_type_b == DataType::QSYMM8_PER_CHANNEL ? b_qinfo : QuantizationInfo(1.0f / 255, b_offset) };

    // Fill reference
    fill(a, 0);
    fill(b, 1);

    // Transpose reference if required
    /* Note: Assuming the usual batch matmul dimensions A = (B x M x K), B = (B x K x N), if pretranspose_A is set to true, then A is assumed to be (B x K x M),
       so A must be pre-transposed before being passed to the fixture. The fixture then transposes A again to make it (B x M x K)
       so that the reference implementation, which works with (B x M x K) input, can be called.
       Similarly, if pretranspose_B is set to true, then B is assumed to be (B x N x K) and must be pre-transposed before being passed to the fixture. */
    if(pretranspose_A)
    {
        transpose_matrix<TI>(a, a_transposed);
    }

    if(pretranspose_B)
    {
        transpose_matrix<TW>(b, b_transposed);
    }

    // Run with variable inputs.
    if(run_twice)
    {
        reference::gemmlowp_matrix_multiply_core<int32_t, TI, TW>((pretranspose_A ? a_transposed : a), (pretranspose_B ? b_transposed : b), shape_output, a_offset, b_offset);
        fill((pretranspose_A) ? a_transposed : a, 3);
        fill((pretranspose_B) ? b_transposed : b, 4);
    }

    return reference::gemmlowp_matrix_multiply_core<int32_t, TI, TW>((pretranspose_A ? a_transposed : a), (pretranspose_B ? b_transposed : b), shape_output, a_offset, b_offset);
}
}

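/** Fixture that validates a GEMMLowp matrix multiply core function against the S32 reference implementation.
 *
 * Illustrative usage only (the concrete tensor, accessor, function and dataset types depend on the backend under test):
 *
 *     using NEGEMMLowpFixture = GEMMLowpMatrixMultiplyCoreValidationFixture<Tensor, Accessor, NEGEMMLowpMatrixMultiplyCore>;
 *     FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMLowpFixture, framework::DatasetMode::ALL, datasets::SmallGEMMLowpDataset())
 *     {
 *         validate(Accessor(_target), _reference);
 *     }
 */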
template <typename TensorType, typename AccessorType, typename FunctionType, bool reinterpret_input_as_3d = false, bool reinterpret_output_as_3d = false, bool run_twice = false>
class GEMMLowpMatrixMultiplyCoreValidationFixture : public framework::Fixture
{
public:
    template <typename...>
    void setup(TensorShape shape_a, TensorShape shape_b, TensorShape shape_output, int32_t a_offset, int32_t b_offset)
    {
        _target    = compute_target(shape_a, shape_b, shape_output, a_offset, b_offset);
        _reference = compute_reference(shape_a, shape_b, shape_output, a_offset, b_offset);
    }

protected:
    TensorType compute_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_output, int32_t a_offset, int32_t b_offset)
    {
        return compute_gemmlowp_target<TensorType, AccessorType, FunctionType, reinterpret_input_as_3d, reinterpret_output_as_3d, int32_t, false, run_twice>(shape_a, shape_b, shape_output, a_offset,
                b_offset);
    }

    SimpleTensor<int32_t> compute_reference(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_output, int32_t a_offset, int32_t b_offset)
    {
        return compute_gemmlowp_reference<reinterpret_input_as_3d, uint8_t, uint8_t, false, false, run_twice>(shape_a, shape_b, shape_output, a_offset, b_offset);
    }

    TensorType            _target{};
    SimpleTensor<int32_t> _reference{};
};

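/** Fixture with a fused output stage: the output stage type must not be NONE. For QSYMM8_PER_CHANNEL weights,
 *  random per-channel scales are generated and converted to quantized multipliers/shifts before running.
 */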
template <typename TensorType, typename AccessorType, typename FunctionType, bool reinterpret_input_as_3d = false, bool reinterpret_output_as_3d = false, typename TI = uint8_t, typename TW = uint8_t, bool run_twice = false>
class GEMMLowpMatrixMultiplyCoreFusedOffsetOutputGenericValidationFixture : public framework::Fixture
{
public:
    template <typename...>
    void setup(TensorShape shape_a, TensorShape shape_b, TensorShape shape_output, int32_t a_offset, int32_t b_offset, GEMMLowpOutputStageInfo output_stage, DataType data_type_b,
               bool reshape_b_only_on_first_run)
    {
        ARM_COMPUTE_ASSERT(output_stage.type != GEMMLowpOutputStageType::NONE);
        DataType data_type_a = data_type_b == DataType::QASYMM8_SIGNED ? DataType::QASYMM8_SIGNED : DataType::QASYMM8;

        if(data_type_b == DataType::QSYMM8_PER_CHANNEL)
        {
            output_stage.is_quantized_per_channel              = true;
            const size_t                          num_channels = shape_b[0];
            std::vector<float>                    scales(num_channels);
            std::uniform_real_distribution<float> distribution(0.f, 1.f);
            library->fill(scales, distribution, 0);
            output_stage.gemmlowp_multipliers.resize(num_channels);
            output_stage.gemmlowp_shifts.resize(num_channels);
            for(size_t i = 0; i < num_channels; ++i)
            {
                quantization::calculate_quantized_multiplier(scales[i], &output_stage.gemmlowp_multipliers[i], &output_stage.gemmlowp_shifts[i]);
            }

            _reference = compute_reference(shape_a, shape_b, shape_output, a_offset, 0, output_stage, data_type_a, data_type_b, QuantizationInfo(scales));
            _target    = compute_target(shape_a, shape_b, shape_output, a_offset, 0, output_stage, data_type_a, data_type_b, QuantizationInfo(scales), reshape_b_only_on_first_run);
        }
        else
        {
            _reference = compute_reference(shape_a, shape_b, shape_output, a_offset, b_offset, output_stage, data_type_a, data_type_b, QuantizationInfo());
            _target    = compute_target(shape_a, shape_b, shape_output, a_offset, b_offset, output_stage, data_type_a, data_type_b, QuantizationInfo(), reshape_b_only_on_first_run);
        }
    }

protected:
    TensorType compute_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_output, int32_t a_offset, int32_t b_offset, GEMMLowpOutputStageInfo output_stage,
                              DataType data_type_a, DataType data_type_b, QuantizationInfo b_qinfo, bool reshape_b_only_on_first_run = false)
    {
        return compute_gemmlowp_target<TensorType, AccessorType, FunctionType, reinterpret_input_as_3d, reinterpret_output_as_3d, qasymm8_t, true, run_twice>(shape_a, shape_b, shape_output, a_offset,
                b_offset,
                output_stage, data_type_a, data_type_b, b_qinfo, reshape_b_only_on_first_run);
    }

    SimpleTensor<TI> compute_reference(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_output, int32_t a_offset, int32_t b_offset,
                                       GEMMLowpOutputStageInfo output_stage, DataType data_type_a, DataType data_type_b, QuantizationInfo b_qinfo)
    {
        SimpleTensor<int32_t> output = compute_gemmlowp_reference<reinterpret_input_as_3d, TI, TW, false, false, run_twice>(shape_a, shape_b, shape_output, a_offset, b_offset, data_type_a, data_type_b,
                                                                                                                            b_qinfo);

        TensorShape           bias_shape(shape_b[0]);
        SimpleTensor<int32_t> bias{ bias_shape, DataType::S32, 1 };
        (run_twice) ? fill(bias, 5) : fill(bias, 2); // Fill bias with same seed as last run of gemmlowp_target

        switch(output_stage.type)
        {
            case GEMMLowpOutputStageType::QUANTIZE_DOWN:
                return reference::gemmlowp_quantize_down_scale<int32_t, TW>(output, bias,
                                                                            output_stage.gemmlowp_offset, output_stage.gemmlowp_multipliers, output_stage.gemmlowp_shifts, output_stage.gemmlowp_min_bound, output_stage.gemmlowp_max_bound);
                break;
            case GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT:
                return reference::gemmlowp_quantize_down_scale_by_fixedpoint<int32_t, TW>(output, bias,
                                                                                          output_stage.gemmlowp_multipliers, output_stage.gemmlowp_shifts, output_stage.gemmlowp_offset, output_stage.gemmlowp_min_bound, output_stage.gemmlowp_max_bound);
                break;
            default:
                ARM_COMPUTE_ERROR("Not Supported!");
        }
    }

    TensorType       _target{};
    SimpleTensor<TI> _reference{};
};

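/** Convenience wrapper around the generic fused-offset-output fixture that always passes
 *  reshape_b_only_on_first_run = false.
 */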
template <typename TensorType, typename AccessorType, typename FunctionType, bool reinterpret_input_as_3d = false, bool reinterpret_output_as_3d = false, typename TI = uint8_t, typename TW = uint8_t>
class GEMMLowpMatrixMultiplyCoreFusedOffsetOutputValidationFixture : public
    GEMMLowpMatrixMultiplyCoreFusedOffsetOutputGenericValidationFixture<TensorType, AccessorType, FunctionType, reinterpret_input_as_3d, reinterpret_output_as_3d, TI, TW>
{
public:
    template <typename...>
    void setup(TensorShape shape_a, TensorShape shape_b, TensorShape shape_output, int32_t a_offset, int32_t b_offset, GEMMLowpOutputStageInfo output_stage, DataType data_type_b)
    {
        GEMMLowpMatrixMultiplyCoreFusedOffsetOutputGenericValidationFixture<TensorType, AccessorType, FunctionType, reinterpret_input_as_3d, reinterpret_output_as_3d, TI, TW>::setup(shape_a, shape_b,
                shape_output, a_offset, b_offset, output_stage, data_type_b, false);
    }
};

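/** Validates the QUANTIZE_DOWN output stage (integer multiplier and shift) producing QASYMM8 output. */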
template <typename TensorType, typename AccessorType, typename FunctionType>
class GEMMLowpQuantizeDownInt32ToUint8ScaleValidationFixture : public framework::Fixture
{
public:
    template <typename...>
    void setup(TensorShape shape, int32_t result_offset, int32_t result_mult_int, int32_t result_shift, int32_t min, int32_t max, bool add_bias)
    {
        _target    = compute_target(shape, result_offset, result_mult_int, result_shift, min, max, add_bias);
        _reference = compute_reference(shape, result_offset, result_mult_int, result_shift, min, max, add_bias);
    }

protected:
    template <typename U>
    void fill(U &&tensor, int i)
    {
        std::uniform_int_distribution<> distribution(-6000, 6000);
        library->fill(tensor, distribution, i);
    }

    TensorType compute_target(const TensorShape &shape, int32_t result_offset, int32_t result_mult_int, int32_t result_shift, int32_t min, int32_t max, bool add_bias)
    {
        TensorShape shape_bias(shape[0]);

        // Create tensors
        TensorType a = create_tensor<TensorType>(shape, DataType::S32, 1);
        TensorType b = create_tensor<TensorType>(shape_bias, DataType::S32, 1);
        TensorType c = create_tensor<TensorType>(shape, DataType::QASYMM8, 1);

        // Create and configure function
        FunctionType            output_stage;
        GEMMLowpOutputStageInfo output_stage_info = GEMMLowpOutputStageInfo();
        output_stage_info.type                    = GEMMLowpOutputStageType::QUANTIZE_DOWN;
        output_stage_info.gemmlowp_offset         = result_offset;
        output_stage_info.gemmlowp_multiplier     = result_mult_int;
        output_stage_info.gemmlowp_shift          = result_shift;
        output_stage_info.gemmlowp_min_bound      = min;
        output_stage_info.gemmlowp_max_bound      = max;
        output_stage_info.output_data_type        = DataType::QASYMM8;
        output_stage.configure(&a, add_bias ? &b : nullptr, &c, output_stage_info);

        ARM_COMPUTE_ASSERT(a.info()->is_resizable());
        ARM_COMPUTE_ASSERT(c.info()->is_resizable());

        // Allocate tensors
        a.allocator()->allocate();
        c.allocator()->allocate();

        ARM_COMPUTE_ASSERT(!a.info()->is_resizable());
        ARM_COMPUTE_ASSERT(!c.info()->is_resizable());

        // Fill tensor
        fill(AccessorType(a), 0);

        if(add_bias)
        {
            ARM_COMPUTE_ASSERT(b.info()->is_resizable());

            // Allocate bias tensor
            b.allocator()->allocate();

            ARM_COMPUTE_ASSERT(!b.info()->is_resizable());

            // Fill tensor
            fill(AccessorType(b), 1);
        }

        // Compute GEMM function
        output_stage.run();
        return c;
    }

    SimpleTensor<uint8_t> compute_reference(const TensorShape &shape, int32_t result_offset, int32_t result_mult_int, int32_t result_shift, int32_t min, int32_t max, bool add_bias)
    {
        // Create reference
        TensorShape shape_bias(shape[0]);

        SimpleTensor<int32_t> a{ shape, DataType::S32, 1 };
        SimpleTensor<int32_t> b{ shape_bias, DataType::S32, 1 };

        // Fill reference
        fill(a, 0);

        const std::vector<int32_t> result_mult_int_vec = { result_mult_int };
        const std::vector<int32_t> result_shift_vec    = { result_shift };

        if(add_bias)
        {
            // Fill bias
            fill(b, 1);

            return reference::gemmlowp_quantize_down_scale<int32_t, uint8_t>(a, b, result_offset, result_mult_int_vec, result_shift_vec, min, max);
        }
        else
        {
            return reference::gemmlowp_quantize_down_scale<int32_t, uint8_t>(a, result_offset, result_mult_int_vec, result_shift_vec, min, max);
        }
    }

    TensorType            _target{};
    SimpleTensor<uint8_t> _reference{};
};

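/** Validates the QUANTIZE_DOWN output stage (integer multiplier and shift) producing QASYMM8_SIGNED output. */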
template <typename TensorType, typename AccessorType, typename FunctionType>
class GEMMLowpQuantizeDownInt32ToInt8ScaleValidationFixture : public framework::Fixture
{
public:
    template <typename...>
    void setup(TensorShape shape, int32_t result_offset, int32_t result_mult_int, int32_t result_shift, int32_t min, int32_t max, bool add_bias)
    {
        _target    = compute_target(shape, result_offset, result_mult_int, result_shift, min, max, add_bias);
        _reference = compute_reference(shape, result_offset, result_mult_int, result_shift, min, max, add_bias);
    }

protected:
    template <typename U>
    void fill(U &&tensor, int i)
    {
        std::uniform_int_distribution<> distribution(-6000, 6000);
        library->fill(tensor, distribution, i);
    }

    TensorType compute_target(const TensorShape &shape, int32_t result_offset, int32_t result_mult_int, int32_t result_shift, int32_t min, int32_t max, bool add_bias)
    {
        TensorShape shape_bias(shape[0]);

        // Create tensors
        TensorType a = create_tensor<TensorType>(shape, DataType::S32, 1);
        TensorType b = create_tensor<TensorType>(shape_bias, DataType::S32, 1);
        TensorType c = create_tensor<TensorType>(shape, DataType::QASYMM8_SIGNED, 1);

        // Create and configure function
        FunctionType            output_stage;
        GEMMLowpOutputStageInfo output_stage_info = GEMMLowpOutputStageInfo();
        output_stage_info.type                    = GEMMLowpOutputStageType::QUANTIZE_DOWN;
        output_stage_info.gemmlowp_offset         = result_offset;
        output_stage_info.gemmlowp_multiplier     = result_mult_int;
        output_stage_info.gemmlowp_shift          = result_shift;
        output_stage_info.gemmlowp_min_bound      = min;
        output_stage_info.gemmlowp_max_bound      = max;
        output_stage_info.output_data_type        = DataType::QASYMM8_SIGNED;
        output_stage.configure(&a, add_bias ? &b : nullptr, &c, output_stage_info);

        ARM_COMPUTE_ASSERT(a.info()->is_resizable());
        ARM_COMPUTE_ASSERT(c.info()->is_resizable());

        // Allocate tensors
        a.allocator()->allocate();
        c.allocator()->allocate();

        ARM_COMPUTE_ASSERT(!a.info()->is_resizable());
        ARM_COMPUTE_ASSERT(!c.info()->is_resizable());

        // Fill tensor
        fill(AccessorType(a), 0);

        if(add_bias)
        {
            ARM_COMPUTE_ASSERT(b.info()->is_resizable());

            // Allocate bias tensor
            b.allocator()->allocate();

            ARM_COMPUTE_ASSERT(!b.info()->is_resizable());

            // Fill tensor
            fill(AccessorType(b), 1);
        }

        // Compute GEMM function
        output_stage.run();
        return c;
    }

    SimpleTensor<int8_t> compute_reference(const TensorShape &shape, int32_t result_offset, int32_t result_mult_int, int32_t result_shift, int32_t min, int32_t max, bool add_bias)
    {
        // Create reference
        TensorShape shape_bias(shape[0]);

        SimpleTensor<int32_t> a{ shape, DataType::S32, 1 };
        SimpleTensor<int32_t> b{ shape_bias, DataType::S32, 1 };

        // Fill reference
        fill(a, 0);

        const std::vector<int32_t> result_mult_int_vec = { result_mult_int };
        const std::vector<int32_t> result_shift_vec    = { result_shift };

        if(add_bias)
        {
            // Fill bias
            fill(b, 1);

            return reference::gemmlowp_quantize_down_scale<int32_t, int8_t>(a, b, result_offset, result_mult_int_vec, result_shift_vec, min, max);
        }
        else
        {
            return reference::gemmlowp_quantize_down_scale<int32_t, int8_t>(a, result_offset, result_mult_int_vec, result_shift_vec, min, max);
        }
    }

    TensorType           _target{};
    SimpleTensor<int8_t> _reference{};
};

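/** Validates the fixed-point requantization output stage producing QASYMM8_SIGNED output. */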
template <typename TensorType, typename AccessorType, typename FunctionType>
class GEMMLowpQuantizeDownInt32ToInt8ScaleByFixedPointValidationFixture : public framework::Fixture
{
public:
    template <typename...>
    void setup(TensorShape shape, int32_t result_fixedpoint_multiplier, int32_t result_shift, int32_t result_offset_after_shift, int32_t min, int32_t max, bool add_bias)
    {
        _target    = compute_target(shape, result_fixedpoint_multiplier, result_shift, result_offset_after_shift, min, max, add_bias);
        _reference = compute_reference(shape, result_fixedpoint_multiplier, result_shift, result_offset_after_shift, min, max, add_bias);
    }

protected:
    template <typename U>
    void fill(U &&tensor, int i)
    {
        std::uniform_int_distribution<> distribution(-6000, 6000);
        library->fill(tensor, distribution, i);
    }

    TensorType compute_target(const TensorShape &shape, int32_t result_fixedpoint_multiplier, int32_t result_shift, int32_t result_offset_after_shift, int32_t min, int32_t max, bool add_bias)
    {
        TensorShape shape_bias(shape[0]);

        // Create tensors
        TensorType a = create_tensor<TensorType>(shape, DataType::S32, 1);
        TensorType b = create_tensor<TensorType>(shape_bias, DataType::S32, 1);
        TensorType c = create_tensor<TensorType>(shape, DataType::QASYMM8_SIGNED, 1);

        // Create and configure function
        FunctionType output_stage;
        output_stage.configure(&a, add_bias ? &b : nullptr, &c, result_fixedpoint_multiplier, result_shift, result_offset_after_shift, min, max);

        ARM_COMPUTE_ASSERT(a.info()->is_resizable());
        ARM_COMPUTE_ASSERT(c.info()->is_resizable());

        // Allocate tensors
        a.allocator()->allocate();
        c.allocator()->allocate();

        ARM_COMPUTE_ASSERT(!a.info()->is_resizable());
        ARM_COMPUTE_ASSERT(!c.info()->is_resizable());

        // Fill tensor
        fill(AccessorType(a), 0);

        if(add_bias)
        {
            ARM_COMPUTE_ASSERT(b.info()->is_resizable());

            // Allocate bias tensor
            b.allocator()->allocate();

            ARM_COMPUTE_ASSERT(!b.info()->is_resizable());

            // Fill tensor
            fill(AccessorType(b), 1);
        }

        // Compute GEMM function
        output_stage.run();
        return c;
    }

    SimpleTensor<int8_t> compute_reference(const TensorShape &shape, int32_t result_fixed_point_multiplier, int32_t result_shift, int32_t result_offset_after_shift, int32_t min, int32_t max,
                                           bool add_bias)
    {
        // Create reference
        TensorShape shape_bias(shape[0]);

        SimpleTensor<int32_t> a{ shape, DataType::S32, 1 };
        SimpleTensor<int32_t> b{ shape_bias, DataType::S32, 1 };

        // Fill reference
        fill(a, 0);

        const std::vector<int32_t> result_fixed_point_multiplier_vec = { result_fixed_point_multiplier };
        const std::vector<int32_t> result_shift_vec                  = { result_shift };

        if(add_bias)
        {
            // Fill bias
            fill(b, 1);

            return reference::gemmlowp_quantize_down_scale_by_fixedpoint<int32_t, int8_t>(a, b, result_fixed_point_multiplier_vec, result_shift_vec, result_offset_after_shift, min, max);
        }
        else
        {
            return reference::gemmlowp_quantize_down_scale_by_fixedpoint<int32_t, int8_t>(a, result_fixed_point_multiplier_vec, result_shift_vec, result_offset_after_shift, min, max);
        }
    }

    TensorType           _target{};
    SimpleTensor<int8_t> _reference{};
};

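/** Validates the fixed-point requantization output stage producing QASYMM8 output. */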
template <typename TensorType, typename AccessorType, typename FunctionType>
class GEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPointValidationFixture : public framework::Fixture
{
public:
    template <typename...>
    void setup(TensorShape shape, int32_t result_fixedpoint_multiplier, int32_t result_shift, int32_t result_offset_after_shift, int32_t min, int32_t max, bool add_bias)
    {
        _target    = compute_target(shape, result_fixedpoint_multiplier, result_shift, result_offset_after_shift, min, max, add_bias);
        _reference = compute_reference(shape, result_fixedpoint_multiplier, result_shift, result_offset_after_shift, min, max, add_bias);
    }

protected:
    template <typename U>
    void fill(U &&tensor, int i)
    {
        std::uniform_int_distribution<> distribution(-6000, 6000);
        library->fill(tensor, distribution, i);
    }

    TensorType compute_target(const TensorShape &shape, int32_t result_fixedpoint_multiplier, int32_t result_shift, int32_t result_offset_after_shift, int32_t min, int32_t max, bool add_bias)
    {
        TensorShape shape_bias(shape[0]);

        // Create tensors
        TensorType a = create_tensor<TensorType>(shape, DataType::S32, 1);
        TensorType b = create_tensor<TensorType>(shape_bias, DataType::S32, 1);
        TensorType c = create_tensor<TensorType>(shape, DataType::QASYMM8, 1);

        // Create and configure function
        FunctionType output_stage;
        output_stage.configure(&a, add_bias ? &b : nullptr, &c, result_fixedpoint_multiplier, result_shift, result_offset_after_shift, min, max);

        ARM_COMPUTE_ASSERT(a.info()->is_resizable());
        ARM_COMPUTE_ASSERT(c.info()->is_resizable());

        // Allocate tensors
        a.allocator()->allocate();
        c.allocator()->allocate();

        ARM_COMPUTE_ASSERT(!a.info()->is_resizable());
        ARM_COMPUTE_ASSERT(!c.info()->is_resizable());

        // Fill tensor
        fill(AccessorType(a), 0);

        if(add_bias)
        {
            ARM_COMPUTE_ASSERT(b.info()->is_resizable());

            // Allocate bias tensor
            b.allocator()->allocate();

            ARM_COMPUTE_ASSERT(!b.info()->is_resizable());

            // Fill tensor
            fill(AccessorType(b), 1);
        }

        // Compute GEMM function
        output_stage.run();
        return c;
    }

    SimpleTensor<uint8_t> compute_reference(const TensorShape &shape, int32_t result_fixed_point_multiplier, int32_t result_shift, int32_t result_offset_after_shift, int32_t min, int32_t max,
                                            bool add_bias)
    {
        // Create reference
        TensorShape shape_bias(shape[0]);

        SimpleTensor<int32_t> a{ shape, DataType::S32, 1 };
        SimpleTensor<int32_t> b{ shape_bias, DataType::S32, 1 };

        // Fill reference
        fill(a, 0);

        const std::vector<int32_t> result_fixed_point_multiplier_vec = { result_fixed_point_multiplier };
        const std::vector<int32_t> result_shift_vec                  = { result_shift };

        if(add_bias)
        {
            // Fill bias
            fill(b, 1);

            return reference::gemmlowp_quantize_down_scale_by_fixedpoint<int32_t, uint8_t>(a, b, result_fixed_point_multiplier_vec, result_shift_vec, result_offset_after_shift, min, max);
        }
        else
        {
            return reference::gemmlowp_quantize_down_scale_by_fixedpoint<int32_t, uint8_t>(a, result_fixed_point_multiplier_vec, result_shift_vec, result_offset_after_shift, min, max);
        }
    }

    TensorType            _target{};
    SimpleTensor<uint8_t> _reference{};
};

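/** Validates the QUANTIZE_DOWN_FLOAT output stage, which requantizes with a real (floating-point) multiplier;
 *  the output data type is supplied at setup time.
 */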
template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
class GEMMLowpQuantizeDownInt32ScaleByFloatValidationFixture : public framework::Fixture
{
public:
    template <typename...>
    void setup(DataType data_type, TensorShape shape, float result_real_multiplier, int32_t result_offset, int32_t min, int32_t max, bool add_bias)
    {
        _target    = compute_target(data_type, shape, result_real_multiplier, result_offset, min, max, add_bias);
        _reference = compute_reference(shape, result_real_multiplier, result_offset, min, max, add_bias);
    }

protected:
    template <typename U>
    void fill(U &&tensor, int i)
    {
        // To avoid all data being clamped
753*c217d954SCole Faust         std::uniform_int_distribution<> distribution(-500, 500);
754*c217d954SCole Faust         library->fill(tensor, distribution, i);
755*c217d954SCole Faust     }
756*c217d954SCole Faust 
compute_target(DataType data_type,const TensorShape & shape,float result_multiplier,int32_t result_offset,int32_t min,int32_t max,bool add_bias)757*c217d954SCole Faust     TensorType compute_target(DataType data_type, const TensorShape &shape, float result_multiplier, int32_t result_offset, int32_t min, int32_t max, bool add_bias)
758*c217d954SCole Faust     {
759*c217d954SCole Faust         TensorShape shape_bias(shape[0]);
760*c217d954SCole Faust 
761*c217d954SCole Faust         // Create tensors
762*c217d954SCole Faust         TensorType a = create_tensor<TensorType>(shape, DataType::S32, 1);
763*c217d954SCole Faust         TensorType b = create_tensor<TensorType>(shape_bias, DataType::S32, 1);
764*c217d954SCole Faust         TensorType c = create_tensor<TensorType>(shape, data_type, 1);
765*c217d954SCole Faust 
766*c217d954SCole Faust         // create output stage info
767*c217d954SCole Faust         GEMMLowpOutputStageInfo info;
768*c217d954SCole Faust         info.gemmlowp_max_bound       = max;
769*c217d954SCole Faust         info.gemmlowp_min_bound       = min;
770*c217d954SCole Faust         info.gemmlowp_real_multiplier = result_multiplier;
771*c217d954SCole Faust         info.gemmlowp_offset          = result_offset;
772*c217d954SCole Faust         info.type                     = GEMMLowpOutputStageType::QUANTIZE_DOWN_FLOAT;
773*c217d954SCole Faust         info.output_data_type         = data_type;
774*c217d954SCole Faust 
775*c217d954SCole Faust         // Create and configure function
776*c217d954SCole Faust         FunctionType output_stage;
777*c217d954SCole Faust         output_stage.configure(&a, add_bias ? &b : nullptr, &c, info);
778*c217d954SCole Faust 
779*c217d954SCole Faust         ARM_COMPUTE_ASSERT(a.info()->is_resizable());
780*c217d954SCole Faust         ARM_COMPUTE_ASSERT(c.info()->is_resizable());
781*c217d954SCole Faust 
782*c217d954SCole Faust         // Allocate tensors
783*c217d954SCole Faust         a.allocator()->allocate();
784*c217d954SCole Faust         c.allocator()->allocate();
785*c217d954SCole Faust 
786*c217d954SCole Faust         ARM_COMPUTE_ASSERT(!a.info()->is_resizable());
787*c217d954SCole Faust         ARM_COMPUTE_ASSERT(!c.info()->is_resizable());
788*c217d954SCole Faust 
789*c217d954SCole Faust         // Fill tensor
790*c217d954SCole Faust         fill(AccessorType(a), 0);
791*c217d954SCole Faust 
792*c217d954SCole Faust         if(add_bias)
793*c217d954SCole Faust         {
794*c217d954SCole Faust             ARM_COMPUTE_ASSERT(b.info()->is_resizable());
795*c217d954SCole Faust 
796*c217d954SCole Faust             // Allocate bias tensor
797*c217d954SCole Faust             b.allocator()->allocate();
798*c217d954SCole Faust 
799*c217d954SCole Faust             ARM_COMPUTE_ASSERT(!b.info()->is_resizable());
800*c217d954SCole Faust 
801*c217d954SCole Faust             // Fill tensor
802*c217d954SCole Faust             fill(AccessorType(b), 1);
803*c217d954SCole Faust         }
804*c217d954SCole Faust 
805*c217d954SCole Faust         // Compute GEMM function
806*c217d954SCole Faust         output_stage.run();
807*c217d954SCole Faust         return c;
808*c217d954SCole Faust     }
809*c217d954SCole Faust 
810*c217d954SCole Faust     SimpleTensor<T> compute_reference(const TensorShape &shape, float_t result_real_multiplier, int32_t result_offset, int32_t min, int32_t max, bool add_bias)
811*c217d954SCole Faust     {
812*c217d954SCole Faust         // Create reference
813*c217d954SCole Faust         TensorShape shape_bias(shape[0]);
814*c217d954SCole Faust 
815*c217d954SCole Faust         SimpleTensor<int32_t> a{ shape, DataType::S32, 1 };
816*c217d954SCole Faust         SimpleTensor<int32_t> b{ shape_bias, DataType::S32, 1 };
817*c217d954SCole Faust 
818*c217d954SCole Faust         // Fill reference
819*c217d954SCole Faust         fill(a, 0);
820*c217d954SCole Faust 
821*c217d954SCole Faust         const std::vector<float_t> result_float_multiplier_vec = { result_real_multiplier };
822*c217d954SCole Faust 
823*c217d954SCole Faust         if(add_bias)
824*c217d954SCole Faust         {
825*c217d954SCole Faust             // Fill bias
826*c217d954SCole Faust             fill(b, 1);
827*c217d954SCole Faust 
828*c217d954SCole Faust             return reference::gemmlowp_quantize_down_scale_by_float<int32_t, T>(a, b, result_float_multiplier_vec, result_offset, min, max);
829*c217d954SCole Faust         }
830*c217d954SCole Faust         else
831*c217d954SCole Faust         {
832*c217d954SCole Faust             return reference::gemmlowp_quantize_down_scale_by_float<int32_t, T>(a, result_float_multiplier_vec, result_offset, min, max);
833*c217d954SCole Faust         }
834*c217d954SCole Faust     }
835*c217d954SCole Faust 
836*c217d954SCole Faust     TensorType      _target{};
837*c217d954SCole Faust     SimpleTensor<T> _reference{};
838*c217d954SCole Faust };
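
// A minimal sketch of the per-element computation that the QUANTIZE_DOWN_FLOAT output stage tested
// above performs, assuming the usual convention of (accumulator + bias) * real_multiplier + offset,
// clamped to [min, max]. This helper is illustrative only (it is not used by the fixtures), its name
// is made up for this example, and the exact rounding mode of the backends may differ.
inline int32_t example_quantize_down_float_sketch(int32_t acc, int32_t bias, float result_multiplier, int32_t result_offset, int32_t min, int32_t max)
{
    // Add the bias (pass 0 when add_bias is false), scale by the real multiplier and add the offset
    const float scaled = static_cast<float>(acc + bias) * result_multiplier;
    int32_t     out    = static_cast<int32_t>(scaled >= 0.f ? scaled + 0.5f : scaled - 0.5f) + result_offset;
    // Clamp to the requested output range before the cast to the narrow output data type
    out = out < min ? min : out;
    out = out > max ? max : out;
    return out;
}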
839*c217d954SCole Faust 
840*c217d954SCole Faust template <typename TensorType, typename AccessorType, typename FunctionType>
841*c217d954SCole Faust class GEMMLowpQuantizeDownInt32ToInt16ScaleByFixedPointValidationFixture : public framework::Fixture
842*c217d954SCole Faust {
843*c217d954SCole Faust public:
844*c217d954SCole Faust     template <typename...>
845*c217d954SCole Faust     void setup(TensorShape shape, int32_t result_fixedpoint_multiplier, int32_t result_shift, int32_t min, int32_t max, bool add_bias)
846*c217d954SCole Faust     {
847*c217d954SCole Faust         _target    = compute_target(shape, result_fixedpoint_multiplier, result_shift, min, max, add_bias);
848*c217d954SCole Faust         _reference = compute_reference(shape, result_fixedpoint_multiplier, result_shift, min, max, add_bias);
849*c217d954SCole Faust     }
850*c217d954SCole Faust 
851*c217d954SCole Faust protected:
852*c217d954SCole Faust     template <typename U>
853*c217d954SCole Faust     void fill(U &&tensor, int i)
854*c217d954SCole Faust     {
855*c217d954SCole Faust         std::uniform_int_distribution<> distribution(-6000, 6000);
856*c217d954SCole Faust         library->fill(tensor, distribution, i);
857*c217d954SCole Faust     }
858*c217d954SCole Faust 
859*c217d954SCole Faust     TensorType compute_target(const TensorShape &shape, int32_t result_fixedpoint_multiplier, int32_t result_shift, int32_t min, int32_t max, bool add_bias)
860*c217d954SCole Faust     {
861*c217d954SCole Faust         TensorShape shape_bias(shape[0]);
862*c217d954SCole Faust 
863*c217d954SCole Faust         // Create tensors
864*c217d954SCole Faust         TensorType a = create_tensor<TensorType>(shape, DataType::S32, 1);
865*c217d954SCole Faust         TensorType b = create_tensor<TensorType>(shape_bias, DataType::S32, 1);
866*c217d954SCole Faust         TensorType c = create_tensor<TensorType>(shape, DataType::QSYMM16, 1);
867*c217d954SCole Faust 
868*c217d954SCole Faust         // Create and configure function
869*c217d954SCole Faust         FunctionType output_stage;
870*c217d954SCole Faust         output_stage.configure(&a, add_bias ? &b : nullptr, &c, result_fixedpoint_multiplier, result_shift, min, max);
871*c217d954SCole Faust 
872*c217d954SCole Faust         ARM_COMPUTE_ASSERT(a.info()->is_resizable());
873*c217d954SCole Faust         ARM_COMPUTE_ASSERT(c.info()->is_resizable());
874*c217d954SCole Faust 
875*c217d954SCole Faust         // Allocate tensors
876*c217d954SCole Faust         a.allocator()->allocate();
877*c217d954SCole Faust         c.allocator()->allocate();
878*c217d954SCole Faust 
879*c217d954SCole Faust         ARM_COMPUTE_ASSERT(!a.info()->is_resizable());
880*c217d954SCole Faust         ARM_COMPUTE_ASSERT(!c.info()->is_resizable());
881*c217d954SCole Faust 
882*c217d954SCole Faust         // Fill tensor
883*c217d954SCole Faust         fill(AccessorType(a), 0);
884*c217d954SCole Faust 
885*c217d954SCole Faust         if(add_bias)
886*c217d954SCole Faust         {
887*c217d954SCole Faust             ARM_COMPUTE_ASSERT(b.info()->is_resizable());
888*c217d954SCole Faust 
889*c217d954SCole Faust             // Allocate bias tensor
890*c217d954SCole Faust             b.allocator()->allocate();
891*c217d954SCole Faust 
892*c217d954SCole Faust             ARM_COMPUTE_ASSERT(!b.info()->is_resizable());
893*c217d954SCole Faust 
894*c217d954SCole Faust             // Fill tensor
895*c217d954SCole Faust             fill(AccessorType(b), 1);
896*c217d954SCole Faust         }
897*c217d954SCole Faust 
898*c217d954SCole Faust         // Run the output stage
899*c217d954SCole Faust         output_stage.run();
900*c217d954SCole Faust         return c;
901*c217d954SCole Faust     }
902*c217d954SCole Faust 
903*c217d954SCole Faust     SimpleTensor<int16_t> compute_reference(const TensorShape &shape, int32_t result_fixed_point_multiplier, int32_t result_shift, int32_t min, int32_t max,
904*c217d954SCole Faust                                             bool add_bias)
905*c217d954SCole Faust     {
906*c217d954SCole Faust         // Create reference
907*c217d954SCole Faust         TensorShape shape_bias(shape[0]);
908*c217d954SCole Faust 
909*c217d954SCole Faust         SimpleTensor<int32_t> a{ shape, DataType::S32, 1 };
910*c217d954SCole Faust         SimpleTensor<int32_t> b{ shape_bias, DataType::S32, 1 };
911*c217d954SCole Faust 
912*c217d954SCole Faust         // Fill reference
913*c217d954SCole Faust         fill(a, 0);
914*c217d954SCole Faust 
915*c217d954SCole Faust         const std::vector<int32_t> result_fixed_point_multiplier_vec = { result_fixed_point_multiplier };
916*c217d954SCole Faust         const std::vector<int32_t> result_shift_vec                  = { result_shift };
917*c217d954SCole Faust 
918*c217d954SCole Faust         if(add_bias)
919*c217d954SCole Faust         {
920*c217d954SCole Faust             // Fill bias
921*c217d954SCole Faust             fill(b, 1);
922*c217d954SCole Faust 
923*c217d954SCole Faust             return reference::gemmlowp_quantize_down_scale_by_fixedpoint<int32_t, int16_t>(a, b, result_fixed_point_multiplier_vec, result_shift_vec, 0, min, max);
924*c217d954SCole Faust         }
925*c217d954SCole Faust         else
926*c217d954SCole Faust         {
927*c217d954SCole Faust             return reference::gemmlowp_quantize_down_scale_by_fixedpoint<int32_t, int16_t>(a, result_fixed_point_multiplier_vec, result_shift_vec, 0, min, max);
928*c217d954SCole Faust         }
929*c217d954SCole Faust     }
930*c217d954SCole Faust 
931*c217d954SCole Faust     TensorType            _target{};
932*c217d954SCole Faust     SimpleTensor<int16_t> _reference{};
933*c217d954SCole Faust };
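
// A minimal sketch of the fixed-point requantization exercised by the QSYMM16 fixture above, assuming
// the usual GEMMLowp scheme: a saturating rounding doubling high multiply by result_fixedpoint_multiplier
// followed by a rounding arithmetic shift right by result_shift and a clamp to [min, max]. Illustrative
// only; the helper name is made up, the INT32_MIN saturation edge case is omitted, and the library
// kernels may differ in detail.
inline int32_t example_fixedpoint_requant_sketch(int32_t acc, int32_t multiplier, int32_t shift, int32_t min, int32_t max)
{
    // Keep the upper 32 bits of 2 * acc * multiplier, rounded to nearest
    const int64_t ab    = static_cast<int64_t>(acc) * static_cast<int64_t>(multiplier);
    const int64_t nudge = (ab >= 0) ? (static_cast<int64_t>(1) << 30) : (1 - (static_cast<int64_t>(1) << 30));
    int32_t       res   = static_cast<int32_t>((ab + nudge) / (static_cast<int64_t>(1) << 31));
    // Rounding arithmetic shift right by 'shift'
    if(shift > 0)
    {
        const int32_t mask      = (1 << shift) - 1;
        const int32_t remainder = res & mask;
        const int32_t threshold = (mask >> 1) + ((res < 0) ? 1 : 0);
        res                     = (res >> shift) + ((remainder > threshold) ? 1 : 0);
    }
    // Clamp before the final cast to int16_t done by the output stage
    res = res < min ? min : res;
    res = res > max ? max : res;
    return res;
}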
934*c217d954SCole Faust 
935*c217d954SCole Faust template <typename TensorType, typename AccessorType, typename ReshapeLHSOperatorType, typename ReshapeRHSOperatorType, typename GEMMFunctionType>
936*c217d954SCole Faust class GEMMLowpMatrixMultiplyReshapedValidationFixture : public framework::Fixture
937*c217d954SCole Faust {
938*c217d954SCole Faust public:
939*c217d954SCole Faust     template <typename...>
940*c217d954SCole Faust     void setup(unsigned int m, unsigned int n, unsigned int k, unsigned int batch_size, unsigned int m0, unsigned int n0, unsigned int k0, unsigned int v0, unsigned int h0, bool interleave_lhs,
941*c217d954SCole Faust                bool interleave_rhs, DataType data_type)
942*c217d954SCole Faust     {
943*c217d954SCole Faust         GEMMLHSMatrixInfo lhs_info;
944*c217d954SCole Faust         lhs_info.m0         = m0;
945*c217d954SCole Faust         lhs_info.k0         = k0;
946*c217d954SCole Faust         lhs_info.v0         = v0;
947*c217d954SCole Faust         lhs_info.interleave = interleave_lhs;
948*c217d954SCole Faust         lhs_info.transpose  = false;
949*c217d954SCole Faust 
950*c217d954SCole Faust         GEMMRHSMatrixInfo rhs_info;
951*c217d954SCole Faust         rhs_info.n0         = n0;
952*c217d954SCole Faust         rhs_info.k0         = k0;
953*c217d954SCole Faust         rhs_info.h0         = h0;
954*c217d954SCole Faust         rhs_info.interleave = interleave_rhs;
955*c217d954SCole Faust         rhs_info.transpose  = true;
956*c217d954SCole Faust 
957*c217d954SCole Faust         // Set the tensor shapes for LHS and RHS matrices
958*c217d954SCole Faust         const TensorShape lhs_shape(k, m, batch_size);
959*c217d954SCole Faust         const TensorShape rhs_shape(n, k, batch_size);
960*c217d954SCole Faust 
961*c217d954SCole Faust         _target    = compute_target(lhs_shape, rhs_shape, lhs_info, rhs_info, data_type);
962*c217d954SCole Faust         _reference = compute_reference(lhs_shape, rhs_shape, data_type);
963*c217d954SCole Faust     }
964*c217d954SCole Faust 
965*c217d954SCole Faust protected:
966*c217d954SCole Faust     template <typename U>
967*c217d954SCole Faust     void fill(U &&tensor, int i)
968*c217d954SCole Faust     {
969*c217d954SCole Faust         switch(tensor.data_type())
970*c217d954SCole Faust         {
971*c217d954SCole Faust             case DataType::QASYMM8:
972*c217d954SCole Faust             {
973*c217d954SCole Faust                 // Between 1 and 254 in order to avoid having -128 and 128 for the DOT product path
974*c217d954SCole Faust                 std::uniform_int_distribution<> distribution(1, 254);
975*c217d954SCole Faust                 library->fill(tensor, distribution, i);
976*c217d954SCole Faust             }
977*c217d954SCole Faust             break;
978*c217d954SCole Faust             case DataType::QASYMM8_SIGNED:
979*c217d954SCole Faust             {
980*c217d954SCole Faust                 std::uniform_int_distribution<> distribution(-127, 126);
981*c217d954SCole Faust                 library->fill(tensor, distribution, i);
982*c217d954SCole Faust             }
983*c217d954SCole Faust             break;
984*c217d954SCole Faust             default:
985*c217d954SCole Faust                 ARM_COMPUTE_ERROR("Unsupported data type");
986*c217d954SCole Faust         }
987*c217d954SCole Faust     }
988*c217d954SCole Faust 
989*c217d954SCole Faust     TensorType compute_target(const TensorShape &lhs_shape, const TensorShape &rhs_shape, const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, DataType data_type)
990*c217d954SCole Faust     {
991*c217d954SCole Faust         // Create tensors
992*c217d954SCole Faust         TensorType lhs = create_tensor<TensorType>(lhs_shape, data_type, 1);
993*c217d954SCole Faust         TensorType rhs = create_tensor<TensorType>(rhs_shape, data_type, 1);
994*c217d954SCole Faust         TensorType lhs_reshaped;
995*c217d954SCole Faust         TensorType rhs_reshaped;
996*c217d954SCole Faust         TensorType dst;
997*c217d954SCole Faust 
998*c217d954SCole Faust         const unsigned int M = lhs_shape[1];
999*c217d954SCole Faust         const unsigned int N = rhs_shape[0];
1000*c217d954SCole Faust         const unsigned int K = lhs_shape[0];
1001*c217d954SCole Faust 
1002*c217d954SCole Faust         // The output tensor will be auto-initialized within the function
1003*c217d954SCole Faust 
1004*c217d954SCole Faust         // Create and configure function
1005*c217d954SCole Faust         ReshapeLHSOperatorType reshape_lhs;
1006*c217d954SCole Faust         ReshapeRHSOperatorType reshape_rhs;
1007*c217d954SCole Faust         GEMMFunctionType       gemm;
1008*c217d954SCole Faust         reshape_lhs.configure(lhs.info(), lhs_reshaped.info(), lhs_info);
1009*c217d954SCole Faust         reshape_rhs.configure(rhs.info(), rhs_reshaped.info(), rhs_info);
1010*c217d954SCole Faust         gemm.configure(lhs_reshaped.info(), rhs_reshaped.info(), dst.info(), lhs_info, rhs_info, GEMMReshapeInfo(M, N, K));
1011*c217d954SCole Faust 
1012*c217d954SCole Faust         ARM_COMPUTE_ASSERT(lhs.info()->is_resizable());
1013*c217d954SCole Faust         ARM_COMPUTE_ASSERT(rhs.info()->is_resizable());
1014*c217d954SCole Faust 
1015*c217d954SCole Faust         add_padding_x({ &lhs, &rhs, &lhs_reshaped, &rhs_reshaped, &dst });
1016*c217d954SCole Faust 
1017*c217d954SCole Faust         // Allocate tensors
1018*c217d954SCole Faust         lhs.allocator()->allocate();
1019*c217d954SCole Faust         rhs.allocator()->allocate();
1020*c217d954SCole Faust         lhs_reshaped.allocator()->allocate();
1021*c217d954SCole Faust         rhs_reshaped.allocator()->allocate();
1022*c217d954SCole Faust         dst.allocator()->allocate();
1023*c217d954SCole Faust 
1024*c217d954SCole Faust         ARM_COMPUTE_ASSERT(!lhs.info()->is_resizable());
1025*c217d954SCole Faust         ARM_COMPUTE_ASSERT(!rhs.info()->is_resizable());
1026*c217d954SCole Faust         ARM_COMPUTE_ASSERT(!lhs_reshaped.info()->is_resizable());
1027*c217d954SCole Faust         ARM_COMPUTE_ASSERT(!rhs_reshaped.info()->is_resizable());
1028*c217d954SCole Faust         ARM_COMPUTE_ASSERT(!dst.info()->is_resizable());
1029*c217d954SCole Faust 
1030*c217d954SCole Faust         // Fill tensors
1031*c217d954SCole Faust         fill(AccessorType(lhs), 0);
1032*c217d954SCole Faust         fill(AccessorType(rhs), 1);
1033*c217d954SCole Faust 
1034*c217d954SCole Faust         // Compute GEMM
1035*c217d954SCole Faust         ITensorPack reshape_lhs_pack = { { ACL_SRC, &lhs }, { ACL_DST, &lhs_reshaped } };
1036*c217d954SCole Faust         reshape_lhs.run(reshape_lhs_pack);
1037*c217d954SCole Faust         ITensorPack reshape_rhs_pack = { { ACL_SRC, &rhs }, { ACL_DST, &rhs_reshaped } };
1038*c217d954SCole Faust         reshape_rhs.run(reshape_rhs_pack);
1039*c217d954SCole Faust         ITensorPack gemm_pack({ { ACL_SRC_0, &lhs_reshaped }, { ACL_SRC_1, &rhs_reshaped }, { ACL_DST, &dst } });
1040*c217d954SCole Faust         gemm.run(gemm_pack);
1041*c217d954SCole Faust 
1042*c217d954SCole Faust         return dst;
1043*c217d954SCole Faust     }
1044*c217d954SCole Faust 
1045*c217d954SCole Faust     SimpleTensor<int32_t> compute_reference(const TensorShape &lhs_shape, const TensorShape &rhs_shape, DataType data_type)
1046*c217d954SCole Faust     {
1047*c217d954SCole Faust         TensorShape dst_shape = lhs_shape;
1048*c217d954SCole Faust         dst_shape[0]          = rhs_shape[0];
1049*c217d954SCole Faust         dst_shape[1]          = lhs_shape[1];
1050*c217d954SCole Faust 
1051*c217d954SCole Faust         switch(data_type)
1052*c217d954SCole Faust         {
1053*c217d954SCole Faust             case DataType::QASYMM8:
1054*c217d954SCole Faust             {
1055*c217d954SCole Faust                 // Create reference
1056*c217d954SCole Faust                 SimpleTensor<uint8_t> lhs{ lhs_shape, data_type, 1 };
1057*c217d954SCole Faust                 SimpleTensor<uint8_t> rhs{ rhs_shape, data_type, 1 };
1058*c217d954SCole Faust 
1059*c217d954SCole Faust                 // Fill reference
1060*c217d954SCole Faust                 fill(lhs, 0);
1061*c217d954SCole Faust                 fill(rhs, 1);
1062*c217d954SCole Faust 
1063*c217d954SCole Faust                 return reference::gemmlowp_matrix_multiply_core<int32_t, uint8_t>(lhs, rhs, dst_shape, 0, 0);
1064*c217d954SCole Faust             }
1065*c217d954SCole Faust             case DataType::QASYMM8_SIGNED:
1066*c217d954SCole Faust             {
1067*c217d954SCole Faust                 // Create reference
1068*c217d954SCole Faust                 SimpleTensor<int8_t> lhs{ lhs_shape, data_type, 1 };
1069*c217d954SCole Faust                 SimpleTensor<int8_t> rhs{ rhs_shape, data_type, 1 };
1070*c217d954SCole Faust 
1071*c217d954SCole Faust                 // Fill reference
1072*c217d954SCole Faust                 fill(lhs, 0);
1073*c217d954SCole Faust                 fill(rhs, 1);
1074*c217d954SCole Faust 
1075*c217d954SCole Faust                 return reference::gemmlowp_matrix_multiply_core<int32_t, int8_t>(lhs, rhs, dst_shape, 0, 0);
1076*c217d954SCole Faust             }
1077*c217d954SCole Faust             default:
1078*c217d954SCole Faust                 ARM_COMPUTE_ERROR("Unsupported data type");
1079*c217d954SCole Faust         }
1080*c217d954SCole Faust     }
1081*c217d954SCole Faust 
1082*c217d954SCole Faust     TensorType            _target{};
1083*c217d954SCole Faust     SimpleTensor<int32_t> _reference{};
1084*c217d954SCole Faust };
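
// The reshaped GEMMLowp fixtures above build their inputs with the innermost dimension first:
// for an M x N x K multiplication, lhs_shape = (K, M, batches) and rhs_shape = (N, K, batches),
// and the reference output therefore has shape (N, M, batches). A minimal sketch that restates
// this convention; the helper is illustrative only and is not used by the fixtures.
inline TensorShape example_gemmlowp_dst_shape_sketch(const TensorShape &lhs_shape, const TensorShape &rhs_shape)
{
    TensorShape dst_shape = lhs_shape;
    dst_shape[0]          = rhs_shape[0]; // N columns come from the RHS
    dst_shape[1]          = lhs_shape[1]; // M rows come from the LHS
    return dst_shape;                     // remaining dimensions (batches) follow the LHS
}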
1085*c217d954SCole Faust 
1086*c217d954SCole Faust template <typename TensorType, typename AccessorType, typename ReshapeLHSOperatorType, typename ReshapeRHSOperatorType, typename GEMMFunctionType>
1087*c217d954SCole Faust class GEMMLowpMatrixMultiplyReshaped3DValidationFixture : public framework::Fixture
1088*c217d954SCole Faust {
1089*c217d954SCole Faust public:
1090*c217d954SCole Faust     template <typename...>
1091*c217d954SCole Faust     void setup(unsigned int m_w, unsigned int m_h, unsigned int n, unsigned int k, unsigned int batch_size, unsigned int m0, unsigned int n0, unsigned int k0, unsigned int v0, unsigned int h0,
1092*c217d954SCole Faust                bool interleave_lhs, bool interleave_rhs, DataType data_type)
1093*c217d954SCole Faust     {
1094*c217d954SCole Faust         GEMMLHSMatrixInfo lhs_info;
1095*c217d954SCole Faust         lhs_info.m0         = m0;
1096*c217d954SCole Faust         lhs_info.k0         = k0;
1097*c217d954SCole Faust         lhs_info.v0         = v0;
1098*c217d954SCole Faust         lhs_info.interleave = interleave_lhs;
1099*c217d954SCole Faust         lhs_info.transpose  = false;
1100*c217d954SCole Faust 
1101*c217d954SCole Faust         GEMMRHSMatrixInfo rhs_info;
1102*c217d954SCole Faust         rhs_info.n0         = n0;
1103*c217d954SCole Faust         rhs_info.k0         = k0;
1104*c217d954SCole Faust         rhs_info.h0         = h0;
1105*c217d954SCole Faust         rhs_info.interleave = interleave_rhs;
1106*c217d954SCole Faust         rhs_info.transpose  = true;
1107*c217d954SCole Faust 
1108*c217d954SCole Faust         // In the GEMM3D case, m is the product of m_w and m_h
1109*c217d954SCole Faust         const unsigned int m = m_w * m_h;
1110*c217d954SCole Faust 
1111*c217d954SCole Faust         // Set the tensor shapes for LHS and RHS matrices
1112*c217d954SCole Faust         const TensorShape lhs_shape(k, m, batch_size);
1113*c217d954SCole Faust         const TensorShape rhs_shape(n, k, batch_size);
1114*c217d954SCole Faust 
1115*c217d954SCole Faust         _target    = compute_target(lhs_shape, rhs_shape, lhs_info, rhs_info, m_h, data_type);
1116*c217d954SCole Faust         _reference = compute_reference(lhs_shape, rhs_shape, m_h, data_type);
1117*c217d954SCole Faust     }
1118*c217d954SCole Faust 
1119*c217d954SCole Faust protected:
1120*c217d954SCole Faust     template <typename U>
1121*c217d954SCole Faust     void fill(U &&tensor, int i)
1122*c217d954SCole Faust     {
1123*c217d954SCole Faust         switch(tensor.data_type())
1124*c217d954SCole Faust         {
1125*c217d954SCole Faust             case DataType::QASYMM8:
1126*c217d954SCole Faust             {
1127*c217d954SCole Faust                 // Between 1 and 254 in order to avoid having -128 and 128 for the DOT product path
1128*c217d954SCole Faust                 std::uniform_int_distribution<> distribution(1, 254);
1129*c217d954SCole Faust                 library->fill(tensor, distribution, i);
1130*c217d954SCole Faust             }
1131*c217d954SCole Faust             break;
1132*c217d954SCole Faust             case DataType::QASYMM8_SIGNED:
1133*c217d954SCole Faust             {
1134*c217d954SCole Faust                 std::uniform_int_distribution<> distribution(-127, 126);
1135*c217d954SCole Faust                 library->fill(tensor, distribution, i);
1136*c217d954SCole Faust             }
1137*c217d954SCole Faust             break;
1138*c217d954SCole Faust             default:
1139*c217d954SCole Faust                 ARM_COMPUTE_ERROR("Unsupported data type");
1140*c217d954SCole Faust         }
1141*c217d954SCole Faust     }
1142*c217d954SCole Faust 
1143*c217d954SCole Faust     TensorType compute_target(const TensorShape &lhs_shape, const TensorShape &rhs_shape, const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, unsigned int m_h,
1144*c217d954SCole Faust                               DataType data_type)
1145*c217d954SCole Faust     {
1146*c217d954SCole Faust         // Create tensors
1147*c217d954SCole Faust         TensorType lhs = create_tensor<TensorType>(lhs_shape, data_type, 1);
1148*c217d954SCole Faust         TensorType rhs = create_tensor<TensorType>(rhs_shape, data_type, 1);
1149*c217d954SCole Faust         TensorType lhs_reshaped;
1150*c217d954SCole Faust         TensorType rhs_reshaped;
1151*c217d954SCole Faust         TensorType dst;
1152*c217d954SCole Faust 
1153*c217d954SCole Faust         const unsigned int M = lhs_shape[1];
1154*c217d954SCole Faust         const unsigned int N = rhs_shape[0];
1155*c217d954SCole Faust         const unsigned int K = lhs_shape[0];
1156*c217d954SCole Faust 
1157*c217d954SCole Faust         // The output tensor will be auto-initialized within the function
1158*c217d954SCole Faust 
1159*c217d954SCole Faust         // Create and configure function
1160*c217d954SCole Faust         ReshapeLHSOperatorType reshape_lhs;
1161*c217d954SCole Faust         ReshapeRHSOperatorType reshape_rhs;
1162*c217d954SCole Faust         GEMMFunctionType       gemm;
1163*c217d954SCole Faust         reshape_lhs.configure(lhs.info(), lhs_reshaped.info(), lhs_info);
1164*c217d954SCole Faust         reshape_rhs.configure(rhs.info(), rhs_reshaped.info(), rhs_info);
1165*c217d954SCole Faust         gemm.configure(lhs_reshaped.info(), rhs_reshaped.info(), dst.info(), lhs_info, rhs_info, GEMMReshapeInfo(M, N, K, 1, 1, m_h));
1166*c217d954SCole Faust 
1167*c217d954SCole Faust         ARM_COMPUTE_ASSERT(lhs.info()->is_resizable());
1168*c217d954SCole Faust         ARM_COMPUTE_ASSERT(rhs.info()->is_resizable());
1169*c217d954SCole Faust 
1170*c217d954SCole Faust         add_padding_x({ &lhs, &rhs, &lhs_reshaped, &rhs_reshaped, &dst });
1171*c217d954SCole Faust 
1172*c217d954SCole Faust         // Allocate tensors
1173*c217d954SCole Faust         lhs.allocator()->allocate();
1174*c217d954SCole Faust         rhs.allocator()->allocate();
1175*c217d954SCole Faust         lhs_reshaped.allocator()->allocate();
1176*c217d954SCole Faust         rhs_reshaped.allocator()->allocate();
1177*c217d954SCole Faust         dst.allocator()->allocate();
1178*c217d954SCole Faust 
1179*c217d954SCole Faust         ARM_COMPUTE_ASSERT(!lhs.info()->is_resizable());
1180*c217d954SCole Faust         ARM_COMPUTE_ASSERT(!rhs.info()->is_resizable());
1181*c217d954SCole Faust         ARM_COMPUTE_ASSERT(!lhs_reshaped.info()->is_resizable());
1182*c217d954SCole Faust         ARM_COMPUTE_ASSERT(!rhs_reshaped.info()->is_resizable());
1183*c217d954SCole Faust         ARM_COMPUTE_ASSERT(!dst.info()->is_resizable());
1184*c217d954SCole Faust 
1185*c217d954SCole Faust         // Fill tensors
1186*c217d954SCole Faust         fill(AccessorType(lhs), 0);
1187*c217d954SCole Faust         fill(AccessorType(rhs), 1);
1188*c217d954SCole Faust 
1189*c217d954SCole Faust         // Compute GEMM
1190*c217d954SCole Faust         ITensorPack reshape_lhs_pack = { { ACL_SRC, &lhs }, { ACL_DST, &lhs_reshaped } };
1191*c217d954SCole Faust         reshape_lhs.run(reshape_lhs_pack);
1192*c217d954SCole Faust         ITensorPack reshape_rhs_pack = { { ACL_SRC, &rhs }, { ACL_DST, &rhs_reshaped } };
1193*c217d954SCole Faust         reshape_rhs.run(reshape_rhs_pack);
1194*c217d954SCole Faust         ITensorPack gemm_pack({ { ACL_SRC_0, &lhs_reshaped }, { ACL_SRC_1, &rhs_reshaped }, { ACL_DST, &dst } });
1195*c217d954SCole Faust         gemm.run(gemm_pack);
1196*c217d954SCole Faust 
1197*c217d954SCole Faust         return dst;
1198*c217d954SCole Faust     }
1199*c217d954SCole Faust 
1200*c217d954SCole Faust     SimpleTensor<int32_t> compute_reference(const TensorShape &lhs_shape, const TensorShape &rhs_shape, unsigned int m_h, DataType data_type)
1201*c217d954SCole Faust     {
1202*c217d954SCole Faust         TensorShape dst_shape = lhs_shape;
1203*c217d954SCole Faust         dst_shape.set(0, rhs_shape[0]);
1204*c217d954SCole Faust         dst_shape.set(1, lhs_shape[1] / m_h);
1205*c217d954SCole Faust         dst_shape.set(2, m_h);
1206*c217d954SCole Faust         dst_shape.set(3, lhs_shape[2]);
1207*c217d954SCole Faust 
1208*c217d954SCole Faust         switch(data_type)
1209*c217d954SCole Faust         {
1210*c217d954SCole Faust             case DataType::QASYMM8:
1211*c217d954SCole Faust             {
1212*c217d954SCole Faust                 // Create reference
1213*c217d954SCole Faust                 SimpleTensor<uint8_t> lhs{ lhs_shape, data_type, 1 };
1214*c217d954SCole Faust                 SimpleTensor<uint8_t> rhs{ rhs_shape, data_type, 1 };
1215*c217d954SCole Faust 
1216*c217d954SCole Faust                 // Fill reference
1217*c217d954SCole Faust                 fill(lhs, 0);
1218*c217d954SCole Faust                 fill(rhs, 1);
1219*c217d954SCole Faust 
1220*c217d954SCole Faust                 return reference::gemmlowp_matrix_multiply_core<int32_t, uint8_t>(lhs, rhs, dst_shape, 0, 0);
1221*c217d954SCole Faust             }
1222*c217d954SCole Faust             case DataType::QASYMM8_SIGNED:
1223*c217d954SCole Faust             {
1224*c217d954SCole Faust                 // Create reference
1225*c217d954SCole Faust                 SimpleTensor<int8_t> lhs{ lhs_shape, data_type, 1 };
1226*c217d954SCole Faust                 SimpleTensor<int8_t> rhs{ rhs_shape, data_type, 1 };
1227*c217d954SCole Faust 
1228*c217d954SCole Faust                 // Fill reference
1229*c217d954SCole Faust                 fill(lhs, 0);
1230*c217d954SCole Faust                 fill(rhs, 1);
1231*c217d954SCole Faust 
1232*c217d954SCole Faust                 return reference::gemmlowp_matrix_multiply_core<int32_t, int8_t>(lhs, rhs, dst_shape, 0, 0);
1233*c217d954SCole Faust             }
1234*c217d954SCole Faust             default:
1235*c217d954SCole Faust                 ARM_COMPUTE_ERROR("Unsupported data type");
1236*c217d954SCole Faust         }
1237*c217d954SCole Faust     }
1238*c217d954SCole Faust 
1239*c217d954SCole Faust     TensorType            _target{};
1240*c217d954SCole Faust     SimpleTensor<int32_t> _reference{};
1241*c217d954SCole Faust };
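
// In the GEMM3D variant above, the reference first computes a flat (N, m_w * m_h, batches) output and
// then reinterprets it in 3D, exactly as done in compute_reference(): dimension 1 is split into
// (m_w, m_h) and the batch dimension moves to index 3. A minimal sketch of that reshaping logic
// (illustrative only; the helper name is made up and it is not used by the fixtures):
inline TensorShape example_gemm3d_dst_shape_sketch(const TensorShape &lhs_shape, const TensorShape &rhs_shape, unsigned int m_h)
{
    TensorShape dst_shape = lhs_shape;
    dst_shape.set(0, rhs_shape[0]);       // N
    dst_shape.set(1, lhs_shape[1] / m_h); // m_w
    dst_shape.set(2, m_h);                // m_h
    dst_shape.set(3, lhs_shape[2]);       // batches
    return dst_shape;
}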
1242*c217d954SCole Faust 
1243*c217d954SCole Faust template <typename TensorType, typename AccessorType, typename ReshapeRHSOperatorType, typename GEMMFunctionType>
1244*c217d954SCole Faust class GEMMLowpMatrixMultiplyReshapedOnlyRHSValidationFixture : public framework::Fixture
1245*c217d954SCole Faust {
1246*c217d954SCole Faust public:
1247*c217d954SCole Faust     template <typename...>
1248*c217d954SCole Faust     void setup(unsigned int m, unsigned int n, unsigned int k, unsigned int batch_size, unsigned int m0, unsigned int n0,
1249*c217d954SCole Faust                unsigned int k0, unsigned int h0, bool interleave_rhs, bool transpose_rhs, DataType data_type)
1250*c217d954SCole Faust     {
1251*c217d954SCole Faust         GEMMLHSMatrixInfo lhs_info;
1252*c217d954SCole Faust         lhs_info.m0 = m0;
1253*c217d954SCole Faust         lhs_info.k0 = k0;
1254*c217d954SCole Faust 
1255*c217d954SCole Faust         GEMMRHSMatrixInfo rhs_info;
1256*c217d954SCole Faust         rhs_info.n0         = n0;
1257*c217d954SCole Faust         rhs_info.k0         = k0;
1258*c217d954SCole Faust         rhs_info.h0         = h0;
1259*c217d954SCole Faust         rhs_info.interleave = interleave_rhs;
1260*c217d954SCole Faust         rhs_info.transpose  = transpose_rhs;
1261*c217d954SCole Faust 
1262*c217d954SCole Faust         // Set the tensor shapes for LHS and RHS matrices
1263*c217d954SCole Faust         const TensorShape lhs_shape(k, m, batch_size);
1264*c217d954SCole Faust         const TensorShape rhs_shape(n, k, batch_size);
1265*c217d954SCole Faust 
1266*c217d954SCole Faust         _target    = compute_target(lhs_shape, rhs_shape, lhs_info, rhs_info, data_type);
1267*c217d954SCole Faust         _reference = compute_reference(lhs_shape, rhs_shape, data_type);
1268*c217d954SCole Faust     }
1269*c217d954SCole Faust 
1270*c217d954SCole Faust protected:
1271*c217d954SCole Faust     template <typename U>
1272*c217d954SCole Faust     void fill(U &&tensor, int i)
1273*c217d954SCole Faust     {
1274*c217d954SCole Faust         switch(tensor.data_type())
1275*c217d954SCole Faust         {
1276*c217d954SCole Faust             case DataType::QASYMM8:
1277*c217d954SCole Faust             {
1278*c217d954SCole Faust                 // Between 1 and 254 in order to avoid having -128 and 128 for the DOT product path
1279*c217d954SCole Faust                 std::uniform_int_distribution<> distribution(1, 254);
1280*c217d954SCole Faust                 library->fill(tensor, distribution, i);
1281*c217d954SCole Faust             }
1282*c217d954SCole Faust             break;
1283*c217d954SCole Faust             case DataType::QASYMM8_SIGNED:
1284*c217d954SCole Faust             {
1285*c217d954SCole Faust                 std::uniform_int_distribution<> distribution(-127, 126);
1286*c217d954SCole Faust                 library->fill(tensor, distribution, i);
1287*c217d954SCole Faust             }
1288*c217d954SCole Faust             break;
1289*c217d954SCole Faust             default:
1290*c217d954SCole Faust                 ARM_COMPUTE_ERROR("Unsupported data type");
1291*c217d954SCole Faust         }
1292*c217d954SCole Faust     }
1293*c217d954SCole Faust 
1294*c217d954SCole Faust     TensorType compute_target(const TensorShape &lhs_shape, const TensorShape &rhs_shape, const GEMMLHSMatrixInfo &lhs_info,
1295*c217d954SCole Faust                               const GEMMRHSMatrixInfo &rhs_info, DataType data_type)
1296*c217d954SCole Faust     {
1297*c217d954SCole Faust         // Create tensors
1298*c217d954SCole Faust         TensorType lhs = create_tensor<TensorType>(lhs_shape, data_type, 1);
1299*c217d954SCole Faust         TensorType rhs = create_tensor<TensorType>(rhs_shape, data_type, 1);
1300*c217d954SCole Faust         TensorType rhs_reshaped;
1301*c217d954SCole Faust         TensorType dst;
1302*c217d954SCole Faust 
1303*c217d954SCole Faust         const unsigned int M = lhs_shape[1];
1304*c217d954SCole Faust         const unsigned int N = rhs_shape[0];
1305*c217d954SCole Faust         const unsigned int K = lhs_shape[0];
1306*c217d954SCole Faust 
1307*c217d954SCole Faust         GEMMKernelInfo gemm_info;
1308*c217d954SCole Faust         gemm_info.m        = M;
1309*c217d954SCole Faust         gemm_info.n        = N;
1310*c217d954SCole Faust         gemm_info.k        = K;
1311*c217d954SCole Faust         gemm_info.lhs_info = lhs_info;
1312*c217d954SCole Faust         gemm_info.rhs_info = rhs_info;
1313*c217d954SCole Faust         // The output tensor will be auto-initialized within the function
1314*c217d954SCole Faust 
1315*c217d954SCole Faust         // Create and configure function
1316*c217d954SCole Faust         ReshapeRHSOperatorType reshape_rhs;
1317*c217d954SCole Faust         GEMMFunctionType       gemm;
1318*c217d954SCole Faust         reshape_rhs.configure(rhs.info(), rhs_reshaped.info(), rhs_info);
1319*c217d954SCole Faust         gemm.configure(lhs.info(), rhs_reshaped.info(), dst.info(), gemm_info);
1320*c217d954SCole Faust 
1321*c217d954SCole Faust         ARM_COMPUTE_ASSERT(lhs.info()->is_resizable());
1322*c217d954SCole Faust         ARM_COMPUTE_ASSERT(rhs.info()->is_resizable());
1323*c217d954SCole Faust 
1324*c217d954SCole Faust         add_padding_x({ &lhs, &rhs, &rhs_reshaped, &dst });
1325*c217d954SCole Faust 
1326*c217d954SCole Faust         // Allocate tensors
1327*c217d954SCole Faust         lhs.allocator()->allocate();
1328*c217d954SCole Faust         rhs.allocator()->allocate();
1329*c217d954SCole Faust         rhs_reshaped.allocator()->allocate();
1330*c217d954SCole Faust         dst.allocator()->allocate();
1331*c217d954SCole Faust 
1332*c217d954SCole Faust         ARM_COMPUTE_ASSERT(!lhs.info()->is_resizable());
1333*c217d954SCole Faust         ARM_COMPUTE_ASSERT(!rhs.info()->is_resizable());
1334*c217d954SCole Faust         ARM_COMPUTE_ASSERT(!rhs_reshaped.info()->is_resizable());
1335*c217d954SCole Faust         ARM_COMPUTE_ASSERT(!dst.info()->is_resizable());
1336*c217d954SCole Faust 
1337*c217d954SCole Faust         // Fill tensors
1338*c217d954SCole Faust         fill(AccessorType(lhs), 0);
1339*c217d954SCole Faust         fill(AccessorType(rhs), 1);
1340*c217d954SCole Faust 
1341*c217d954SCole Faust         // Compute GEMM
1342*c217d954SCole Faust         ITensorPack reshape_rhs_pack = { { ACL_SRC, &rhs }, { ACL_DST, &rhs_reshaped } };
1343*c217d954SCole Faust         reshape_rhs.run(reshape_rhs_pack);
1344*c217d954SCole Faust         ITensorPack gemm_pack({ { ACL_SRC_0, &lhs }, { ACL_SRC_1, &rhs_reshaped }, { ACL_DST, &dst } });
1345*c217d954SCole Faust         gemm.run(gemm_pack);
1346*c217d954SCole Faust 
1347*c217d954SCole Faust         return dst;
1348*c217d954SCole Faust     }
1349*c217d954SCole Faust 
1350*c217d954SCole Faust     SimpleTensor<int32_t> compute_reference(const TensorShape &lhs_shape, const TensorShape &rhs_shape, DataType data_type)
1351*c217d954SCole Faust     {
1352*c217d954SCole Faust         TensorShape dst_shape = lhs_shape;
1353*c217d954SCole Faust         dst_shape[0]          = rhs_shape[0];
1354*c217d954SCole Faust         dst_shape[1]          = lhs_shape[1];
1355*c217d954SCole Faust 
1356*c217d954SCole Faust         if(data_type == DataType::QASYMM8)
1357*c217d954SCole Faust         {
1358*c217d954SCole Faust             // Create reference
1359*c217d954SCole Faust             SimpleTensor<uint8_t> lhs{ lhs_shape, data_type, 1 };
1360*c217d954SCole Faust             SimpleTensor<uint8_t> rhs{ rhs_shape, data_type, 1 };
1361*c217d954SCole Faust 
1362*c217d954SCole Faust             // Fill reference
1363*c217d954SCole Faust             fill(lhs, 0);
1364*c217d954SCole Faust             fill(rhs, 1);
1365*c217d954SCole Faust 
1366*c217d954SCole Faust             return reference::gemmlowp_matrix_multiply_core<int32_t, uint8_t>(lhs, rhs, dst_shape, 0, 0);
1367*c217d954SCole Faust         }
1368*c217d954SCole Faust         else
1369*c217d954SCole Faust         {
1370*c217d954SCole Faust             // Create reference
1371*c217d954SCole Faust             SimpleTensor<int8_t> lhs{ lhs_shape, data_type, 1 };
1372*c217d954SCole Faust             SimpleTensor<int8_t> rhs{ rhs_shape, data_type, 1 };
1373*c217d954SCole Faust 
1374*c217d954SCole Faust             // Fill reference
1375*c217d954SCole Faust             fill(lhs, 0);
1376*c217d954SCole Faust             fill(rhs, 1);
1377*c217d954SCole Faust 
1378*c217d954SCole Faust             return reference::gemmlowp_matrix_multiply_core<int32_t, int8_t>(lhs, rhs, dst_shape, 0, 0);
1379*c217d954SCole Faust         }
1380*c217d954SCole Faust     }
1381*c217d954SCole Faust 
1382*c217d954SCole Faust     TensorType            _target{};
1383*c217d954SCole Faust     SimpleTensor<int32_t> _reference{};
1384*c217d954SCole Faust };
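
// All of the matrix-multiply fixtures above validate against reference::gemmlowp_matrix_multiply_core,
// which accumulates the quantized inputs in int32 after adding the input offsets (the fixtures in this
// block pass offsets of 0). A minimal, naive sketch of that accumulation for row-major data, assuming
// the usual GEMMLowp convention; the helper is illustrative only and is not the library's reference
// implementation.
template <typename TIn>
inline void example_gemmlowp_core_sketch(const TIn *lhs, const TIn *rhs, int32_t *dst,
                                         unsigned int M, unsigned int N, unsigned int K,
                                         int32_t a_offset, int32_t b_offset)
{
    for(unsigned int m = 0; m < M; ++m)
    {
        for(unsigned int n = 0; n < N; ++n)
        {
            // dst(m, n) = sum_k (lhs(m, k) + a_offset) * (rhs(k, n) + b_offset)
            int32_t acc = 0;
            for(unsigned int k = 0; k < K; ++k)
            {
                acc += (static_cast<int32_t>(lhs[m * K + k]) + a_offset) * (static_cast<int32_t>(rhs[k * N + n]) + b_offset);
            }
            dst[m * N + n] = acc;
        }
    }
}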
1385*c217d954SCole Faust 
1386*c217d954SCole Faust template <typename T, typename TensorType, typename AccessorType, typename ReshapeRHSOperatorType, typename GEMMFunctionType, typename ReduceOperation, typename CastOperation>
1387*c217d954SCole Faust class GEMMLowpMatrixMultiplyReshapedOnlyRHSMMULOutputStageValidationFixture : public framework::Fixture
1388*c217d954SCole Faust {
1389*c217d954SCole Faust public:
1390*c217d954SCole Faust     template <typename...>
1391*c217d954SCole Faust     void setup(unsigned int m, unsigned int n, unsigned int k, unsigned int batch_size, unsigned int m0, unsigned int n0,
1392*c217d954SCole Faust                unsigned int k0, unsigned int h0, bool interleave_rhs, bool transpose_rhs, bool broadcast_bias, DataType data_type)
1393*c217d954SCole Faust     {
1394*c217d954SCole Faust         GEMMLowpOutputStageInfo output_stage;
1395*c217d954SCole Faust         output_stage.type                    = GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT;
1396*c217d954SCole Faust         output_stage.output_data_type        = data_type;
1397*c217d954SCole Faust         output_stage.gemmlowp_multipliers    = std::vector<int32_t> { 1 };
1398*c217d954SCole Faust         output_stage.gemmlowp_shifts         = std::vector<int32_t> { 1 };
1399*c217d954SCole Faust         output_stage.gemmlowp_multipliers[0] = 1;
1400*c217d954SCole Faust         output_stage.gemmlowp_shifts[0]      = 1;
1401*c217d954SCole Faust         output_stage.gemmlowp_offset         = 0;
1402*c217d954SCole Faust         constexpr float scale                = 0.001f;
1403*c217d954SCole Faust         quantization::calculate_quantized_multiplier(scale, &output_stage.gemmlowp_multipliers[0], &output_stage.gemmlowp_shifts[0]);
1404*c217d954SCole Faust         output_stage.gemmlowp_min_bound = -100;
1405*c217d954SCole Faust         output_stage.gemmlowp_max_bound = 100;
1406*c217d954SCole Faust 
1407*c217d954SCole Faust         GEMMLHSMatrixInfo lhs_info;
1408*c217d954SCole Faust         lhs_info.m0 = m0;
1409*c217d954SCole Faust         lhs_info.k0 = k0;
1410*c217d954SCole Faust 
1411*c217d954SCole Faust         GEMMRHSMatrixInfo rhs_info;
1412*c217d954SCole Faust         rhs_info.n0         = n0;
1413*c217d954SCole Faust         rhs_info.k0         = k0;
1414*c217d954SCole Faust         rhs_info.h0         = h0;
1415*c217d954SCole Faust         rhs_info.interleave = interleave_rhs;
1416*c217d954SCole Faust         rhs_info.transpose  = transpose_rhs;
1417*c217d954SCole Faust 
1418*c217d954SCole Faust         int a_offset = 1;
1419*c217d954SCole Faust         int b_offset = 1;
1420*c217d954SCole Faust 
1421*c217d954SCole Faust         // Set the tensor shapes for LHS and RHS matrices
1422*c217d954SCole Faust         const TensorShape lhs_shape(k, m, batch_size);
1423*c217d954SCole Faust         const TensorShape rhs_shape(n, k, batch_size);
1424*c217d954SCole Faust         const TensorShape bias_shape(n,
1425*c217d954SCole Faust                                      broadcast_bias ? 1 : m,
1426*c217d954SCole Faust                                      broadcast_bias ? 1 : batch_size);
1427*c217d954SCole Faust 
1428*c217d954SCole Faust         _target = compute_target(lhs_shape, rhs_shape, bias_shape, lhs_info, rhs_info, data_type, output_stage, a_offset, b_offset);
1429*c217d954SCole Faust         if(gemm_validated == true)
1430*c217d954SCole Faust         {
1431*c217d954SCole Faust             _reference = compute_reference(lhs_shape, rhs_shape, bias_shape, data_type, output_stage, a_offset, b_offset);
1432*c217d954SCole Faust         }
1433*c217d954SCole Faust     }
1434*c217d954SCole Faust 
1435*c217d954SCole Faust protected:
1436*c217d954SCole Faust     template <typename U>
1437*c217d954SCole Faust     void fill(U &&tensor, int i)
1438*c217d954SCole Faust     {
1439*c217d954SCole Faust         switch(tensor.data_type())
1440*c217d954SCole Faust         {
1441*c217d954SCole Faust             case DataType::QASYMM8:
1442*c217d954SCole Faust             {
1443*c217d954SCole Faust                 // Between 1 and 254 in order to avoid having -128 and 128 for the DOT product path
1444*c217d954SCole Faust                 std::uniform_int_distribution<> distribution(1, 254);
1445*c217d954SCole Faust                 library->fill(tensor, distribution, i);
1446*c217d954SCole Faust             }
1447*c217d954SCole Faust             break;
1448*c217d954SCole Faust             case DataType::QASYMM8_SIGNED:
1449*c217d954SCole Faust             {
1450*c217d954SCole Faust                 std::uniform_int_distribution<> distribution(-127, 126);
1451*c217d954SCole Faust                 library->fill(tensor, distribution, i);
1452*c217d954SCole Faust             }
1453*c217d954SCole Faust             break;
1454*c217d954SCole Faust             case DataType::S32:
1455*c217d954SCole Faust             {
1456*c217d954SCole Faust                 std::uniform_int_distribution<> distribution(-10000, 10000);
1457*c217d954SCole Faust                 library->fill(tensor, distribution, i);
1458*c217d954SCole Faust             }
1459*c217d954SCole Faust             break;
1460*c217d954SCole Faust             default:
1461*c217d954SCole Faust                 ARM_COMPUTE_ERROR("Unsupported data type");
1462*c217d954SCole Faust         }
1463*c217d954SCole Faust     }
1464*c217d954SCole Faust 
1465*c217d954SCole Faust     TensorType compute_target(const TensorShape &lhs_shape, const TensorShape &rhs_shape, const TensorShape &bias_shape, const GEMMLHSMatrixInfo &lhs_info,
1466*c217d954SCole Faust                               const GEMMRHSMatrixInfo &rhs_info, DataType data_type, GEMMLowpOutputStageInfo output_stage, const int a_offset, const int b_offset)
1467*c217d954SCole Faust     {
1468*c217d954SCole Faust         // Create tensors
1469*c217d954SCole Faust         TensorType lhs  = create_tensor<TensorType>(lhs_shape, data_type, 1, QuantizationInfo(1.0f / 255, a_offset));
1470*c217d954SCole Faust         TensorType rhs  = create_tensor<TensorType>(rhs_shape, data_type, 1, QuantizationInfo(1.0f / 255, b_offset));
1471*c217d954SCole Faust         TensorType bias = create_tensor<TensorType>(bias_shape, DataType::S32, 1);
1472*c217d954SCole Faust         TensorType dst;
1473*c217d954SCole Faust         TensorType rhs_reshaped;
1474*c217d954SCole Faust 
1475*c217d954SCole Faust         const unsigned int M = lhs_shape[1];
1476*c217d954SCole Faust         const unsigned int N = rhs_shape[0];
1477*c217d954SCole Faust         const unsigned int K = lhs_shape[0];
1478*c217d954SCole Faust 
1479*c217d954SCole Faust         // Tensors for precomputing sum of lhs rows / rhs columns
1480*c217d954SCole Faust         TensorType vec_sum_rows = create_tensor<TensorType>(TensorShape(M, 1, lhs_shape[2]), DataType::S32, 1);
1481*c217d954SCole Faust         TensorType vec_sum_cols = create_tensor<TensorType>(TensorShape(N, 1, rhs_shape[2]), DataType::S32, 1);
1482*c217d954SCole Faust 
1483*c217d954SCole Faust         GEMMKernelInfo gemm_info;
1484*c217d954SCole Faust         gemm_info.m            = M;
1485*c217d954SCole Faust         gemm_info.n            = N;
1486*c217d954SCole Faust         gemm_info.k            = K;
1487*c217d954SCole Faust         gemm_info.lhs_info     = lhs_info;
1488*c217d954SCole Faust         gemm_info.rhs_info     = rhs_info;
1489*c217d954SCole Faust         gemm_info.output_stage = output_stage;
1490*c217d954SCole Faust         gemm_info.a_offset     = a_offset;
1491*c217d954SCole Faust         gemm_info.b_offset     = b_offset;
1492*c217d954SCole Faust         // The output tensor will be auto-initialized within the function
1493*c217d954SCole Faust 
1494*c217d954SCole Faust         // Create and configure function
1495*c217d954SCole Faust         ReshapeRHSOperatorType reshape_rhs;
1496*c217d954SCole Faust         GEMMFunctionType       gemm;
1497*c217d954SCole Faust         reshape_rhs.configure(rhs.info(), rhs_reshaped.info(), rhs_info);
1498*c217d954SCole Faust 
1499*c217d954SCole Faust         // If GEMM is not validated, do not try to run: validation checks whether the target
1500*c217d954SCole Faust         // supports the required extension. If it does not, the test will be skipped; running
1501*c217d954SCole Faust         // it anyway would only make the test fail because target and reference would not
1502*c217d954SCole Faust         // match.
1503*c217d954SCole Faust         gemm_validated = bool(gemm.validate(lhs.info(), rhs_reshaped.info(), dst.info(), gemm_info, vec_sum_cols.info(), vec_sum_rows.info(), bias.info()));
1504*c217d954SCole Faust         if(gemm_validated == true)
1505*c217d954SCole Faust         {
1506*c217d954SCole Faust             gemm.configure(lhs.info(), rhs_reshaped.info(), dst.info(), gemm_info, vec_sum_cols.info(), vec_sum_rows.info(), bias.info());
1507*c217d954SCole Faust 
1508*c217d954SCole Faust             ARM_COMPUTE_ASSERT(lhs.info()->is_resizable());
1509*c217d954SCole Faust             ARM_COMPUTE_ASSERT(rhs.info()->is_resizable());
1510*c217d954SCole Faust             ARM_COMPUTE_ASSERT(bias.info()->is_resizable());
1511*c217d954SCole Faust 
1512*c217d954SCole Faust             // Allocate tensors
1513*c217d954SCole Faust             lhs.allocator()->allocate();
1514*c217d954SCole Faust             rhs.allocator()->allocate();
1515*c217d954SCole Faust             rhs_reshaped.allocator()->allocate();
1516*c217d954SCole Faust             bias.allocator()->allocate();
1517*c217d954SCole Faust             vec_sum_cols.allocator()->allocate();
1518*c217d954SCole Faust             vec_sum_rows.allocator()->allocate();
1519*c217d954SCole Faust             dst.allocator()->allocate();
1520*c217d954SCole Faust 
1521*c217d954SCole Faust             ARM_COMPUTE_ASSERT(!lhs.info()->is_resizable());
1522*c217d954SCole Faust             ARM_COMPUTE_ASSERT(!rhs.info()->is_resizable());
1523*c217d954SCole Faust             ARM_COMPUTE_ASSERT(!rhs_reshaped.info()->is_resizable());
1524*c217d954SCole Faust             ARM_COMPUTE_ASSERT(!bias.info()->is_resizable());
1525*c217d954SCole Faust             ARM_COMPUTE_ASSERT(!vec_sum_cols.info()->is_resizable());
1526*c217d954SCole Faust             ARM_COMPUTE_ASSERT(!vec_sum_rows.info()->is_resizable());
1527*c217d954SCole Faust             ARM_COMPUTE_ASSERT(!dst.info()->is_resizable());
1528*c217d954SCole Faust 
1529*c217d954SCole Faust             // Fill tensors
1530*c217d954SCole Faust             fill(AccessorType(lhs), 0);
1531*c217d954SCole Faust             fill(AccessorType(rhs), 1);
1532*c217d954SCole Faust             fill(AccessorType(bias), 2);
1533*c217d954SCole Faust 
1534*c217d954SCole Faust             TensorType    lhs_32 = create_tensor<TensorType>(lhs_shape, DataType::S32, 1);
1535*c217d954SCole Faust             TensorType    rhs_32 = create_tensor<TensorType>(rhs_shape, DataType::S32, 1);
1536*c217d954SCole Faust             CastOperation cast_lhs;
1537*c217d954SCole Faust             CastOperation cast_rhs;
1538*c217d954SCole Faust             cast_lhs.configure(&lhs, &lhs_32, ConvertPolicy::SATURATE);
1539*c217d954SCole Faust             cast_rhs.configure(&rhs, &rhs_32, ConvertPolicy::SATURATE);
1540*c217d954SCole Faust             lhs_32.allocator()->allocate();
1541*c217d954SCole Faust             rhs_32.allocator()->allocate();
1542*c217d954SCole Faust             cast_lhs.run();
1543*c217d954SCole Faust             cast_rhs.run();
1544*c217d954SCole Faust 
1545*c217d954SCole Faust             ReduceOperation lhs_sum_rows;
1546*c217d954SCole Faust             ReduceOperation rhs_sum_cols;
1547*c217d954SCole Faust 
1548*c217d954SCole Faust             lhs_sum_rows.configure(&lhs_32, &vec_sum_rows, 0, ReductionOperation::SUM, false);
1549*c217d954SCole Faust             rhs_sum_cols.configure(&rhs_32, &vec_sum_cols, 1, ReductionOperation::SUM);
1550*c217d954SCole Faust 
1551*c217d954SCole Faust             lhs_sum_rows.run();
1552*c217d954SCole Faust             rhs_sum_cols.run();
1553*c217d954SCole Faust 
1554*c217d954SCole Faust             // Compute GEMM
1555*c217d954SCole Faust             ITensorPack reshape_rhs_pack = { { ACL_SRC, &rhs }, { ACL_DST, &rhs_reshaped } };
1556*c217d954SCole Faust             reshape_rhs.run(reshape_rhs_pack);
1557*c217d954SCole Faust             ITensorPack gemm_pack({ { ACL_SRC_0, &lhs }, { ACL_SRC_1, &rhs_reshaped }, { ACL_SRC_2, &bias }, { ACL_DST, &dst }, { ACL_VEC_COL_SUM, &vec_sum_cols }, { ACL_VEC_ROW_SUM, &vec_sum_rows } });
1558*c217d954SCole Faust             gemm.run(gemm_pack);
1559*c217d954SCole Faust         }
1560*c217d954SCole Faust 
1561*c217d954SCole Faust         return dst;
1562*c217d954SCole Faust     }
1563*c217d954SCole Faust 
1564*c217d954SCole Faust     SimpleTensor<T> compute_reference(const TensorShape &lhs_shape, const TensorShape &rhs_shape, const TensorShape &bias_shape, DataType data_type, GEMMLowpOutputStageInfo output_stage,
1565*c217d954SCole Faust                                       const int a_offset, const int b_offset)
1566*c217d954SCole Faust     {
1567*c217d954SCole Faust         TensorShape dst_shape = lhs_shape;
1568*c217d954SCole Faust         dst_shape[0]          = rhs_shape[0];
1569*c217d954SCole Faust         dst_shape[1]          = lhs_shape[1];
1570*c217d954SCole Faust 
1571*c217d954SCole Faust         // Create reference
1572*c217d954SCole Faust         SimpleTensor<T>       lhs{ lhs_shape, data_type, 1, QuantizationInfo(1.0f / 255, a_offset) };
1573*c217d954SCole Faust         SimpleTensor<T>       rhs{ rhs_shape, data_type, 1, QuantizationInfo(1.0f / 255, b_offset) };
1574*c217d954SCole Faust         SimpleTensor<int32_t> bias{ bias_shape, DataType::S32, 1 };
1575*c217d954SCole Faust         SimpleTensor<int32_t> dst{ dst_shape, DataType::S32, 1 };
1576*c217d954SCole Faust         SimpleTensor<T>       dst_final{ dst_shape, data_type, 1 };
1577*c217d954SCole Faust 
1578*c217d954SCole Faust         // Fill reference
1579*c217d954SCole Faust         fill(lhs, 0);
1580*c217d954SCole Faust         fill(rhs, 1);
1581*c217d954SCole Faust         fill(bias, 2);
1582*c217d954SCole Faust 
1583*c217d954SCole Faust         dst       = reference::gemmlowp_matrix_multiply_core<int32_t, T>(lhs, rhs, dst_shape, a_offset, b_offset);
1584*c217d954SCole Faust         dst_final = reference::gemmlowp_quantize_down_scale_by_fixedpoint<int32_t, T>(dst, bias,
1585*c217d954SCole Faust                                                                                       output_stage.gemmlowp_multipliers, output_stage.gemmlowp_shifts, output_stage.gemmlowp_offset, output_stage.gemmlowp_min_bound, output_stage.gemmlowp_max_bound);
1586*c217d954SCole Faust         return dst_final;
1587*c217d954SCole Faust     }
1588*c217d954SCole Faust 
1589*c217d954SCole Faust     bool            gemm_validated = true;
1590*c217d954SCole Faust     TensorType      _target{};
1591*c217d954SCole Faust     SimpleTensor<T> _reference{};
1592*c217d954SCole Faust };
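
// The MMUL output-stage fixture above precomputes vec_sum_rows (per-row sums of the LHS over K) and
// vec_sum_cols (per-column sums of the RHS over K) so that the kernel can fold the quantization
// offsets into the raw int32 dot products. A minimal sketch of that algebra, assuming the usual
// offset-contribution identity obtained by expanding sum_k (lhs + a_offset) * (rhs + b_offset);
// the helper and its parameter names are made up for this example and are not used by the fixtures.
inline int32_t example_offset_contribution_sketch(int32_t raw_dot, int32_t lhs_row_sum, int32_t rhs_col_sum,
                                                  int32_t a_offset, int32_t b_offset, int32_t k)
{
    // raw_dot     : sum_k lhs(m, k) * rhs(k, n) without any offsets
    // lhs_row_sum : sum_k lhs(m, k), one value per output row
    // rhs_col_sum : sum_k rhs(k, n), one value per output column
    return raw_dot + a_offset * rhs_col_sum + b_offset * lhs_row_sum + k * a_offset * b_offset;
}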
1593*c217d954SCole Faust 
1594*c217d954SCole Faust template <typename TensorType, typename AccessorType, typename ReshapeRHSOperatorType, typename GEMMFunctionType>
1595*c217d954SCole Faust class GEMMLowpMatrixMultiplyReshapedOnlyRHSMMULValidationFixture : public framework::Fixture
1596*c217d954SCole Faust {
1597*c217d954SCole Faust public:
1598*c217d954SCole Faust     template <typename...>
1599*c217d954SCole Faust     void setup(unsigned int m, unsigned int n, unsigned int k, unsigned int batch_size, unsigned int m0, unsigned int n0,
1600*c217d954SCole Faust                unsigned int k0, unsigned int h0, bool interleave_rhs, bool transpose_rhs, DataType data_type)
1601*c217d954SCole Faust     {
1602*c217d954SCole Faust         GEMMLHSMatrixInfo lhs_info;
1603*c217d954SCole Faust         lhs_info.m0 = m0;
1604*c217d954SCole Faust         lhs_info.k0 = k0;
1605*c217d954SCole Faust 
1606*c217d954SCole Faust         GEMMRHSMatrixInfo rhs_info;
1607*c217d954SCole Faust         rhs_info.n0         = n0;
1608*c217d954SCole Faust         rhs_info.k0         = k0;
1609*c217d954SCole Faust         rhs_info.h0         = h0;
1610*c217d954SCole Faust         rhs_info.interleave = interleave_rhs;
1611*c217d954SCole Faust         rhs_info.transpose  = transpose_rhs;
1612*c217d954SCole Faust 
1613*c217d954SCole Faust         // Set the tensor shapes for LHS and RHS matrices
1614*c217d954SCole Faust         const TensorShape lhs_shape(k, m, batch_size);
1615*c217d954SCole Faust         const TensorShape rhs_shape(n, k, batch_size);
1616*c217d954SCole Faust 
1617*c217d954SCole Faust         _target = compute_target(lhs_shape, rhs_shape, lhs_info, rhs_info, data_type);
1618*c217d954SCole Faust         if(gemm_validated == true)
1619*c217d954SCole Faust         {
1620*c217d954SCole Faust             _reference = compute_reference(lhs_shape, rhs_shape, data_type);
1621*c217d954SCole Faust         }
1622*c217d954SCole Faust     }
1623*c217d954SCole Faust 
1624*c217d954SCole Faust protected:
1625*c217d954SCole Faust     template <typename U>
1626*c217d954SCole Faust     void fill(U &&tensor, int i)
1627*c217d954SCole Faust     {
1628*c217d954SCole Faust         switch(tensor.data_type())
1629*c217d954SCole Faust         {
1630*c217d954SCole Faust             case DataType::QASYMM8:
1631*c217d954SCole Faust             {
1632*c217d954SCole Faust                 // Between 1 and 254 in order to avoid having -128 and 128 for the DOT product path
1633*c217d954SCole Faust                 std::uniform_int_distribution<> distribution(1, 254);
1634*c217d954SCole Faust                 library->fill(tensor, distribution, i);
1635*c217d954SCole Faust             }
1636*c217d954SCole Faust             break;
1637*c217d954SCole Faust             case DataType::QASYMM8_SIGNED:
1638*c217d954SCole Faust             {
1639*c217d954SCole Faust                 std::uniform_int_distribution<> distribution(-127, 126);
1640*c217d954SCole Faust                 library->fill(tensor, distribution, i);
1641*c217d954SCole Faust             }
1642*c217d954SCole Faust             break;
1643*c217d954SCole Faust             default:
1644*c217d954SCole Faust                 ARM_COMPUTE_ERROR("Unsupported data type");
1645*c217d954SCole Faust         }
1646*c217d954SCole Faust     }
1647*c217d954SCole Faust 
1648*c217d954SCole Faust     TensorType compute_target(const TensorShape &lhs_shape, const TensorShape &rhs_shape, const GEMMLHSMatrixInfo &lhs_info,
1649*c217d954SCole Faust                               const GEMMRHSMatrixInfo &rhs_info, DataType data_type)
1650*c217d954SCole Faust     {
1651*c217d954SCole Faust         // Create tensors
1652*c217d954SCole Faust         TensorType lhs = create_tensor<TensorType>(lhs_shape, data_type, 1);
1653*c217d954SCole Faust         TensorType rhs = create_tensor<TensorType>(rhs_shape, data_type, 1);
1654*c217d954SCole Faust         TensorType rhs_reshaped;
1655*c217d954SCole Faust         TensorType dst;
1656*c217d954SCole Faust 
1657*c217d954SCole Faust         const unsigned int M = lhs_shape[1];
1658*c217d954SCole Faust         const unsigned int N = rhs_shape[0];
1659*c217d954SCole Faust         const unsigned int K = lhs_shape[0];
1660*c217d954SCole Faust 
1661*c217d954SCole Faust         GEMMKernelInfo gemm_info;
1662*c217d954SCole Faust         gemm_info.m        = M;
1663*c217d954SCole Faust         gemm_info.n        = N;
1664*c217d954SCole Faust         gemm_info.k        = K;
1665*c217d954SCole Faust         gemm_info.lhs_info = lhs_info;
1666*c217d954SCole Faust         gemm_info.rhs_info = rhs_info;
1667*c217d954SCole Faust         // The output tensor will be auto-initialized within the function
1668*c217d954SCole Faust 
1669*c217d954SCole Faust         // Create and configure function
1670*c217d954SCole Faust         ReshapeRHSOperatorType reshape_rhs;
1671*c217d954SCole Faust         GEMMFunctionType       gemm;
1672*c217d954SCole Faust         reshape_rhs.configure(rhs.info(), rhs_reshaped.info(), rhs_info);
1673*c217d954SCole Faust 
1674*c217d954SCole Faust         // If GEMM does not validate, do not try to configure or run it. Validation checks
1675*c217d954SCole Faust         // whether the target supports the required extension; if it does not, the test is
1676*c217d954SCole Faust         // effectively skipped. Running without that support would make the test fail anyway,
1677*c217d954SCole Faust         // because target and reference would not match.
1678*c217d954SCole Faust         gemm_validated = bool(gemm.validate(lhs.info(), rhs_reshaped.info(), dst.info(), gemm_info, nullptr, nullptr, nullptr));
1679*c217d954SCole Faust         if(gemm_validated == true)
1680*c217d954SCole Faust         {
1681*c217d954SCole Faust             gemm.configure(lhs.info(), rhs_reshaped.info(), dst.info(), gemm_info, nullptr, nullptr, nullptr);
1682*c217d954SCole Faust 
1683*c217d954SCole Faust             ARM_COMPUTE_ASSERT(lhs.info()->is_resizable());
1684*c217d954SCole Faust             ARM_COMPUTE_ASSERT(rhs.info()->is_resizable());
1685*c217d954SCole Faust 
1686*c217d954SCole Faust             // Allocate tensors
1687*c217d954SCole Faust             lhs.allocator()->allocate();
1688*c217d954SCole Faust             rhs.allocator()->allocate();
1689*c217d954SCole Faust             rhs_reshaped.allocator()->allocate();
1690*c217d954SCole Faust             dst.allocator()->allocate();
1691*c217d954SCole Faust 
1692*c217d954SCole Faust             ARM_COMPUTE_ASSERT(!lhs.info()->is_resizable());
1693*c217d954SCole Faust             ARM_COMPUTE_ASSERT(!rhs.info()->is_resizable());
1694*c217d954SCole Faust             ARM_COMPUTE_ASSERT(!rhs_reshaped.info()->is_resizable());
1695*c217d954SCole Faust             ARM_COMPUTE_ASSERT(!dst.info()->is_resizable());
1696*c217d954SCole Faust 
1697*c217d954SCole Faust             // Fill tensors
1698*c217d954SCole Faust             fill(AccessorType(lhs), 0);
1699*c217d954SCole Faust             fill(AccessorType(rhs), 1);
1700*c217d954SCole Faust 
1701*c217d954SCole Faust             // Compute GEMM
1702*c217d954SCole Faust             ITensorPack reshape_rhs_pack = { { ACL_SRC, &rhs }, { ACL_DST, &rhs_reshaped } };
1703*c217d954SCole Faust             reshape_rhs.run(reshape_rhs_pack);
1704*c217d954SCole Faust             ITensorPack gemm_pack({ { ACL_SRC_0, &lhs }, { ACL_SRC_1, &rhs_reshaped }, { ACL_DST, &dst } });
1705*c217d954SCole Faust             gemm.run(gemm_pack);
1706*c217d954SCole Faust         }
1707*c217d954SCole Faust 
1708*c217d954SCole Faust         return dst;
1709*c217d954SCole Faust     }
1710*c217d954SCole Faust 
1711*c217d954SCole Faust     SimpleTensor<int32_t> compute_reference(const TensorShape &lhs_shape, const TensorShape &rhs_shape, DataType data_type)
1712*c217d954SCole Faust     {
1713*c217d954SCole Faust         TensorShape dst_shape = lhs_shape;
1714*c217d954SCole Faust         dst_shape[0]          = rhs_shape[0];
1715*c217d954SCole Faust         dst_shape[1]          = lhs_shape[1];
1716*c217d954SCole Faust 
1717*c217d954SCole Faust         if(data_type == DataType::QASYMM8)
1718*c217d954SCole Faust         {
1719*c217d954SCole Faust             // Create reference
1720*c217d954SCole Faust             SimpleTensor<uint8_t> lhs{ lhs_shape, data_type, 1 };
1721*c217d954SCole Faust             SimpleTensor<uint8_t> rhs{ rhs_shape, data_type, 1 };
1722*c217d954SCole Faust             SimpleTensor<int32_t> dst{ dst_shape, DataType::S32, 1 };
1723*c217d954SCole Faust 
1724*c217d954SCole Faust             // Fill reference
1725*c217d954SCole Faust             fill(lhs, 0);
1726*c217d954SCole Faust             fill(rhs, 1);
1727*c217d954SCole Faust 
1728*c217d954SCole Faust             return reference::gemmlowp_matrix_multiply_core<int32_t, uint8_t>(lhs, rhs, dst_shape, 0, 0);
1729*c217d954SCole Faust         }
1730*c217d954SCole Faust         else
1731*c217d954SCole Faust         {
1732*c217d954SCole Faust             // Create reference
1733*c217d954SCole Faust             SimpleTensor<int8_t>  lhs{ lhs_shape, data_type, 1 };
1734*c217d954SCole Faust             SimpleTensor<int8_t>  rhs{ rhs_shape, data_type, 1 };
1735*c217d954SCole Faust             SimpleTensor<int32_t> dst{ dst_shape, DataType::S32, 1 };
1736*c217d954SCole Faust 
1737*c217d954SCole Faust             // Fill reference
1738*c217d954SCole Faust             fill(lhs, 0);
1739*c217d954SCole Faust             fill(rhs, 1);
1740*c217d954SCole Faust 
1741*c217d954SCole Faust             return reference::gemmlowp_matrix_multiply_core<int32_t, int8_t>(lhs, rhs, dst_shape, 0, 0);
1742*c217d954SCole Faust         }
1743*c217d954SCole Faust     }
1744*c217d954SCole Faust 
1745*c217d954SCole Faust     bool                  gemm_validated = true;
1746*c217d954SCole Faust     TensorType            _target{};
1747*c217d954SCole Faust     SimpleTensor<int32_t> _reference{};
1748*c217d954SCole Faust };
1749*c217d954SCole Faust 
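// Fixture for the GEMMLowp matrix multiplication with reshaped-only-RHS and a 3D output:
// the flat M dimension (m_w * m_h) of the 2D GEMM result is re-interpreted as m_h slices of
// height m_w via GEMMKernelInfo::depth_output_gemm3d.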
1750*c217d954SCole Faust template <typename TensorType, typename AccessorType, typename ReshapeRHSOperatorType, typename GEMMFunctionType>
1751*c217d954SCole Faust class GEMMLowpMatrixMultiplyReshapedOnlyRHS3DValidationFixture : public framework::Fixture
1752*c217d954SCole Faust {
1753*c217d954SCole Faust public:
1754*c217d954SCole Faust     template <typename...>
1755*c217d954SCole Faust     void setup(unsigned int m_w, unsigned int m_h, unsigned int n, unsigned int k, unsigned int batch_size, unsigned int m0, unsigned int n0,
1756*c217d954SCole Faust                unsigned int k0, unsigned int h0, bool interleave_rhs, bool transpose_rhs, DataType data_type)
1757*c217d954SCole Faust     {
1758*c217d954SCole Faust         GEMMLHSMatrixInfo lhs_info;
1759*c217d954SCole Faust         lhs_info.m0 = m0;
1760*c217d954SCole Faust         lhs_info.k0 = k0;
1761*c217d954SCole Faust 
1762*c217d954SCole Faust         GEMMRHSMatrixInfo rhs_info;
1763*c217d954SCole Faust         rhs_info.n0         = n0;
1764*c217d954SCole Faust         rhs_info.k0         = k0;
1765*c217d954SCole Faust         rhs_info.h0         = h0;
1766*c217d954SCole Faust         rhs_info.interleave = interleave_rhs;
1767*c217d954SCole Faust         rhs_info.transpose  = transpose_rhs;
1768*c217d954SCole Faust 
1769*c217d954SCole Faust         // In case of GEMM3D, m is the product between m_w and m_h
1770*c217d954SCole Faust         const unsigned int m = m_w * m_h;
1771*c217d954SCole Faust 
1772*c217d954SCole Faust         // Set the tensor shapes for LHS and RHS matrices
1773*c217d954SCole Faust         const TensorShape lhs_shape(k, m, batch_size);
1774*c217d954SCole Faust         const TensorShape rhs_shape(n, k, batch_size);
1775*c217d954SCole Faust 
1776*c217d954SCole Faust         _target    = compute_target(lhs_shape, rhs_shape, lhs_info, rhs_info, m_h, data_type);
1777*c217d954SCole Faust         _reference = compute_reference(lhs_shape, rhs_shape, m_h, data_type);
1778*c217d954SCole Faust     }
1779*c217d954SCole Faust 
1780*c217d954SCole Faust protected:
1781*c217d954SCole Faust     template <typename U>
1782*c217d954SCole Faust     void fill(U &&tensor, int i)
1783*c217d954SCole Faust     {
1784*c217d954SCole Faust         switch(tensor.data_type())
1785*c217d954SCole Faust         {
1786*c217d954SCole Faust             case DataType::QASYMM8:
1787*c217d954SCole Faust             {
1788*c217d954SCole Faust                 // Between 1 and 254 in order to avoid having -128 and 128 for the DOT product path
1789*c217d954SCole Faust                 std::uniform_int_distribution<> distribution(1, 254);
1790*c217d954SCole Faust                 library->fill(tensor, distribution, i);
1791*c217d954SCole Faust             }
1792*c217d954SCole Faust             break;
1793*c217d954SCole Faust             case DataType::QASYMM8_SIGNED:
1794*c217d954SCole Faust             {
1795*c217d954SCole Faust                 std::uniform_int_distribution<> distribution(-127, 126);
1796*c217d954SCole Faust                 library->fill(tensor, distribution, i);
1797*c217d954SCole Faust             }
1798*c217d954SCole Faust             break;
1799*c217d954SCole Faust             default:
1800*c217d954SCole Faust                 ARM_COMPUTE_ERROR("Unsupported data type");
1801*c217d954SCole Faust         }
1802*c217d954SCole Faust     }
1803*c217d954SCole Faust 
1804*c217d954SCole Faust     TensorType compute_target(const TensorShape &lhs_shape, const TensorShape &rhs_shape, const GEMMLHSMatrixInfo &lhs_info,
1805*c217d954SCole Faust                               const GEMMRHSMatrixInfo &rhs_info, unsigned int m_h, DataType data_type)
1806*c217d954SCole Faust     {
1807*c217d954SCole Faust         // Create tensors
1808*c217d954SCole Faust         TensorType lhs = create_tensor<TensorType>(lhs_shape, data_type, 1);
1809*c217d954SCole Faust         TensorType rhs = create_tensor<TensorType>(rhs_shape, data_type, 1);
1810*c217d954SCole Faust         TensorType rhs_reshaped;
1811*c217d954SCole Faust         TensorType dst;
1812*c217d954SCole Faust 
1813*c217d954SCole Faust         const unsigned int M = lhs_shape[1];
1814*c217d954SCole Faust         const unsigned int N = rhs_shape[0];
1815*c217d954SCole Faust         const unsigned int K = lhs_shape[0];
1816*c217d954SCole Faust 
1817*c217d954SCole Faust         GEMMKernelInfo gemm_info;
1818*c217d954SCole Faust         gemm_info.m                   = M;
1819*c217d954SCole Faust         gemm_info.n                   = N;
1820*c217d954SCole Faust         gemm_info.k                   = K;
1821*c217d954SCole Faust         gemm_info.depth_output_gemm3d = m_h;
1822*c217d954SCole Faust         gemm_info.lhs_info            = lhs_info;
1823*c217d954SCole Faust         gemm_info.rhs_info            = rhs_info;
1824*c217d954SCole Faust         // The output tensor will be auto-initialized within the function
1825*c217d954SCole Faust 
1826*c217d954SCole Faust         // Create and configure function
1827*c217d954SCole Faust         ReshapeRHSOperatorType reshape_rhs;
1828*c217d954SCole Faust         GEMMFunctionType       gemm;
1829*c217d954SCole Faust         reshape_rhs.configure(rhs.info(), rhs_reshaped.info(), rhs_info);
1830*c217d954SCole Faust         gemm.configure(lhs.info(), rhs_reshaped.info(), dst.info(), gemm_info);
1831*c217d954SCole Faust 
1832*c217d954SCole Faust         ARM_COMPUTE_ASSERT(lhs.info()->is_resizable());
1833*c217d954SCole Faust         ARM_COMPUTE_ASSERT(rhs.info()->is_resizable());
1834*c217d954SCole Faust 
1835*c217d954SCole Faust         add_padding_x({ &lhs, &rhs, &rhs_reshaped, &dst });
1836*c217d954SCole Faust 
1837*c217d954SCole Faust         // Allocate tensors
1838*c217d954SCole Faust         lhs.allocator()->allocate();
1839*c217d954SCole Faust         rhs.allocator()->allocate();
1840*c217d954SCole Faust         rhs_reshaped.allocator()->allocate();
1841*c217d954SCole Faust         dst.allocator()->allocate();
1842*c217d954SCole Faust 
1843*c217d954SCole Faust         ARM_COMPUTE_ASSERT(!lhs.info()->is_resizable());
1844*c217d954SCole Faust         ARM_COMPUTE_ASSERT(!rhs.info()->is_resizable());
1845*c217d954SCole Faust         ARM_COMPUTE_ASSERT(!rhs_reshaped.info()->is_resizable());
1846*c217d954SCole Faust         ARM_COMPUTE_ASSERT(!dst.info()->is_resizable());
1847*c217d954SCole Faust 
1848*c217d954SCole Faust         // Fill tensors
1849*c217d954SCole Faust         fill(AccessorType(lhs), 0);
1850*c217d954SCole Faust         fill(AccessorType(rhs), 1);
1851*c217d954SCole Faust 
1852*c217d954SCole Faust         // Compute GEMM
1853*c217d954SCole Faust         ITensorPack reshape_rhs_pack = { { ACL_SRC, &rhs }, { ACL_DST, &rhs_reshaped } };
1854*c217d954SCole Faust         reshape_rhs.run(reshape_rhs_pack);
1855*c217d954SCole Faust         ITensorPack gemm_pack({ { ACL_SRC_0, &lhs }, { ACL_SRC_1, &rhs_reshaped }, { ACL_DST, &dst } });
1856*c217d954SCole Faust         gemm.run(gemm_pack);
1857*c217d954SCole Faust 
1858*c217d954SCole Faust         return dst;
1859*c217d954SCole Faust     }
1860*c217d954SCole Faust 
1861*c217d954SCole Faust     SimpleTensor<int32_t> compute_reference(const TensorShape &lhs_shape, const TensorShape &rhs_shape, unsigned int m_h, DataType data_type)
1862*c217d954SCole Faust     {
1863*c217d954SCole Faust         TensorShape dst_shape = lhs_shape;
1864*c217d954SCole Faust         dst_shape.set(0, rhs_shape[0]);
1865*c217d954SCole Faust         dst_shape.set(1, lhs_shape[1] / m_h);
1866*c217d954SCole Faust         dst_shape.set(2, m_h);
1867*c217d954SCole Faust         dst_shape.set(3, lhs_shape[2]);
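        // The reference is still a plain 2D GEMM; only the output shape is re-interpreted.
        // For example (hypothetical sizes), with m_w = 4, m_h = 3 and batch_size = 2 the flat
        // result (N, 12, 2) is viewed as (N, 4, 3, 2).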
1868*c217d954SCole Faust 
1869*c217d954SCole Faust         if(data_type == DataType::QASYMM8)
1870*c217d954SCole Faust         {
1871*c217d954SCole Faust             // Create reference
1872*c217d954SCole Faust             SimpleTensor<uint8_t> lhs{ lhs_shape, data_type, 1 };
1873*c217d954SCole Faust             SimpleTensor<uint8_t> rhs{ rhs_shape, data_type, 1 };
1874*c217d954SCole Faust 
1875*c217d954SCole Faust             // Fill reference
1876*c217d954SCole Faust             fill(lhs, 0);
1877*c217d954SCole Faust             fill(rhs, 1);
1878*c217d954SCole Faust 
1879*c217d954SCole Faust             return reference::gemmlowp_matrix_multiply_core<int32_t, uint8_t>(lhs, rhs, dst_shape, 0, 0);
1880*c217d954SCole Faust         }
1881*c217d954SCole Faust         else
1882*c217d954SCole Faust         {
1883*c217d954SCole Faust             // Create reference
1884*c217d954SCole Faust             SimpleTensor<int8_t> lhs{ lhs_shape, data_type, 1 };
1885*c217d954SCole Faust             SimpleTensor<int8_t> rhs{ rhs_shape, data_type, 1 };
1886*c217d954SCole Faust 
1887*c217d954SCole Faust             // Fill reference
1888*c217d954SCole Faust             fill(lhs, 0);
1889*c217d954SCole Faust             fill(rhs, 1);
1890*c217d954SCole Faust 
1891*c217d954SCole Faust             return reference::gemmlowp_matrix_multiply_core<int32_t, int8_t>(lhs, rhs, dst_shape, 0, 0);
1892*c217d954SCole Faust         }
1893*c217d954SCole Faust     }
1894*c217d954SCole Faust 
1895*c217d954SCole Faust     TensorType            _target{};
1896*c217d954SCole Faust     SimpleTensor<int32_t> _reference{};
1897*c217d954SCole Faust };
1898*c217d954SCole Faust 
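// Fixture for the native (non-reshaped) GEMMLowp matrix multiplication; inputs are always
// QASYMM8 and the raw S32 accumulation is compared against the reference.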
1899*c217d954SCole Faust template <typename TensorType, typename AccessorType, typename GEMMFunctionType>
1900*c217d954SCole Faust class GEMMLowpMatrixMultiplyNativeValidationFixture : public framework::Fixture
1901*c217d954SCole Faust {
1902*c217d954SCole Faust public:
1903*c217d954SCole Faust     template <typename...>
1904*c217d954SCole Faust     void setup(unsigned int m, unsigned int n, unsigned int k, unsigned int batch_size, unsigned int m0, unsigned int n0, unsigned int k0)
1905*c217d954SCole Faust     {
1906*c217d954SCole Faust         GEMMLHSMatrixInfo lhs_info;
1907*c217d954SCole Faust         lhs_info.m0 = m0;
1908*c217d954SCole Faust         lhs_info.k0 = k0;
1909*c217d954SCole Faust 
1910*c217d954SCole Faust         GEMMRHSMatrixInfo rhs_info;
1911*c217d954SCole Faust         rhs_info.n0 = n0;
1912*c217d954SCole Faust         rhs_info.k0 = k0;
1913*c217d954SCole Faust 
1914*c217d954SCole Faust         // Set the tensor shapes for LHS and RHS matrices
1915*c217d954SCole Faust         const TensorShape lhs_shape(k, m, batch_size);
1916*c217d954SCole Faust         const TensorShape rhs_shape(n, k, batch_size);
1917*c217d954SCole Faust 
1918*c217d954SCole Faust         _target    = compute_target(lhs_shape, rhs_shape, lhs_info, rhs_info);
1919*c217d954SCole Faust         _reference = compute_reference(lhs_shape, rhs_shape);
1920*c217d954SCole Faust     }
1921*c217d954SCole Faust 
1922*c217d954SCole Faust protected:
1923*c217d954SCole Faust     template <typename U>
1924*c217d954SCole Faust     void fill(U &&tensor, int i)
1925*c217d954SCole Faust     {
1926*c217d954SCole Faust         // Between 1 and 254 in order to avoid having -128 and 128 for the DOT product path
1927*c217d954SCole Faust         std::uniform_int_distribution<> distribution(1, 254);
1928*c217d954SCole Faust         library->fill(tensor, distribution, i);
1929*c217d954SCole Faust     }
1930*c217d954SCole Faust 
1931*c217d954SCole Faust     TensorType compute_target(const TensorShape &lhs_shape, const TensorShape &rhs_shape, const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info)
1932*c217d954SCole Faust     {
1933*c217d954SCole Faust         // Create tensors
1934*c217d954SCole Faust         TensorType lhs = create_tensor<TensorType>(lhs_shape, DataType::QASYMM8, 1);
1935*c217d954SCole Faust         TensorType rhs = create_tensor<TensorType>(rhs_shape, DataType::QASYMM8, 1);
1936*c217d954SCole Faust         TensorType dst;
1937*c217d954SCole Faust 
1938*c217d954SCole Faust         const unsigned int M = lhs_shape[1];
1939*c217d954SCole Faust         const unsigned int N = rhs_shape[0];
1940*c217d954SCole Faust         const unsigned int K = lhs_shape[0];
1941*c217d954SCole Faust 
1942*c217d954SCole Faust         // The output tensor will be auto-initialized within the function
1943*c217d954SCole Faust 
1944*c217d954SCole Faust         // Create and configure function
1945*c217d954SCole Faust         GEMMFunctionType gemm;
1946*c217d954SCole Faust         gemm.configure(lhs.info(), rhs.info(), dst.info(), lhs_info, rhs_info, GEMMReshapeInfo(M, N, K));
1947*c217d954SCole Faust 
1948*c217d954SCole Faust         ARM_COMPUTE_ASSERT(lhs.info()->is_resizable());
1949*c217d954SCole Faust         ARM_COMPUTE_ASSERT(rhs.info()->is_resizable());
1950*c217d954SCole Faust 
1951*c217d954SCole Faust         add_padding_x({ &lhs, &rhs, &dst });
1952*c217d954SCole Faust 
1953*c217d954SCole Faust         // Allocate tensors
1954*c217d954SCole Faust         lhs.allocator()->allocate();
1955*c217d954SCole Faust         rhs.allocator()->allocate();
1956*c217d954SCole Faust         dst.allocator()->allocate();
1957*c217d954SCole Faust 
1958*c217d954SCole Faust         ARM_COMPUTE_ASSERT(!lhs.info()->is_resizable());
1959*c217d954SCole Faust         ARM_COMPUTE_ASSERT(!rhs.info()->is_resizable());
1960*c217d954SCole Faust         ARM_COMPUTE_ASSERT(!dst.info()->is_resizable());
1961*c217d954SCole Faust 
1962*c217d954SCole Faust         // Fill tensors
1963*c217d954SCole Faust         fill(AccessorType(lhs), 0);
1964*c217d954SCole Faust         fill(AccessorType(rhs), 1);
1965*c217d954SCole Faust 
1966*c217d954SCole Faust         // Compute GEMM
1967*c217d954SCole Faust         ITensorPack gemm_pack({ { ACL_SRC_0, &lhs }, { ACL_SRC_1, &rhs }, { ACL_DST, &dst } });
1968*c217d954SCole Faust         gemm.run(gemm_pack);
1969*c217d954SCole Faust 
1970*c217d954SCole Faust         return dst;
1971*c217d954SCole Faust     }
1972*c217d954SCole Faust 
1973*c217d954SCole Faust     SimpleTensor<int32_t> compute_reference(const TensorShape &lhs_shape, const TensorShape &rhs_shape)
1974*c217d954SCole Faust     {
1975*c217d954SCole Faust         TensorShape dst_shape = lhs_shape;
1976*c217d954SCole Faust         dst_shape[0]          = rhs_shape[0];
1977*c217d954SCole Faust         dst_shape[1]          = lhs_shape[1];
1978*c217d954SCole Faust 
1979*c217d954SCole Faust         // Create reference
1980*c217d954SCole Faust         SimpleTensor<uint8_t> lhs{ lhs_shape, DataType::QASYMM8, 1 };
1981*c217d954SCole Faust         SimpleTensor<uint8_t> rhs{ rhs_shape, DataType::QASYMM8, 1 };
1982*c217d954SCole Faust 
1983*c217d954SCole Faust         // Fill reference
1984*c217d954SCole Faust         fill(lhs, 0);
1985*c217d954SCole Faust         fill(rhs, 1);
1986*c217d954SCole Faust 
1987*c217d954SCole Faust         return reference::gemmlowp_matrix_multiply_core<int32_t, uint8_t>(lhs, rhs, dst_shape, 0, 0);
1988*c217d954SCole Faust     }
1989*c217d954SCole Faust 
1990*c217d954SCole Faust     TensorType            _target{};
1991*c217d954SCole Faust     SimpleTensor<int32_t> _reference{};
1992*c217d954SCole Faust };
1993*c217d954SCole Faust 
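// 3D variant of the native GEMMLowp fixture: as above, the flat M dimension (m_w * m_h) of
// the 2D result is re-interpreted as m_h slices of height m_w, here through GEMMReshapeInfo.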
1994*c217d954SCole Faust template <typename TensorType, typename AccessorType, typename GEMMFunctionType>
1995*c217d954SCole Faust class GEMMLowpMatrixMultiplyNative3DValidationFixture : public framework::Fixture
1996*c217d954SCole Faust {
1997*c217d954SCole Faust public:
1998*c217d954SCole Faust     template <typename...>
1999*c217d954SCole Faust     void setup(unsigned int m_w, unsigned int m_h, unsigned int n, unsigned int k, unsigned int batch_size, unsigned int m0, unsigned int n0, unsigned int k0)
2000*c217d954SCole Faust     {
2001*c217d954SCole Faust         GEMMLHSMatrixInfo lhs_info;
2002*c217d954SCole Faust         lhs_info.m0 = m0;
2003*c217d954SCole Faust         lhs_info.k0 = k0;
2004*c217d954SCole Faust 
2005*c217d954SCole Faust         GEMMRHSMatrixInfo rhs_info;
2006*c217d954SCole Faust         rhs_info.n0 = n0;
2007*c217d954SCole Faust         rhs_info.k0 = k0;
2008*c217d954SCole Faust 
2009*c217d954SCole Faust         // In case of GEMM3D, m is the product between m_w and m_h
2010*c217d954SCole Faust         const unsigned int m = m_w * m_h;
2011*c217d954SCole Faust 
2012*c217d954SCole Faust         // Set the tensor shapes for LHS and RHS matrices
2013*c217d954SCole Faust         const TensorShape lhs_shape(k, m, batch_size);
2014*c217d954SCole Faust         const TensorShape rhs_shape(n, k, batch_size);
2015*c217d954SCole Faust 
2016*c217d954SCole Faust         _target    = compute_target(lhs_shape, rhs_shape, lhs_info, rhs_info, m_h);
2017*c217d954SCole Faust         _reference = compute_reference(lhs_shape, rhs_shape, m_h);
2018*c217d954SCole Faust     }
2019*c217d954SCole Faust 
2020*c217d954SCole Faust protected:
2021*c217d954SCole Faust     template <typename U>
2022*c217d954SCole Faust     void fill(U &&tensor, int i)
2023*c217d954SCole Faust     {
2024*c217d954SCole Faust         // Between 1 and 254 in order to avoid having -128 and 128 for the DOT product path
2025*c217d954SCole Faust         std::uniform_int_distribution<> distribution(1, 254);
2026*c217d954SCole Faust         library->fill(tensor, distribution, i);
2027*c217d954SCole Faust     }
2028*c217d954SCole Faust 
2029*c217d954SCole Faust     TensorType compute_target(const TensorShape &lhs_shape, const TensorShape &rhs_shape, const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, unsigned int m_h)
2030*c217d954SCole Faust     {
2031*c217d954SCole Faust         // Create tensors
2032*c217d954SCole Faust         TensorType lhs = create_tensor<TensorType>(lhs_shape, DataType::QASYMM8, 1);
2033*c217d954SCole Faust         TensorType rhs = create_tensor<TensorType>(rhs_shape, DataType::QASYMM8, 1);
2034*c217d954SCole Faust         TensorType dst;
2035*c217d954SCole Faust 
2036*c217d954SCole Faust         const unsigned int M = lhs_shape[1];
2037*c217d954SCole Faust         const unsigned int N = rhs_shape[0];
2038*c217d954SCole Faust         const unsigned int K = lhs_shape[0];
2039*c217d954SCole Faust 
2040*c217d954SCole Faust         // The output tensor will be auto-initialized within the function
2041*c217d954SCole Faust 
2042*c217d954SCole Faust         // Create and configure function
2043*c217d954SCole Faust         GEMMFunctionType gemm;
2044*c217d954SCole Faust         gemm.configure(lhs.info(), rhs.info(), dst.info(), lhs_info, rhs_info, GEMMReshapeInfo(M, N, K, 1, 1, m_h));
2045*c217d954SCole Faust 
2046*c217d954SCole Faust         ARM_COMPUTE_ASSERT(lhs.info()->is_resizable());
2047*c217d954SCole Faust         ARM_COMPUTE_ASSERT(rhs.info()->is_resizable());
2048*c217d954SCole Faust 
2049*c217d954SCole Faust         add_padding_x({ &lhs, &rhs, &dst });
2050*c217d954SCole Faust 
2051*c217d954SCole Faust         // Allocate tensors
2052*c217d954SCole Faust         lhs.allocator()->allocate();
2053*c217d954SCole Faust         rhs.allocator()->allocate();
2054*c217d954SCole Faust         dst.allocator()->allocate();
2055*c217d954SCole Faust 
2056*c217d954SCole Faust         ARM_COMPUTE_ASSERT(!lhs.info()->is_resizable());
2057*c217d954SCole Faust         ARM_COMPUTE_ASSERT(!rhs.info()->is_resizable());
2058*c217d954SCole Faust         ARM_COMPUTE_ASSERT(!dst.info()->is_resizable());
2059*c217d954SCole Faust 
2060*c217d954SCole Faust         // Fill tensors
2061*c217d954SCole Faust         fill(AccessorType(lhs), 0);
2062*c217d954SCole Faust         fill(AccessorType(rhs), 1);
2063*c217d954SCole Faust 
2064*c217d954SCole Faust         // Compute GEMM
2065*c217d954SCole Faust         ITensorPack gemm_pack({ { ACL_SRC_0, &lhs }, { ACL_SRC_1, &rhs }, { ACL_DST, &dst } });
2066*c217d954SCole Faust         gemm.run(gemm_pack);
2067*c217d954SCole Faust 
2068*c217d954SCole Faust         return dst;
2069*c217d954SCole Faust     }
2070*c217d954SCole Faust 
2071*c217d954SCole Faust     SimpleTensor<int32_t> compute_reference(const TensorShape &lhs_shape, const TensorShape &rhs_shape, unsigned int m_h)
2072*c217d954SCole Faust     {
2073*c217d954SCole Faust         TensorShape dst_shape = lhs_shape;
2074*c217d954SCole Faust         dst_shape.set(0, rhs_shape[0]);
2075*c217d954SCole Faust         dst_shape.set(1, lhs_shape[1] / m_h);
2076*c217d954SCole Faust         dst_shape.set(2, m_h);
2077*c217d954SCole Faust         dst_shape.set(3, lhs_shape[2]);
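        // As in the reshaped-RHS 3D fixture above, only the shape changes: e.g. (hypothetical
        // sizes) m_w = 4, m_h = 3 and batch_size = 2 turn the flat (N, 12, 2) result into (N, 4, 3, 2).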
2078*c217d954SCole Faust 
2079*c217d954SCole Faust         // Create reference
2080*c217d954SCole Faust         SimpleTensor<uint8_t> lhs{ lhs_shape, DataType::QASYMM8, 1 };
2081*c217d954SCole Faust         SimpleTensor<uint8_t> rhs{ rhs_shape, DataType::QASYMM8, 1 };
2082*c217d954SCole Faust 
2083*c217d954SCole Faust         // Fill reference
2084*c217d954SCole Faust         fill(lhs, 0);
2085*c217d954SCole Faust         fill(rhs, 1);
2086*c217d954SCole Faust 
2087*c217d954SCole Faust         return reference::gemmlowp_matrix_multiply_core<int32_t, uint8_t>(lhs, rhs, dst_shape, 0, 0);
2088*c217d954SCole Faust     }
2089*c217d954SCole Faust 
2090*c217d954SCole Faust     TensorType            _target{};
2091*c217d954SCole Faust     SimpleTensor<int32_t> _reference{};
2092*c217d954SCole Faust };
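// A minimal usage sketch (not part of the library: the fixture/kernel type names below are
// hypothetical placeholders; real instantiations live in the backend-specific test suites):
//
//   using ExampleNativeFixture =
//       GEMMLowpMatrixMultiplyNativeValidationFixture<Tensor, Accessor, SomeGEMMLowpNativeKernel>;
//
//   FIXTURE_DATA_TEST_CASE(RunSmall, ExampleNativeFixture, framework::DatasetMode::ALL, some_dataset)
//   {
//       // The fixture exposes the computed target and reference for comparison
//       validate(Accessor(_target), _reference);
//   }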
2093*c217d954SCole Faust } // namespace validation
2094*c217d954SCole Faust } // namespace test
2095*c217d954SCole Faust } // namespace arm_compute
2096*c217d954SCole Faust #endif /* ARM_COMPUTE_TEST_GEMMLOWP_FIXTURE */