1*c217d954SCole Faust /*
2*c217d954SCole Faust * Copyright (c) 2022 Arm Limited.
3*c217d954SCole Faust *
4*c217d954SCole Faust * SPDX-License-Identifier: MIT
5*c217d954SCole Faust *
6*c217d954SCole Faust * Permission is hereby granted, free of charge, to any person obtaining a copy
7*c217d954SCole Faust * of this software and associated documentation files (the "Software"), to
8*c217d954SCole Faust * deal in the Software without restriction, including without limitation the
9*c217d954SCole Faust * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10*c217d954SCole Faust * sell copies of the Software, and to permit persons to whom the Software is
11*c217d954SCole Faust * furnished to do so, subject to the following conditions:
12*c217d954SCole Faust *
13*c217d954SCole Faust * The above copyright notice and this permission notice shall be included in all
14*c217d954SCole Faust * copies or substantial portions of the Software.
15*c217d954SCole Faust *
16*c217d954SCole Faust * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17*c217d954SCole Faust * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18*c217d954SCole Faust * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19*c217d954SCole Faust * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20*c217d954SCole Faust * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21*c217d954SCole Faust * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22*c217d954SCole Faust * SOFTWARE.
23*c217d954SCole Faust */
24*c217d954SCole Faust
25*c217d954SCole Faust #include "src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel.h"
26*c217d954SCole Faust #include "src/gpu/cl/kernels/ClGemmReshapeRhsMatrixKernel.h"
27*c217d954SCole Faust #include "tests/CL/CLAccessor.h"
28*c217d954SCole Faust #include "tests/CL/Helper.h"
29*c217d954SCole Faust #include "tests/framework/Macros.h"
30*c217d954SCole Faust #include "tests/framework/datasets/Datasets.h"
31*c217d954SCole Faust #include "tests/validation/Validation.h"
32*c217d954SCole Faust #include "tests/validation/fixtures/GEMMFixture.h"
33*c217d954SCole Faust
34*c217d954SCole Faust namespace arm_compute
35*c217d954SCole Faust {
36*c217d954SCole Faust namespace test
37*c217d954SCole Faust {
38*c217d954SCole Faust namespace validation
39*c217d954SCole Faust {
40*c217d954SCole Faust using namespace arm_compute::opencl::kernels;
41*c217d954SCole Faust
42*c217d954SCole Faust // Create function for ClGemmReshapeRhsMatrixKernel
43*c217d954SCole Faust using CLGEMMReshapeRHSMatrix = CLSynthetizeOperator<ClGemmReshapeRhsMatrixKernel>;
44*c217d954SCole Faust
45*c217d954SCole Faust // Create function for ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel
46*c217d954SCole Faust using CLGEMMMatrixMultiplyReshapedOnlyRhsMMUL = CLSynthetizeOperator<ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel>;
47*c217d954SCole Faust
48*c217d954SCole Faust // Fixture for CLGEMMMatrixMultiplyReshapedOnlyRhsMMUL
49*c217d954SCole Faust template <typename T>
50*c217d954SCole Faust using CLGEMMMatrixMultiplyReshapedOnlyRhsMMULFixture = GEMMMatrixMultiplyReshapedOnlyRhsMMULValidationFixture<CLTensor, CLAccessor, T, CLGEMMReshapeRHSMatrix, CLGEMMMatrixMultiplyReshapedOnlyRhsMMUL>;
51*c217d954SCole Faust
52*c217d954SCole Faust namespace
53*c217d954SCole Faust {
54*c217d954SCole Faust // *INDENT-OFF*
55*c217d954SCole Faust // clang-format off
56*c217d954SCole Faust RelativeTolerance<float> rel_tolerance_f32(0.001f);
57*c217d954SCole Faust constexpr float abs_tolerance_f32(0.0001f);
58*c217d954SCole Faust RelativeTolerance<half_float::half> rel_tolerance_f16(half_float::half(0.001f));
59*c217d954SCole Faust constexpr float abs_tolerance_f16(0.3f);
60*c217d954SCole Faust
61*c217d954SCole Faust /** Alpha values to test - Precommit */
62*c217d954SCole Faust const auto a_values = framework::dataset::make("alpha", {1.0f, 0.75f} );
63*c217d954SCole Faust
64*c217d954SCole Faust /** Beta values to test - Precommit */
65*c217d954SCole Faust const auto beta_values = framework::dataset::make("beta", {0.0f, -0.75f} );
66*c217d954SCole Faust
67*c217d954SCole Faust /** M values to test */
68*c217d954SCole Faust const auto m_values = framework::dataset::make("M", {49});
69*c217d954SCole Faust
70*c217d954SCole Faust /** N values to test */
71*c217d954SCole Faust const auto n_values = framework::dataset::make("N", {257});
72*c217d954SCole Faust
73*c217d954SCole Faust /** K values to test */
74*c217d954SCole Faust /** The test case requires this to be multiple of 4*/
75*c217d954SCole Faust const auto k_values = framework::dataset::make("K", {192});
76*c217d954SCole Faust
77*c217d954SCole Faust /** Batch size values to test */
78*c217d954SCole Faust const auto b_values = framework::dataset::make("batch_size", {1, 2});
79*c217d954SCole Faust
80*c217d954SCole Faust /** Activation values to test */
81*c217d954SCole Faust const auto act_values = framework::dataset::make("Activation",
82*c217d954SCole Faust {
83*c217d954SCole Faust ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU),
84*c217d954SCole Faust ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::ELU),
85*c217d954SCole Faust });
86*c217d954SCole Faust
87*c217d954SCole Faust /** M0 values to test - Precommit */
88*c217d954SCole Faust const auto m0_values_precommit = framework::dataset::make("M0", { 1, 2, 4 });
89*c217d954SCole Faust
90*c217d954SCole Faust /** N0 values to test - Precommit */
91*c217d954SCole Faust const auto n0_values_precommit = framework::dataset::make("N0", { 4, 8 });
92*c217d954SCole Faust
93*c217d954SCole Faust /** K0 values to test - Precommit */
94*c217d954SCole Faust const auto k0_values_precommit = framework::dataset::make("K0", { 1 });
95*c217d954SCole Faust
96*c217d954SCole Faust /** Broadcast bias from vector to matrix */
97*c217d954SCole Faust const auto broadcast_bias_values = framework::dataset::make("broadcast_bias", { false, true } );
98*c217d954SCole Faust
99*c217d954SCole Faust } // namespace
100*c217d954SCole Faust
101*c217d954SCole Faust TEST_SUITE(CL)
TEST_SUITE(GEMMMatrixMultiplyReshapedOnlyRhsMMUL)
103*c217d954SCole Faust TEST_SUITE(Float)
104*c217d954SCole Faust TEST_SUITE(FP32)
105*c217d954SCole Faust FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedOnlyRhsMMULFixture<float>, framework::DatasetMode::ALL,
106*c217d954SCole Faust combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
107*c217d954SCole Faust m_values,
108*c217d954SCole Faust n_values),
109*c217d954SCole Faust k_values),
110*c217d954SCole Faust b_values),
111*c217d954SCole Faust m0_values_precommit),
112*c217d954SCole Faust n0_values_precommit),
113*c217d954SCole Faust k0_values_precommit),
114*c217d954SCole Faust framework::dataset::make("ExportToCLImage", false)),
115*c217d954SCole Faust framework::dataset::make("DataType", DataType::F32)),
116*c217d954SCole Faust a_values),
117*c217d954SCole Faust beta_values),
118*c217d954SCole Faust broadcast_bias_values),
119*c217d954SCole Faust act_values))
120*c217d954SCole Faust {
121*c217d954SCole Faust // Validate output
122*c217d954SCole Faust if(validate_result)
123*c217d954SCole Faust {
124*c217d954SCole Faust validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
125*c217d954SCole Faust }
126*c217d954SCole Faust else
127*c217d954SCole Faust {
128*c217d954SCole Faust ARM_COMPUTE_TEST_INFO("cl_arm_matrix_multiply not supported. TEST skipped");
129*c217d954SCole Faust framework::ARM_COMPUTE_PRINT_INFO();
130*c217d954SCole Faust }
131*c217d954SCole Faust }
132*c217d954SCole Faust
133*c217d954SCole Faust TEST_SUITE_END() // FP32
134*c217d954SCole Faust
TEST_SUITE(FP16)135*c217d954SCole Faust TEST_SUITE(FP16)
136*c217d954SCole Faust FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedOnlyRhsMMULFixture<half>, framework::DatasetMode::ALL,
137*c217d954SCole Faust combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
138*c217d954SCole Faust m_values,
139*c217d954SCole Faust n_values),
140*c217d954SCole Faust k_values),
141*c217d954SCole Faust b_values),
142*c217d954SCole Faust m0_values_precommit),
143*c217d954SCole Faust n0_values_precommit),
144*c217d954SCole Faust k0_values_precommit),
145*c217d954SCole Faust framework::dataset::make("ExportToCLImage", false)),
146*c217d954SCole Faust framework::dataset::make("DataType", DataType::F16)),
147*c217d954SCole Faust a_values),
148*c217d954SCole Faust beta_values),
149*c217d954SCole Faust broadcast_bias_values),
150*c217d954SCole Faust act_values))
151*c217d954SCole Faust {
152*c217d954SCole Faust // Validate output
153*c217d954SCole Faust if(validate_result)
154*c217d954SCole Faust {
155*c217d954SCole Faust validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
156*c217d954SCole Faust }
157*c217d954SCole Faust else
158*c217d954SCole Faust {
159*c217d954SCole Faust ARM_COMPUTE_TEST_INFO("cl_arm_matrix_multiply not supported. TEST skipped");
160*c217d954SCole Faust framework::ARM_COMPUTE_PRINT_INFO();
161*c217d954SCole Faust }
162*c217d954SCole Faust }
163*c217d954SCole Faust TEST_SUITE_END() // FP16
164*c217d954SCole Faust
TEST_SUITE(ExportToCLImage)
166*c217d954SCole Faust TEST_SUITE(FP32)
167*c217d954SCole Faust FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedOnlyRhsMMULFixture<float>, framework::DatasetMode::ALL,
168*c217d954SCole Faust combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
169*c217d954SCole Faust m_values,
170*c217d954SCole Faust n_values),
171*c217d954SCole Faust k_values),
172*c217d954SCole Faust b_values),
173*c217d954SCole Faust m0_values_precommit),
174*c217d954SCole Faust n0_values_precommit),
175*c217d954SCole Faust k0_values_precommit),
176*c217d954SCole Faust framework::dataset::make("ExportToCLImage", true)),
177*c217d954SCole Faust framework::dataset::make("DataType", DataType::F32)),
178*c217d954SCole Faust a_values),
179*c217d954SCole Faust beta_values),
180*c217d954SCole Faust broadcast_bias_values),
181*c217d954SCole Faust act_values))
182*c217d954SCole Faust {
183*c217d954SCole Faust // Validate output
184*c217d954SCole Faust if(validate_result)
185*c217d954SCole Faust {
186*c217d954SCole Faust validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
187*c217d954SCole Faust }
188*c217d954SCole Faust else
189*c217d954SCole Faust {
190*c217d954SCole Faust ARM_COMPUTE_TEST_INFO("cl_arm_matrix_multiply not supported. TEST skipped");
191*c217d954SCole Faust framework::ARM_COMPUTE_PRINT_INFO();
192*c217d954SCole Faust }
193*c217d954SCole Faust }
194*c217d954SCole Faust
195*c217d954SCole Faust TEST_SUITE_END() // FP32
196*c217d954SCole Faust
TEST_SUITE(FP16)197*c217d954SCole Faust TEST_SUITE(FP16)
198*c217d954SCole Faust FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedOnlyRhsMMULFixture<half>, framework::DatasetMode::ALL,
199*c217d954SCole Faust combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
200*c217d954SCole Faust m_values,
201*c217d954SCole Faust n_values),
202*c217d954SCole Faust k_values),
203*c217d954SCole Faust b_values),
204*c217d954SCole Faust m0_values_precommit),
205*c217d954SCole Faust n0_values_precommit),
206*c217d954SCole Faust k0_values_precommit),
207*c217d954SCole Faust framework::dataset::make("ExportToCLImage", true)),
208*c217d954SCole Faust framework::dataset::make("DataType", DataType::F16)),
209*c217d954SCole Faust a_values),
210*c217d954SCole Faust beta_values),
211*c217d954SCole Faust broadcast_bias_values),
212*c217d954SCole Faust act_values))
213*c217d954SCole Faust {
214*c217d954SCole Faust // Validate output
215*c217d954SCole Faust if(validate_result)
216*c217d954SCole Faust {
217*c217d954SCole Faust validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
218*c217d954SCole Faust }
219*c217d954SCole Faust else
220*c217d954SCole Faust {
221*c217d954SCole Faust ARM_COMPUTE_TEST_INFO("cl_arm_matrix_multiply not supported. TEST skipped");
222*c217d954SCole Faust framework::ARM_COMPUTE_PRINT_INFO();
223*c217d954SCole Faust }
224*c217d954SCole Faust }
225*c217d954SCole Faust TEST_SUITE_END() // FP16
226*c217d954SCole Faust TEST_SUITE_END() // ExportToCLImage
227*c217d954SCole Faust TEST_SUITE_END() // Float
228*c217d954SCole Faust TEST_SUITE_END() // GEMMMatrixMultiplyReshapedOnlyRhsMMUL
229*c217d954SCole Faust TEST_SUITE_END() // CL
230*c217d954SCole Faust } // namespace validation
231*c217d954SCole Faust } // namespace test
232*c217d954SCole Faust } // namespace arm_compute
233