/*
 * Copyright (c) 2019-2022 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "arm_compute/core/KernelDescriptors.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include "arm_compute/runtime/CL/CLTensor.h"
#include "arm_compute/runtime/CL/CLTensorAllocator.h"
#include "src/gpu/cl/kernels/ClGemmMatrixMultiplyNativeKernel.h"
#include "tests/CL/CLAccessor.h"
#include "tests/CL/Helper.h"
#include "tests/PaddingCalculator.h"
#include "tests/datasets/ShapeDatasets.h"
#include "tests/framework/Asserts.h"
#include "tests/framework/Macros.h"
#include "tests/framework/datasets/Datasets.h"
#include "tests/validation/Validation.h"
#include "tests/validation/fixtures/GEMMFixture.h"

namespace arm_compute
{
namespace test
{
namespace validation
{
using namespace arm_compute::misc::shape_calculator;
using namespace arm_compute::opencl::kernels;

// Create function for ClGemmMatrixMultiplyNativeKernel
using CLGEMMMatrixMultiplyNative = CLSynthetizeOperator<ClGemmMatrixMultiplyNativeKernel>;

// Fixture for CLGEMMMatrixMultiplyNative
template <typename T>
using CLGEMMMatrixMultiplyNativeFixture = GEMMMatrixMultiplyNativeValidationFixture<CLTensor, CLAccessor, T, CLGEMMMatrixMultiplyNative>;

// Fixture for CLGEMMMatrixMultiplyNative with post ops
template <typename T>
using CLGEMMMatrixMultiplyNativeWithPostOpsFixture =
    GEMMMatrixMultiplyNativeWithPostOpsValidationFixture<CLTensor, CLAccessor, T, CLGEMMMatrixMultiplyNative>;

// Fixture for CLGEMMMatrixMultiplyNative3D
template <typename T>
using CLGEMMMatrixMultiplyNative3DFixture = GEMMMatrixMultiplyNative3DValidationFixture<CLTensor, CLAccessor, T, CLGEMMMatrixMultiplyNative>;
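// Note: the 3D fixture is expected to exercise the path where the GEMM output is reinterpreted as a
// 3D tensor, with the output rows described by the M_W and M_H values below instead of a single M.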

namespace
{
// *INDENT-OFF*
// clang-format off
RelativeTolerance<float> rel_tolerance_f32(0.001f);
constexpr float          abs_tolerance_f32(0.0001f);

/** Alpha values to test - Precommit */
const auto a_values = framework::dataset::make("alpha", {1.0f, -0.75f} );

/** Beta values to test - Precommit */
const auto beta_values = framework::dataset::make("beta", {-0.75f, 0.0f} );

/** M values to test */
const auto m_values = framework::dataset::make("M", 37);

/** M_W values to test */
const auto m_w_values = framework::dataset::make("M_W", 5);

/** M_H values to test */
const auto m_h_values = framework::dataset::make("M_H", 7);

/** N values to test */
const auto n_values = framework::dataset::make("N", 51);

/** K values to test */
const auto k_values = framework::dataset::make("K", 23);

/** Batch size values to test */
const auto b_values = framework::dataset::make("batch_size", 1, 3);

/** Activation values to test */
const auto act_values = framework::dataset::make("Activation",
{
    ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, 8.f, 2.f),
    ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::ELU),
});
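// Note: LU_BOUNDED_RELU computes min(a, max(b, x)), so the (8.f, 2.f) case above clamps the output to [2, 8].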

/** M0 values to test - Precommit */
const auto m0_values_precommit = framework::dataset::make("M0", { 4, 6 });

/** N0 values to test - Precommit */
const auto n0_values_precommit = framework::dataset::make("N0", { 4 });

/** K0 values to test - Precommit */
const auto k0_values_precommit = framework::dataset::make("K0", { 4 });

/** H0 values to test - Precommit */
const auto h0_values_precommit = framework::dataset::make("H0", 1, 3);

/** M0 values to test - Nightly */
const auto m0_values_nightly = framework::dataset::make("M0", 1, 8);

/** N0 values to test - Nightly */
const auto n0_values_nightly = framework::dataset::make("N0", { 2, 3, 4, 8 });

/** K0 values to test - Nightly */
const auto k0_values_nightly = framework::dataset::make("K0", { 2, 3, 4, 8 });

/** Broadcast bias from vector to matrix */
const auto broadcast_bias_values = framework::dataset::make("broadcast_bias", { false, true } );
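// Note: when broadcast_bias is true the bias is a vector of N elements broadcast across the M rows (and
// batches) of the output; when false it is a full N x M (x batch) matrix (see bias_shape in
// validate_configuration below).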

/** Boundary handling cases for testing partial/non-partial (full) block dimensions, resulting from different combinations
 * of M, M0, N and N0 values.
 * M0 and N0 are kept constant, while the test cases vary M and N.
 *
 * E.g. M = 64 and N = 33 result in block dimensions that have no partial blocks (all full blocks) in the Y dimension and
 * partial blocks in the X dimension (a worked example follows the dataset below).
 */
const auto boundary_handling_cases = combine(combine(combine(combine(combine(combine(combine(combine(combine(
                                    // Large k to force potential out-of-bound reads on input0
                                    framework::dataset::make("K", 315),
                                    // Batch size == 1 to force potential out-of-bound reads on input0
                                    framework::dataset::make("batch_size", 1)),
                                    framework::dataset::make("M0", 4)),
                                    framework::dataset::make("N0", 4)),
                                    framework::dataset::make("K0", 4)),
                                    // Only need to test F32 as F16 shares identical boundary handling logic
                                    framework::dataset::make("DataType", DataType::F32)),
                                    framework::dataset::make("alpha", -0.75f )),
                                    framework::dataset::make("beta", -0.35f )),
                                    broadcast_bias_values),
                                    framework::dataset::make("Activation", ActivationLayerInfo()));
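// Worked example for the boundary handling cases (M0 = N0 = 4, fixed above):
//   M = 64 -> 64 / 4 = 16 full blocks in Y, no partial block
//   N = 33 -> 33 = 8 * 4 + 1, i.e. 8 full blocks plus one partial block of width 1 in X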

/** Post Ops */
using PostOpArgBroadcast =  CLGEMMMatrixMultiplyNativeWithPostOpsFixture<float>::PostOpArgBroadcast;
experimental::PostOpList<PostOpArgBroadcast> post_ops_1()
{
    experimental::PostOpList<PostOpArgBroadcast> post_ops{};
    post_ops.push_back_op<experimental::PostOpAct<PostOpArgBroadcast>>(ActivationLayerInfo{ActivationLayerInfo::ActivationFunction::LINEAR, 0.5F, 0.0F});
    post_ops.push_back_op<experimental::PostOpEltwiseAdd<PostOpArgBroadcast>>(
        std::make_tuple(true, true, false),   // If broadcast in dims 0, 1 and 2
        0,
        ConvertPolicy::SATURATE);
    post_ops.push_back_op<experimental::PostOpAct<PostOpArgBroadcast>>(ActivationLayerInfo{ActivationLayerInfo::ActivationFunction::RELU, 2.1F, 1.3F});
    return post_ops;
}
experimental::PostOpList<PostOpArgBroadcast> post_ops_2()
{
    experimental::PostOpList<PostOpArgBroadcast> post_ops{};
    post_ops.push_back_op<experimental::PostOpEltwiseAdd<PostOpArgBroadcast>>(
        std::make_tuple(false, true, true),   // If broadcast in dims 0, 1 and 2
        1,
        ConvertPolicy::SATURATE);
    post_ops.push_back_op<experimental::PostOpAct<PostOpArgBroadcast>>(ActivationLayerInfo{ActivationLayerInfo::ActivationFunction::RELU, 2.1F, 1.3F});
    return post_ops;
}
experimental::PostOpList<PostOpArgBroadcast> post_ops_3()
{
    experimental::PostOpList<PostOpArgBroadcast> post_ops{};
    // post_ops.push_back_op<experimental::PostOpAct<PostOpArgBroadcast>>(ActivationLayerInfo{ActivationLayerInfo::ActivationFunction::RELU, 2.1F, 1.3F});
    post_ops.push_back_op<experimental::PostOpEltwiseAdd<PostOpArgBroadcast>>(
        std::make_tuple(false, false, false),  // If broadcast in dims 0, 1 and 2
        1,
        ConvertPolicy::SATURATE);
    return post_ops;
}
// To test that the output of the main op is the first parameter in prelu post op
experimental::PostOpList<PostOpArgBroadcast> post_ops_4()
{
    experimental::PostOpList<PostOpArgBroadcast> post_ops{};
    post_ops.push_back_op<experimental::PostOpAct<PostOpArgBroadcast>>(ActivationLayerInfo{ActivationLayerInfo::ActivationFunction::LINEAR, 0.5F, 0.0F});
    post_ops.push_back_op<experimental::PostOpEltwisePRelu<PostOpArgBroadcast>>(
        std::make_tuple(false, false, true),   // If true, broadcast in corresponding dim: 0, 1 or 2
        0,
        ConvertPolicy::SATURATE);
    post_ops.push_back_op<experimental::PostOpAct<PostOpArgBroadcast>>(ActivationLayerInfo{ActivationLayerInfo::ActivationFunction::RELU, 2.1F, 1.3F});
    return post_ops;
}
// To test that the output of the main op is the second parameter in prelu post op i.e. it is the alpha_param
experimental::PostOpList<PostOpArgBroadcast> post_ops_5()
{
    experimental::PostOpList<PostOpArgBroadcast> post_ops{};
    post_ops.push_back_op<experimental::PostOpAct<PostOpArgBroadcast>>(ActivationLayerInfo{ActivationLayerInfo::ActivationFunction::LINEAR, 0.5F, 0.0F});
    post_ops.push_back_op<experimental::PostOpEltwisePRelu<PostOpArgBroadcast>>(
        std::make_tuple(false, false, false),   // If true, broadcast in corresponding dim: 0, 1 or 2
        1,
        ConvertPolicy::SATURATE);
    post_ops.push_back_op<experimental::PostOpAct<PostOpArgBroadcast>>(ActivationLayerInfo{ActivationLayerInfo::ActivationFunction::RELU, 2.1F, 1.3F});
    return post_ops;
}
/** Different Post Op Lists */
const auto post_op_lists = framework::dataset::make("post_op_lists", {
    post_ops_1(),
    post_ops_2(),
    post_ops_3(),
    post_ops_4(),
    post_ops_5()
 } );
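// Note: the (bool, bool, bool) tuples above select per-dimension broadcasting of the extra post-op
// argument; the fixture is expected to collapse each broadcast dimension to 1 when it creates that
// tensor (e.g. (true, true, false) on an N x M x batch output would give a 1 x 1 x batch argument).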

bool is_post_op_list_valid(unsigned int m, unsigned int n, unsigned int k, unsigned int batch, DataType data_type, const experimental::PostOpList<ITensorInfo*>& post_ops)
{
    const auto lhs_info = GEMMLHSMatrixInfo(4,4,1,false,true);
    const auto rhs_info = GEMMRHSMatrixInfo(4,4,1,true,true,false);

    // Create TensorInfo for post op arguments
    TensorInfo input0_info(TensorShape(k, m, batch), 1, data_type);
    TensorInfo input1_info(TensorShape(n, k, batch), 1, data_type);
    TensorInfo input2_info(TensorShape(n), 1, data_type);
    TensorInfo output_info(TensorShape(n, m, batch), 1, data_type);

    GEMMKernelInfo gemm_info(m, n, k, 0 /**< Depth of the output tensor in case it is reinterpreted as 3D */,
             false /**< reinterpret the input as 3D */,
             true  /**< Flag used to broadcast the bias addition */,
             false /**< wider accumulation */,
             false /**< has pad y */,
             ActivationLayerInfo::ActivationFunction::IDENTITY,
             1   /**< Multiplication factor for the width of the 1xW transposed block */,
             1   /**< Multiplication factor for the height of the 4x4 interleaved block */,
             lhs_info,
             rhs_info,
             0  /**< Offset to be added to each element of the matrix A */,
             0 /**< Offset to be added to each element of the matrix B */,
             post_ops);
    return bool(ClGemmMatrixMultiplyNativeKernel::validate(&input0_info.clone()->set_is_resizable(true),
                                                          &input1_info.clone()->set_is_resizable(true),
                                                          &input2_info.clone()->set_is_resizable(true),
                                                          &output_info.clone()->set_is_resizable(true),1.f,1.f,
                                                          lhs_info,
                                                          rhs_info,
                                                          gemm_info));
}

/** Configuration test */
void validate_configuration(unsigned int m_value, unsigned int n_value, unsigned int k_value, unsigned int b_value, unsigned int m0_value, unsigned int n0_value, unsigned int k0_value, bool broadcast_bias, DataType data_type, const ActivationLayerInfo &act_info)
{
    const unsigned int M = m_value;
    const unsigned int N = n_value;
    const unsigned int K = k_value;

    GEMMLHSMatrixInfo lhs_info;
    lhs_info.m0         = m0_value;
    lhs_info.k0         = k0_value;

    GEMMRHSMatrixInfo rhs_info;
    rhs_info.n0         = n0_value;
    rhs_info.k0         = k0_value;

    GEMMKernelInfo kernel_info;
    kernel_info.m               = M;
    kernel_info.n               = N;
    kernel_info.k               = K;
    kernel_info.broadcast_bias  = broadcast_bias;
    kernel_info.activation_info = act_info;

    const TensorShape lhs_shape(K, M, b_value);
    const TensorShape rhs_shape(N, K, b_value);
    const TensorShape bias_shape(N,
                                 broadcast_bias? 1 : M,
                                 broadcast_bias? 1 : b_value);
    const TensorShape dst_shape = compute_mm_shape(TensorInfo(lhs_shape, 1, data_type),
                                                   TensorInfo(rhs_shape, 1, data_type),
                                                   kernel_info);

    // Create tensors
    CLTensor lhs  = create_tensor<CLTensor>(lhs_shape, data_type);
    CLTensor rhs  = create_tensor<CLTensor>(rhs_shape, data_type);
    CLTensor bias = create_tensor<CLTensor>(bias_shape, data_type);
    CLTensor dst  = create_tensor<CLTensor>(dst_shape, data_type);

    ARM_COMPUTE_EXPECT(lhs.info()->is_resizable(), framework::LogLevel::ERRORS);
    ARM_COMPUTE_EXPECT(rhs.info()->is_resizable(), framework::LogLevel::ERRORS);
    ARM_COMPUTE_EXPECT(bias.info()->is_resizable(), framework::LogLevel::ERRORS);
    ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS);

    // Create and configure function
    CLGEMMMatrixMultiplyNative gemm;
    gemm.configure(lhs.info(), rhs.info(), bias.info(), dst.info(), 1.0f, 1.0f, lhs_info, rhs_info, kernel_info);
}
} // namespace

TEST_SUITE(CL)
TEST_SUITE(GEMMMatrixMultiplyNative)
TEST_SUITE(ValidateFusedPostOpsConfigs)
TEST_SUITE(Invalid)
TEST_CASE(UnsupportedPostOpSequence, framework::DatasetMode::ALL)
{
    const auto data_type = DataType::F32;
    const unsigned int m = 17;
    const unsigned int n = 1;
    const unsigned int k = 13;
    const unsigned int batch = 2;
    TensorShape post_op_arg0_shape(n, m, batch);
    TensorInfo post_op_arg_info(post_op_arg0_shape, 1, data_type);
    auto post_op_arg1_info = post_op_arg_info.clone();

    // Unsupported sequence of post ops
    experimental::PostOpList<ITensorInfo*> post_ops{};
    post_ops.push_back_op<experimental::PostOpEltwiseAdd<ITensorInfo*>>(
        &post_op_arg_info,
        1,
        ConvertPolicy::SATURATE);
    post_ops.push_back_op<experimental::PostOpEltwiseAdd<ITensorInfo*>>(
        post_op_arg1_info.get(),
        0,
        ConvertPolicy::SATURATE);

    ARM_COMPUTE_EXPECT(is_post_op_list_valid(m, n, k, batch, data_type, post_ops) == false, framework::LogLevel::ERRORS);
}
TEST_CASE(OutputWidened, framework::DatasetMode::ALL)
{
    // Invalid broadcast: post op tensors "widen" the output tensor
    const auto data_type = DataType::F32;
    const unsigned int m = 1;
    const unsigned int n = 18;
    const unsigned int k = 13;
    const unsigned int batch = 2;
    TensorShape post_op_arg_shape(n, m + 1, batch); // output's Y dimension (m) is "widened", which is not allowed
    TensorInfo post_op_arg_info(post_op_arg_shape, 1, data_type);
    experimental::PostOpList<ITensorInfo*> post_ops{};
    post_ops.push_back_op<experimental::PostOpEltwiseAdd<ITensorInfo*>>( &post_op_arg_info, 0, ConvertPolicy::SATURATE);

    ARM_COMPUTE_EXPECT(is_post_op_list_valid(m, n, k, batch, data_type, post_ops) == false, framework::LogLevel::ERRORS);
}
TEST_CASE(BroadcastInXDimOnly, framework::DatasetMode::ALL)
{
    // Invalid broadcast: post op tensors broadcast in the first dimension (X) only
    const auto data_type = DataType::F32;
    const unsigned int m = 22;
    const unsigned int n = 16;
    const unsigned int k = 15;
    const unsigned int batch = 3;
    TensorShape post_op_arg_shape(1, m, batch);
    TensorInfo post_op_arg_info(post_op_arg_shape, 1, data_type);
    experimental::PostOpList<ITensorInfo*> post_ops{};
    post_ops.push_back_op<experimental::PostOpEltwiseAdd<ITensorInfo*>>( &post_op_arg_info, 0, ConvertPolicy::SATURATE);

    ARM_COMPUTE_EXPECT(is_post_op_list_valid(m, n, k, batch, data_type, post_ops) == false, framework::LogLevel::ERRORS);
}
TEST_SUITE_END() // Invalid
TEST_SUITE(Valid)
TEST_CASE(EmptyPostOpList, framework::DatasetMode::ALL)
{
    const auto data_type = DataType::F32;
    const unsigned int m = 22;
    const unsigned int n = 16;
    const unsigned int k = 15;
    const unsigned int batch = 3;
    experimental::PostOpList<ITensorInfo*> post_ops{};

    ARM_COMPUTE_EXPECT(is_post_op_list_valid(m, n, k, batch, data_type, post_ops) == true, framework::LogLevel::ERRORS);
}
TEST_CASE(BroadcastInYDimOnly, framework::DatasetMode::ALL)
{
    const auto data_type = DataType::F32;
    const unsigned int m = 22;
    const unsigned int n = 16;
    const unsigned int k = 15;
    const unsigned int batch = 3;
    TensorShape post_op_arg_shape(n, 1, batch);
    TensorInfo post_op_arg_info(post_op_arg_shape, 1, data_type);
    experimental::PostOpList<ITensorInfo*> post_ops{};
    post_ops.push_back_op<experimental::PostOpEltwiseAdd<ITensorInfo*>>( &post_op_arg_info, 0, ConvertPolicy::SATURATE);

    ARM_COMPUTE_EXPECT(is_post_op_list_valid(m, n, k, batch, data_type, post_ops) == true, framework::LogLevel::ERRORS);
}
TEST_CASE(BroadcastInBothXandYDims, framework::DatasetMode::ALL)
{
    const auto data_type = DataType::F32;
    const unsigned int m = 22;
    const unsigned int n = 16;
    const unsigned int k = 15;
    const unsigned int batch = 3;
    TensorShape post_op_arg_shape(1, 1, batch);
    TensorInfo post_op_arg_info(post_op_arg_shape, 1, data_type);
    experimental::PostOpList<ITensorInfo*> post_ops{};
    post_ops.push_back_op<experimental::PostOpEltwiseAdd<ITensorInfo*>>( &post_op_arg_info, 0, ConvertPolicy::SATURATE);

    ARM_COMPUTE_EXPECT(is_post_op_list_valid(m, n, k, batch, data_type, post_ops) == true, framework::LogLevel::ERRORS);
}
TEST_CASE(BroadcastInAllDims, framework::DatasetMode::ALL)
{
    const auto data_type = DataType::F32;
    const unsigned int m = 22;
    const unsigned int n = 16;
    const unsigned int k = 15;
    const unsigned int batch = 3;
    TensorShape post_op_arg_shape(1, 1, 1);
    TensorInfo post_op_arg_info(post_op_arg_shape, 1, data_type);
    experimental::PostOpList<ITensorInfo*> post_ops{};
    post_ops.push_back_op<experimental::PostOpEltwiseAdd<ITensorInfo*>>( &post_op_arg_info, 0, ConvertPolicy::SATURATE);

    ARM_COMPUTE_EXPECT(is_post_op_list_valid(m, n, k, batch, data_type, post_ops) == true, framework::LogLevel::ERRORS);
}
TEST_SUITE_END() // Valid
TEST_SUITE_END() // ValidateFusedPostOpsConfigs
TEST_SUITE(Float)
TEST_SUITE(FP32)
DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(combine(combine(combine(
                                                                   m_values,
                                                                   n_values),
                                                                   k_values),
                                                                   framework::dataset::make("batch_size", 1)),
                                                                   m0_values_precommit),
                                                                   n0_values_precommit),
                                                                   k0_values_precommit),
                                                                   broadcast_bias_values),
                                                                   act_values),
m_value, n_value, k_value, b_value, m0_value, n0_value, k0_value, broadcast_bias, act_value)
{
    validate_configuration(m_value, n_value, k_value, b_value, m0_value, n0_value, k0_value, broadcast_bias, DataType::F32, act_value);
}

FIXTURE_DATA_TEST_CASE(RunSmallBoundaryHandlingPartialInXPartialInY, CLGEMMMatrixMultiplyNativeFixture<float>, framework::DatasetMode::ALL,
                combine(combine(
                        framework::dataset::make("M", 3),
                        framework::dataset::make("N", 1)),
                        boundary_handling_cases))
{
    // Validate output
    validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
}

FIXTURE_DATA_TEST_CASE(RunSmallBoundaryHandlingPartialInXFullInY, CLGEMMMatrixMultiplyNativeFixture<float>, framework::DatasetMode::ALL,
                combine(combine(
                        framework::dataset::make("M", 64),
                        framework::dataset::make("N", 51)),
                        boundary_handling_cases))
{
    // Validate output
    validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
}

FIXTURE_DATA_TEST_CASE(RunSmallBoundaryHandlingFullInXFullInY, CLGEMMMatrixMultiplyNativeFixture<float>, framework::DatasetMode::ALL,
                combine(combine(
                        framework::dataset::make("M", 64),
                        framework::dataset::make("N", 32)),
                        boundary_handling_cases))
{
    // Validate output
    validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
}

FIXTURE_DATA_TEST_CASE(RunSmallBoundaryHandlingFullInXPartialInY, CLGEMMMatrixMultiplyNativeFixture<float>, framework::DatasetMode::ALL,
                combine(combine(
                        framework::dataset::make("M", 37),
                        framework::dataset::make("N", 32)),
                        boundary_handling_cases))
{
    // Validate output
    validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
}

FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyNativeFixture<float>, framework::DatasetMode::ALL,
                combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
                                                                   m_values,
                                                                   n_values),
                                                                   k_values),
                                                                   b_values),
                                                                   m0_values_precommit),
                                                                   n0_values_precommit),
                                                                   k0_values_precommit),
                                                                   framework::dataset::make("DataType", DataType::F32)),
                                                                   a_values),
                                                                   beta_values),
                                                                   broadcast_bias_values),
                                                                   act_values))
{
    // Validate output
    validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
}

FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyNativeFixture<float>, framework::DatasetMode::DISABLED,
                combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
                                                                   m_values,
                                                                   n_values),
                                                                   k_values),
                                                                   b_values),
                                                                   m0_values_nightly),
                                                                   n0_values_nightly),
                                                                   k0_values_nightly),
                                                                   framework::dataset::make("DataType", DataType::F32)),
                                                                   a_values),
                                                                   beta_values),
                                                                   broadcast_bias_values),
                                                                   act_values))
{
    // Validate output
    validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
}

FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyNative3DFixture<float>, framework::DatasetMode::ALL,
                combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
                                                                   m_w_values,
                                                                   m_h_values),
                                                                   n_values),
                                                                   k_values),
                                                                   b_values),
                                                                   m0_values_precommit),
                                                                   n0_values_precommit),
                                                                   k0_values_precommit),
                                                                   framework::dataset::make("DataType", DataType::F32)),
                                                                   a_values),
                                                                   beta_values),
                                                                   act_values))
{
    // Validate output
    validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
}

FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyNative3DFixture<float>, framework::DatasetMode::DISABLED,
                combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
                                                                   m_w_values,
                                                                   m_h_values),
                                                                   n_values),
                                                                   k_values),
                                                                   b_values),
                                                                   m0_values_nightly),
                                                                   n0_values_nightly),
                                                                   k0_values_nightly),
                                                                   framework::dataset::make("DataType", DataType::F32)),
                                                                   a_values),
                                                                   beta_values),
                                                                   act_values))
{
    // Validate output
    validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
}

TEST_SUITE(FusedPostOps)

FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyNativeWithPostOpsFixture<float>, framework::DatasetMode::ALL,
                combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
                                                                   m_values,
                                                                   n_values),
                                                                   k_values),
                                                                   b_values),
                                                                   framework::dataset::make("M0", { 4 })),
                                                                   n0_values_precommit),
                                                                   k0_values_precommit),
                                                                   framework::dataset::make("DataType", DataType::F32)),
                                                                   framework::dataset::make("alpha", {1.0f} )),
                                                                   framework::dataset::make("beta", {1.0f} )),
                                                                   framework::dataset::make("broadcast_bias", { false, true } )),
                                                                   framework::dataset::make("Activation", { ActivationLayerInfo() })),
                                                                   post_op_lists)
                                                                   )
{
    // Validate output
    validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
}

TEST_SUITE_END() // FusedPostOps

TEST_SUITE_END() // FP32
TEST_SUITE_END() // Float
TEST_SUITE_END() // GEMMMatrixMultiplyNative
TEST_SUITE_END() // CL
} // namespace validation
} // namespace test
} // namespace arm_compute