/*
 * Copyright (c) 2017-2022 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#ifndef ARM_COMPUTE_TEST_GEMMLOWP_FIXTURE
#define ARM_COMPUTE_TEST_GEMMLOWP_FIXTURE

#include "arm_compute/core/utils/quantization/AsymmHelpers.h"
#include "tests/framework/Fixture.h"
#include "tests/validation/Validation.h"
#include "tests/validation/reference/GEMMLowp.h"

namespace arm_compute
{
namespace test
{
namespace validation
{
namespace
{
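// Fills a tensor with random values drawn from a distribution appropriate for its data type:
// quantized types are filled within their valid quantization bounds, S32 with moderate
// accumulator-range values, and floating-point types within [-1, 1]. The index i seeds the
// pseudo-random generator so each tensor gets a reproducible, distinct fill.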
template <typename U>
void fill(U &&tensor, int i)
{
    switch(tensor.data_type())
    {
        case DataType::QSYMM8_PER_CHANNEL:
        {
            int min_bound = 128;
            int max_bound = -127;
            for(size_t j = 0; j < tensor.quantization_info().scale().size(); j++)
            {
                std::pair<int, int> bounds = get_symm_quantized_per_channel_bounds(tensor.quantization_info(), -1.0f, 1.0f, i);
                if(bounds.first < min_bound)
                {
                    min_bound = bounds.first;
                }
                if(bounds.second > max_bound)
                {
                    max_bound = bounds.second;
                }
            }
            std::uniform_int_distribution<int32_t> distribution(min_bound, max_bound);
            library->fill(tensor, distribution, i);
            break;
        }
        case DataType::QASYMM8:
        {
            std::uniform_int_distribution<uint32_t> distribution(1, 254);
            library->fill(tensor, distribution, i);
            break;
        }
        case DataType::S32:
        {
            std::uniform_int_distribution<int32_t> distribution(-20000, 20000);
            library->fill(tensor, distribution, i);
            break;
        }
        case DataType::F16:
        {
            arm_compute::utils::uniform_real_distribution_16bit<half> distribution{ -1.0f, 1.0f };
            library->fill(tensor, distribution, i);
            break;
        }
        case DataType::F32:
        {
            std::uniform_real_distribution<float> distribution(-1.0f, 1.0f);
            library->fill(tensor, distribution, i);
            break;
        }
        default:
            library->fill_tensor_uniform(tensor, i);
    }
}

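// Creates, configures, fills and runs the GEMMLowp function under test and returns its output
// tensor. The optional output stage controls whether the result stays S32 or is requantized to
// the input data type; when is_fused is true a bias tensor is also created and passed to the
// function. With run_twice the function is run once, the inputs are refilled with new seeds,
// and it is run again to exercise variable inputs.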
template <typename TensorType, typename AccessorType, typename FunctionType, bool reinterpret_input_as_3d, bool reinterpret_output_as_3d, typename OutputType, bool is_fused = false, bool run_twice = false>
TensorType compute_gemmlowp_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_output, int32_t a_offset, int32_t b_offset,
                                   GEMMLowpOutputStageInfo output_stage = GEMMLowpOutputStageInfo(), DataType data_type_a = DataType::QASYMM8, DataType data_type_b = DataType::QASYMM8,
                                   QuantizationInfo b_qinfo = QuantizationInfo(), bool reshape_b_only_on_first_run = false)
{
    // Create tensors
    DataType data_type_output = output_stage.type == GEMMLowpOutputStageType::NONE ? DataType::S32 : data_type_a;

    TensorType a      = create_tensor<TensorType>(shape_a, data_type_a, 1);
    TensorType b      = create_tensor<TensorType>(shape_b, data_type_b, 1); // Note: passing data_type_output here makes the GEMM output before the output stage mismatch; to be investigated
    TensorType output = create_tensor<TensorType>(shape_output, data_type_output, 1);

    a.info()->set_quantization_info(QuantizationInfo(1.0f / 255, a_offset));

    if(data_type_b == DataType::QSYMM8_PER_CHANNEL)
    {
        b.info()->set_quantization_info(b_qinfo);
    }
    else
    {
        b.info()->set_quantization_info(QuantizationInfo(1.0f / 255, b_offset));
    }
    TensorType bias;
    if(is_fused)
    {
        TensorShape bias_shape(shape_b[0]);
        bias = create_tensor<TensorType>(bias_shape, DataType::S32, 1);
    }

    // Create and configure function
    // The GEMMInfo includes the values of the depth in case of reinterpreted 3d input/output
    FunctionType gemmlowp;
    gemmlowp.configure(&a, &b, is_fused ? &bias : nullptr, &output, GEMMInfo(false, false, reshape_b_only_on_first_run, (reinterpret_output_as_3d ? shape_output[2] : 0), reinterpret_input_as_3d, false,
                                                                             output_stage));

    ARM_COMPUTE_ASSERT(a.info()->is_resizable());
    ARM_COMPUTE_ASSERT(b.info()->is_resizable());
    ARM_COMPUTE_ASSERT(output.info()->is_resizable());

    add_padding_x({ &a, &b, &output });

    // Allocate tensors
    a.allocator()->allocate();
    b.allocator()->allocate();
    output.allocator()->allocate();

    ARM_COMPUTE_ASSERT(!a.info()->is_resizable());
    ARM_COMPUTE_ASSERT(!b.info()->is_resizable());
    ARM_COMPUTE_ASSERT(!output.info()->is_resizable());

    // Fill tensors
    fill(AccessorType(a), 0);
    fill(AccessorType(b), 1);

    if(is_fused)
    {
        ARM_COMPUTE_ASSERT(bias.info()->is_resizable());
        bias.allocator()->allocate();
        ARM_COMPUTE_ASSERT(!bias.info()->is_resizable());
        fill(AccessorType(bias), 2);
    }

    // Run with variable inputs.
    if(run_twice)
    {
        gemmlowp.run();
        fill(AccessorType(a), 3); // Fill tensors with new seed after run
        fill(AccessorType(b), 4);
        if(is_fused)
        {
            fill(AccessorType(bias), 5);
        }
    }

    // Compute GEMM function
    gemmlowp.run();
    return output;
}

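// Computes the S32 reference result using the SimpleTensor reference implementation. When
// pretranspose_A/pretranspose_B are set, the incoming matrices are assumed to be pre-transposed
// and are transposed back here so the reference kernel always receives (B x M x K) and (B x K x N)
// operands (see the note below).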
template <bool reinterpret_input_as_3d, typename TI = uint8_t, typename TW = uint8_t, bool pretranspose_A = false, bool pretranspose_B = false, bool run_twice = false>
SimpleTensor<int32_t> compute_gemmlowp_reference(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_output, int32_t a_offset, int32_t b_offset,
                                                 DataType data_type_a = DataType::QASYMM8, DataType data_type_b = DataType::QASYMM8, QuantizationInfo b_qinfo = QuantizationInfo())
{
    TensorShape shape_a_to_use = shape_a;
    if(reinterpret_input_as_3d)
    {
        // Collapse the second and third dimension if the input is 3D
        shape_a_to_use.collapse(2U, 1U);
    }

    // Create reference
    SimpleTensor<TI> a{ shape_a_to_use, data_type_a, 1 };
    SimpleTensor<TW> b{ shape_b, data_type_b, 1, data_type_b == DataType::QSYMM8_PER_CHANNEL ? b_qinfo : QuantizationInfo(1.0f / 255, b_offset) };

    TensorShape shape_a_to_use_transposed{ shape_a_to_use };
    TensorShape shape_b_transposed{ shape_b };

    shape_a_to_use_transposed.set(0, shape_a_to_use[1]);
    shape_a_to_use_transposed.set(1, shape_a_to_use[0]);
    shape_b_transposed.set(0, shape_b[1]);
    shape_b_transposed.set(1, shape_b[0]);

    SimpleTensor<TI> a_transposed{ shape_a_to_use_transposed, data_type_a, 1 };
    SimpleTensor<TW> b_transposed{ shape_b_transposed, data_type_b, 1, data_type_b == DataType::QSYMM8_PER_CHANNEL ? b_qinfo : QuantizationInfo(1.0f / 255, b_offset) };

    // Fill reference
    fill(a, 0);
    fill(b, 1);

    // Transpose reference if required
    /* Note: Assuming the usual batch matmul dimensions A = (B x M x K), B = (B x K x N), if pretranspose_A is set to true, then A is assumed to be (B x K x M),
       therefore, A must be pre-transposed before passing it to the fixture. And, we transpose A again in the fixture to make it (B x M x K)
       in order to be able to call reference implementation that works with (B x M x K) input.
       Similarly, if pretranspose_B is set to true, then B is assumed to be (B x N x K), B must be pre-transposed before passing it to the fixture. */
    if(pretranspose_A)
    {
        transpose_matrix<TI>(a, a_transposed);
    }

    if(pretranspose_B)
    {
        transpose_matrix<TW>(b, b_transposed);
    }

    // Run with variable inputs.
    if(run_twice)
    {
        reference::gemmlowp_matrix_multiply_core<int32_t, TI, TW>((pretranspose_A ? a_transposed : a), (pretranspose_B ? b_transposed : b), shape_output, a_offset, b_offset);
        fill((pretranspose_A) ? a_transposed : a, 3);
        fill((pretranspose_B) ? b_transposed : b, 4);
    }

    return reference::gemmlowp_matrix_multiply_core<int32_t, TI, TW>((pretranspose_A ? a_transposed : a), (pretranspose_B ? b_transposed : b), shape_output, a_offset, b_offset);
}
}

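/** Fixture that validates the raw GEMMLowp core (no output stage): the S32 target result is
 *  compared against the reference matrix multiplication with the given a/b quantization offsets. */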
template <typename TensorType, typename AccessorType, typename FunctionType, bool reinterpret_input_as_3d = false, bool reinterpret_output_as_3d = false, bool run_twice = false>
class GEMMLowpMatrixMultiplyCoreValidationFixture : public framework::Fixture
{
public:
    template <typename...>
    void setup(TensorShape shape_a, TensorShape shape_b, TensorShape shape_output, int32_t a_offset, int32_t b_offset)
    {
        _target    = compute_target(shape_a, shape_b, shape_output, a_offset, b_offset);
        _reference = compute_reference(shape_a, shape_b, shape_output, a_offset, b_offset);
    }

protected:
    TensorType compute_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_output, int32_t a_offset, int32_t b_offset)
    {
        return compute_gemmlowp_target<TensorType, AccessorType, FunctionType, reinterpret_input_as_3d, reinterpret_output_as_3d, int32_t, false, run_twice>(shape_a, shape_b, shape_output, a_offset,
                                                                                                                                                              b_offset);
    }

    SimpleTensor<int32_t> compute_reference(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_output, int32_t a_offset, int32_t b_offset)
    {
        return compute_gemmlowp_reference<reinterpret_input_as_3d, uint8_t, uint8_t, false, false, run_twice>(shape_a, shape_b, shape_output, a_offset, b_offset);
    }

    TensorType            _target{};
    SimpleTensor<int32_t> _reference{};
};

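/** Fixture that validates GEMMLowp with a fused output stage and bias. For QSYMM8_PER_CHANNEL
 *  weights it generates random per-channel scales and the corresponding quantized
 *  multipliers/shifts before computing the target and the reference. */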
template <typename TensorType, typename AccessorType, typename FunctionType, bool reinterpret_input_as_3d = false, bool reinterpret_output_as_3d = false, typename TI = uint8_t, typename TW = uint8_t, bool run_twice = false>
class GEMMLowpMatrixMultiplyCoreFusedOffsetOutputGenericValidationFixture : public framework::Fixture
{
public:
    template <typename...>
    void setup(TensorShape shape_a, TensorShape shape_b, TensorShape shape_output, int32_t a_offset, int32_t b_offset, GEMMLowpOutputStageInfo output_stage, DataType data_type_b,
               bool reshape_b_only_on_first_run)
    {
        ARM_COMPUTE_ASSERT(output_stage.type != GEMMLowpOutputStageType::NONE);
        DataType data_type_a = data_type_b == DataType::QASYMM8_SIGNED ? DataType::QASYMM8_SIGNED : DataType::QASYMM8;

        if(data_type_b == DataType::QSYMM8_PER_CHANNEL)
        {
            output_stage.is_quantized_per_channel = true;
            const size_t num_channels             = shape_b[0];
            std::vector<float> scales(num_channels);
            std::uniform_real_distribution<float> distribution(0.f, 1.f);
            library->fill(scales, distribution, 0);
            output_stage.gemmlowp_multipliers.resize(num_channels);
            output_stage.gemmlowp_shifts.resize(num_channels);
            for(size_t i = 0; i < num_channels; ++i)
            {
                quantization::calculate_quantized_multiplier(scales[i], &output_stage.gemmlowp_multipliers[i], &output_stage.gemmlowp_shifts[i]);
            }

            _reference = compute_reference(shape_a, shape_b, shape_output, a_offset, 0, output_stage, data_type_a, data_type_b, QuantizationInfo(scales));
            _target    = compute_target(shape_a, shape_b, shape_output, a_offset, 0, output_stage, data_type_a, data_type_b, QuantizationInfo(scales), reshape_b_only_on_first_run);
        }
        else
        {
            _reference = compute_reference(shape_a, shape_b, shape_output, a_offset, b_offset, output_stage, data_type_a, data_type_b, QuantizationInfo());
            _target    = compute_target(shape_a, shape_b, shape_output, a_offset, b_offset, output_stage, data_type_a, data_type_b, QuantizationInfo(), reshape_b_only_on_first_run);
        }
    }

protected:
    TensorType compute_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_output, int32_t a_offset, int32_t b_offset, GEMMLowpOutputStageInfo output_stage,
                              DataType data_type_a, DataType data_type_b, QuantizationInfo b_qinfo, bool reshape_b_only_on_first_run = false)
    {
        return compute_gemmlowp_target<TensorType, AccessorType, FunctionType, reinterpret_input_as_3d, reinterpret_output_as_3d, qasymm8_t, true, run_twice>(shape_a, shape_b, shape_output, a_offset,
                                                                                                                                                               b_offset,
                                                                                                                                                               output_stage, data_type_a, data_type_b, b_qinfo, reshape_b_only_on_first_run);
    }

    SimpleTensor<TI> compute_reference(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_output, int32_t a_offset, int32_t b_offset,
                                       GEMMLowpOutputStageInfo output_stage, DataType data_type_a, DataType data_type_b, QuantizationInfo b_qinfo)
    {
        SimpleTensor<int32_t> output = compute_gemmlowp_reference<reinterpret_input_as_3d, TI, TW, false, false, run_twice>(shape_a, shape_b, shape_output, a_offset, b_offset, data_type_a, data_type_b,
                                                                                                                            b_qinfo);

        TensorShape bias_shape(shape_b[0]);
        SimpleTensor<int32_t> bias{ bias_shape, DataType::S32, 1 };
        (run_twice) ? fill(bias, 5) : fill(bias, 2); // Fill bias with same seed as last run of gemmlowp_target

        switch(output_stage.type)
        {
            case GEMMLowpOutputStageType::QUANTIZE_DOWN:
                return reference::gemmlowp_quantize_down_scale<int32_t, TW>(output, bias,
                                                                            output_stage.gemmlowp_offset, output_stage.gemmlowp_multipliers, output_stage.gemmlowp_shifts, output_stage.gemmlowp_min_bound, output_stage.gemmlowp_max_bound);
                break;
            case GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT:
                return reference::gemmlowp_quantize_down_scale_by_fixedpoint<int32_t, TW>(output, bias,
                                                                                          output_stage.gemmlowp_multipliers, output_stage.gemmlowp_shifts, output_stage.gemmlowp_offset, output_stage.gemmlowp_min_bound, output_stage.gemmlowp_max_bound);
                break;
            default:
                ARM_COMPUTE_ERROR("Not Supported!");
        }
    }

    TensorType       _target{};
    SimpleTensor<TI> _reference{};
};

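/** Convenience fixture: same as the generic fused-offset-output fixture above, but with
 *  reshape_b_only_on_first_run fixed to false. */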
template <typename TensorType, typename AccessorType, typename FunctionType, bool reinterpret_input_as_3d = false, bool reinterpret_output_as_3d = false, typename TI = uint8_t, typename TW = uint8_t>
class GEMMLowpMatrixMultiplyCoreFusedOffsetOutputValidationFixture : public
    GEMMLowpMatrixMultiplyCoreFusedOffsetOutputGenericValidationFixture<TensorType, AccessorType, FunctionType, reinterpret_input_as_3d, reinterpret_output_as_3d, TI, TW>
{
public:
    template <typename...>
    void setup(TensorShape shape_a, TensorShape shape_b, TensorShape shape_output, int32_t a_offset, int32_t b_offset, GEMMLowpOutputStageInfo output_stage, DataType data_type_b)
    {
        GEMMLowpMatrixMultiplyCoreFusedOffsetOutputGenericValidationFixture<TensorType, AccessorType, FunctionType, reinterpret_input_as_3d, reinterpret_output_as_3d, TI, TW>::setup(shape_a, shape_b,
                                                                                                                                                                                      shape_output, a_offset, b_offset, output_stage, data_type_b, false);
    }
};

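/** Fixture for the QUANTIZE_DOWN output stage (integer multiplier and shift) from S32 to QASYMM8,
 *  with and without a fused bias. */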
template <typename TensorType, typename AccessorType, typename FunctionType>
class GEMMLowpQuantizeDownInt32ToUint8ScaleValidationFixture : public framework::Fixture
{
public:
    template <typename...>
    void setup(TensorShape shape, int32_t result_offset, int32_t result_mult_int, int32_t result_shift, int32_t min, int32_t max, bool add_bias)
    {
        _target    = compute_target(shape, result_offset, result_mult_int, result_shift, min, max, add_bias);
        _reference = compute_reference(shape, result_offset, result_mult_int, result_shift, min, max, add_bias);
    }

protected:
    template <typename U>
    void fill(U &&tensor, int i)
    {
        std::uniform_int_distribution<> distribution(-6000, 6000);
        library->fill(tensor, distribution, i);
    }

    TensorType compute_target(const TensorShape &shape, int32_t result_offset, int32_t result_mult_int, int32_t result_shift, int32_t min, int32_t max, bool add_bias)
    {
        TensorShape shape_bias(shape[0]);

        // Create tensors
        TensorType a = create_tensor<TensorType>(shape, DataType::S32, 1);
        TensorType b = create_tensor<TensorType>(shape_bias, DataType::S32, 1);
        TensorType c = create_tensor<TensorType>(shape, DataType::QASYMM8, 1);

        // Create and configure function
        FunctionType            output_stage;
        GEMMLowpOutputStageInfo output_stage_info = GEMMLowpOutputStageInfo();
        output_stage_info.type                    = GEMMLowpOutputStageType::QUANTIZE_DOWN;
        output_stage_info.gemmlowp_offset         = result_offset;
        output_stage_info.gemmlowp_multiplier     = result_mult_int;
        output_stage_info.gemmlowp_shift          = result_shift;
        output_stage_info.gemmlowp_min_bound      = min;
        output_stage_info.gemmlowp_max_bound      = max;
        output_stage_info.output_data_type        = DataType::QASYMM8;
        output_stage.configure(&a, add_bias ? &b : nullptr, &c, output_stage_info);

        ARM_COMPUTE_ASSERT(a.info()->is_resizable());
        ARM_COMPUTE_ASSERT(c.info()->is_resizable());

        // Allocate tensors
        a.allocator()->allocate();
        c.allocator()->allocate();

        ARM_COMPUTE_ASSERT(!a.info()->is_resizable());
        ARM_COMPUTE_ASSERT(!c.info()->is_resizable());

        // Fill tensor
        fill(AccessorType(a), 0);

        if(add_bias)
        {
            ARM_COMPUTE_ASSERT(b.info()->is_resizable());

            // Allocate bias tensor
            b.allocator()->allocate();

            ARM_COMPUTE_ASSERT(!b.info()->is_resizable());

            // Fill tensor
            fill(AccessorType(b), 1);
        }

        // Compute GEMM function
        output_stage.run();
        return c;
    }

    SimpleTensor<uint8_t> compute_reference(const TensorShape &shape, int32_t result_offset, int32_t result_mult_int, int32_t result_shift, int32_t min, int32_t max, bool add_bias)
    {
        // Create reference
        TensorShape shape_bias(shape[0]);

        SimpleTensor<int32_t> a{ shape, DataType::S32, 1 };
        SimpleTensor<int32_t> b{ shape_bias, DataType::S32, 1 };

        // Fill reference
        fill(a, 0);

        const std::vector<int32_t> result_mult_int_vec = { result_mult_int };
        const std::vector<int32_t> result_shift_vec    = { result_shift };

        if(add_bias)
        {
            // Fill bias
            fill(b, 1);

            return reference::gemmlowp_quantize_down_scale<int32_t, uint8_t>(a, b, result_offset, result_mult_int_vec, result_shift_vec, min, max);
        }
        else
        {
            return reference::gemmlowp_quantize_down_scale<int32_t, uint8_t>(a, result_offset, result_mult_int_vec, result_shift_vec, min, max);
        }
    }

    TensorType            _target{};
    SimpleTensor<uint8_t> _reference{};
};

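/** Same as the fixture above, but quantizing down from S32 to QASYMM8_SIGNED. */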
template <typename TensorType, typename AccessorType, typename FunctionType>
class GEMMLowpQuantizeDownInt32ToInt8ScaleValidationFixture : public framework::Fixture
{
public:
    template <typename...>
    void setup(TensorShape shape, int32_t result_offset, int32_t result_mult_int, int32_t result_shift, int32_t min, int32_t max, bool add_bias)
    {
        _target    = compute_target(shape, result_offset, result_mult_int, result_shift, min, max, add_bias);
        _reference = compute_reference(shape, result_offset, result_mult_int, result_shift, min, max, add_bias);
    }

protected:
    template <typename U>
    void fill(U &&tensor, int i)
    {
        std::uniform_int_distribution<> distribution(-6000, 6000);
        library->fill(tensor, distribution, i);
    }

    TensorType compute_target(const TensorShape &shape, int32_t result_offset, int32_t result_mult_int, int32_t result_shift, int32_t min, int32_t max, bool add_bias)
    {
        TensorShape shape_bias(shape[0]);

        // Create tensors
        TensorType a = create_tensor<TensorType>(shape, DataType::S32, 1);
        TensorType b = create_tensor<TensorType>(shape_bias, DataType::S32, 1);
        TensorType c = create_tensor<TensorType>(shape, DataType::QASYMM8_SIGNED, 1);

        // Create and configure function
        FunctionType            output_stage;
        GEMMLowpOutputStageInfo output_stage_info = GEMMLowpOutputStageInfo();
        output_stage_info.type                    = GEMMLowpOutputStageType::QUANTIZE_DOWN;
        output_stage_info.gemmlowp_offset         = result_offset;
        output_stage_info.gemmlowp_multiplier     = result_mult_int;
        output_stage_info.gemmlowp_shift          = result_shift;
        output_stage_info.gemmlowp_min_bound      = min;
        output_stage_info.gemmlowp_max_bound      = max;
        output_stage_info.output_data_type        = DataType::QASYMM8_SIGNED;
        output_stage.configure(&a, add_bias ? &b : nullptr, &c, output_stage_info);

        ARM_COMPUTE_ASSERT(a.info()->is_resizable());
        ARM_COMPUTE_ASSERT(c.info()->is_resizable());

        // Allocate tensors
        a.allocator()->allocate();
        c.allocator()->allocate();

        ARM_COMPUTE_ASSERT(!a.info()->is_resizable());
        ARM_COMPUTE_ASSERT(!c.info()->is_resizable());

        // Fill tensor
        fill(AccessorType(a), 0);

        if(add_bias)
        {
            ARM_COMPUTE_ASSERT(b.info()->is_resizable());

            // Allocate bias tensor
            b.allocator()->allocate();

            ARM_COMPUTE_ASSERT(!b.info()->is_resizable());

            // Fill tensor
            fill(AccessorType(b), 1);
        }

        // Compute GEMM function
        output_stage.run();
        return c;
    }

    SimpleTensor<int8_t> compute_reference(const TensorShape &shape, int32_t result_offset, int32_t result_mult_int, int32_t result_shift, int32_t min, int32_t max, bool add_bias)
    {
        // Create reference
        TensorShape shape_bias(shape[0]);

        SimpleTensor<int32_t> a{ shape, DataType::S32, 1 };
        SimpleTensor<int32_t> b{ shape_bias, DataType::S32, 1 };

        // Fill reference
        fill(a, 0);

        const std::vector<int32_t> result_mult_int_vec = { result_mult_int };
        const std::vector<int32_t> result_shift_vec    = { result_shift };

        if(add_bias)
        {
            // Fill bias
            fill(b, 1);

            return reference::gemmlowp_quantize_down_scale<int32_t, int8_t>(a, b, result_offset, result_mult_int_vec, result_shift_vec, min, max);
        }
        else
        {
            return reference::gemmlowp_quantize_down_scale<int32_t, int8_t>(a, result_offset, result_mult_int_vec, result_shift_vec, min, max);
        }
    }

    TensorType           _target{};
    SimpleTensor<int8_t> _reference{};
};

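/** Fixture for the QUANTIZE_DOWN_FIXEDPOINT output stage to QASYMM8_SIGNED: requantization uses a
 *  fixed-point multiplier, a shift, and an offset applied after the shift. */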
template <typename TensorType, typename AccessorType, typename FunctionType>
class GEMMLowpQuantizeDownInt32ToInt8ScaleByFixedPointValidationFixture : public framework::Fixture
{
public:
    template <typename...>
    void setup(TensorShape shape, int32_t result_fixedpoint_multiplier, int32_t result_shift, int32_t result_offset_after_shift, int32_t min, int32_t max, bool add_bias)
    {
        _target    = compute_target(shape, result_fixedpoint_multiplier, result_shift, result_offset_after_shift, min, max, add_bias);
        _reference = compute_reference(shape, result_fixedpoint_multiplier, result_shift, result_offset_after_shift, min, max, add_bias);
    }

protected:
    template <typename U>
    void fill(U &&tensor, int i)
    {
        std::uniform_int_distribution<> distribution(-6000, 6000);
        library->fill(tensor, distribution, i);
    }

    TensorType compute_target(const TensorShape &shape, int32_t result_fixedpoint_multiplier, int32_t result_shift, int32_t result_offset_after_shift, int32_t min, int32_t max, bool add_bias)
    {
        TensorShape shape_bias(shape[0]);

        // Create tensors
        TensorType a = create_tensor<TensorType>(shape, DataType::S32, 1);
        TensorType b = create_tensor<TensorType>(shape_bias, DataType::S32, 1);
        TensorType c = create_tensor<TensorType>(shape, DataType::QASYMM8_SIGNED, 1);

        // Create and configure function
        FunctionType output_stage;
        output_stage.configure(&a, add_bias ? &b : nullptr, &c, result_fixedpoint_multiplier, result_shift, result_offset_after_shift, min, max);

        ARM_COMPUTE_ASSERT(a.info()->is_resizable());
        ARM_COMPUTE_ASSERT(c.info()->is_resizable());

        // Allocate tensors
        a.allocator()->allocate();
        c.allocator()->allocate();

        ARM_COMPUTE_ASSERT(!a.info()->is_resizable());
        ARM_COMPUTE_ASSERT(!c.info()->is_resizable());

        // Fill tensor
        fill(AccessorType(a), 0);

        if(add_bias)
        {
            ARM_COMPUTE_ASSERT(b.info()->is_resizable());

            // Allocate bias tensor
            b.allocator()->allocate();

            ARM_COMPUTE_ASSERT(!b.info()->is_resizable());

            // Fill tensor
            fill(AccessorType(b), 1);
        }

        // Compute GEMM function
        output_stage.run();
        return c;
    }

    SimpleTensor<int8_t> compute_reference(const TensorShape &shape, int32_t result_fixed_point_multiplier, int32_t result_shift, int32_t result_offset_after_shift, int32_t min, int32_t max,
                                           bool add_bias)
    {
        // Create reference
        TensorShape shape_bias(shape[0]);

        SimpleTensor<int32_t> a{ shape, DataType::S32, 1 };
        SimpleTensor<int32_t> b{ shape_bias, DataType::S32, 1 };

        // Fill reference
        fill(a, 0);

        const std::vector<int32_t> result_fixed_point_multiplier_vec = { result_fixed_point_multiplier };
        const std::vector<int32_t> result_shift_vec                  = { result_shift };

        if(add_bias)
        {
            // Fill bias
            fill(b, 1);

            return reference::gemmlowp_quantize_down_scale_by_fixedpoint<int32_t, int8_t>(a, b, result_fixed_point_multiplier_vec, result_shift_vec, result_offset_after_shift, min, max);
        }
        else
        {
            return reference::gemmlowp_quantize_down_scale_by_fixedpoint<int32_t, int8_t>(a, result_fixed_point_multiplier_vec, result_shift_vec, result_offset_after_shift, min, max);
        }
    }

    TensorType           _target{};
    SimpleTensor<int8_t> _reference{};
};

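/** Fixture for the QUANTIZE_DOWN_FIXEDPOINT output stage to QASYMM8 (unsigned). */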
template <typename TensorType, typename AccessorType, typename FunctionType>
class GEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPointValidationFixture : public framework::Fixture
{
public:
    template <typename...>
    void setup(TensorShape shape, int32_t result_fixedpoint_multiplier, int32_t result_shift, int32_t result_offset_after_shift, int32_t min, int32_t max, bool add_bias)
    {
        _target    = compute_target(shape, result_fixedpoint_multiplier, result_shift, result_offset_after_shift, min, max, add_bias);
        _reference = compute_reference(shape, result_fixedpoint_multiplier, result_shift, result_offset_after_shift, min, max, add_bias);
    }

protected:
    template <typename U>
    void fill(U &&tensor, int i)
    {
        std::uniform_int_distribution<> distribution(-6000, 6000);
        library->fill(tensor, distribution, i);
    }

    TensorType compute_target(const TensorShape &shape, int32_t result_fixedpoint_multiplier, int32_t result_shift, int32_t result_offset_after_shift, int32_t min, int32_t max, bool add_bias)
    {
        TensorShape shape_bias(shape[0]);

        // Create tensors
        TensorType a = create_tensor<TensorType>(shape, DataType::S32, 1);
        TensorType b = create_tensor<TensorType>(shape_bias, DataType::S32, 1);
        TensorType c = create_tensor<TensorType>(shape, DataType::QASYMM8, 1);

        // Create and configure function
        FunctionType output_stage;
        output_stage.configure(&a, add_bias ? &b : nullptr, &c, result_fixedpoint_multiplier, result_shift, result_offset_after_shift, min, max);

        ARM_COMPUTE_ASSERT(a.info()->is_resizable());
        ARM_COMPUTE_ASSERT(c.info()->is_resizable());

        // Allocate tensors
        a.allocator()->allocate();
        c.allocator()->allocate();

        ARM_COMPUTE_ASSERT(!a.info()->is_resizable());
        ARM_COMPUTE_ASSERT(!c.info()->is_resizable());

        // Fill tensor
        fill(AccessorType(a), 0);

        if(add_bias)
        {
            ARM_COMPUTE_ASSERT(b.info()->is_resizable());

            // Allocate bias tensor
            b.allocator()->allocate();

            ARM_COMPUTE_ASSERT(!b.info()->is_resizable());

            // Fill tensor
            fill(AccessorType(b), 1);
        }

        // Compute GEMM function
        output_stage.run();
        return c;
    }

    SimpleTensor<uint8_t> compute_reference(const TensorShape &shape, int32_t result_fixed_point_multiplier, int32_t result_shift, int32_t result_offset_after_shift, int32_t min, int32_t max,
                                            bool add_bias)
    {
        // Create reference
        TensorShape shape_bias(shape[0]);

        SimpleTensor<int32_t> a{ shape, DataType::S32, 1 };
        SimpleTensor<int32_t> b{ shape_bias, DataType::S32, 1 };

        // Fill reference
        fill(a, 0);

        const std::vector<int32_t> result_fixed_point_multiplier_vec = { result_fixed_point_multiplier };
        const std::vector<int32_t> result_shift_vec                  = { result_shift };

        if(add_bias)
        {
            // Fill bias
            fill(b, 1);

            return reference::gemmlowp_quantize_down_scale_by_fixedpoint<int32_t, uint8_t>(a, b, result_fixed_point_multiplier_vec, result_shift_vec, result_offset_after_shift, min, max);
        }
        else
        {
            return reference::gemmlowp_quantize_down_scale_by_fixedpoint<int32_t, uint8_t>(a, result_fixed_point_multiplier_vec, result_shift_vec, result_offset_after_shift, min, max);
        }
    }

    TensorType            _target{};
    SimpleTensor<uint8_t> _reference{};
};

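/** Fixture for the QUANTIZE_DOWN_FLOAT output stage: the S32 accumulators are requantized with a
 *  real (floating-point) multiplier and an offset to the requested quantized output data type. */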
template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
class GEMMLowpQuantizeDownInt32ScaleByFloatValidationFixture : public framework::Fixture
{
public:
    template <typename...>
    void setup(DataType data_type, TensorShape shape, float result_real_multiplier, int32_t result_offset, int32_t min, int32_t max, bool add_bias)
    {
        _target    = compute_target(data_type, shape, result_real_multiplier, result_offset, min, max, add_bias);
        _reference = compute_reference(shape, result_real_multiplier, result_offset, min, max, add_bias);
    }

protected:
    template <typename U>
    void fill(U &&tensor, int i)
    {
        // Keep the range small to avoid all the data being clamped
        std::uniform_int_distribution<> distribution(-500, 500);
        library->fill(tensor, distribution, i);
    }

    TensorType compute_target(DataType data_type, const TensorShape &shape, float result_multiplier, int32_t result_offset, int32_t min, int32_t max, bool add_bias)
    {
        TensorShape shape_bias(shape[0]);

        // Create tensors
        TensorType a = create_tensor<TensorType>(shape, DataType::S32, 1);
        TensorType b = create_tensor<TensorType>(shape_bias, DataType::S32, 1);
        TensorType c = create_tensor<TensorType>(shape, data_type, 1);

        // Create output stage info
        GEMMLowpOutputStageInfo info;
        info.gemmlowp_max_bound       = max;
        info.gemmlowp_min_bound       = min;
        info.gemmlowp_real_multiplier = result_multiplier;
        info.gemmlowp_offset          = result_offset;
        info.type                     = GEMMLowpOutputStageType::QUANTIZE_DOWN_FLOAT;
        info.output_data_type         = data_type;

        // Create and configure function
        FunctionType output_stage;
        output_stage.configure(&a, add_bias ? &b : nullptr, &c, info);

        ARM_COMPUTE_ASSERT(a.info()->is_resizable());
        ARM_COMPUTE_ASSERT(c.info()->is_resizable());

        // Allocate tensors
        a.allocator()->allocate();
        c.allocator()->allocate();

        ARM_COMPUTE_ASSERT(!a.info()->is_resizable());
        ARM_COMPUTE_ASSERT(!c.info()->is_resizable());

        // Fill tensor
        fill(AccessorType(a), 0);

        if(add_bias)
        {
            ARM_COMPUTE_ASSERT(b.info()->is_resizable());

            // Allocate bias tensor
            b.allocator()->allocate();

            ARM_COMPUTE_ASSERT(!b.info()->is_resizable());

            // Fill tensor
            fill(AccessorType(b), 1);
        }

        // Compute GEMM function
        output_stage.run();
        return c;
    }

    SimpleTensor<T> compute_reference(const TensorShape &shape, float_t result_real_multiplier, int32_t result_offset, int32_t min, int32_t max, bool add_bias)
    {
        // Create reference
        TensorShape shape_bias(shape[0]);

        SimpleTensor<int32_t> a{ shape, DataType::S32, 1 };
        SimpleTensor<int32_t> b{ shape_bias, DataType::S32, 1 };

        // Fill reference
        fill(a, 0);

        const std::vector<float_t> result_float_multiplier_vec = { result_real_multiplier };

        if(add_bias)
        {
            // Fill bias
            fill(b, 1);

            return reference::gemmlowp_quantize_down_scale_by_float<int32_t, T>(a, b, result_float_multiplier_vec, result_offset, min, max);
        }
        else
        {
            return reference::gemmlowp_quantize_down_scale_by_float<int32_t, T>(a, result_float_multiplier_vec, result_offset, min, max);
        }
    }

    TensorType      _target{};
    SimpleTensor<T> _reference{};
};

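/** Fixture for the QUANTIZE_DOWN_FIXEDPOINT output stage to QSYMM16 (no offset applied after the shift). */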
840*c217d954SCole Faust template <typename TensorType, typename AccessorType, typename FunctionType>
841*c217d954SCole Faust class GEMMLowpQuantizeDownInt32ToInt16ScaleByFixedPointValidationFixture : public framework::Fixture
842*c217d954SCole Faust {
843*c217d954SCole Faust public:
844*c217d954SCole Faust template <typename...>
setup(TensorShape shape,int32_t result_fixedpoint_multiplier,int32_t result_shift,int32_t min,int32_t max,bool add_bias)845*c217d954SCole Faust void setup(TensorShape shape, int32_t result_fixedpoint_multiplier, int32_t result_shift, int32_t min, int32_t max, bool add_bias)
846*c217d954SCole Faust {
847*c217d954SCole Faust _target = compute_target(shape, result_fixedpoint_multiplier, result_shift, min, max, add_bias);
848*c217d954SCole Faust _reference = compute_reference(shape, result_fixedpoint_multiplier, result_shift, min, max, add_bias);
849*c217d954SCole Faust }
850*c217d954SCole Faust
851*c217d954SCole Faust protected:
852*c217d954SCole Faust template <typename U>
fill(U && tensor,int i)853*c217d954SCole Faust void fill(U &&tensor, int i)
854*c217d954SCole Faust {
855*c217d954SCole Faust std::uniform_int_distribution<> distribution(-6000, 6000);
856*c217d954SCole Faust library->fill(tensor, distribution, i);
857*c217d954SCole Faust }
858*c217d954SCole Faust
compute_target(const TensorShape & shape,int32_t result_fixedpoint_multiplier,int32_t result_shift,int32_t min,int32_t max,bool add_bias)859*c217d954SCole Faust TensorType compute_target(const TensorShape &shape, int32_t result_fixedpoint_multiplier, int32_t result_shift, int32_t min, int32_t max, bool add_bias)
860*c217d954SCole Faust {
861*c217d954SCole Faust TensorShape shape_bias(shape[0]);
862*c217d954SCole Faust
863*c217d954SCole Faust // Create tensors
864*c217d954SCole Faust TensorType a = create_tensor<TensorType>(shape, DataType::S32, 1);
865*c217d954SCole Faust TensorType b = create_tensor<TensorType>(shape_bias, DataType::S32, 1);
866*c217d954SCole Faust TensorType c = create_tensor<TensorType>(shape, DataType::QSYMM16, 1);
867*c217d954SCole Faust
868*c217d954SCole Faust // Create and configure function
869*c217d954SCole Faust FunctionType output_stage;
870*c217d954SCole Faust output_stage.configure(&a, add_bias ? &b : nullptr, &c, result_fixedpoint_multiplier, result_shift, min, max);
871*c217d954SCole Faust
872*c217d954SCole Faust ARM_COMPUTE_ASSERT(a.info()->is_resizable());
873*c217d954SCole Faust ARM_COMPUTE_ASSERT(c.info()->is_resizable());
874*c217d954SCole Faust
875*c217d954SCole Faust // Allocate tensors
876*c217d954SCole Faust a.allocator()->allocate();
877*c217d954SCole Faust c.allocator()->allocate();
878*c217d954SCole Faust
879*c217d954SCole Faust ARM_COMPUTE_ASSERT(!a.info()->is_resizable());
880*c217d954SCole Faust ARM_COMPUTE_ASSERT(!c.info()->is_resizable());
881*c217d954SCole Faust
882*c217d954SCole Faust // Fill tensor
883*c217d954SCole Faust fill(AccessorType(a), 0);
884*c217d954SCole Faust
885*c217d954SCole Faust if(add_bias)
886*c217d954SCole Faust {
887*c217d954SCole Faust ARM_COMPUTE_ASSERT(b.info()->is_resizable());
888*c217d954SCole Faust
889*c217d954SCole Faust // Allocate bias tensor
890*c217d954SCole Faust b.allocator()->allocate();
891*c217d954SCole Faust
892*c217d954SCole Faust ARM_COMPUTE_ASSERT(!b.info()->is_resizable());
893*c217d954SCole Faust
894*c217d954SCole Faust // Fill tensor
895*c217d954SCole Faust fill(AccessorType(b), 1);
896*c217d954SCole Faust }
897*c217d954SCole Faust
898*c217d954SCole Faust // Compute GEMM function
        output_stage.run();
        return c;
    }

    SimpleTensor<int16_t> compute_reference(const TensorShape &shape, int32_t result_fixed_point_multiplier, int32_t result_shift, int32_t min, int32_t max,
                                            bool add_bias)
    {
        // Create reference
        TensorShape shape_bias(shape[0]);

        SimpleTensor<int32_t> a{ shape, DataType::S32, 1 };
        SimpleTensor<int32_t> b{ shape_bias, DataType::S32, 1 };

        // Fill reference
        fill(a, 0);

        const std::vector<int32_t> result_fixed_point_multiplier_vec = { result_fixed_point_multiplier };
        const std::vector<int32_t> result_shift_vec                  = { result_shift };

        if(add_bias)
        {
            // Fill bias
            fill(b, 1);

            return reference::gemmlowp_quantize_down_scale_by_fixedpoint<int32_t, int16_t>(a, b, result_fixed_point_multiplier_vec, result_shift_vec, 0, min, max);
        }
        else
        {
            return reference::gemmlowp_quantize_down_scale_by_fixedpoint<int32_t, int16_t>(a, result_fixed_point_multiplier_vec, result_shift_vec, 0, min, max);
        }
    }

    TensorType            _target{};
    SimpleTensor<int16_t> _reference{};
};

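/* Validation fixture for the low-precision GEMM that reshapes both the LHS and the RHS matrix.
 * compute_target() reshapes the inputs according to GEMMLHSMatrixInfo / GEMMRHSMatrixInfo
 * (block sizes m0/n0/k0, interleave flags and the v0/h0 parameters), runs the GEMM kernel on the
 * reshaped tensors and returns the raw S32 accumulators; compute_reference() runs
 * reference::gemmlowp_matrix_multiply_core on the original (non-reshaped) matrices for comparison.
 */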
template <typename TensorType, typename AccessorType, typename ReshapeLHSOperatorType, typename ReshapeRHSOperatorType, typename GEMMFunctionType>
class GEMMLowpMatrixMultiplyReshapedValidationFixture : public framework::Fixture
{
public:
    template <typename...>
    void setup(unsigned int m, unsigned int n, unsigned int k, unsigned int batch_size, unsigned int m0, unsigned int n0, unsigned int k0, unsigned int v0, unsigned int h0, bool interleave_lhs,
               bool interleave_rhs, DataType data_type)
    {
        GEMMLHSMatrixInfo lhs_info;
        lhs_info.m0         = m0;
        lhs_info.k0         = k0;
        lhs_info.v0         = v0;
        lhs_info.interleave = interleave_lhs;
        lhs_info.transpose  = false;

        GEMMRHSMatrixInfo rhs_info;
        rhs_info.n0         = n0;
        rhs_info.k0         = k0;
        rhs_info.h0         = h0;
        rhs_info.interleave = interleave_rhs;
        rhs_info.transpose  = true;

        // Set the tensor shapes for LHS and RHS matrices
        const TensorShape lhs_shape(k, m, batch_size);
        const TensorShape rhs_shape(n, k, batch_size);

        _target    = compute_target(lhs_shape, rhs_shape, lhs_info, rhs_info, data_type);
        _reference = compute_reference(lhs_shape, rhs_shape, data_type);
    }

protected:
    template <typename U>
    void fill(U &&tensor, int i)
    {
        switch(tensor.data_type())
        {
            case DataType::QASYMM8:
            {
                // Between 1 and 254 in order to avoid having -128 and 128 for the DOT product path
                std::uniform_int_distribution<> distribution(1, 254);
                library->fill(tensor, distribution, i);
            }
            break;
            case DataType::QASYMM8_SIGNED:
            {
                std::uniform_int_distribution<> distribution(-127, 126);
                library->fill(tensor, distribution, i);
            }
            break;
            default:
                ARM_COMPUTE_ERROR("Unsupported data type");
        }
    }

    TensorType compute_target(const TensorShape &lhs_shape, const TensorShape &rhs_shape, const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, DataType data_type)
    {
        // Create tensors
        TensorType lhs = create_tensor<TensorType>(lhs_shape, data_type, 1);
        TensorType rhs = create_tensor<TensorType>(rhs_shape, data_type, 1);
        TensorType lhs_reshaped;
        TensorType rhs_reshaped;
        TensorType dst;

        const unsigned int M = lhs_shape[1];
        const unsigned int N = rhs_shape[0];
        const unsigned int K = lhs_shape[0];

        // The output tensor will be auto-initialized within the function

        // Create and configure function
        ReshapeLHSOperatorType reshape_lhs;
        ReshapeRHSOperatorType reshape_rhs;
        GEMMFunctionType       gemm;
        reshape_lhs.configure(lhs.info(), lhs_reshaped.info(), lhs_info);
        reshape_rhs.configure(rhs.info(), rhs_reshaped.info(), rhs_info);
        gemm.configure(lhs_reshaped.info(), rhs_reshaped.info(), dst.info(), lhs_info, rhs_info, GEMMReshapeInfo(M, N, K));

        ARM_COMPUTE_ASSERT(lhs.info()->is_resizable());
        ARM_COMPUTE_ASSERT(rhs.info()->is_resizable());

        add_padding_x({ &lhs, &rhs, &lhs_reshaped, &rhs_reshaped, &dst });

        // Allocate tensors
        lhs.allocator()->allocate();
        rhs.allocator()->allocate();
        lhs_reshaped.allocator()->allocate();
        rhs_reshaped.allocator()->allocate();
        dst.allocator()->allocate();

        ARM_COMPUTE_ASSERT(!lhs.info()->is_resizable());
        ARM_COMPUTE_ASSERT(!rhs.info()->is_resizable());
        ARM_COMPUTE_ASSERT(!lhs_reshaped.info()->is_resizable());
        ARM_COMPUTE_ASSERT(!rhs_reshaped.info()->is_resizable());
        ARM_COMPUTE_ASSERT(!dst.info()->is_resizable());

        // Fill tensors
        fill(AccessorType(lhs), 0);
        fill(AccessorType(rhs), 1);

        // Compute GEMM
        ITensorPack reshape_lhs_pack = { { ACL_SRC, &lhs }, { ACL_DST, &lhs_reshaped } };
        reshape_lhs.run(reshape_lhs_pack);
        ITensorPack reshape_rhs_pack = { { ACL_SRC, &rhs }, { ACL_DST, &rhs_reshaped } };
        reshape_rhs.run(reshape_rhs_pack);
        ITensorPack gemm_pack({ { ACL_SRC_0, &lhs_reshaped }, { ACL_SRC_1, &rhs_reshaped }, { ACL_DST, &dst } });
        gemm.run(gemm_pack);

        return dst;
    }

    SimpleTensor<int32_t> compute_reference(const TensorShape &lhs_shape, const TensorShape &rhs_shape, DataType data_type)
    {
        TensorShape dst_shape = lhs_shape;
        dst_shape[0]          = rhs_shape[0];
        dst_shape[1]          = lhs_shape[1];

        switch(data_type)
        {
            case DataType::QASYMM8:
            {
                // Create reference
                SimpleTensor<uint8_t> lhs{ lhs_shape, data_type, 1 };
                SimpleTensor<uint8_t> rhs{ rhs_shape, data_type, 1 };

                // Fill reference
                fill(lhs, 0);
                fill(rhs, 1);

                return reference::gemmlowp_matrix_multiply_core<int32_t, uint8_t>(lhs, rhs, dst_shape, 0, 0);
            }
            case DataType::QASYMM8_SIGNED:
            {
                // Create reference
                SimpleTensor<int8_t> lhs{ lhs_shape, data_type, 1 };
                SimpleTensor<int8_t> rhs{ rhs_shape, data_type, 1 };

                // Fill reference
                fill(lhs, 0);
                fill(rhs, 1);

                return reference::gemmlowp_matrix_multiply_core<int32_t, int8_t>(lhs, rhs, dst_shape, 0, 0);
            }
            default:
                ARM_COMPUTE_ERROR("Unsupported data type");
        }
    }

    TensorType            _target{};
    SimpleTensor<int32_t> _reference{};
};

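/* Same flow as GEMMLowpMatrixMultiplyReshapedValidationFixture, but for the GEMM3D output case:
 * m is the product of m_w and m_h, and the reference output shape is reinterpreted as a 3D tensor
 * of height m_h before comparison.
 */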
template <typename TensorType, typename AccessorType, typename ReshapeLHSOperatorType, typename ReshapeRHSOperatorType, typename GEMMFunctionType>
class GEMMLowpMatrixMultiplyReshaped3DValidationFixture : public framework::Fixture
{
public:
    template <typename...>
    void setup(unsigned int m_w, unsigned int m_h, unsigned int n, unsigned int k, unsigned int batch_size, unsigned int m0, unsigned int n0, unsigned int k0, unsigned int v0, unsigned int h0,
               bool interleave_lhs, bool interleave_rhs, DataType data_type)
    {
        GEMMLHSMatrixInfo lhs_info;
        lhs_info.m0         = m0;
        lhs_info.k0         = k0;
        lhs_info.v0         = v0;
        lhs_info.interleave = interleave_lhs;
        lhs_info.transpose  = false;

        GEMMRHSMatrixInfo rhs_info;
        rhs_info.n0         = n0;
        rhs_info.k0         = k0;
        rhs_info.h0         = h0;
        rhs_info.interleave = interleave_rhs;
        rhs_info.transpose  = true;

        // In case of GEMM3D, m is the product between m_w and m_h
        const unsigned int m = m_w * m_h;

        // Set the tensor shapes for LHS and RHS matrices
        const TensorShape lhs_shape(k, m, batch_size);
        const TensorShape rhs_shape(n, k, batch_size);

        _target    = compute_target(lhs_shape, rhs_shape, lhs_info, rhs_info, m_h, data_type);
        _reference = compute_reference(lhs_shape, rhs_shape, m_h, data_type);
    }

protected:
    template <typename U>
    void fill(U &&tensor, int i)
    {
        switch(tensor.data_type())
        {
            case DataType::QASYMM8:
            {
                // Between 1 and 254 in order to avoid having -128 and 128 for the DOT product path
                std::uniform_int_distribution<> distribution(1, 254);
                library->fill(tensor, distribution, i);
            }
            break;
            case DataType::QASYMM8_SIGNED:
            {
                std::uniform_int_distribution<> distribution(-127, 126);
                library->fill(tensor, distribution, i);
            }
            break;
            default:
                ARM_COMPUTE_ERROR("Unsupported data type");
        }
    }

    TensorType compute_target(const TensorShape &lhs_shape, const TensorShape &rhs_shape, const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, unsigned int m_h,
                              DataType data_type)
    {
        // Create tensors
        TensorType lhs = create_tensor<TensorType>(lhs_shape, data_type, 1);
        TensorType rhs = create_tensor<TensorType>(rhs_shape, data_type, 1);
        TensorType lhs_reshaped;
        TensorType rhs_reshaped;
        TensorType dst;

        const unsigned int M = lhs_shape[1];
        const unsigned int N = rhs_shape[0];
        const unsigned int K = lhs_shape[0];

        // The output tensor will be auto-initialized within the function

        // Create and configure function
        ReshapeLHSOperatorType reshape_lhs;
        ReshapeRHSOperatorType reshape_rhs;
        GEMMFunctionType       gemm;
        reshape_lhs.configure(lhs.info(), lhs_reshaped.info(), lhs_info);
        reshape_rhs.configure(rhs.info(), rhs_reshaped.info(), rhs_info);
        gemm.configure(lhs_reshaped.info(), rhs_reshaped.info(), dst.info(), lhs_info, rhs_info, GEMMReshapeInfo(M, N, K, 1, 1, m_h));

        ARM_COMPUTE_ASSERT(lhs.info()->is_resizable());
        ARM_COMPUTE_ASSERT(rhs.info()->is_resizable());

        add_padding_x({ &lhs, &rhs, &lhs_reshaped, &rhs_reshaped, &dst });

        // Allocate tensors
        lhs.allocator()->allocate();
        rhs.allocator()->allocate();
        lhs_reshaped.allocator()->allocate();
        rhs_reshaped.allocator()->allocate();
        dst.allocator()->allocate();

        ARM_COMPUTE_ASSERT(!lhs.info()->is_resizable());
        ARM_COMPUTE_ASSERT(!rhs.info()->is_resizable());
        ARM_COMPUTE_ASSERT(!lhs_reshaped.info()->is_resizable());
        ARM_COMPUTE_ASSERT(!rhs_reshaped.info()->is_resizable());
        ARM_COMPUTE_ASSERT(!dst.info()->is_resizable());

        // Fill tensors
        fill(AccessorType(lhs), 0);
        fill(AccessorType(rhs), 1);

        // Compute GEMM
        ITensorPack reshape_lhs_pack = { { ACL_SRC, &lhs }, { ACL_DST, &lhs_reshaped } };
        reshape_lhs.run(reshape_lhs_pack);
        ITensorPack reshape_rhs_pack = { { ACL_SRC, &rhs }, { ACL_DST, &rhs_reshaped } };
        reshape_rhs.run(reshape_rhs_pack);
        ITensorPack gemm_pack({ { ACL_SRC_0, &lhs_reshaped }, { ACL_SRC_1, &rhs_reshaped }, { ACL_DST, &dst } });
        gemm.run(gemm_pack);

        return dst;
    }

    SimpleTensor<int32_t> compute_reference(const TensorShape &lhs_shape, const TensorShape &rhs_shape, unsigned int m_h, DataType data_type)
    {
        TensorShape dst_shape = lhs_shape;
        dst_shape.set(0, rhs_shape[0]);
        dst_shape.set(1, lhs_shape[1] / m_h);
        dst_shape.set(2, m_h);
        dst_shape.set(3, lhs_shape[2]);
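        // The kernel produces a 2D output of width N and height m = m_w * m_h (per batch); the
        // reference is computed on a shape reinterpreted as (N, m_w, m_h, batches) so it can be
        // compared element-wise against the GEMM3D output of the target.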

        switch(data_type)
        {
            case DataType::QASYMM8:
            {
                // Create reference
                SimpleTensor<uint8_t> lhs{ lhs_shape, data_type, 1 };
                SimpleTensor<uint8_t> rhs{ rhs_shape, data_type, 1 };

                // Fill reference
                fill(lhs, 0);
                fill(rhs, 1);

                return reference::gemmlowp_matrix_multiply_core<int32_t, uint8_t>(lhs, rhs, dst_shape, 0, 0);
            }
            case DataType::QASYMM8_SIGNED:
            {
                // Create reference
                SimpleTensor<int8_t> lhs{ lhs_shape, data_type, 1 };
                SimpleTensor<int8_t> rhs{ rhs_shape, data_type, 1 };

                // Fill reference
                fill(lhs, 0);
                fill(rhs, 1);

                return reference::gemmlowp_matrix_multiply_core<int32_t, int8_t>(lhs, rhs, dst_shape, 0, 0);
            }
            default:
                ARM_COMPUTE_ERROR("Unsupported data type");
        }
    }

    TensorType            _target{};
    SimpleTensor<int32_t> _reference{};
};

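/* Validation fixture for the low-precision GEMM variant that reshapes only the RHS matrix:
 * the LHS is consumed in its original layout, the RHS is reshaped according to GEMMRHSMatrixInfo,
 * and the raw S32 output is compared against reference::gemmlowp_matrix_multiply_core.
 */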
template <typename TensorType, typename AccessorType, typename ReshapeRHSOperatorType, typename GEMMFunctionType>
class GEMMLowpMatrixMultiplyReshapedOnlyRHSValidationFixture : public framework::Fixture
{
public:
    template <typename...>
    void setup(unsigned int m, unsigned int n, unsigned int k, unsigned int batch_size, unsigned int m0, unsigned int n0,
               unsigned int k0, unsigned int h0, bool interleave_rhs, bool transpose_rhs, DataType data_type)
    {
        GEMMLHSMatrixInfo lhs_info;
        lhs_info.m0 = m0;
        lhs_info.k0 = k0;

        GEMMRHSMatrixInfo rhs_info;
        rhs_info.n0         = n0;
        rhs_info.k0         = k0;
        rhs_info.h0         = h0;
        rhs_info.interleave = interleave_rhs;
        rhs_info.transpose  = transpose_rhs;

        // Set the tensor shapes for LHS and RHS matrices
        const TensorShape lhs_shape(k, m, batch_size);
        const TensorShape rhs_shape(n, k, batch_size);

        _target    = compute_target(lhs_shape, rhs_shape, lhs_info, rhs_info, data_type);
        _reference = compute_reference(lhs_shape, rhs_shape, data_type);
    }

protected:
    template <typename U>
    void fill(U &&tensor, int i)
    {
        switch(tensor.data_type())
        {
            case DataType::QASYMM8:
            {
                // Between 1 and 254 in order to avoid having -128 and 128 for the DOT product path
                std::uniform_int_distribution<> distribution(1, 254);
                library->fill(tensor, distribution, i);
            }
            break;
            case DataType::QASYMM8_SIGNED:
            {
                std::uniform_int_distribution<> distribution(-127, 126);
                library->fill(tensor, distribution, i);
            }
            break;
            default:
                ARM_COMPUTE_ERROR("Unsupported data type");
        }
    }

    TensorType compute_target(const TensorShape &lhs_shape, const TensorShape &rhs_shape, const GEMMLHSMatrixInfo &lhs_info,
                              const GEMMRHSMatrixInfo &rhs_info, DataType data_type)
    {
        // Create tensors
        TensorType lhs = create_tensor<TensorType>(lhs_shape, data_type, 1);
        TensorType rhs = create_tensor<TensorType>(rhs_shape, data_type, 1);
        TensorType rhs_reshaped;
        TensorType dst;

        const unsigned int M = lhs_shape[1];
        const unsigned int N = rhs_shape[0];
        const unsigned int K = lhs_shape[0];

        GEMMKernelInfo gemm_info;
        gemm_info.m        = M;
        gemm_info.n        = N;
        gemm_info.k        = K;
        gemm_info.lhs_info = lhs_info;
        gemm_info.rhs_info = rhs_info;
        // The output tensor will be auto-initialized within the function

        // Create and configure function
        ReshapeRHSOperatorType reshape_rhs;
        GEMMFunctionType       gemm;
        reshape_rhs.configure(rhs.info(), rhs_reshaped.info(), rhs_info);
        gemm.configure(lhs.info(), rhs_reshaped.info(), dst.info(), gemm_info);

        ARM_COMPUTE_ASSERT(lhs.info()->is_resizable());
        ARM_COMPUTE_ASSERT(rhs.info()->is_resizable());

        add_padding_x({ &lhs, &rhs, &rhs_reshaped, &dst });

        // Allocate tensors
        lhs.allocator()->allocate();
        rhs.allocator()->allocate();
        rhs_reshaped.allocator()->allocate();
        dst.allocator()->allocate();

        ARM_COMPUTE_ASSERT(!lhs.info()->is_resizable());
        ARM_COMPUTE_ASSERT(!rhs.info()->is_resizable());
        ARM_COMPUTE_ASSERT(!rhs_reshaped.info()->is_resizable());
        ARM_COMPUTE_ASSERT(!dst.info()->is_resizable());

        // Fill tensors
        fill(AccessorType(lhs), 0);
        fill(AccessorType(rhs), 1);

        // Compute GEMM
        ITensorPack reshape_rhs_pack = { { ACL_SRC, &rhs }, { ACL_DST, &rhs_reshaped } };
        reshape_rhs.run(reshape_rhs_pack);
        ITensorPack gemm_pack({ { ACL_SRC_0, &lhs }, { ACL_SRC_1, &rhs_reshaped }, { ACL_DST, &dst } });
        gemm.run(gemm_pack);

        return dst;
    }

    SimpleTensor<int32_t> compute_reference(const TensorShape &lhs_shape, const TensorShape &rhs_shape, DataType data_type)
    {
        TensorShape dst_shape = lhs_shape;
        dst_shape[0]          = rhs_shape[0];
        dst_shape[1]          = lhs_shape[1];

        if(data_type == DataType::QASYMM8)
        {
            // Create reference
            SimpleTensor<uint8_t> lhs{ lhs_shape, data_type, 1 };
            SimpleTensor<uint8_t> rhs{ rhs_shape, data_type, 1 };

            // Fill reference
            fill(lhs, 0);
            fill(rhs, 1);

            return reference::gemmlowp_matrix_multiply_core<int32_t, uint8_t>(lhs, rhs, dst_shape, 0, 0);
        }
        else
        {
            // Create reference
            SimpleTensor<int8_t> lhs{ lhs_shape, data_type, 1 };
            SimpleTensor<int8_t> rhs{ rhs_shape, data_type, 1 };

            // Fill reference
            fill(lhs, 0);
            fill(rhs, 1);

            return reference::gemmlowp_matrix_multiply_core<int32_t, int8_t>(lhs, rhs, dst_shape, 0, 0);
        }
    }

    TensorType            _target{};
    SimpleTensor<int32_t> _reference{};
};

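/* Validation fixture for the reshaped-only-RHS MMUL GEMM with a fused fixed-point output stage.
 * Besides the matrix multiplication it exercises the offset contribution (via precomputed row and
 * column sums), the bias addition and the requantization to the output data type. When
 * gemm.validate() reports that the MMUL extension is not available on the target, the run and the
 * reference computation are skipped (see gemm_validated below).
 */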
template <typename T, typename TensorType, typename AccessorType, typename ReshapeRHSOperatorType, typename GEMMFunctionType, typename ReduceOperation, typename CastOperation>
class GEMMLowpMatrixMultiplyReshapedOnlyRHSMMULOutputStageValidationFixture : public framework::Fixture
{
public:
    template <typename...>
    void setup(unsigned int m, unsigned int n, unsigned int k, unsigned int batch_size, unsigned int m0, unsigned int n0,
               unsigned int k0, unsigned int h0, bool interleave_rhs, bool transpose_rhs, bool broadcast_bias, DataType data_type)
    {
        GEMMLowpOutputStageInfo output_stage;
        output_stage.type                    = GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT;
        output_stage.output_data_type        = data_type;
        output_stage.gemmlowp_multipliers    = std::vector<int32_t> { 1 };
        output_stage.gemmlowp_shifts         = std::vector<int32_t> { 1 };
        output_stage.gemmlowp_multipliers[0] = 1;
        output_stage.gemmlowp_shifts[0]      = 1;
        output_stage.gemmlowp_offset         = 0;
        constexpr float scale                = 0.001f;
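        // calculate_quantized_multiplier() decomposes the real scale into a fixed-point multiplier
        // and a power-of-two shift (roughly scale ~= multiplier * 2^-shift, with the multiplier
        // stored in fixed point), so the output stage can requantize the S32 accumulators with
        // integer arithmetic only.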
        quantization::calculate_quantized_multiplier(scale, &output_stage.gemmlowp_multipliers[0], &output_stage.gemmlowp_shifts[0]);
        output_stage.gemmlowp_min_bound = -100;
        output_stage.gemmlowp_max_bound = 100;

        GEMMLHSMatrixInfo lhs_info;
        lhs_info.m0 = m0;
        lhs_info.k0 = k0;

        GEMMRHSMatrixInfo rhs_info;
        rhs_info.n0         = n0;
        rhs_info.k0         = k0;
        rhs_info.h0         = h0;
        rhs_info.interleave = interleave_rhs;
        rhs_info.transpose  = transpose_rhs;

        int a_offset = 1;
        int b_offset = 1;

        // Set the tensor shapes for LHS and RHS matrices
        const TensorShape lhs_shape(k, m, batch_size);
        const TensorShape rhs_shape(n, k, batch_size);
        const TensorShape bias_shape(n,
                                     broadcast_bias ? 1 : m,
                                     broadcast_bias ? 1 : batch_size);

        _target = compute_target(lhs_shape, rhs_shape, bias_shape, lhs_info, rhs_info, data_type, output_stage, a_offset, b_offset);
        if(gemm_validated == true)
        {
            _reference = compute_reference(lhs_shape, rhs_shape, bias_shape, data_type, output_stage, a_offset, b_offset);
        }
    }

protected:
    template <typename U>
    void fill(U &&tensor, int i)
    {
        switch(tensor.data_type())
        {
            case DataType::QASYMM8:
            {
                // Between 1 and 254 in order to avoid having -128 and 128 for the DOT product path
                std::uniform_int_distribution<> distribution(1, 254);
                library->fill(tensor, distribution, i);
            }
            break;
            case DataType::QASYMM8_SIGNED:
            {
                std::uniform_int_distribution<> distribution(-127, 126);
                library->fill(tensor, distribution, i);
            }
            break;
            case DataType::S32:
            {
                std::uniform_int_distribution<> distribution(-10000, 10000);
                library->fill(tensor, distribution, i);
            }
            break;
            default:
                ARM_COMPUTE_ERROR("Unsupported data type");
        }
    }

    TensorType compute_target(const TensorShape &lhs_shape, const TensorShape &rhs_shape, const TensorShape &bias_shape, const GEMMLHSMatrixInfo &lhs_info,
                              const GEMMRHSMatrixInfo &rhs_info, DataType data_type, GEMMLowpOutputStageInfo output_stage, const int a_offset, const int b_offset)
    {
        // Create tensors
        TensorType lhs  = create_tensor<TensorType>(lhs_shape, data_type, 1, QuantizationInfo(1.0f / 255, a_offset));
        TensorType rhs  = create_tensor<TensorType>(rhs_shape, data_type, 1, QuantizationInfo(1.0f / 255, b_offset));
        TensorType bias = create_tensor<TensorType>(bias_shape, DataType::S32, 1);
        TensorType dst;
        TensorType rhs_reshaped;

        const unsigned int M = lhs_shape[1];
        const unsigned int N = rhs_shape[0];
        const unsigned int K = lhs_shape[0];

        // Tensors for precomputing sum of lhs rows / rhs columns
        TensorType vec_sum_rows = create_tensor<TensorType>(TensorShape(M, 1, lhs_shape[2]), DataType::S32, 1);
        TensorType vec_sum_cols = create_tensor<TensorType>(TensorShape(N, 1, rhs_shape[2]), DataType::S32, 1);

        GEMMKernelInfo gemm_info;
        gemm_info.m            = M;
        gemm_info.n            = N;
        gemm_info.k            = K;
        gemm_info.lhs_info     = lhs_info;
        gemm_info.rhs_info     = rhs_info;
        gemm_info.output_stage = output_stage;
        gemm_info.a_offset     = a_offset;
        gemm_info.b_offset     = b_offset;
        // The output tensor will be auto-initialized within the function

        // Create and configure function
        ReshapeRHSOperatorType reshape_rhs;
        GEMMFunctionType       gemm;
        reshape_rhs.configure(rhs.info(), rhs_reshaped.info(), rhs_info);

        // Only configure and run the function if validation succeeds. validate() checks whether
        // the target supports the required MMUL extension: if it does not, the configuration and
        // run are skipped (and the reference is not computed either), so the test is effectively
        // skipped instead of failing on a target/reference mismatch.
        gemm_validated = bool(gemm.validate(lhs.info(), rhs_reshaped.info(), dst.info(), gemm_info, vec_sum_cols.info(), vec_sum_rows.info(), bias.info()));
        if(gemm_validated == true)
        {
            gemm.configure(lhs.info(), rhs_reshaped.info(), dst.info(), gemm_info, vec_sum_cols.info(), vec_sum_rows.info(), bias.info());

            ARM_COMPUTE_ASSERT(lhs.info()->is_resizable());
            ARM_COMPUTE_ASSERT(rhs.info()->is_resizable());
            ARM_COMPUTE_ASSERT(bias.info()->is_resizable());

            // Allocate tensors
            lhs.allocator()->allocate();
            rhs.allocator()->allocate();
            rhs_reshaped.allocator()->allocate();
            bias.allocator()->allocate();
            vec_sum_cols.allocator()->allocate();
            vec_sum_rows.allocator()->allocate();
            dst.allocator()->allocate();

            ARM_COMPUTE_ASSERT(!lhs.info()->is_resizable());
            ARM_COMPUTE_ASSERT(!rhs.info()->is_resizable());
            ARM_COMPUTE_ASSERT(!rhs_reshaped.info()->is_resizable());
            ARM_COMPUTE_ASSERT(!bias.info()->is_resizable());
            ARM_COMPUTE_ASSERT(!vec_sum_cols.info()->is_resizable());
            ARM_COMPUTE_ASSERT(!vec_sum_rows.info()->is_resizable());
            ARM_COMPUTE_ASSERT(!dst.info()->is_resizable());

            // Fill tensors
            fill(AccessorType(lhs), 0);
            fill(AccessorType(rhs), 1);
            fill(AccessorType(bias), 2);

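            // Offset contribution: with asymmetric quantized inputs, folding the quantization
            // offsets into the S32 result requires the per-row sums of the LHS and the per-column
            // sums of the RHS (i.e. sums over the K dimension). Precompute them here by casting
            // the inputs to S32 and reducing along K, then hand them to the kernel through
            // ACL_VEC_ROW_SUM / ACL_VEC_COL_SUM in the tensor pack below.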
            TensorType    lhs_32 = create_tensor<TensorType>(lhs_shape, DataType::S32, 1);
            TensorType    rhs_32 = create_tensor<TensorType>(rhs_shape, DataType::S32, 1);
            CastOperation cast_lhs;
            CastOperation cast_rhs;
            cast_lhs.configure(&lhs, &lhs_32, ConvertPolicy::SATURATE);
            cast_rhs.configure(&rhs, &rhs_32, ConvertPolicy::SATURATE);
            lhs_32.allocator()->allocate();
            rhs_32.allocator()->allocate();
            cast_lhs.run();
            cast_rhs.run();

            ReduceOperation lhs_sum_rows;
            ReduceOperation rhs_sum_cols;

            lhs_sum_rows.configure(&lhs_32, &vec_sum_rows, 0, ReductionOperation::SUM, false);
            rhs_sum_cols.configure(&rhs_32, &vec_sum_cols, 1, ReductionOperation::SUM);

            lhs_sum_rows.run();
            rhs_sum_cols.run();

            // Compute GEMM
            ITensorPack reshape_rhs_pack = { { ACL_SRC, &rhs }, { ACL_DST, &rhs_reshaped } };
            reshape_rhs.run(reshape_rhs_pack);
            ITensorPack gemm_pack({ { ACL_SRC_0, &lhs }, { ACL_SRC_1, &rhs_reshaped }, { ACL_SRC_2, &bias }, { ACL_DST, &dst }, { ACL_VEC_COL_SUM, &vec_sum_cols }, { ACL_VEC_ROW_SUM, &vec_sum_rows } });
            gemm.run(gemm_pack);
        }

        return dst;
    }

    SimpleTensor<T> compute_reference(const TensorShape &lhs_shape, const TensorShape &rhs_shape, const TensorShape &bias_shape, DataType data_type, GEMMLowpOutputStageInfo output_stage,
                                      const int a_offset, const int b_offset)
    {
        TensorShape dst_shape = lhs_shape;
        dst_shape[0]          = rhs_shape[0];
        dst_shape[1]          = lhs_shape[1];

        // Create reference
        SimpleTensor<T>       lhs{ lhs_shape, data_type, 1, QuantizationInfo(1.0f / 255, a_offset) };
        SimpleTensor<T>       rhs{ rhs_shape, data_type, 1, QuantizationInfo(1.0f / 255, b_offset) };
        SimpleTensor<int32_t> bias{ bias_shape, DataType::S32, 1 };
        SimpleTensor<int32_t> dst{ dst_shape, DataType::S32, 1 };
        SimpleTensor<T>       dst_final{ dst_shape, data_type, 1 };

        // Fill reference
        fill(lhs, 0);
        fill(rhs, 1);
        fill(bias, 2);

        dst       = reference::gemmlowp_matrix_multiply_core<int32_t, T>(lhs, rhs, dst_shape, a_offset, b_offset);
        dst_final = reference::gemmlowp_quantize_down_scale_by_fixedpoint<int32_t, T>(dst, bias,
                                                                                      output_stage.gemmlowp_multipliers, output_stage.gemmlowp_shifts, output_stage.gemmlowp_offset, output_stage.gemmlowp_min_bound, output_stage.gemmlowp_max_bound);
        return dst_final;
    }

    bool            gemm_validated = true;
    TensorType      _target{};
    SimpleTensor<T> _reference{};
};

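/* Validation fixture for the reshaped-only-RHS MMUL GEMM without an output stage: the raw S32
 * accumulators are compared directly against the reference. As above, configuration, run and
 * reference are skipped when gemm.validate() reports that the MMUL extension is not available.
 */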
template <typename TensorType, typename AccessorType, typename ReshapeRHSOperatorType, typename GEMMFunctionType>
class GEMMLowpMatrixMultiplyReshapedOnlyRHSMMULValidationFixture : public framework::Fixture
{
public:
    template <typename...>
    void setup(unsigned int m, unsigned int n, unsigned int k, unsigned int batch_size, unsigned int m0, unsigned int n0,
               unsigned int k0, unsigned int h0, bool interleave_rhs, bool transpose_rhs, DataType data_type)
    {
        GEMMLHSMatrixInfo lhs_info;
        lhs_info.m0 = m0;
        lhs_info.k0 = k0;

        GEMMRHSMatrixInfo rhs_info;
        rhs_info.n0         = n0;
        rhs_info.k0         = k0;
        rhs_info.h0         = h0;
        rhs_info.interleave = interleave_rhs;
        rhs_info.transpose  = transpose_rhs;

        // Set the tensor shapes for LHS and RHS matrices
        const TensorShape lhs_shape(k, m, batch_size);
        const TensorShape rhs_shape(n, k, batch_size);

        _target = compute_target(lhs_shape, rhs_shape, lhs_info, rhs_info, data_type);
        if(gemm_validated == true)
        {
            _reference = compute_reference(lhs_shape, rhs_shape, data_type);
        }
    }

protected:
    template <typename U>
    void fill(U &&tensor, int i)
    {
        switch(tensor.data_type())
        {
            case DataType::QASYMM8:
            {
                // Between 1 and 254 in order to avoid having -128 and 128 for the DOT product path
                std::uniform_int_distribution<> distribution(1, 254);
                library->fill(tensor, distribution, i);
            }
            break;
            case DataType::QASYMM8_SIGNED:
            {
                std::uniform_int_distribution<> distribution(-127, 126);
                library->fill(tensor, distribution, i);
            }
            break;
            default:
                ARM_COMPUTE_ERROR("Unsupported data type");
        }
    }

    TensorType compute_target(const TensorShape &lhs_shape, const TensorShape &rhs_shape, const GEMMLHSMatrixInfo &lhs_info,
                              const GEMMRHSMatrixInfo &rhs_info, DataType data_type)
    {
        // Create tensors
        TensorType lhs = create_tensor<TensorType>(lhs_shape, data_type, 1);
        TensorType rhs = create_tensor<TensorType>(rhs_shape, data_type, 1);
        TensorType rhs_reshaped;
        TensorType dst;

        const unsigned int M = lhs_shape[1];
        const unsigned int N = rhs_shape[0];
        const unsigned int K = lhs_shape[0];

        GEMMKernelInfo gemm_info;
        gemm_info.m        = M;
        gemm_info.n        = N;
        gemm_info.k        = K;
        gemm_info.lhs_info = lhs_info;
        gemm_info.rhs_info = rhs_info;
        // The output tensor will be auto-initialized within the function

        // Create and configure function
        ReshapeRHSOperatorType reshape_rhs;
        GEMMFunctionType       gemm;
        reshape_rhs.configure(rhs.info(), rhs_reshaped.info(), rhs_info);

        // Only configure and run the function if validation succeeds. validate() checks whether
        // the target supports the required MMUL extension: if it does not, the configuration and
        // run are skipped (and the reference is not computed either), so the test is effectively
        // skipped instead of failing on a target/reference mismatch.
        gemm_validated = bool(gemm.validate(lhs.info(), rhs_reshaped.info(), dst.info(), gemm_info, nullptr, nullptr, nullptr));
        if(gemm_validated == true)
        {
            gemm.configure(lhs.info(), rhs_reshaped.info(), dst.info(), gemm_info, nullptr, nullptr, nullptr);

            ARM_COMPUTE_ASSERT(lhs.info()->is_resizable());
            ARM_COMPUTE_ASSERT(rhs.info()->is_resizable());

            // Allocate tensors
            lhs.allocator()->allocate();
            rhs.allocator()->allocate();
            rhs_reshaped.allocator()->allocate();
            dst.allocator()->allocate();

            ARM_COMPUTE_ASSERT(!lhs.info()->is_resizable());
            ARM_COMPUTE_ASSERT(!rhs.info()->is_resizable());
            ARM_COMPUTE_ASSERT(!rhs_reshaped.info()->is_resizable());
            ARM_COMPUTE_ASSERT(!dst.info()->is_resizable());

            // Fill tensors
            fill(AccessorType(lhs), 0);
            fill(AccessorType(rhs), 1);

            // Compute GEMM
            ITensorPack reshape_rhs_pack = { { ACL_SRC, &rhs }, { ACL_DST, &rhs_reshaped } };
            reshape_rhs.run(reshape_rhs_pack);
            ITensorPack gemm_pack({ { ACL_SRC_0, &lhs }, { ACL_SRC_1, &rhs_reshaped }, { ACL_DST, &dst } });
            gemm.run(gemm_pack);
        }

        return dst;
    }

    SimpleTensor<int32_t> compute_reference(const TensorShape &lhs_shape, const TensorShape &rhs_shape, DataType data_type)
    {
        TensorShape dst_shape = lhs_shape;
        dst_shape[0]          = rhs_shape[0];
        dst_shape[1]          = lhs_shape[1];

        if(data_type == DataType::QASYMM8)
        {
            // Create reference
            SimpleTensor<uint8_t> lhs{ lhs_shape, data_type, 1 };
            SimpleTensor<uint8_t> rhs{ rhs_shape, data_type, 1 };
            SimpleTensor<int32_t> dst{ dst_shape, DataType::S32, 1 };

            // Fill reference
            fill(lhs, 0);
            fill(rhs, 1);

            return reference::gemmlowp_matrix_multiply_core<int32_t, uint8_t>(lhs, rhs, dst_shape, 0, 0);
        }
        else
        {
            // Create reference
            SimpleTensor<int8_t>  lhs{ lhs_shape, data_type, 1 };
            SimpleTensor<int8_t>  rhs{ rhs_shape, data_type, 1 };
            SimpleTensor<int32_t> dst{ dst_shape, DataType::S32, 1 };

            // Fill reference
            fill(lhs, 0);
            fill(rhs, 1);

            return reference::gemmlowp_matrix_multiply_core<int32_t, int8_t>(lhs, rhs, dst_shape, 0, 0);
        }
    }

    bool                  gemm_validated = true;
    TensorType            _target{};
    SimpleTensor<int32_t> _reference{};
};

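/* GEMM3D variant of the reshaped-only-RHS fixture: the output height m = m_w * m_h is split via
 * GEMMKernelInfo::depth_output_gemm3d, and the reference shape is adjusted to (N, m_w, m_h, batches)
 * before comparison.
 */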
1750*c217d954SCole Faust template <typename TensorType, typename AccessorType, typename ReshapeRHSOperatorType, typename GEMMFunctionType>
1751*c217d954SCole Faust class GEMMLowpMatrixMultiplyReshapedOnlyRHS3DValidationFixture : public framework::Fixture
1752*c217d954SCole Faust {
1753*c217d954SCole Faust public:
1754*c217d954SCole Faust template <typename...>
setup(unsigned int m_w,unsigned int m_h,unsigned int n,unsigned int k,unsigned int batch_size,unsigned int m0,unsigned int n0,unsigned int k0,unsigned int h0,bool interleave_rhs,bool transpose_rhs,DataType data_type)1755*c217d954SCole Faust void setup(unsigned int m_w, unsigned int m_h, unsigned int n, unsigned int k, unsigned int batch_size, unsigned int m0, unsigned int n0,
1756*c217d954SCole Faust unsigned int k0, unsigned int h0, bool interleave_rhs, bool transpose_rhs, DataType data_type)
1757*c217d954SCole Faust {
1758*c217d954SCole Faust GEMMLHSMatrixInfo lhs_info;
1759*c217d954SCole Faust lhs_info.m0 = m0;
1760*c217d954SCole Faust lhs_info.k0 = k0;
1761*c217d954SCole Faust
1762*c217d954SCole Faust GEMMRHSMatrixInfo rhs_info;
1763*c217d954SCole Faust rhs_info.n0 = n0;
1764*c217d954SCole Faust rhs_info.k0 = k0;
1765*c217d954SCole Faust rhs_info.h0 = h0;
1766*c217d954SCole Faust rhs_info.interleave = interleave_rhs;
1767*c217d954SCole Faust rhs_info.transpose = transpose_rhs;
1768*c217d954SCole Faust
1769*c217d954SCole Faust // In case of GEMM3D, m is the product between m_w and m_h
1770*c217d954SCole Faust const unsigned int m = m_w * m_h;
1771*c217d954SCole Faust
1772*c217d954SCole Faust // Set the tensor shapes for LHS and RHS matrices
1773*c217d954SCole Faust const TensorShape lhs_shape(k, m, batch_size);
1774*c217d954SCole Faust const TensorShape rhs_shape(n, k, batch_size);
1775*c217d954SCole Faust
1776*c217d954SCole Faust _target = compute_target(lhs_shape, rhs_shape, lhs_info, rhs_info, m_h, data_type);
1777*c217d954SCole Faust _reference = compute_reference(lhs_shape, rhs_shape, m_h, data_type);
1778*c217d954SCole Faust }
1779*c217d954SCole Faust
1780*c217d954SCole Faust protected:
1781*c217d954SCole Faust template <typename U>
fill(U && tensor,int i)1782*c217d954SCole Faust void fill(U &&tensor, int i)
1783*c217d954SCole Faust {
1784*c217d954SCole Faust switch(tensor.data_type())
1785*c217d954SCole Faust {
1786*c217d954SCole Faust case DataType::QASYMM8:
1787*c217d954SCole Faust {
1788*c217d954SCole Faust // Between 1 and 254 in order to avoid having -128 and 128 for the DOT product path
1789*c217d954SCole Faust std::uniform_int_distribution<> distribution(1, 254);
1790*c217d954SCole Faust library->fill(tensor, distribution, i);
1791*c217d954SCole Faust }
1792*c217d954SCole Faust break;
1793*c217d954SCole Faust case DataType::QASYMM8_SIGNED:
1794*c217d954SCole Faust {
1795*c217d954SCole Faust std::uniform_int_distribution<> distribution(-127, 126);
1796*c217d954SCole Faust library->fill(tensor, distribution, i);
1797*c217d954SCole Faust }
1798*c217d954SCole Faust break;
1799*c217d954SCole Faust default:
1800*c217d954SCole Faust ARM_COMPUTE_ERROR("Unsupported data type");
1801*c217d954SCole Faust }
1802*c217d954SCole Faust }
1803*c217d954SCole Faust
compute_target(const TensorShape & lhs_shape,const TensorShape & rhs_shape,const GEMMLHSMatrixInfo & lhs_info,const GEMMRHSMatrixInfo & rhs_info,unsigned int m_h,DataType data_type)1804*c217d954SCole Faust TensorType compute_target(const TensorShape &lhs_shape, const TensorShape &rhs_shape, const GEMMLHSMatrixInfo &lhs_info,
1805*c217d954SCole Faust const GEMMRHSMatrixInfo &rhs_info, unsigned int m_h, DataType data_type)
1806*c217d954SCole Faust {
1807*c217d954SCole Faust // Create tensors
1808*c217d954SCole Faust TensorType lhs = create_tensor<TensorType>(lhs_shape, data_type, 1);
1809*c217d954SCole Faust TensorType rhs = create_tensor<TensorType>(rhs_shape, data_type, 1);
1810*c217d954SCole Faust TensorType rhs_reshaped;
1811*c217d954SCole Faust TensorType dst;
1812*c217d954SCole Faust
1813*c217d954SCole Faust const unsigned int M = lhs_shape[1];
1814*c217d954SCole Faust const unsigned int N = rhs_shape[0];
1815*c217d954SCole Faust const unsigned int K = lhs_shape[0];
1816*c217d954SCole Faust
1817*c217d954SCole Faust GEMMKernelInfo gemm_info;
1818*c217d954SCole Faust gemm_info.m = M;
1819*c217d954SCole Faust gemm_info.n = N;
1820*c217d954SCole Faust gemm_info.k = K;
1821*c217d954SCole Faust gemm_info.depth_output_gemm3d = m_h;
1822*c217d954SCole Faust gemm_info.lhs_info = lhs_info;
1823*c217d954SCole Faust gemm_info.rhs_info = rhs_info;
1824*c217d954SCole Faust // The output tensor will be auto-initialized within the function
1825*c217d954SCole Faust
1826*c217d954SCole Faust // Create and configure function
1827*c217d954SCole Faust ReshapeRHSOperatorType reshape_rhs;
1828*c217d954SCole Faust GEMMFunctionType gemm;
1829*c217d954SCole Faust reshape_rhs.configure(rhs.info(), rhs_reshaped.info(), rhs_info);
1830*c217d954SCole Faust gemm.configure(lhs.info(), rhs_reshaped.info(), dst.info(), gemm_info);
1831*c217d954SCole Faust
1832*c217d954SCole Faust ARM_COMPUTE_ASSERT(lhs.info()->is_resizable());
1833*c217d954SCole Faust ARM_COMPUTE_ASSERT(rhs.info()->is_resizable());
1834*c217d954SCole Faust
1835*c217d954SCole Faust add_padding_x({ &lhs, &rhs, &rhs_reshaped, &dst });
1836*c217d954SCole Faust
1837*c217d954SCole Faust // Allocate tensors
1838*c217d954SCole Faust lhs.allocator()->allocate();
1839*c217d954SCole Faust rhs.allocator()->allocate();
1840*c217d954SCole Faust rhs_reshaped.allocator()->allocate();
1841*c217d954SCole Faust dst.allocator()->allocate();
1842*c217d954SCole Faust
1843*c217d954SCole Faust ARM_COMPUTE_ASSERT(!lhs.info()->is_resizable());
1844*c217d954SCole Faust ARM_COMPUTE_ASSERT(!rhs.info()->is_resizable());
1845*c217d954SCole Faust ARM_COMPUTE_ASSERT(!rhs_reshaped.info()->is_resizable());
1846*c217d954SCole Faust ARM_COMPUTE_ASSERT(!dst.info()->is_resizable());
1847*c217d954SCole Faust
1848*c217d954SCole Faust // Fill tensors
1849*c217d954SCole Faust fill(AccessorType(lhs), 0);
1850*c217d954SCole Faust fill(AccessorType(rhs), 1);
1851*c217d954SCole Faust
1852*c217d954SCole Faust // Compute GEMM
        ITensorPack reshape_rhs_pack = { { ACL_SRC, &rhs }, { ACL_DST, &rhs_reshaped } };
        reshape_rhs.run(reshape_rhs_pack);
        ITensorPack gemm_pack({ { ACL_SRC_0, &lhs }, { ACL_SRC_1, &rhs_reshaped }, { ACL_DST, &dst } });
        gemm.run(gemm_pack);

        return dst;
    }

    SimpleTensor<int32_t> compute_reference(const TensorShape &lhs_shape, const TensorShape &rhs_shape, unsigned int m_h, DataType data_type)
    {
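        // The reference output is 4D: N x (M / m_h) x m_h x batches, matching the 3D re-interpretation done by the target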
        TensorShape dst_shape = lhs_shape;
        dst_shape.set(0, rhs_shape[0]);
        dst_shape.set(1, lhs_shape[1] / m_h);
        dst_shape.set(2, m_h);
        dst_shape.set(3, lhs_shape[2]);

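        // The reference buffers are uint8_t for QASYMM8 and int8_t otherwise (e.g. QASYMM8_SIGNED)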
        if(data_type == DataType::QASYMM8)
        {
            // Create reference
            SimpleTensor<uint8_t> lhs{ lhs_shape, data_type, 1 };
            SimpleTensor<uint8_t> rhs{ rhs_shape, data_type, 1 };

            // Fill reference
            fill(lhs, 0);
            fill(rhs, 1);

            return reference::gemmlowp_matrix_multiply_core<int32_t, uint8_t>(lhs, rhs, dst_shape, 0, 0);
        }
        else
        {
            // Create reference
            SimpleTensor<int8_t> lhs{ lhs_shape, data_type, 1 };
            SimpleTensor<int8_t> rhs{ rhs_shape, data_type, 1 };

            // Fill reference
            fill(lhs, 0);
            fill(rhs, 1);

            return reference::gemmlowp_matrix_multiply_core<int32_t, int8_t>(lhs, rhs, dst_shape, 0, 0);
        }
    }

    TensorType            _target{};
    SimpleTensor<int32_t> _reference{};
};

template <typename TensorType, typename AccessorType, typename GEMMFunctionType>
class GEMMLowpMatrixMultiplyNativeValidationFixture : public framework::Fixture
{
public:
    template <typename...>
    void setup(unsigned int m, unsigned int n, unsigned int k, unsigned int batch_size, unsigned int m0, unsigned int n0, unsigned int k0)
    {
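        // m0, n0 and k0 are the block sizes used by the kernel: rows of the LHS, columns of the RHS and accumulations along K processed per iteration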
        GEMMLHSMatrixInfo lhs_info;
        lhs_info.m0 = m0;
        lhs_info.k0 = k0;

        GEMMRHSMatrixInfo rhs_info;
        rhs_info.n0 = n0;
        rhs_info.k0 = k0;

        // Set the tensor shapes for LHS and RHS matrices
        const TensorShape lhs_shape(k, m, batch_size);
        const TensorShape rhs_shape(n, k, batch_size);

        _target    = compute_target(lhs_shape, rhs_shape, lhs_info, rhs_info);
        _reference = compute_reference(lhs_shape, rhs_shape);
    }

protected:
    template <typename U>
    void fill(U &&tensor, int i)
    {
        // Between 1 and 254 in order to avoid having -128 and 128 for the DOT product path
        std::uniform_int_distribution<> distribution(1, 254);
        library->fill(tensor, distribution, i);
    }

    TensorType compute_target(const TensorShape &lhs_shape, const TensorShape &rhs_shape, const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info)
    {
        // Create tensors
        TensorType lhs = create_tensor<TensorType>(lhs_shape, DataType::QASYMM8, 1);
        TensorType rhs = create_tensor<TensorType>(rhs_shape, DataType::QASYMM8, 1);
        TensorType dst;

        const unsigned int M = lhs_shape[1];
        const unsigned int N = rhs_shape[0];
        const unsigned int K = lhs_shape[0];

        // The output tensor will be auto-initialized within the function

        // Create and configure function
        GEMMFunctionType gemm;
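        // The native kernel consumes LHS and RHS directly, so no reshape pass is configured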
        gemm.configure(lhs.info(), rhs.info(), dst.info(), lhs_info, rhs_info, GEMMReshapeInfo(M, N, K));

        ARM_COMPUTE_ASSERT(lhs.info()->is_resizable());
        ARM_COMPUTE_ASSERT(rhs.info()->is_resizable());

        add_padding_x({ &lhs, &rhs, &dst });

        // Allocate tensors
        lhs.allocator()->allocate();
        rhs.allocator()->allocate();
        dst.allocator()->allocate();

        ARM_COMPUTE_ASSERT(!lhs.info()->is_resizable());
        ARM_COMPUTE_ASSERT(!rhs.info()->is_resizable());
        ARM_COMPUTE_ASSERT(!dst.info()->is_resizable());

        // Fill tensors
        fill(AccessorType(lhs), 0);
        fill(AccessorType(rhs), 1);

        // Compute GEMM
        ITensorPack gemm_pack({ { ACL_SRC_0, &lhs }, { ACL_SRC_1, &rhs }, { ACL_DST, &dst } });
        gemm.run(gemm_pack);

        return dst;
    }

    SimpleTensor<int32_t> compute_reference(const TensorShape &lhs_shape, const TensorShape &rhs_shape)
    {
        TensorShape dst_shape = lhs_shape;
        dst_shape[0]          = rhs_shape[0];
        dst_shape[1]          = lhs_shape[1];

        // Create reference
        SimpleTensor<uint8_t> lhs{ lhs_shape, DataType::QASYMM8, 1 };
        SimpleTensor<uint8_t> rhs{ rhs_shape, DataType::QASYMM8, 1 };

        // Fill reference
        fill(lhs, 0);
        fill(rhs, 1);

        return reference::gemmlowp_matrix_multiply_core<int32_t, uint8_t>(lhs, rhs, dst_shape, 0, 0);
    }

    TensorType            _target{};
    SimpleTensor<int32_t> _reference{};
};

template <typename TensorType, typename AccessorType, typename GEMMFunctionType>
class GEMMLowpMatrixMultiplyNative3DValidationFixture : public framework::Fixture
{
public:
    template <typename...>
    void setup(unsigned int m_w, unsigned int m_h, unsigned int n, unsigned int k, unsigned int batch_size, unsigned int m0, unsigned int n0, unsigned int k0)
    {
        GEMMLHSMatrixInfo lhs_info;
        lhs_info.m0 = m0;
        lhs_info.k0 = k0;

        GEMMRHSMatrixInfo rhs_info;
        rhs_info.n0 = n0;
        rhs_info.k0 = k0;

        // In case of GEMM3D, m is the product between m_w and m_h
        const unsigned int m = m_w * m_h;

        // Set the tensor shapes for LHS and RHS matrices
        const TensorShape lhs_shape(k, m, batch_size);
        const TensorShape rhs_shape(n, k, batch_size);

        _target    = compute_target(lhs_shape, rhs_shape, lhs_info, rhs_info, m_h);
        _reference = compute_reference(lhs_shape, rhs_shape, m_h);
    }

protected:
    template <typename U>
    void fill(U &&tensor, int i)
    {
        // Between 1 and 254 in order to avoid having -128 and 128 for the DOT product path
        std::uniform_int_distribution<> distribution(1, 254);
        library->fill(tensor, distribution, i);
    }

    TensorType compute_target(const TensorShape &lhs_shape, const TensorShape &rhs_shape, const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, unsigned int m_h)
    {
        // Create tensors
        TensorType lhs = create_tensor<TensorType>(lhs_shape, DataType::QASYMM8, 1);
        TensorType rhs = create_tensor<TensorType>(rhs_shape, DataType::QASYMM8, 1);
        TensorType dst;

        const unsigned int M = lhs_shape[1];
        const unsigned int N = rhs_shape[0];
        const unsigned int K = lhs_shape[0];

        // The output tensor will be auto-initialized within the function

        // Create and configure function
        GEMMFunctionType gemm;
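        // Passing m_h as the last GEMMReshapeInfo argument here sets depth_output_gemm3d, so the output is written as a 3D tensor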
        gemm.configure(lhs.info(), rhs.info(), dst.info(), lhs_info, rhs_info, GEMMReshapeInfo(M, N, K, 1, 1, m_h));

        ARM_COMPUTE_ASSERT(lhs.info()->is_resizable());
        ARM_COMPUTE_ASSERT(rhs.info()->is_resizable());

        add_padding_x({ &lhs, &rhs, &dst });

        // Allocate tensors
        lhs.allocator()->allocate();
        rhs.allocator()->allocate();
        dst.allocator()->allocate();

        ARM_COMPUTE_ASSERT(!lhs.info()->is_resizable());
        ARM_COMPUTE_ASSERT(!rhs.info()->is_resizable());
        ARM_COMPUTE_ASSERT(!dst.info()->is_resizable());

        // Fill tensors
        fill(AccessorType(lhs), 0);
        fill(AccessorType(rhs), 1);

        // Compute GEMM
        ITensorPack gemm_pack({ { ACL_SRC_0, &lhs }, { ACL_SRC_1, &rhs }, { ACL_DST, &dst } });
        gemm.run(gemm_pack);

        return dst;
    }

    SimpleTensor<int32_t> compute_reference(const TensorShape &lhs_shape, const TensorShape &rhs_shape, unsigned int m_h)
    {
        TensorShape dst_shape = lhs_shape;
        dst_shape.set(0, rhs_shape[0]);
        dst_shape.set(1, lhs_shape[1] / m_h);
        dst_shape.set(2, m_h);
        dst_shape.set(3, lhs_shape[2]);

        // Create reference
        SimpleTensor<uint8_t> lhs{ lhs_shape, DataType::QASYMM8, 1 };
        SimpleTensor<uint8_t> rhs{ rhs_shape, DataType::QASYMM8, 1 };

        // Fill reference
        fill(lhs, 0);
        fill(rhs, 1);

        return reference::gemmlowp_matrix_multiply_core<int32_t, uint8_t>(lhs, rhs, dst_shape, 0, 0);
    }

    TensorType            _target{};
    SimpleTensor<int32_t> _reference{};
};
} // namespace validation
} // namespace test
} // namespace arm_compute
#endif /* ARM_COMPUTE_TEST_GEMMLOWP_FIXTURE */