#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/native/CPUBlas.h>
#include <ATen/native/cpu/zmath.h>
#include <c10/util/irange.h>
#include <c10/util/Unroll.h>
#include <array>
#include <memory>
#include <type_traits>

#if defined(__aarch64__) && !defined(C10_MOBILE)
#include <arm_neon.h>

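// Hand-written NEON GEMV and dot kernels for fp16/bf16 that accumulate in
// fp32. These are declarations only; the definitions live in a separate
// aarch64-only translation unit.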
namespace at::native::blas_impl {
void fp16_gemv_notrans(
    const int m,
    const int n,
    const float alpha,
    const float16_t* a,
    const int lda,
    const float16_t* x,
    const int incx,
    const float beta,
    float16_t* y,
    const int incy);

void fp16_gemv_trans(
    const int m,
    const int n,
    const float alpha,
    const float16_t* a,
    const int lda,
    const float16_t* x,
    const int incx,
    const float beta,
    float16_t* y,
    const int incy);

float fp16_dot_with_fp32_arith(
    const float16_t* x,
    const float16_t* a,
    int64_t len);

float bf16_dot_with_fp32_arith(
    const at::BFloat16* x,
    const at::BFloat16* a,
    int64_t len);
} // namespace at::native::blas_impl
#endif

namespace at::native {
namespace cpublas {
namespace {

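// Scales the m-by-n column-major matrix `a` (leading dimension lda) in place
// by alpha. The alpha == 0 path stores zeros instead of multiplying, so any
// NaN/Inf already in the matrix is discarded, matching BLAS beta semantics.
// The same-precision GEMM paths below use this to apply beta to c.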
template <typename scalar_t, typename opmath_t>
void scale_(int64_t m, int64_t n, opmath_t alpha, scalar_t *a, int64_t lda) {
  if (alpha == opmath_t(1)) {
    return; // identity
  }

  if (alpha == opmath_t(0)) {
    for (const auto j : c10::irange(n)) {
      for (const auto i : c10::irange(m)) {
        a[j * lda + i] = scalar_t(0);
      }
    }
    return;
  }

  for (const auto j : c10::irange(n)) {
    for (const auto i : c10::irange(m)) {
      a[j * lda + i] *= alpha;
    }
  }
}

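// Sums f(0) + f(1) + ... + f(N - 1) using ilp_factor independent accumulators
// so the additions form separate dependency chains and can overlap in the
// pipeline (instruction-level parallelism), instead of one serial chain.
// Roughly equivalent to:
//   acc_t s0{}, s1{}, s2{}, s3{};
//   int64_t i = 0;
//   for (; i + 4 <= N; i += 4) {
//     s0 += f(i); s1 += f(i + 1); s2 += f(i + 2); s3 += f(i + 3);
//   }
//   for (; i < N; ++i) s0 += f(i);
//   return s0 + s1 + s2 + s3;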
template <typename Func>
auto sum(int64_t N, Func f) {
  constexpr int ilp_factor = 4;
  using acc_t = decltype(f(0));

  // Calculate independent partial sums then add together at the end
  std::array<acc_t, ilp_factor> partial_sums{};

  int64_t i = 0;
  for (; i + ilp_factor <= N; i += ilp_factor) {
    c10::ForcedUnroll<ilp_factor>{}([&](int k) {
      partial_sums[k] += f(i + k);
    });
  }
  for (; i < N; ++i) {
    partial_sums[0] += f(i);
  }
  for (int k = 1; k < ilp_factor; ++k) {
    partial_sums[0] += partial_sums[k];
  }
  return partial_sums[0];
}

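// Same-precision case, c = beta * c + alpha * (a @ b) with all operands
// column-major. The inner update is axpy-style: each element of b scales a
// contiguous column of a into a contiguous column of c, manually unrolled
// by 4.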
template <typename scalar_t, typename opmath_t>
typename std::enable_if<std::is_same<scalar_t, opmath_t>::value, void>::type
gemm_notrans_(
    int64_t m,
    int64_t n,
    int64_t k,
    opmath_t alpha,
    const scalar_t* a,
    int64_t lda,
    const scalar_t* b,
    int64_t ldb,
    opmath_t beta,
    scalar_t* c,
    int64_t ldc) {
  // c *= beta
  scale_(m, n, beta, c, ldc);

  // c += alpha * (a @ b)
  for (const auto l : c10::irange(k)) {
    for (const auto j : c10::irange(n)) {
      opmath_t val = b[l + j * ldb] * alpha;
      int64_t i_m = m / 4;
      for (const auto i_i : c10::irange(i_m)) {
        c[j * ldc + i_i * 4 + 0] += a[i_i * 4 + 0 + l * lda] * val;
        c[j * ldc + i_i * 4 + 1] += a[i_i * 4 + 1 + l * lda] * val;
        c[j * ldc + i_i * 4 + 2] += a[i_i * 4 + 2 + l * lda] * val;
        c[j * ldc + i_i * 4 + 3] += a[i_i * 4 + 3 + l * lda] * val;
      }
      int64_t i = i_m * 4;
      for (; i < m; i++) {
        c[j * ldc + i] += a[i + l * lda] * val;
      }
    }
  }
}

// Mixed-precision case: scalar_t is a reduced-precision type (at::BFloat16
// or at::Half) and opmath_t is float. Each output element is computed as a
// full dot product at opmath_t precision before being narrowed back.
template <typename scalar_t, typename opmath_t>
typename std::enable_if<!std::is_same<scalar_t, opmath_t>::value, void>::type
gemm_notrans_(
    int64_t m,
    int64_t n,
    int64_t k,
    opmath_t alpha,
    const scalar_t* a,
    int64_t lda,
    const scalar_t* b,
    int64_t ldb,
    opmath_t beta,
    scalar_t* c,
    int64_t ldc) {
  // c = beta * c + alpha * (a @ b)
  for (const auto i : c10::irange(m)) {
    for (const auto j : c10::irange(n)) {
      const auto dot = sum(k, [&](int64_t l) -> opmath_t {
        return static_cast<opmath_t>(a[l * lda + i]) *
            static_cast<opmath_t>(b[j * ldb + l]);
      });
      if (beta == opmath_t(0)) {
        c[j * ldc + i] = alpha * dot;
      } else {
        c[j * ldc + i] = beta * c[j * ldc + i] + alpha * dot;
      }
    }
  }
}

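// With a transposed, the rows of a.T are contiguous in memory, so each output
// element is a straightforward contiguous dot product (optionally conjugating
// the a term for ConjTranspose).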
template <typename scalar_t, typename opmath_t>
void gemm_transa_(
    TransposeType transa,
    int64_t m, int64_t n, int64_t k,
    opmath_t alpha,
    const scalar_t *a, int64_t lda,
    const scalar_t *b, int64_t ldb,
    opmath_t beta,
    scalar_t *c, int64_t ldc) {
  // c = alpha * (a.T @ b) + beta * c
  const scalar_t *a_ = a;
  for (const auto i : c10::irange(m)) {
    const scalar_t *b_ = b;
    for (const auto j : c10::irange(n)) {
      const auto dot = sum(k, [&](int64_t l) -> opmath_t {
        return static_cast<opmath_t>(
                   transa == TransposeType::ConjTranspose ? conj_impl(a_[l])
                                                          : a_[l]) *
            static_cast<opmath_t>(b_[l]);
      });
      b_ += ldb;
      if (beta == opmath_t(0)) {
        c[j * ldc + i] = alpha * dot;
      } else {
        c[j * ldc + i] = beta * c[j * ldc + i] + alpha * dot;
      }
    }
    a_ += lda;
  }
}

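// Shared inner loop for the b-transposed case: accumulates alpha * (a @ b.T)
// into c, which must already hold any beta * c contribution. Accumulation
// happens at opmath_t precision, so both gemm_transb_ overloads below can
// reuse it.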
template <typename scalar_t, typename opmath_t>
void gemm_transb_impl(
    TransposeType transb,
    int64_t m,
    int64_t n,
    int64_t k,
    opmath_t alpha,
    const scalar_t* a,
    int64_t lda,
    const scalar_t* b,
    int64_t ldb,
    opmath_t* c, // the beta * c contribution is expected to be pre-applied
    int64_t ldc) {
  // c += alpha * (a @ b.T)
  for (const auto l : c10::irange(k)) {
    for (const auto j : c10::irange(n)) {
      opmath_t val =
          (transb == TransposeType::ConjTranspose ? conj_impl(b[j + l * ldb])
                                                  : b[j + l * ldb]) *
          alpha;
      int64_t i_m = m / 4;
      for (const auto i_i : c10::irange(i_m)) {
        c[j * ldc + i_i * 4 + 0] += a[i_i * 4 + 0 + l * lda] * val;
        c[j * ldc + i_i * 4 + 1] += a[i_i * 4 + 1 + l * lda] * val;
        c[j * ldc + i_i * 4 + 2] += a[i_i * 4 + 2 + l * lda] * val;
        c[j * ldc + i_i * 4 + 3] += a[i_i * 4 + 3 + l * lda] * val;
      }
      int64_t i = i_m * 4;
      for (; i < m; i++) {
        c[j * ldc + i] += a[i + l * lda] * val;
      }
    }
  }
}

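// Same-precision case: apply beta to c, then accumulate directly into it.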
template <typename scalar_t, typename opmath_t>
typename std::enable_if<std::is_same<scalar_t, opmath_t>::value, void>::type
gemm_transb_(
    TransposeType transb,
    int64_t m,
    int64_t n,
    int64_t k,
    opmath_t alpha,
    const scalar_t* a,
    int64_t lda,
    const scalar_t* b,
    int64_t ldb,
    opmath_t beta,
    scalar_t* c,
    int64_t ldc) {
  // c *= beta
  scale_(m, n, beta, c, ldc);

  gemm_transb_impl(transb, m, n, k, alpha, a, lda, b, ldb, c, ldc);
}

// Mixed-precision case: scalar_t is a reduced-precision type (at::BFloat16
// or at::Half) and opmath_t is float.
template <typename scalar_t, typename opmath_t>
typename std::enable_if<!std::is_same<scalar_t, opmath_t>::value, void>::type
gemm_transb_(
    TransposeType transb,
    int64_t m,
    int64_t n,
    int64_t k,
    opmath_t alpha,
    const scalar_t* a,
    int64_t lda,
    const scalar_t* b,
    int64_t ldb,
    opmath_t beta,
    scalar_t* c,
    int64_t ldc) {
  // We need to calculate full-precision dot products for correctness;
  // users notice error accumulation with reduced-width types (e.g.,
  // https://github.com/pytorch/pytorch/issues/95125 and
  // https://github.com/pytorch/pytorch/issues/83863, which were filed
  // when we used gemm_transb_impl naively, accumulating into
  // float16/bfloat16). The straightforward way to do this would be to
  // compute each output element as a dot product anyway, but that gives
  // terrible performance because of the non-contiguous matrix accesses.
  // Instead, we allocate temporary space holding the output at higher
  // precision so that we can accumulate into it with the cache-friendly
  // "load one vector element, FMA it with an entire matrix row into the
  // entire result vector" algorithm above.
  const auto c_size = m * n;
  auto c_accum = std::make_unique<opmath_t[]>(c_size);
  if (beta == 1) {
    for (const auto j : c10::irange(n)) {
      for (const auto i : c10::irange(m)) {
        c_accum[j * m + i] = c[j * ldc + i];
      }
    }
  } else if (beta == 0) {
    for (const auto j : c10::irange(n)) {
      for (const auto i : c10::irange(m)) {
        c_accum[j * m + i] = 0;
      }
    }
  } else {
    for (const auto j : c10::irange(n)) {
      for (const auto i : c10::irange(m)) {
        c_accum[j * m + i] = beta * c[j * ldc + i];
      }
    }
  }
  gemm_transb_impl(transb, m, n, k, alpha, a, lda, b, ldb, c_accum.get(), m);
  for (const auto j : c10::irange(n)) {
    for (const auto i : c10::irange(m)) {
      c[j * ldc + i] = c_accum[j * m + i];
    }
  }
}

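// Both operands transposed: every output element is a dot product in which
// the a term is read contiguously (stride 1 in l) while the b term strides
// by ldb, with optional conjugation of either operand.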
template <typename scalar_t, typename opmath_t>
void gemm_transab_(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    opmath_t alpha,
    const scalar_t *a, int64_t lda,
    const scalar_t *b, int64_t ldb,
    opmath_t beta,
    scalar_t *c, int64_t ldc) {
  // c = beta * c + alpha * (a.T @ b.T)
  for (const auto i : c10::irange(m)) {
    for (const auto j : c10::irange(n)) {
      const auto dot = sum(k, [&](int64_t l) -> opmath_t {
        return static_cast<opmath_t>(
                   transa == TransposeType::ConjTranspose
                       ? conj_impl(a[i * lda + l])
                       : a[i * lda + l]) *
            static_cast<opmath_t>(
                   transb == TransposeType::ConjTranspose
                       ? conj_impl(b[l * ldb + j])
                       : b[l * ldb + j]);
      });

      if (beta == opmath_t(0)) {
        c[j * ldc + i] = alpha * dot;
      } else {
        c[j * ldc + i] = beta * c[j * ldc + i] + alpha * dot;
      }
    }
  }
}

#if defined(__aarch64__) && !defined(C10_MOBILE)
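// aarch64 fast paths: GEMM calls that are really GEMVs (n == 1) with
// alpha == 1 and beta == 0 dispatch to the hand-written NEON kernels declared
// above; everything else falls back to fp32-accumulating dot products.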
template <>
void gemm_notrans_(
    int64_t m,
    int64_t n,
    int64_t k,
    float alpha,
    const at::Half* a,
    int64_t lda,
    const at::Half* b,
    int64_t ldb,
    float beta,
    at::Half* c,
    int64_t ldc) {
  // c = beta * c + alpha * (a @ b)
  if (n == 1 && beta == 0.0 && alpha == 1.0) {
    at::native::blas_impl::fp16_gemv_notrans(
        m,
        k,
        1.0,
        reinterpret_cast<const float16_t*>(a),
        lda,
        reinterpret_cast<const float16_t*>(b),
        1,
        0.0,
        reinterpret_cast<float16_t*>(c),
        1);
    return;
  }
  for (const auto i : c10::irange(m)) {
    for (const auto j : c10::irange(n)) {
      const auto dot = sum(k, [&](int64_t l) -> float {
        return float(c10::detail::fp16_from_bits(a[l * lda + i].x)) *
            float(c10::detail::fp16_from_bits(b[j * ldb + l].x));
      });
      if (beta == 0) {
        c[j * ldc + i] = alpha * dot;
      } else {
        c[j * ldc + i] = beta * c[j * ldc + i] + alpha * dot;
      }
    }
  }
}
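// Widens 4 bf16 values to fp32: bfloat16 is the upper 16 bits of an IEEE
// fp32, so zero-extending each lane to 32 bits and shifting left by 16
// reproduces the exact fp32 value.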
inline float32x4_t load_as_float32x4(const BFloat16* ptr) {
  int32x4_t shift = vdupq_n_s32(16);
  uint32x4_t as_int = vmovl_u16(vld1_u16(reinterpret_cast<const uint16_t *>(ptr)));
  return vreinterpretq_f32_u32(vshlq_u32(as_int, shift));
}

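// Both overloads compute the dot product with fp32 accumulation, which keeps
// the reduced-precision gemm_transa_ specializations below accurate.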
static float compute_dot(const at::Half* a, const at::Half* b, int64_t len) {
  return at::native::blas_impl::fp16_dot_with_fp32_arith(
      reinterpret_cast<const float16_t*>(a),
      reinterpret_cast<const float16_t*>(b),
      len);
}

static float compute_dot(const at::BFloat16* a, const at::BFloat16* b, int64_t len) {
  return at::native::blas_impl::bf16_dot_with_fp32_arith(a, b, len);
}

template <>
void gemm_transa_(
    TransposeType transa,
    int64_t m, int64_t n, int64_t k,
    float alpha,
    const at::Half *a, int64_t lda,
    const at::Half *b, int64_t ldb,
    float beta,
    at::Half *c, int64_t ldc) {
  // c = alpha * (a.T @ b) + beta * c
  if (n == 1 && beta == 0.0 && alpha == 1.0) {
    at::native::blas_impl::fp16_gemv_trans(
        k,
        m,
        1.0,
        reinterpret_cast<const float16_t*>(a),
        lda,
        reinterpret_cast<const float16_t*>(b),
        1,
        0.0,
        reinterpret_cast<float16_t*>(c),
        1);
    return;
  }
  parallel_for(0, m, 1, [&](int64_t begin, int64_t end) {
    const auto *a_ = a + begin * lda;
    for (const auto i : c10::irange(begin, end)) {
      const auto *b_ = b;
      for (const auto j : c10::irange(n)) {
        const auto dot = compute_dot(a_, b_, k);
        b_ += ldb;
        if (beta == 0) {
          c[j * ldc + i] = alpha * dot;
        } else {
          c[j * ldc + i] = beta * c[j * ldc + i] + alpha * dot;
        }
      }
      a_ += lda;
    }
  });
}

template <>
void gemm_transa_(
    TransposeType transa,
    int64_t m, int64_t n, int64_t k,
    float alpha,
    const at::BFloat16 *a, int64_t lda,
    const at::BFloat16 *b, int64_t ldb,
    float beta,
    at::BFloat16 *c, int64_t ldc) {
  // c = alpha * (a.T @ b) + beta * c
  parallel_for(0, m, 1, [&](int64_t begin, int64_t end) {
    const auto *a_ = a + begin * lda;
    for (const auto i : c10::irange(begin, end)) {
      const auto *b_ = b;
      for (const auto j : c10::irange(n)) {
        const auto dot = compute_dot(a_, b_, k);
        b_ += ldb;
        if (beta == 0) {
          c[j * ldc + i] = alpha * dot;
        } else {
          c[j * ldc + i] = beta * c[j * ldc + i] + alpha * dot;
        }
      }
      a_ += lda;
    }
  });
}

#endif

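// Dispatches on the transpose flags to one of the four kernels above,
// implementing c = beta * c + alpha * (op(a) @ op(b)) where op is the
// identity, transpose, or conjugate transpose.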
template <typename scalar_t, typename opmath_t>
void gemm_core_(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    opmath_t alpha,
    const scalar_t *a, int64_t lda,
    const scalar_t *b, int64_t ldb,
    opmath_t beta,
    scalar_t *c, int64_t ldc) {
  if (transa == TransposeType::NoTranspose &&
      transb == TransposeType::NoTranspose) {
    return gemm_notrans_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
  } else if (
      transa != TransposeType::NoTranspose &&
      transb == TransposeType::NoTranspose) {
    gemm_transa_(transa, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
  } else if (
      transa == TransposeType::NoTranspose &&
      transb != TransposeType::NoTranspose) {
    gemm_transb_(transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
  } else {
    gemm_transab_(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
  }
}

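// On non-mobile builds the fallback GEMM also covers the four float8
// variants; mobile builds dispatch only the standard types plus half and
// bfloat16.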
#if !defined(C10_MOBILE)
#define _AT_DISPATCH_GEMM_TYPES(TYPE, NAME, ...)                             \
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND6(                                    \
      kHalf, kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn, kFloat8_e5m2fnuz,      \
      kFloat8_e4m3fnuz,                                                      \
      TYPE, NAME, __VA_ARGS__)
#else
#define _AT_DISPATCH_GEMM_TYPES(TYPE, NAME, ...)                             \
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(                                    \
      kHalf, kBFloat16,                                                      \
      TYPE, NAME, __VA_ARGS__)
#endif
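
// Implementation registered for cpublas::gemm_stub below: dispatches on dtype
// and runs the generic kernels above with opmath_t accumulation.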
void cpublas_gemm_impl(
    at::ScalarType type,
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    const Scalar& alpha,
    const void *a, int64_t lda,
    const void *b, int64_t ldb,
    const Scalar& beta,
    void *c, int64_t ldc) {
  _AT_DISPATCH_GEMM_TYPES(type, "cpublas_gemm_impl", [&]{
    using opmath_t = at::opmath_type<scalar_t>;
    gemm_core_(
        transa, transb, m, n, k,
        alpha.to<opmath_t>(),
        static_cast<const scalar_t *>(a), lda,
        static_cast<const scalar_t *>(b), ldb,
        beta.to<opmath_t>(),
        static_cast<scalar_t *>(c), ldc);
  });
}

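// Strided axpy, y[i] = a * x[i] + y[i]; bool uses logical AND/OR in place of
// multiply/add.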
void cpublas_axpy_impl(at::ScalarType type, int64_t n, const Scalar& _a, const void *_x, int64_t incx, void *_y, int64_t incy) {
  if (type == at::kBool) {
    auto a = _a.to<bool>();
    auto x = static_cast<const bool *>(_x);
    auto y = static_cast<bool *>(_y);
    for (int64_t i = 0; i < n; i++) {
      y[i * incy] |= a & x[i * incx];
    }
  } else {
    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(at::kHalf, at::kBFloat16, type, "cpublas_axpy_impl",
      [&] {
        using opmath_t = at::opmath_type<scalar_t>;
        auto a = _a.to<opmath_t>();
        auto x = static_cast<const scalar_t *>(_x);
        auto y = static_cast<scalar_t *>(_y);
        for (int64_t i = 0; i < n; i++) {
          y[i * incy] += a * x[i * incx];
        }
      });
  }
}

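// Strided copy, y[i] = x[i], for all dispatchable types including bool and
// complex half.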
void cpublas_copy_impl(at::ScalarType type, int64_t n, const void *_x, int64_t incx, void *_y, int64_t incy) {
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(at::kComplexHalf, at::kHalf, at::kBFloat16, at::kBool, type, "cpublas_copy_impl",
    [&] {
      auto x = static_cast<const scalar_t *>(_x);
      auto y = static_cast<scalar_t *>(_y);
      for (int64_t i = 0; i < n; i++) {
        y[i * incy] = x[i * incx];
      }
    });
}

}} // namespace cpublas::(anonymous)

REGISTER_DISPATCH(cpublas::gemm_stub, &cpublas::cpublas_gemm_impl);
REGISTER_DISPATCH(cpublas::axpy_stub, &cpublas::cpublas_axpy_impl);
REGISTER_DISPATCH(cpublas::copy_stub, &cpublas::cpublas_copy_impl);

} // namespace at::native