xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cpu/BlasKernel.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_NO_OPERATORS
2 #include <ATen/Dispatch.h>
3 #include <ATen/Parallel.h>
4 #include <ATen/native/CPUBlas.h>
5 #include <ATen/native/cpu/zmath.h>
6 #include <c10/util/irange.h>
7 #include <c10/util/Unroll.h>
8 
#if defined(__aarch64__) && !defined(C10_MOBILE)
#include <arm_neon.h>

// Forward declarations of hand-written NEON kernels (defined in a
// separate translation unit) used as fast paths for reduced-precision
// GEMV / dot products on 64-bit ARM, non-mobile builds.
namespace at::native::blas_impl {
// fp16 GEMV, a not transposed. Presumably standard BLAS gemv contract
// (y = beta * y + alpha * a @ x); the only call site below uses
// alpha == 1, beta == 0, incx == incy == 1 — confirm before relying on
// other parameter combinations.
void fp16_gemv_notrans(
    const int m,
    const int n,
    const float alpha,
    const float16_t* a,
    const int lda,
    const float16_t* x,
    const int incx,
    const float beta,
    float16_t* y,
    const int incy);

// fp16 GEMV with a transposed (y relates to a.T @ x); same caveat as
// above about the parameter combinations actually exercised here.
void fp16_gemv_trans(
    const int m,
    const int n,
    const float alpha,
    const float16_t* a,
    const int lda,
    const float16_t* x,
    const int incx,
    const float beta,
    float16_t* y,
    const int incy);

// Dot product of two length-`len` fp16 vectors, accumulated in fp32.
float fp16_dot_with_fp32_arith(
  const float16_t* x,
  const float16_t* a,
  int64_t len);

// Dot product of two length-`len` bf16 vectors, accumulated in fp32.
float bf16_dot_with_fp32_arith(
  const at::BFloat16* x,
  const at::BFloat16* a,
  int64_t len);
}
#endif
48 
49 namespace at::native {
50 namespace cpublas {
51 namespace {
52 
// Scales the m-by-n column-major matrix `a` (leading dimension `lda`)
// in place by `alpha`. Fast paths: alpha == 1 is a no-op; alpha == 0
// stores exact zeros instead of multiplying, so Inf/NaN entries are
// cleared too (matching BLAS beta == 0 semantics).
template <typename scalar_t, typename opmath_t>
void scale_(int64_t m, int64_t n, opmath_t alpha, scalar_t *a, int64_t lda) {
  if (alpha == opmath_t(1)) {
    return;  // multiplying by one changes nothing
  }

  if (alpha == opmath_t(0)) {
    for (int64_t col = 0; col < n; ++col) {
      for (int64_t row = 0; row < m; ++row) {
        a[col * lda + row] = scalar_t(0);
      }
    }
    return;
  }

  for (int64_t col = 0; col < n; ++col) {
    for (int64_t row = 0; row < m; ++row) {
      a[col * lda + row] *= alpha;
    }
  }
}
74 
75 template <typename Func>
sum(int64_t N,Func f)76 auto sum(int64_t N, Func f) {
77   constexpr int ilp_factor = 4;
78   using acc_t = decltype(f(0));
79 
80   // Calculate independent partial sums then add together at the end
81   std::array<acc_t, ilp_factor> partial_sums{};
82 
83   int64_t i = 0;
84   for (; i + ilp_factor <= N; i += ilp_factor) {
85     c10::ForcedUnroll<ilp_factor>{}([&](int k) {
86       partial_sums[k] += f(i + k);
87     });
88   }
89   for (; i < N; ++i) {
90     partial_sums[0] += f(i);
91   }
92   for (int k = 1; k < ilp_factor; ++k) {
93     partial_sums[0] += partial_sums[k];
94   }
95   return partial_sums[0];
96 }
97 
// c = beta * c + alpha * (a @ b), all matrices column-major.
// Enabled when the accumulation type equals the storage type
// (e.g. float/double): it is then safe to accumulate directly into c.
template <typename scalar_t, typename opmath_t>
typename std::enable_if<std::is_same<scalar_t, opmath_t>::value, void>::type
gemm_notrans_(
    int64_t m,
    int64_t n,
    int64_t k,
    opmath_t alpha,
    const scalar_t* a,
    int64_t lda,
    const scalar_t* b,
    int64_t ldb,
    opmath_t beta,
    scalar_t* c,
    int64_t ldc) {
  // Fold beta into c first (exact zeroing when beta == 0).
  scale_(m, n, beta, c, ldc);

  // Rank-1 update per inner index l: column l of a, scaled by
  // alpha * b[l, j], is added into column j of c. This ordering keeps
  // the accesses to a and c contiguous down their columns.
  for (int64_t l = 0; l < k; ++l) {
    for (int64_t j = 0; j < n; ++j) {
      const opmath_t val = b[l + j * ldb] * alpha;
      const int64_t quads = m / 4;
      // Rows of c, manually unrolled by 4.
      for (int64_t q = 0; q < quads; ++q) {
        const int64_t base = q * 4;
        c[j * ldc + base + 0] += a[base + 0 + l * lda] * val;
        c[j * ldc + base + 1] += a[base + 1 + l * lda] * val;
        c[j * ldc + base + 2] += a[base + 2 + l * lda] * val;
        c[j * ldc + base + 3] += a[base + 3 + l * lda] * val;
      }
      // Remaining rows when m is not a multiple of 4.
      for (int64_t i = quads * 4; i < m; ++i) {
        c[j * ldc + i] += a[i + l * lda] * val;
      }
    }
  }
}
132 
// c = beta * c + alpha * (a @ b) for reduced-precision storage types
// (scalar_t is at::BFloat16 or at::Half while opmath_t is float):
// every dot product is accumulated in opmath_t to limit rounding
// error, and the result is narrowed exactly once on store.
template <typename scalar_t, typename opmath_t>
typename std::enable_if<!std::is_same<scalar_t, opmath_t>::value, void>::type
gemm_notrans_(
    int64_t m,
    int64_t n,
    int64_t k,
    opmath_t alpha,
    const scalar_t* a,
    int64_t lda,
    const scalar_t* b,
    int64_t ldb,
    opmath_t beta,
    scalar_t* c,
    int64_t ldc) {
  for (int64_t i = 0; i < m; ++i) {
    for (int64_t j = 0; j < n; ++j) {
      // Full-precision dot of row i of a with column j of b.
      const auto dot = sum(k, [&](int64_t l) -> opmath_t {
        return static_cast<opmath_t>(a[l * lda + i]) *
            static_cast<opmath_t>(b[j * ldb + l]);
      });
      // When beta == 0 the existing contents of c are never read
      // (they may be uninitialized).
      c[j * ldc + i] = (beta == opmath_t(0))
          ? alpha * dot
          : beta * c[j * ldc + i] + alpha * dot;
    }
  }
}
163 
164 template <typename scalar_t, typename opmath_t>
gemm_transa_(TransposeType transa,int64_t m,int64_t n,int64_t k,opmath_t alpha,const scalar_t * a,int64_t lda,const scalar_t * b,int64_t ldb,opmath_t beta,scalar_t * c,int64_t ldc)165 void gemm_transa_(
166     TransposeType transa,
167     int64_t m, int64_t n, int64_t k,
168     opmath_t alpha,
169     const scalar_t *a, int64_t lda,
170     const scalar_t *b, int64_t ldb,
171     opmath_t beta,
172     scalar_t *c, int64_t ldc) {
173   // c = alpha * (a.T @ b) + beta * c
174   const scalar_t *a_ = a;
175   for (const auto i : c10::irange(m)) {
176     const scalar_t *b_ = b;
177     for (const auto j : c10::irange(n)) {
178       const auto dot = sum(k, [&](int64_t l) -> opmath_t {
179         return static_cast<opmath_t>(transa == TransposeType::ConjTranspose ? conj_impl(a_[l]) : a_[l]) * static_cast<opmath_t>(b_[l]);
180       });
181       b_ += ldb;
182       if (beta == opmath_t(0)) {
183         c[j*ldc+i] = alpha*dot;
184       } else {
185         c[j*ldc+i] = beta*c[j*ldc+i]+alpha*dot;
186       }
187     }
188     a_ += lda;
189   }
190 }
191 
192 template <typename scalar_t, typename opmath_t>
gemm_transb_impl(TransposeType transb,int64_t m,int64_t n,int64_t k,opmath_t alpha,const scalar_t * a,int64_t lda,const scalar_t * b,int64_t ldb,opmath_t * c,int64_t ldc)193 void gemm_transb_impl(
194     TransposeType transb,
195     int64_t m,
196     int64_t n,
197     int64_t k,
198     opmath_t alpha,
199     const scalar_t* a,
200     int64_t lda,
201     const scalar_t* b,
202     int64_t ldb,
203     /* we expect pre-applied beta */
204     opmath_t* c,
205     int64_t ldc) {
206   // c += alpha * (a @ b.T)
207   for (const auto l : c10::irange(k)) {
208     for (const auto j : c10::irange(n)) {
209       opmath_t val = (transb == TransposeType::ConjTranspose ? conj_impl(b[j + l * ldb]) : b[j + l * ldb]) * alpha;
210       int64_t i_m = m / 4;
211       for (const auto i_i : c10::irange(i_m)) {
212         c[j * ldc + i_i * 4 + 0] += a[i_i * 4 + 0 + l * lda] * val;
213         c[j * ldc + i_i * 4 + 1] += a[i_i * 4 + 1 + l * lda] * val;
214         c[j * ldc + i_i * 4 + 2] += a[i_i * 4 + 2 + l * lda] * val;
215         c[j * ldc + i_i * 4 + 3] += a[i_i * 4 + 3 + l * lda] * val;
216       }
217       int64_t i = i_m * 4;
218       for (; i < m; i++)
219         c[j * ldc + i] += a[i + l * lda] * val;
220     }
221   }
222 }
223 
224 template <typename scalar_t, typename opmath_t>
225 typename std::enable_if<std::is_same<scalar_t, opmath_t>::value, void>::type
gemm_transb_(TransposeType transb,int64_t m,int64_t n,int64_t k,opmath_t alpha,const scalar_t * a,int64_t lda,const scalar_t * b,int64_t ldb,opmath_t beta,scalar_t * c,int64_t ldc)226 gemm_transb_(
227     TransposeType transb,
228     int64_t m,
229     int64_t n,
230     int64_t k,
231     opmath_t alpha,
232     const scalar_t* a,
233     int64_t lda,
234     const scalar_t* b,
235     int64_t ldb,
236     opmath_t beta,
237     scalar_t* c,
238     int64_t ldc) {
239   // c *= beta
240   scale_(m, n, beta, c, ldc);
241 
242   gemm_transb_impl(transb, m, n, k, alpha, a, lda, b, ldb, c, ldc);
243 }
244 
245 // std::is_same<scalar_t, at::BFloat16> || std::is_same<scalar_t, at::Half>
246 template <typename scalar_t, typename opmath_t>
247 typename std::enable_if<!std::is_same<scalar_t, opmath_t>::value, void>::type
gemm_transb_(TransposeType transb,int64_t m,int64_t n,int64_t k,opmath_t alpha,const scalar_t * a,int64_t lda,const scalar_t * b,int64_t ldb,opmath_t beta,scalar_t * c,int64_t ldc)248 gemm_transb_(
249     TransposeType transb,
250     int64_t m,
251     int64_t n,
252     int64_t k,
253     opmath_t alpha,
254     const scalar_t* a,
255     int64_t lda,
256     const scalar_t* b,
257     int64_t ldb,
258     opmath_t beta,
259     scalar_t* c,
260     int64_t ldc) {
261   // We need to calculate full-precision dot products for correctness;
262   // users notice error accumulation with reduced-width types (e.g.,
263   // https://github.com/pytorch/pytorch/issues/95125 and
264   // https://github.com/pytorch/pytorch/issues/83863, which were filed
265   // when we used gemm_transb_impl naively, accumulating into
266   // float16/bfloat16). The straightforward way to do this is to use
267   // the vector dot column algorithm anyway, but this gives terrible
268   // performance because of the non-contiguous matrix
269   // access. Therefore, we instead elect to allocate temporary space
270   // to hold the output at higher-precision so that we can accumulate
271   // into it using the above cache-friendly "load one vector element,
272   // FMA it with an entire matrix row into the entire result vector"
273   // algorithm instead.
274   const auto c_size = m * n;
275   auto c_accum = std::make_unique<opmath_t[]>(c_size);
276   if (beta == 1) {
277     for (const auto j : c10::irange(n)) {
278       for (const auto i : c10::irange(m)) {
279         c_accum[j * m + i] = c[j * ldc + i];
280       }
281     }
282   } else if (beta == 0) {
283     for (const auto j : c10::irange(n)) {
284       for (const auto i : c10::irange(m)) {
285         c_accum[j * m + i] = 0;
286       }
287     }
288   } else {
289     for (const auto j : c10::irange(n)) {
290       for (const auto i : c10::irange(m)) {
291         c_accum[j * m + i] = beta * c[j * ldc + i];
292       }
293     }
294   }
295   gemm_transb_impl(transb, m, n, k, alpha, a, lda, b, ldb, c_accum.get(), m);
296   for (const auto j : c10::irange(n)) {
297     for (const auto i : c10::irange(m)) {
298       c[j * ldc + i] = c_accum[j * m + i];
299     }
300   }
301 }
302 
303 template <typename scalar_t, typename opmath_t>
gemm_transab_(TransposeType transa,TransposeType transb,int64_t m,int64_t n,int64_t k,opmath_t alpha,const scalar_t * a,int64_t lda,const scalar_t * b,int64_t ldb,opmath_t beta,scalar_t * c,int64_t ldc)304 void gemm_transab_(
305     TransposeType transa, TransposeType transb,
306     int64_t m, int64_t n, int64_t k,
307     opmath_t alpha,
308     const scalar_t *a, int64_t lda,
309     const scalar_t *b, int64_t ldb,
310     opmath_t beta,
311     scalar_t *c, int64_t ldc) {
312   // c = beta * c + alpha * (a.T @ b.T)
313   for (const auto i : c10::irange(m)) {
314     for (const auto j : c10::irange(n)) {
315       const auto dot = sum(k, [&](int64_t l) -> opmath_t {
316         return static_cast<opmath_t>(transa == TransposeType::ConjTranspose ? conj_impl(a[i * lda + l]) : a[i * lda + l]) *
317             static_cast<opmath_t>(transb == TransposeType::ConjTranspose ? conj_impl(b[l * ldb + j]) : b[l * ldb + j]);
318       });
319 
320       if (beta == opmath_t(0)) {
321         c[j * ldc + i] = alpha * dot;
322       } else {
323         c[j * ldc + i] = beta * c[j * ldc + i] + alpha * dot;
324       }
325     }
326   }
327 }
328 
329 #if defined(__aarch64__) && !defined(C10_MOBILE)
// GEMM for at::Half storage with float accumulation on aarch64:
// c = beta * c + alpha * (a @ b). The GEMV-shaped case
// (n == 1, beta == 0, alpha == 1) dispatches to the hand-written NEON
// fp16 kernel; everything else falls back to a scalar loop that widens
// each fp16 element to float through its bit pattern.
template <>
void gemm_notrans_(
    int64_t m,
    int64_t n,
    int64_t k,
    float alpha,
    const at::Half* a,
    int64_t lda,
    const at::Half* b,
    int64_t ldb,
    float beta,
    at::Half* c,
    int64_t ldc) {
  // c += alpha * (a @ b)
  if (n == 1 && beta == 0.0 && alpha == 1.0) {
    // a is m-by-k here; b and c are single contiguous columns, hence
    // increments of 1. at::Half and float16_t share a bit layout, so
    // the reinterpret_casts are layout-preserving.
    at::native::blas_impl::fp16_gemv_notrans(m, k, 1.0, reinterpret_cast<const float16_t*>(a), lda, reinterpret_cast<const float16_t*>(b), 1, 0.0, reinterpret_cast<float16_t*>(c), 1);
    return;
  }
  for (const auto i : c10::irange(m)) {
    for (const auto j : c10::irange(n)) {
      // fp32-accumulated dot product; .x is the raw fp16 payload.
      const auto dot = sum(k, [&](int64_t l) -> float {
        return float(c10::detail::fp16_from_bits(a[l * lda + i].x)) *
            float(c10::detail::fp16_from_bits(b[j * ldb + l].x));
      });
      if (beta == 0) {
        // beta == 0: do not read c (may hold uninitialized data).
        c[j * ldc + i] = alpha * dot;
      } else {
        c[j * ldc + i] = beta * c[j * ldc + i] + alpha * dot;
      }
    }
  }
}
362 
363 
load_as_float32x4(const BFloat16 * ptr)364 inline float32x4_t load_as_float32x4(const BFloat16* ptr) {
365   int32x4_t shift = vdupq_n_s32(16);
366   uint32x4_t as_int = vmovl_u16(vld1_u16(reinterpret_cast<const uint16_t *>(ptr)));
367   return vreinterpretq_f32_u32(vshlq_u32(as_int, shift));
368 }
369 
compute_dot(const at::Half * a,const at::Half * b,int64_t len)370 static float compute_dot(const at::Half* a, const at::Half* b, int64_t len) {
371   return at::native::blas_impl::fp16_dot_with_fp32_arith(
372     reinterpret_cast<const float16_t*>(a),
373     reinterpret_cast<const float16_t*>(b),
374     len);
375 }
376 
compute_dot(const at::BFloat16 * a,const at::BFloat16 * b,int64_t len)377 static float compute_dot(const at::BFloat16* a, const at::BFloat16* b, int64_t len) {
378   return at::native::blas_impl::bf16_dot_with_fp32_arith(a, b, len);
379 }
380 
// c = alpha * (a.T @ b) + beta * c for at::Half on aarch64.
// NOTE(review): `transa` is unused here — presumably because
// conjugation is a no-op for real-valued half floats, making Transpose
// and ConjTranspose coincide; confirm if complex support is ever added.
template <>
void gemm_transa_(
    TransposeType transa,
    int64_t m, int64_t n, int64_t k,
    float alpha,
    const at::Half *a, int64_t lda,
    const at::Half *b, int64_t ldb,
    float beta,
    at::Half *c, int64_t ldc) {
  // c = alpha * (a.T @ b) + beta * c
  // GEMV-shaped fast path: one output column, no scaling/accumulation.
  if (n == 1 && beta == 0.0 && alpha == 1.0) {
    at::native::blas_impl::fp16_gemv_trans(k, m, 1.0, reinterpret_cast<const float16_t*>(a), lda, reinterpret_cast<const float16_t*>(b), 1, 0.0, reinterpret_cast<float16_t*>(c), 1);
    return;
  }
  // Parallelize over output rows; each task walks contiguous columns
  // of a (i.e. rows of a.T) against every column of b.
  parallel_for(0, m, 1, [&](int64_t begin, int64_t end) {
    const auto *a_ = a + begin * lda;
    for (const auto i : c10::irange(begin, end)) {
      const auto *b_ = b;
      for (const auto j : c10::irange(n)) {
        // fp32-accumulated fp16 dot product of length k.
        const auto dot = compute_dot(a_, b_, k);
        b_ += ldb;
        if (beta == 0) {
          // beta == 0: skip reading c entirely.
          c[j*ldc+i] = alpha*dot;
        } else {
          c[j*ldc+i] = beta*c[j*ldc+i]+alpha*dot;
        }
      }
      a_ += lda;
    }
  });
}
412 
// c = alpha * (a.T @ b) + beta * c for at::BFloat16 on aarch64.
// Same row-parallel structure as the at::Half specialization, but with
// no GEMV fast path (no bf16 gemv kernel is declared in this file).
// NOTE(review): `transa` is unused — presumably fine for real-valued
// bf16, where conjugation is the identity.
template <>
void gemm_transa_(
    TransposeType transa,
    int64_t m, int64_t n, int64_t k,
    float alpha,
    const at::BFloat16 *a, int64_t lda,
    const at::BFloat16 *b, int64_t ldb,
    float beta,
    at::BFloat16 *c, int64_t ldc) {
  // c = alpha * (a.T @ b) + beta * c
  // Parallelize over output rows (columns of a).
  parallel_for(0, m, 1, [&](int64_t begin, int64_t end) {
    const auto *a_ = a + begin * lda;
    for (const auto i : c10::irange(begin, end)) {
      const auto *b_ = b;
      for (const auto j : c10::irange(n)) {
        // fp32-accumulated bf16 dot product of length k.
        const auto dot = compute_dot(a_, b_, k);
        b_ += ldb;
        if (beta == 0) {
          // beta == 0: skip reading c entirely.
          c[j*ldc+i] = alpha*dot;
        } else {
          c[j*ldc+i] = beta*c[j*ldc+i]+alpha*dot;
        }
      }
      a_ += lda;
    }
  });
}
440 
441 #endif
442 
443 template <typename scalar_t, typename opmath_t>
gemm_core_(TransposeType transa,TransposeType transb,int64_t m,int64_t n,int64_t k,opmath_t alpha,const scalar_t * a,int64_t lda,const scalar_t * b,int64_t ldb,opmath_t beta,scalar_t * c,int64_t ldc)444 void gemm_core_(
445     TransposeType transa, TransposeType transb,
446     int64_t m, int64_t n, int64_t k,
447     opmath_t alpha,
448     const scalar_t *a, int64_t lda,
449     const scalar_t *b, int64_t ldb,
450     opmath_t beta,
451     scalar_t *c, int64_t ldc) {
452   if (transa == TransposeType::NoTranspose &&
453       transb == TransposeType::NoTranspose) {
454     return gemm_notrans_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
455   } else if (
456       transa != TransposeType::NoTranspose &&
457       transb == TransposeType::NoTranspose) {
458     gemm_transa_(transa, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
459   } else if (
460       transa == TransposeType::NoTranspose &&
461       transb != TransposeType::NoTranspose) {
462     gemm_transb_(transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
463   } else {
464     gemm_transab_(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
465   }
466 }
467 
#if !defined(C10_MOBILE)
// Non-mobile builds additionally dispatch the four float8 variants.
#define _AT_DISPATCH_GEMM_TYPES(TYPE, NAME, ...)                                                \
        AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND6(                                                 \
            kHalf, kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn, kFloat8_e5m2fnuz, kFloat8_e4m3fnuz, \
            TYPE, NAME, __VA_ARGS__)
#else
// Mobile builds cover only the classic types plus half and bfloat16.
#define _AT_DISPATCH_GEMM_TYPES(TYPE, NAME, ...)         \
        AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(          \
            kHalf, kBFloat16,                            \
            TYPE, NAME, __VA_ARGS__)
#endif
cpublas_gemm_impl(at::ScalarType type,TransposeType transa,TransposeType transb,int64_t m,int64_t n,int64_t k,const Scalar & alpha,const void * a,int64_t lda,const void * b,int64_t ldb,const Scalar & beta,void * c,int64_t ldc)479 void cpublas_gemm_impl(
480     at::ScalarType type,
481     TransposeType transa, TransposeType transb,
482     int64_t m, int64_t n, int64_t k,
483     const Scalar& alpha,
484     const void *a, int64_t lda,
485     const void *b, int64_t ldb,
486     const Scalar& beta,
487     void *c, int64_t ldc) {
488   _AT_DISPATCH_GEMM_TYPES(type, "cpublas_gemm_impl", [&]{
489         using opmath_t = at::opmath_type<scalar_t>;
490         gemm_core_(
491             transa, transb, m, n, k,
492             alpha.to<opmath_t>(),
493             static_cast<const scalar_t *>(a), lda,
494             static_cast<const scalar_t *>(b), ldb,
495             beta.to<opmath_t>(),
496             static_cast<scalar_t *>(c), ldc);
497       });
498 }
499 
cpublas_axpy_impl(at::ScalarType type,int64_t n,const Scalar & _a,const void * _x,int64_t incx,void * _y,int64_t incy)500 void cpublas_axpy_impl(at::ScalarType type, int64_t n, const Scalar& _a, const void *_x, int64_t incx, void *_y, int64_t incy){
501   if (type == at::kBool) {
502       auto a = _a.to<bool>();
503       auto x = static_cast<const bool *>(_x);
504       auto y = static_cast<bool *>(_y);
505       int64_t i;
506       for(i = 0; i < n; i++)
507         y[i*incy] |= a & x[i*incx];
508   } else {
509     AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(at::kHalf, at::kBFloat16, type, "cpublas_axpy_impl",
510       [&] {
511         using opmath_t = at::opmath_type<scalar_t>;
512         auto a = _a.to<opmath_t>();
513         auto x = static_cast<const scalar_t *>(_x);
514         auto y = static_cast<scalar_t *>(_y);
515         int64_t i;
516         for(i = 0; i < n; i++)
517           y[i*incy] += a*x[i*incx];
518       });
519   }
520 }
521 
cpublas_copy_impl(at::ScalarType type,int64_t n,const void * _x,int64_t incx,void * _y,int64_t incy)522 void cpublas_copy_impl(at::ScalarType type, int64_t n, const void *_x, int64_t incx, void *_y, int64_t incy){
523   AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(at::kComplexHalf, at::kHalf, at::kBFloat16, at::kBool, type, "cpublas_copy_impl",
524     [&] {
525       auto x = static_cast<const scalar_t *>(_x);
526       auto y = static_cast<scalar_t *>(_y);
527       int64_t i;
528       for(i = 0; i < n; i++)
529         y[i*incy] = x[i*incx];
530     });
531 }
532 
533 }}  // namespace cpublas::(anonymous)
534 
535 
// Wire the anonymous-namespace implementations into the cpublas
// dispatch stubs.
REGISTER_DISPATCH(cpublas::gemm_stub, &cpublas::cpublas_gemm_impl);
REGISTER_DISPATCH(cpublas::axpy_stub, &cpublas::cpublas_axpy_impl);
REGISTER_DISPATCH(cpublas::copy_stub, &cpublas::cpublas_copy_impl);
539 
540 }  // namespace at::native
541