1 /*
2 * Copyright (c) 2019-2021 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24 #include "src/gpu/cl/kernels/gemm/native/ClGemmDefaultConfigNativeBifrost.h"
25
26 #include "arm_compute/core/CL/CLHelpers.h"
27 #include "arm_compute/core/CL/CLKernelLibrary.h"
28 #include "arm_compute/core/GPUTarget.h"
29 #include "src/gpu/cl/kernels/gemm/ClGemmHelpers.h"
30
31 #include <utility>
32
33 namespace arm_compute
34 {
35 namespace opencl
36 {
37 namespace kernels
38 {
39 namespace gemm
40 {
ClGemmDefaultConfigNativeBifrost(GPUTarget gpu)41 ClGemmDefaultConfigNativeBifrost::ClGemmDefaultConfigNativeBifrost(GPUTarget gpu)
42 : IClGemmKernelConfig(gpu)
43 {
44 }
45
configure(unsigned int m,unsigned int n,unsigned int k,unsigned int b,DataType data_type)46 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigNativeBifrost::configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type)
47 {
48 using ConfigurationFunctionExecutorPtr = std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> (ClGemmDefaultConfigNativeBifrost::*)(unsigned int m, unsigned int n, unsigned int k,
49 unsigned int b);
50
51 CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G71(&ClGemmDefaultConfigNativeBifrost::configure_G71_f32,
52 &ClGemmDefaultConfigNativeBifrost::configure_G71_f32, // We use the F32 heuristic
53 &ClGemmDefaultConfigNativeBifrost::configure_G71_u8);
54
55 CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G76(&ClGemmDefaultConfigNativeBifrost::configure_G76_f32,
56 &ClGemmDefaultConfigNativeBifrost::configure_G76_f32, // We use the F32 heuristic
57 &ClGemmDefaultConfigNativeBifrost::configure_G76_u8);
58
59 CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G7x(&ClGemmDefaultConfigNativeBifrost::configure_default_f32,
60 &ClGemmDefaultConfigNativeBifrost::configure_default_f32, // We use the F32 heuristic
61 &ClGemmDefaultConfigNativeBifrost::configure_default_u8);
62
63 ConfigurationFunctionExecutorPtr func = nullptr;
64
65 switch(_target)
66 {
67 case GPUTarget::G76:
68 func = configs_G76.get_function(data_type);
69 break;
70 case GPUTarget::G71:
71 func = configs_G71.get_function(data_type);
72 break;
73 default:
74 func = configs_G7x.get_function(data_type);
75 break;
76 }
77
78 ARM_COMPUTE_ERROR_ON_MSG(func == nullptr, "Data type not support for GEMM");
79 return (this->*func)(m, n, k, b);
80 }
81
configure_G71_f32(unsigned int m,unsigned int n,unsigned int k,unsigned int b)82 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigNativeBifrost::configure_G71_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
83 {
84 ARM_COMPUTE_UNUSED(k);
85 ARM_COMPUTE_UNUSED(b);
86
87 if(m == 1)
88 {
89 if(n < 2048)
90 {
91 return configure_lhs_rhs_info(m, n, 1, 2, 4, 1, 1, false, false, false, false);
92 }
93 else if(n >= 2048 && n < 8192)
94 {
95 return configure_lhs_rhs_info(m, n, 1, 4, 4, 1, 1, false, false, false, false);
96 }
97 else
98 {
99 return configure_lhs_rhs_info(m, n, 1, 8, 4, 1, 1, false, false, false, false);
100 }
101 }
102 else
103 {
104 return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 1, false, false, false, false);
105 }
106 }
107
configure_G71_u8(unsigned int m,unsigned int n,unsigned int k,unsigned int b)108 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigNativeBifrost::configure_G71_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
109 {
110 ARM_COMPUTE_UNUSED(k);
111 ARM_COMPUTE_UNUSED(b);
112
113 if(dot8_supported(CLKernelLibrary::get().get_device()))
114 {
115 if(m == 1)
116 {
117 if(n < 2048)
118 {
119 return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 1, false, false, false, false);
120 }
121 else if(n >= 2048 && n < 16384)
122 {
123 return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 1, false, false, false, false);
124 }
125 else
126 {
127 return configure_lhs_rhs_info(m, n, 1, 8, 16, 1, 1, false, false, false, false);
128 }
129 }
130 else
131 {
132 if(m < 64)
133 {
134 return configure_lhs_rhs_info(m, n, 2, 2, 16, 1, 1, false, false, false, false);
135 }
136 else
137 {
138 return configure_lhs_rhs_info(m, n, 5, 2, 16, 1, 1, false, false, false, false);
139 }
140 }
141 }
142 else
143 {
144 if(m == 1)
145 {
146 if(n < 8192)
147 {
148 return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 1, false, false, false, false);
149 }
150 else
151 {
152 return configure_lhs_rhs_info(m, n, 1, 8, 16, 1, 1, false, false, false, false);
153 }
154 }
155 else
156 {
157 return configure_lhs_rhs_info(m, n, 2, 8, 16, 1, 1, false, false, false, false);
158 }
159 }
160 }
161
configure_G76_f32(unsigned int m,unsigned int n,unsigned int k,unsigned int b)162 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigNativeBifrost::configure_G76_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
163 {
164 ARM_COMPUTE_UNUSED(k);
165 ARM_COMPUTE_UNUSED(b);
166
167 if(m == 1)
168 {
169 if(n > 4196)
170 {
171 return configure_lhs_rhs_info(m, n, 1, 4, 2, 1, 1, false, false, false, false);
172 }
173 else
174 {
175 if(k < 2048)
176 {
177 return configure_lhs_rhs_info(m, n, 1, 2, 2, 1, 1, false, false, false, false);
178 }
179 else if(k >= 2048 && k < 16384)
180 {
181 return configure_lhs_rhs_info(m, n, 1, 2, 4, 1, 1, false, false, false, false);
182 }
183 else
184 {
185 return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, 1, false, false, false, false);
186 }
187 }
188 }
189 else
190 {
191 return configure_lhs_rhs_info(m, n, 2, 8, 2, 1, 1, false, false, false, false);
192 }
193 }
194
configure_G76_u8(unsigned int m,unsigned int n,unsigned int k,unsigned int b)195 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigNativeBifrost::configure_G76_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
196 {
197 ARM_COMPUTE_UNUSED(k);
198 ARM_COMPUTE_UNUSED(b);
199
200 if(m == 1)
201 {
202 if(n < 2048)
203 {
204 return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 1, false, false, false, false);
205 }
206 else if(n >= 2048 && n < 16384)
207 {
208 return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 1, false, false, false, false);
209 }
210 else
211 {
212 return configure_lhs_rhs_info(m, n, 1, 8, 16, 1, 1, false, false, false, false);
213 }
214 }
215 else
216 {
217 if(m < 64)
218 {
219 return configure_lhs_rhs_info(m, n, 2, 2, 16, 1, 1, false, false, false, false);
220 }
221 else
222 {
223 return configure_lhs_rhs_info(m, n, 5, 2, 16, 1, 1, false, false, false, false);
224 }
225 }
226 }
227
configure_default_f32(unsigned int m,unsigned int n,unsigned int k,unsigned int b)228 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigNativeBifrost::configure_default_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
229 {
230 ARM_COMPUTE_UNUSED(k);
231 ARM_COMPUTE_UNUSED(b);
232
233 return configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 1, false, false, false, false);
234 }
235
configure_default_u8(unsigned int m,unsigned int n,unsigned int k,unsigned int b)236 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigNativeBifrost::configure_default_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
237 {
238 ARM_COMPUTE_UNUSED(k);
239 ARM_COMPUTE_UNUSED(b);
240
241 return configure_lhs_rhs_info(m, n, 5, 2, 16, 1, 1, false, false, false, false);
242 }
243 } // namespace gemm
244 } // namespace kernels
245 } // namespace opencl
246 } // namespace arm_compute