#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/Context.h>
#include <ATen/Config.h>
#include <ATen/OpMathType.h>
#include <ATen/Parallel.h>
#include <c10/core/ScalarType.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <c10/util/Unroll.h>
#include <c10/util/complex.h>
#include <c10/util/irange.h>
#include <algorithm>
#include <climits>
#include <limits>

#if defined(__aarch64__) && !defined(C10_MOBILE)
#include <arm_neon.h>
#endif

C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-function")
namespace {

/// Wrapper for const_cast<T*> with type-inference.
///
/// Use this to call into APIs that are not const-correct.
template <typename T>
T* remove_const(const T* x) {
  return const_cast<T*>(x);
}

} // namespace

#if AT_BUILD_WITH_BLAS()
extern "C" double ddot_(int *n, double *x, int *incx, double *y, int *incy);
extern "C" void dscal_(int *n, double *a, double *x, int *incx);
extern "C" void sscal_(int *n, float *a, float *x, int *incx);
extern "C" void dgemv_(char *trans, int *m, int *n, double *alpha, double *a, int *lda, double *x, int *incx, double *beta, double *y, int *incy);
extern "C" void sgemv_(char *trans, int *m, int *n, float *alpha, float *a, int *lda, float *x, int *incx, float *beta, float *y, int *incy);

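// f2c-translated BLAS libraries follow the f2c convention of returning
// Fortran REAL results as C doubles, so the return type of sdot_ below has to
// vary with the build configuration.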
#if AT_BLAS_F2C()
# define ffloat double
#else
# define ffloat float
#endif

#if AT_BLAS_USE_CBLAS_DOT()
  extern "C" float cblas_sdot(const int n, const float *x, const int incx, const float *y, const int incy);
  extern "C" void cblas_cdotu_sub(const int n, const void *x, const int incx, const void *y, const int incy, void *dotu);
  extern "C" void cblas_zdotu_sub(const int n, const void *x, const int incx, const void *y, const int incy, void *dotu);
  extern "C" void cblas_cdotc_sub(const int n, const void *x, const int incx, const void *y, const int incy, void *dotc);
  extern "C" void cblas_zdotc_sub(const int n, const void *x, const int incx, const void *y, const int incy, void *dotc);

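  // Thin wrappers exposing the CBLAS entry points through the Fortran-style
  // pointer-argument signatures used by the rest of this file. The *_sub
  // variants return complex results through an output parameter, which avoids
  // relying on a complex return-value ABI.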
  static inline ffloat sdot_(const int *n, const float *x, const int *incx, const float *y, const int *incy)
  {
    return cblas_sdot(*n, x, *incx, y, *incy);
  }
  static inline void cdotu_(std::complex<float> *res, const int *n, const std::complex<float> *x, const int *incx,
  const std::complex<float> *y, const int *incy) {
    cblas_cdotu_sub(*n, x, *incx, y, *incy, res);
  }
  static inline void zdotu_(std::complex<double> *res, const int *n, const std::complex<double> *x, const int *incx,
  const std::complex<double> *y, const int *incy) {
    cblas_zdotu_sub(*n, x, *incx, y, *incy, res);
  }
  static inline void cdotc_(std::complex<float> *res, const int *n, const std::complex<float> *x, const int *incx,
  const std::complex<float> *y, const int *incy) {
    cblas_cdotc_sub(*n, x, *incx, y, *incy, res);
  }
  static inline void zdotc_(std::complex<double> *res, const int *n, const std::complex<double> *x, const int *incx,
  const std::complex<double> *y, const int *incy) {
    cblas_zdotc_sub(*n, x, *incx, y, *incy, res);
  }

#else
  extern "C" ffloat sdot_(int *n, float *x, int *incx, float *y, int *incy);
  extern "C" void cdotu_(std::complex<float> *res, int *n, std::complex<float> *x, int *incx, std::complex<float> *y, int *incy);
  extern "C" void zdotu_(std::complex<double> *res, int *n, std::complex<double> *x, int *incx, std::complex<double> *y, int *incy);
  extern "C" void cdotc_(std::complex<float> *res, int *n, std::complex<float> *x, int *incx, std::complex<float> *y, int *incy);
  extern "C" void zdotc_(std::complex<double> *res, int *n, std::complex<double> *x, int *incx, std::complex<double> *y, int *incy);
#endif // AT_BLAS_USE_CBLAS_DOT
#endif // AT_BUILD_WITH_BLAS

namespace at::native {

namespace blas_impl {
#if defined(__aarch64__) && !defined(C10_MOBILE)
void fp16_gemv_notrans(
    const int m,
    const int n,
    const float alpha,
    const float16_t* a,
    const int lda,
    const float16_t* x,
    const int incx,
    const float beta,
    float16_t* y,
    const int incy);

void fp16_gemv_trans(
    const int m,
    const int n,
    const float alpha,
    const float16_t* a,
    const int lda,
    const float16_t* x,
    const int incx,
    const float beta,
    float16_t* y,
    const int incy);

float fp16_dot_with_fp32_arith(
    const float16_t* vec1,
    const float16_t* vec2,
    int64_t len);

void bf16_gemv_trans(
    const int m,
    const int n,
    const at::BFloat16 alpha,
    const at::BFloat16* a,
    const int lda,
    const at::BFloat16* x,
    const int incx,
    const at::BFloat16 beta,
    at::BFloat16* y,
    const int incy);

float bf16_dot_with_fp32_arith(
    const at::BFloat16* vec1,
    const at::BFloat16* vec2,
    int64_t len);
#endif

template <typename scalar_t>
bool scal_use_fast_path(C10_UNUSED int64_t n, C10_UNUSED int64_t incx) {
  return false;
}

template <typename scalar_t>
bool gemv_use_fast_path(C10_UNUSED char trans, C10_UNUSED int64_t m,
                        C10_UNUSED int64_t n, C10_UNUSED scalar_t alpha,
                        C10_UNUSED int64_t lda,
                        C10_UNUSED int64_t incx, C10_UNUSED scalar_t beta,
                        C10_UNUSED int64_t incy) {
  return false;
}

template <typename scalar_t>
void scal_fast_path(C10_UNUSED int *n, C10_UNUSED scalar_t *a, C10_UNUSED scalar_t *x, C10_UNUSED int *incx) {
  TORCH_INTERNAL_ASSERT(false, "scal_fast_path shouldn't be called for this configuration");
}

template <typename scalar_t>
void gemv_fast_path(C10_UNUSED const char *trans, C10_UNUSED const int *m, C10_UNUSED const int *n,
                    C10_UNUSED const scalar_t *alpha, C10_UNUSED const scalar_t *a, C10_UNUSED const int *lda,
                    C10_UNUSED const scalar_t *x, C10_UNUSED const int *incx, C10_UNUSED const scalar_t *beta,
                    C10_UNUSED scalar_t *y, C10_UNUSED const int *incy) {
  TORCH_INTERNAL_ASSERT(false, "gemv_fast_path shouldn't be called for this configuration");
}

#define INSTANTIATE(scalar_t) \
template bool scal_use_fast_path<scalar_t>(int64_t n, int64_t incx); \
template bool gemv_use_fast_path<scalar_t>(char trans, int64_t m, int64_t n, scalar_t alpha, int64_t lda, int64_t incx, scalar_t beta, int64_t incy); \
template void gemv_fast_path<scalar_t>(const char *trans, const int *m, const int *n, const scalar_t *alpha, const scalar_t *a, const int *lda, const scalar_t *x, const int *incx, const scalar_t *beta, scalar_t *y, const int *incy); \
template void scal_fast_path<scalar_t>(int *n, scalar_t *a, scalar_t *x, int *incx);

#if AT_BUILD_WITH_BLAS()
template <>
bool scal_use_fast_path<double>(int64_t n, int64_t incx) {
  auto intmax = std::numeric_limits<int>::max();
  return n <= intmax && incx <= intmax;
}

template <>
bool scal_use_fast_path<float>(int64_t n, int64_t incx) {
  return scal_use_fast_path<double>(n, incx);
}

template <>
void scal_fast_path<double>(int *n, double *a, double *x, int *incx) {
  dscal_(n, a, x, incx);
}

template <>
void scal_fast_path<float>(int *n, float *a, float *x, int *incx) {
  sscal_(n, a, x, incx);
}

template <>
bool gemv_use_fast_path<float>(C10_UNUSED char trans, int64_t m, int64_t n, C10_UNUSED float alpha, int64_t lda, int64_t incx, C10_UNUSED float beta, int64_t incy) {
  auto intmax = std::numeric_limits<int>::max();
  return (m <= intmax) && (n <= intmax) && (lda <= intmax) &&
         (incx > 0) && (incx <= intmax) && (incy > 0) && (incy <= intmax);
}

template <>
bool gemv_use_fast_path<double>(C10_UNUSED char trans, int64_t m, int64_t n, C10_UNUSED double alpha, int64_t lda, int64_t incx, C10_UNUSED double beta, int64_t incy) {
  return gemv_use_fast_path<float>(trans, m, n, (float)alpha, lda, incx, (float)beta, incy);
}

template <>
void gemv_fast_path<double>(const char *trans, const int *m, const int *n, const double *alpha, const double *a, const int *lda, const double *x, const int *incx, const double *beta, double *y, const int *incy) {
  dgemv_(remove_const(trans), remove_const(m), remove_const(n), remove_const(alpha), remove_const(a), remove_const(lda), remove_const(x), remove_const(incx), remove_const(beta), y, remove_const(incy));
}

template <>
void gemv_fast_path<float>(const char *trans, const int *m, const int *n, const float *alpha, const float *a, const int *lda, const float *x, const int *incx, const float *beta, float *y, const int *incy) {
  sgemv_(remove_const(trans), remove_const(m), remove_const(n), remove_const(alpha), remove_const(a), remove_const(lda), remove_const(x), remove_const(incx), remove_const(beta), y, remove_const(incy));
}
#else
INSTANTIATE(float);
INSTANTIATE(double);
#endif // AT_BUILD_WITH_BLAS

INSTANTIATE(uint8_t);
INSTANTIATE(int8_t);
INSTANTIATE(int16_t);
INSTANTIATE(int);
INSTANTIATE(int64_t);
#if defined(__aarch64__) && !defined(C10_MOBILE)
template <>
bool scal_use_fast_path<at::Half>(C10_UNUSED int64_t n, C10_UNUSED int64_t incx) {
  return false;
}

template <>
bool gemv_use_fast_path<at::Half>(
    C10_UNUSED char trans,
    C10_UNUSED int64_t m,
    C10_UNUSED int64_t n,
    at::Half alpha,
    C10_UNUSED int64_t lda,
    C10_UNUSED int64_t incx,
    at::Half beta,
    C10_UNUSED int64_t incy) {
  return incx == 1 && c10::detail::fp16_from_bits(alpha.x) == 1.0f &&
    c10::detail::fp16_from_bits(beta.x) == 0.0f;
}

template <>
bool gemv_use_fast_path<at::BFloat16>(
    C10_UNUSED char trans,
    C10_UNUSED int64_t m,
    C10_UNUSED int64_t n,
    at::BFloat16 alpha,
    C10_UNUSED int64_t lda,
    C10_UNUSED int64_t incx,
    at::BFloat16 beta,
    C10_UNUSED int64_t incy) {
  return (trans == 'T' || trans == 't') && incx == 1 && alpha == 1.0 && beta == 0.0;
}


#ifdef __ARM_FEATURE_FP16_SCALAR_ARITHMETIC
static inline float16_t reduce(float16x4_t x) {
  auto sum = vpadd_f16(x, x);
  return vget_lane_f16(vpadd_f16(sum, sum), 0);
}
static inline float16_t reduce(float16x8_t x) {
  return reduce(vadd_f16(vget_low_f16(x), vget_high_f16(x)));
}

/*
 * NOTE [ GGML Copyright Notice ]
 * The reduce overload and fp16_dot_with_fp16_arith function below are
 * adapted from llama.cpp's ggml_vec_dot_f16 and surrounding utility
 * functions, so here is the required copyright notice:
 *
 * MIT License
 *
 * Copyright (c) 2023-2024 The ggml authors
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
// We need the shift for reduce(), hence the extra constants.
static constexpr auto kF16ElementsPerIterationShift = 7;
static constexpr auto kF16ElementsPerIteration = 1 << kF16ElementsPerIterationShift;
static_assert(kF16ElementsPerIteration == 128);

static constexpr auto kF16ElementsPerRegisterShift = 3;
static constexpr auto kF16ElementsPerRegister = 1 << kF16ElementsPerRegisterShift;
static_assert(kF16ElementsPerRegister == 8);

static constexpr auto kF16RegistersPerIterationShift = kF16ElementsPerIterationShift - kF16ElementsPerRegisterShift;
static constexpr auto kF16RegistersPerIteration = 1 << kF16RegistersPerIterationShift;
static_assert(kF16RegistersPerIteration == kF16ElementsPerIteration / kF16ElementsPerRegister);

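// Tree-reduce the 16 partial-sum registers below: kF16RegistersPerIterationShift
// (= 4) halving steps take 16 -> 8 -> 4 -> 2 -> 1 registers, after which the
// surviving register is widened to f32 and summed horizontally.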
static inline double reduce(float16x8_t x[kF16RegistersPerIteration]) {
  int offset = kF16RegistersPerIteration;
  c10::ForcedUnroll<kF16RegistersPerIterationShift>{}([&offset, &x](auto idx) {
    offset /= 2;
    for (int i = 0; i < offset; ++i) {
      x[i] = vaddq_f16(x[i], x[offset + i]);
    }
  });
  const float32x4_t t0 = vcvt_f32_f16(vget_low_f16(x[0]));
  const float32x4_t t1 = vcvt_f32_f16(vget_high_f16(x[0]));
  return (double)vaddvq_f32(vaddq_f32(t0, t1));
}

static inline float16x8_t f16_fma(float16x8_t a, float16x8_t b, float16x8_t c) {
#ifdef __ARM_FEATURE_FMA
  return vfmaq_f16(a, b, c);
#else
  return vaddq_f16(a, vmulq_f16(b, c));
#endif
}

static float fp16_dot_with_fp16_arith(const float16_t* x, const float16_t* a, int len) {
  float16x8_t sum[kF16RegistersPerIteration] = {vdupq_n_f16(0)};

  const auto len_aligned = len & ~(kF16ElementsPerIteration - 1);
  for (int j = 0; j < len_aligned; j += kF16ElementsPerIteration) {
    for (int k = 0; k < kF16RegistersPerIteration; ++k) {
      const auto temp_x = vld1q_f16(x + j + k * kF16ElementsPerRegister);
      const auto temp_a = vld1q_f16(a + j + k * kF16ElementsPerRegister);
      sum[k] = f16_fma(sum[k], temp_x, temp_a);
    }
  }
  auto reducedSum = reduce(sum);

  for (int j = len_aligned; j < len; ++j) {
    reducedSum += x[j] * a[j];
  }
  return reducedSum;
}

// Rather than unrolling to process multiple rows (transposed columns)
// of matrix A at once as done in fp16_gemv_trans_fp16_arith, unroll
// along an individual dot product.
static void fp16_gemv_trans_fp16_arith_by_dot_products(const int m, const int n, const float16_t* a, const int lda, const float16_t *x, float16_t* y, int incy) {
  parallel_for(0, n, 1, [&](int begin, int end) {
    for (int i = begin; i < end; ++i) {
      y[i * incy] = fp16_dot_with_fp16_arith(x, a + lda * i, m);
    }
  });
}

#endif

static inline float reduce(float32x4_t x) {
  auto sum = vpaddq_f32(x, x);
  return vgetq_lane_f32(vpaddq_f32(sum, sum), 0);
}

static inline float32x4_t f32_fma(float32x4_t a, float32x4_t b, float32x4_t c) {
#ifdef __ARM_FEATURE_FMA
  return vfmaq_f32(a, b, c);
#else
  return vaddq_f32(a, vmulq_f32(b, c));
#endif
}

static inline float32x4_t f32_fma_low_f16(float32x4_t a, float16x8_t b, float16x8_t c) {
#ifdef __ARM_FEATURE_FP16_FML
  // NOTE: this instruction is optional in ARM v8.2 and v8.3, but mandatory
  // in v8.4; see
  // https://developer.arm.com/documentation/ddi0596/2021-03/SIMD-FP-Instructions/FMLAL--FMLAL2--vector---Floating-point-fused-Multiply-Add-Long-to-accumulator--vector--?lang=en
  // I'm not certain that I have the right feature test macro.
  return vfmlalq_low_f16(a, b, c);
#else
  return f32_fma(a, vcvt_f32_f16(vget_low_f16(b)), vcvt_f32_f16(vget_low_f16(c)));
#endif
}

static inline float32x4_t f32_fma_high_f16(float32x4_t a, float16x8_t b, float16x8_t c) {
#ifdef __ARM_FEATURE_FP16_FML
  // See above note about this instruction.
  return vfmlalq_high_f16(a, b, c);
#else
  return f32_fma(a, vcvt_f32_f16(vget_high_f16(b)), vcvt_f32_f16(vget_high_f16(c)));
#endif
}

static inline float32x4_t f32_fma_f16(float32x4_t a, float16x4_t b, float16x4_t c) {
  return f32_fma_low_f16(a, vcombine_f16(b, vdup_n_f16(0)), vcombine_f16(c, vdup_n_f16(0)));
}

// The below reduce overload and fp16_dot_with_fp32_arith are adapted
// from llama.cpp's ggml_vec_dot_f32 and surrounding utility
// functions. See NOTE [ GGML Copyright Notice ] above for the
// required notice.

// We need the shift for reduce(), hence the extra constants.
static constexpr auto kF32ElementsPerIterationShift = 5;
static constexpr auto kF32ElementsPerIteration = 1 << kF32ElementsPerIterationShift;
static_assert(kF32ElementsPerIteration == 32);

static constexpr auto kF32ElementsPerRegisterShift = 2;
static constexpr auto kF32ElementsPerRegister = 1 << kF32ElementsPerRegisterShift;
static_assert(kF32ElementsPerRegister == 4);

static constexpr auto kF32RegisterPairsPerIteration = 4;
static constexpr auto kF32RegistersPerIteration = kF32RegisterPairsPerIteration * 2;
static constexpr auto kF32RegistersPerIterationShift = 3;
static_assert(kF32RegistersPerIteration == kF32ElementsPerIteration / kF32ElementsPerRegister);
static_assert(kF32RegistersPerIteration == 1 << kF32RegistersPerIterationShift);
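// Registers are counted in pairs because each 8-element fp16/bf16 load below
// feeds two f32 accumulators (one for the low half, one for the high half).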

static inline double reduce(float32x4_t x[kF32RegistersPerIteration]) {
  int offset = kF32RegistersPerIteration;
  c10::ForcedUnroll<kF32RegistersPerIterationShift>{}([&offset, &x](auto idx) {
    offset /= 2;
    for (int i = 0; i < offset; ++i) {
      x[i] = vaddq_f32(x[i], x[offset + i]);
    }
  });
  return vaddvq_f32(x[0]);
}

static C10_ALWAYS_INLINE void dot_with_fp32_arith_main_inner_loop(
  const float16_t* vec1,
  const float16_t* vec2,
  float32x4_t sum[kF32RegistersPerIteration],
  int registerPairIndex) {
  // Load a pair of f32 registers at a time.
  const auto temp_vec1 = vld1q_f16(&vec1[registerPairIndex * 2 * kF32ElementsPerRegister]);
  const auto temp_vec2 = vld1q_f16(&vec2[registerPairIndex * 2 * kF32ElementsPerRegister]);

  sum[2 * registerPairIndex] = f32_fma_low_f16(sum[2 * registerPairIndex], temp_vec1, temp_vec2);
  sum[2 * registerPairIndex + 1] = f32_fma_high_f16(sum[2 * registerPairIndex + 1], temp_vec1, temp_vec2);
}

static C10_ALWAYS_INLINE void dot_with_fp32_arith_vectorized_tail_inner_loop(
  const float16_t* vec1,
  const float16_t* vec2,
  float32x4_t* tailSum,
  int idx) {
  const auto temp_vec1 = vld1_f16(&vec1[idx]);
  const auto temp_vec2 = vld1_f16(&vec2[idx]);
  *tailSum = f32_fma_f16(*tailSum, temp_vec1, temp_vec2);
}

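// A bfloat16 value is the upper 16 bits of the corresponding float32, so
// widening the raw bits to 32 and shifting left by 16 reconstructs the float
// value exactly.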
static C10_ALWAYS_INLINE float32x4_t to_bfloat16(uint16x4_t u16) {
  int32x4_t shift = vdupq_n_s32(16);
  return vreinterpretq_f32_u32(vshlq_u32(vmovl_u16(u16), shift));
}

static C10_ALWAYS_INLINE float32x4_t f32_fma_bf16(float32x4_t a, uint16x4_t b, uint16x4_t c) {
  return f32_fma(a, to_bfloat16(b), to_bfloat16(c));
}

static C10_ALWAYS_INLINE void dot_with_fp32_arith_main_inner_loop(
  const at::BFloat16* vec1,
  const at::BFloat16* vec2,
  float32x4_t sum[kF32RegistersPerIteration],
  int registerPairIndex) {
  // TODO: detect intrinsic availability, use them if they're available. __ARM_FEATURE_BF16
  // Load a pair of f32 registers at a time.
  const uint16x8_t temp_vec1 = vld1q_u16(reinterpret_cast<const uint16_t*>(&vec1[registerPairIndex * 2 * kF32ElementsPerRegister]));
  const uint16x8_t temp_vec2 = vld1q_u16(reinterpret_cast<const uint16_t*>(&vec2[registerPairIndex * 2 * kF32ElementsPerRegister]));

  sum[2 * registerPairIndex] = f32_fma_bf16(sum[2 * registerPairIndex], vget_low_u16(temp_vec1), vget_low_u16(temp_vec2));
  sum[2 * registerPairIndex + 1] = f32_fma_bf16(sum[2 * registerPairIndex + 1], vget_high_u16(temp_vec1), vget_high_u16(temp_vec2));
}

static C10_ALWAYS_INLINE void dot_with_fp32_arith_vectorized_tail_inner_loop(
  const at::BFloat16* vec1,
  const at::BFloat16* vec2,
  float32x4_t* tailSum,
  int idx) {
  const auto temp_vec1 = vld1_u16(reinterpret_cast<const uint16_t*>(&vec1[idx]));
  const auto temp_vec2 = vld1_u16(reinterpret_cast<const uint16_t*>(&vec2[idx]));
  *tailSum = f32_fma_bf16(*tailSum, temp_vec1, temp_vec2);
}

template <typename T>
float dot_with_fp32_arith(const T* vec1, const T* vec2, int64_t len) {
  float32x4_t sum[kF32RegistersPerIteration] = {vdupq_n_f32(0)};
  const auto len_aligned = len & ~(kF32ElementsPerIteration - 1);
  for (int j = 0; j < len_aligned; j += kF32ElementsPerIteration) {
    const auto* vec1_ = vec1 + j;
    const auto* vec2_ = vec2 + j;
    c10::ForcedUnroll<kF32RegisterPairsPerIteration>{}([vec1_, vec2_, &sum](auto k) {
      dot_with_fp32_arith_main_inner_loop(vec1_, vec2_, sum, k);
    });
  }
  auto reducedSum = reduce(sum);

  // First-tier tail fixup: make sure we handle workloads that can
  // benefit from vectorization, but don't fit into our fully unrolled
  // loop above.
  float32x4_t tailSum = vdupq_n_f32(0);
  const auto len_aligned_4 = len & ~3;
  for (int j = len_aligned; j < len_aligned_4; j += 4) {
    dot_with_fp32_arith_vectorized_tail_inner_loop(vec1, vec2, &tailSum, j);
  }
  auto reducedTail = vpaddq_f32(tailSum, tailSum);
  reducedSum += vgetq_lane_f32(vpaddq_f32(reducedTail, reducedTail), 0);

  // Second-tier tail fixup: handle all workloads.
  for (int j = len_aligned_4; j < len; ++j) {
    reducedSum += vec1[j] * vec2[j];
  }
  return reducedSum;
}

float fp16_dot_with_fp32_arith(const float16_t* vec1, const float16_t* vec2, int64_t len) {
  return dot_with_fp32_arith(vec1, vec2, len);
}

float bf16_dot_with_fp32_arith(const at::BFloat16* vec1, const at::BFloat16* vec2, int64_t len) {
  return dot_with_fp32_arith(vec1, vec2, len);
}

// On my Apple M1 MacBook (which is ARM v8.5 and thus has the instructions
// that f32_fma_{low,high}_f16 target), this kernel has performance
// equivalent to the fp16-native kernel.
static void fp16_gemv_trans_fp32_arith_by_dot_products(const int m, const int n, const float16_t* a, const int lda, const float16_t *x, float16_t* y, int incy) {
  parallel_for(0, n, 1, [&](int begin, int end) {
    for (int i = begin; i < end; ++i) {
      y[i * incy] = fp16_dot_with_fp32_arith(x, a + lda * i, m);
    }
  });
}

static void bf16_gemv_trans_fp32_arith_by_dot_products(const int m, const int n, const at::BFloat16* a, const int lda, const at::BFloat16 *x, at::BFloat16* y, int incy) {
  parallel_for(0, n, 1, [&](int begin, int end) {
    for (int i = begin; i < end; ++i) {
      y[i * incy] = bf16_dot_with_fp32_arith(x, a + lda * i, m);
    }
  });
}

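// Accumulate in fp16 only when the user has opted in via
// at::globalContext().allowFP16ReductionCPU(); otherwise each dot product is
// accumulated in fp32 for better numerical behavior.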
void fp16_gemv_trans(
    const int m,
    const int n,
    const float alpha,
    const float16_t* a,
    const int lda,
    const float16_t* x,
    const int incx,
    const float beta,
    float16_t* y,
    const int incy) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(incx == 1 && alpha == 1.0 && beta == 0.0);
#ifdef __ARM_FEATURE_FP16_SCALAR_ARITHMETIC
  if (at::globalContext().allowFP16ReductionCPU()) {
    return fp16_gemv_trans_fp16_arith_by_dot_products(m, n, a, lda, x, y, incy);
  }
#endif
  return fp16_gemv_trans_fp32_arith_by_dot_products(m, n, a, lda, x, y, incy);
}

void bf16_gemv_trans(
  const int m,
  const int n,
  const at::BFloat16 alpha,
  const at::BFloat16* a,
  const int lda,
  const at::BFloat16* x,
  const int incx,
  const at::BFloat16 beta,
  at::BFloat16* y,
  const int incy) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(incx == 1 && alpha == 1.0 && beta == 0.0);
  return bf16_gemv_trans_fp32_arith_by_dot_products(m, n, a, lda, x, y, incy);
}


#ifdef __ARM_FEATURE_FP16_SCALAR_ARITHMETIC
static void fp16_gemv_notrans_fp16_arith(int m, int n, const float16_t* a, const int lda, const float16_t *x, float16_t *y) {
  for (auto j = 0; j < n; j++) {
    auto vecCol = vdup_n_f16(x[j]);
    const auto* column = a + lda * j;
    for (auto i = 0; i < m; i += 4) {
      auto yf16 = y + i;
      auto matRow = vld1_f16(column + i);
      auto resVec = j != 0 ? vld1_f16(yf16) : vdup_n_f16(0);
      resVec = vfma_lane_f16(resVec, matRow, vecCol, 0);
      vst1_f16(yf16, resVec);
    }
  }
}
#endif

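// Non-transposed kernel that keeps one fp32 running sum per output element
// and converts back to fp16 only once at the end, instead of accumulating in
// fp16.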
static void fp16_gemv_notrans_fp32_arith(int m, int n, const float16_t* a, const int lda, const float16_t *x, float16_t *y) {
  std::vector<float> sum(m);
  for (auto j = 0; j < n; j++) {
    auto vecCol = vdup_n_f32(x[j]);
    const auto* column = a + lda * j;
    for (auto i = 0; i < m; i += 4) {
      auto sf32 = sum.data() + i;
      auto matRow = vcvt_f32_f16(vld1_f16(column + i));
      auto resVec = j != 0 ? vld1q_f32(sf32) : vdupq_n_f32(0);
      resVec = vfmaq_lane_f32(resVec, matRow, vecCol, 0);
      vst1q_f32(sf32, resVec);
    }
  }

  for (auto i = 0; i < m; i += 4) {
    vst1_f16(y + i, vcvt_f16_f32(vld1q_f32(sum.data() + i)));
  }
}

void fp16_gemv_notrans(
    const int m,
    const int n,
    const float alpha,
    const float16_t* a,
    const int lda,
    const float16_t* x,
    const int incx,
    const float beta,
    float16_t* y,
    const int incy) {
  if (incx == 1 && alpha == 1.0 && beta == 0.0 && m % 4 == 0 && incy == 1) {
#ifdef __ARM_FEATURE_FP16_SCALAR_ARITHMETIC
    return at::globalContext().allowFP16ReductionCPU() ? fp16_gemv_notrans_fp16_arith(m, n, a, lda, x, y)
                                                       : fp16_gemv_notrans_fp32_arith(m, n, a, lda, x, y);
#else
    return fp16_gemv_notrans_fp32_arith(m, n, a, lda, x, y);
#endif
  }
  std::vector<float> sum(m);
  for (const auto j : c10::irange(n)) {
    const auto* column_ = a + lda * j;
    auto z = alpha * x[j * incx];
    for (const auto i : c10::irange(m)) {
      sum[i] += z * column_[i];
    }
  }
  if (beta == 0.0) {
    for (const auto i : c10::irange(m)) {
      y[i * incy] = sum[i];
    }
  } else {
    for (const auto i : c10::irange(m)) {
      y[i * incy] += sum[i];
    }
  }
}

template <>
void gemv_fast_path<at::Half>(
    const char* trans,
    const int* m,
    const int* n,
    const at::Half* alpha,
    const at::Half* a,
    const int* lda,
    const at::Half* x,
    const int* incx,
    const at::Half* beta,
    at::Half* y,
    const int* incy) {
  using namespace c10::detail;
  if ((trans[0] == 'T') || (trans[0] == 't')) {
    fp16_gemv_trans(
        *m,
        *n,
        fp16_from_bits(alpha->x),
        reinterpret_cast<const float16_t*>(a),
        *lda,
        reinterpret_cast<const float16_t*>(x),
        *incx,
        fp16_from_bits(beta->x),
        reinterpret_cast<float16_t*>(y),
        *incy);
  } else {
    fp16_gemv_notrans(
        *m,
        *n,
        fp16_from_bits(alpha->x),
        reinterpret_cast<const float16_t*>(a),
        *lda,
        reinterpret_cast<const float16_t*>(x),
        *incx,
        fp16_from_bits(beta->x),
        reinterpret_cast<float16_t*>(y),
        *incy);
  }
}

template <>
void gemv_fast_path<at::BFloat16>(
    const char* trans,
    const int* m,
    const int* n,
    const at::BFloat16* alpha,
    const at::BFloat16* a,
    const int* lda,
    const at::BFloat16* x,
    const int* incx,
    const at::BFloat16* beta,
    at::BFloat16* y,
    const int* incy) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(trans[0] == 'T' || trans[0] == 't');
  bf16_gemv_trans(
    *m,
    *n,
    *alpha,
    a,
    *lda,
    x,
    *incx,
    *beta,
    y,
    *incy);
}
#else // defined(__aarch64__) && !defined(C10_MOBILE)
INSTANTIATE(c10::Half);
INSTANTIATE(c10::BFloat16);
#endif // defined(__aarch64__) && !defined(C10_MOBILE)
#undef INSTANTIATE

} // namespace blas_impl

template <typename scalar_t>
inline void scal(int64_t n, scalar_t a, scalar_t *x, int64_t incx)
{
  if (n == 1) incx = 1;
#if AT_BUILD_WITH_BLAS()
  if (blas_impl::scal_use_fast_path<scalar_t>(n, incx)) {
    int i_n = (int)n;
    int i_incx = (int)incx;
    blas_impl::scal_fast_path<scalar_t>(&i_n, &a, x, &i_incx);
    return;
  }
#endif
  for (const auto i : c10::irange(n)) {
    if (a == scalar_t(0)) {
      x[i * incx] = 0;
    } else {
      x[i * incx] *= a;
    }
  }
}

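// Storage convention for the fallback below: A is column-major with leading
// dimension lda, i.e. element A(i, j) lives at a[j * lda + i]. The 't'/'T'
// branch computes y = beta * y + alpha * A^T * x; the non-transposed branch
// computes y = beta * y + alpha * A * x.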
template<typename scalar_t>
void gemv(char trans, int64_t m, int64_t n, scalar_t alpha, const scalar_t *a, int64_t lda, const scalar_t *x, int64_t incx, scalar_t beta, scalar_t *y, int64_t incy) {
  if (n == 1) lda = m;

#if AT_BUILD_WITH_BLAS()
  if (blas_impl::gemv_use_fast_path<scalar_t>(trans, m, n, alpha, lda, incx, beta, incy)) {
    TORCH_CHECK(lda >= std::max<int64_t>(1L, m), "lda should be at least max(1,", m, "), but have ", lda);
    int i_m = (int)m;
    int i_n = (int)n;
    int i_lda = (int)lda;
    int i_incx = (int)incx;
    int i_incy = (int)incy;
    blas_impl::gemv_fast_path<scalar_t>(&trans, &i_m, &i_n, &alpha, a, &i_lda, x, &i_incx, &beta, y, &i_incy);
    return;
  }
#endif

  using opmath_t = at::opmath_type<scalar_t>;
  if ((trans == 'T') || (trans == 't')) {
    for (const auto i : c10::irange(n)) {
      opmath_t sum = 0;
      const scalar_t *row_ = a + lda * i;
      for (const auto j : c10::irange(m)) {
        sum += x[j * incx] * row_[j];
      }
      if (beta == scalar_t(0)) {
        y[i * incy] = alpha * sum;
      } else {
        y[i * incy] = beta * y[i * incy] + alpha * sum;
      }
    }
  } else {
    if (beta != scalar_t(1) && beta != scalar_t(0)) scal<scalar_t>(m, beta, y, incy);

    constexpr bool is_low_precision = !std::is_same_v<opmath_t, scalar_t>;
    std::vector<opmath_t> sum;
    if constexpr (is_low_precision) {
      sum.resize(m);
    }
    for (const auto j : c10::irange(n)) {
      const scalar_t *column_ = a + lda * j;
      opmath_t z = alpha * static_cast<opmath_t>(x[j * incx]);
      for (const auto i : c10::irange(m)) {
        // When beta is 0 the existing values in y are ignored and treated as
        // 0, so NaNs and Infs in y are not propagated into the output.
        if (j == 0 && beta == scalar_t(0)) {
          if constexpr (!is_low_precision) {
            y[i * incy] = 0;
          }
        }
        if constexpr (is_low_precision) {
          sum[i] += z * column_[i];
        } else {
          y[i * incy] += z * column_[i];
        }
      }
    }
    if constexpr (is_low_precision) {
      if (beta == scalar_t(0)) {
        for (const auto i : c10::irange(m)) {
          y[i * incy] = sum[i];
        }
      } else {
        for (const auto i : c10::irange(m)) {
          y[i * incy] += sum[i];
        }
      }
    }
  }
  return;
}

#define INSTANTIATE(scalar_t, _) \
template void gemv<scalar_t>(char trans, int64_t m, int64_t n, scalar_t alpha, const scalar_t *a, int64_t lda, const scalar_t *x, int64_t incx, scalar_t beta, scalar_t *y, int64_t incy);
AT_FORALL_SCALAR_TYPES_AND2(BFloat16, Half, INSTANTIATE);
AT_FORALL_COMPLEX_TYPES(INSTANTIATE);
#undef INSTANTIATE

namespace blas_impl {
#if AT_BUILD_WITH_BLAS()
static float dot_fast_path(int n, float* x, int incx, float* y, int incy) {
  // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
  return sdot_(&n, x, &incx, y, &incy);
}

static double dot_fast_path(int n, double* x, int incx, double* y, int incy) {
  return ddot_(&n, x, &incx, y, &incy);
}

static c10::complex<float> vdot_fast_path(int n, c10::complex<float>* x, int incx, c10::complex<float>* y, int incy) {
  c10::complex<float> result;
  cdotc_(reinterpret_cast<std::complex<float>* >(&result), &n, reinterpret_cast<std::complex<float>*>(x), &incx, reinterpret_cast<std::complex<float>*>(y), &incy);
  return result;
}

static c10::complex<double> vdot_fast_path(int n, c10::complex<double>* x, int incx, c10::complex<double>* y, int incy) {
  c10::complex<double> result;
  zdotc_(reinterpret_cast<std::complex<double>* >(&result), &n, reinterpret_cast<std::complex<double>*>(x), &incx, reinterpret_cast<std::complex<double>*>(y), &incy);
  return result;
}

static c10::complex<double> dot_fast_path(int n, c10::complex<double>* x, int incx, c10::complex<double>* y, int incy) {
  c10::complex<double> result;
  zdotu_(reinterpret_cast<std::complex<double>* >(&result), &n, reinterpret_cast<std::complex<double>*>(x), &incx, reinterpret_cast<std::complex<double>*>(y), &incy);
  return result;
}

static c10::complex<float> dot_fast_path(int n, c10::complex<float>* x, int incx, c10::complex<float>* y, int incy) {
  c10::complex<float> result;
  cdotu_(reinterpret_cast<std::complex<float>* >(&result), &n, reinterpret_cast<std::complex<float>*>(x), &incx, reinterpret_cast<std::complex<float>*>(y), &incy);
  return result;
}
#endif

template <typename scalar_t, typename Functor>
scalar_t dot_naive(
    int64_t n,
    scalar_t* x,
    int64_t incx,
    scalar_t* y,
    int64_t incy,
    Functor op) {
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  int64_t i;
  using opmath_t = at::opmath_type<scalar_t>;
  opmath_t sum = 0;
  for (i = 0; i < n; i++) {
    sum += op(static_cast<opmath_t>(x[i * incx]), static_cast<opmath_t>(y[i * incy]));
  }
  return static_cast<scalar_t>(sum);
}

} // namespace blas_impl

template <typename scalar_t>
scalar_t dot_impl_floating(int64_t n, scalar_t* x, int64_t incx, scalar_t* y, int64_t incy)
{
  if (n == 1) {
    incx = 1;
    incy = 1;
  }
#if AT_BUILD_WITH_BLAS()
  if ((n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX)) {
    return blas_impl::dot_fast_path(n, x, incx, y, incy);
  } else {
    return blas_impl::dot_naive(n, x, incx, y, incy, std::multiplies<scalar_t>{});
  }
#else
  { return blas_impl::dot_naive(n, x, incx, y, incy, std::multiplies<scalar_t>{}); }
#endif
}

template <typename scalar_t>
scalar_t dot_impl(int64_t n, scalar_t* x, int64_t incx, scalar_t* y, int64_t incy) {
  if (n == 1) {
    incx = 1;
    incy = 1;
  }
  return blas_impl::dot_naive(n, x, incx, y, incy, std::multiplies<scalar_t>{});
}

template <>
float dot_impl(int64_t n, float* x, int64_t incx, float* y, int64_t incy) {
  return dot_impl_floating(n, x, incx, y, incy);
}

template <>
double dot_impl(int64_t n, double* x, int64_t incx, double* y, int64_t incy) {
  return dot_impl_floating(n, x, incx, y, incy);
}

template <>
c10::complex<double> dot_impl(int64_t n, c10::complex<double>* x, int64_t incx, c10::complex<double>* y, int64_t incy) {
  return dot_impl_floating(n, x, incx, y, incy);
}

template <>
c10::complex<float> dot_impl(int64_t n, c10::complex<float>* x, int64_t incx, c10::complex<float>* y, int64_t incy) {
  return dot_impl_floating(n, x, incx, y, incy);
}

namespace {
template <typename scalar_t>
struct vdot_op {
  scalar_t operator()(scalar_t x, scalar_t y) {
    return std::conj(x) * y;
  }
};
} // anonymous namespace

template <typename scalar_t>
scalar_t vdot_impl(int64_t n, scalar_t* x, int64_t incx, scalar_t* y, int64_t incy) {
  if (n == 1) {
    incx = 1;
    incy = 1;
  }
#if AT_BUILD_WITH_BLAS()
  if ((n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX)) {
    return blas_impl::vdot_fast_path(n, x, incx, y, incy);
  } else {
    return blas_impl::dot_naive(n, x, incx, y, incy, vdot_op<scalar_t>{});
  }
#else
  { return blas_impl::dot_naive(n, x, incx, y, incy, vdot_op<scalar_t>{}); }
#endif
}

// Skip reinstantiating the explicitly specialized types `float` and `double`.
#define INSTANTIATE_DOT_IMPL(scalar_t)  \
  template scalar_t dot_impl<scalar_t>( \
      int64_t n, scalar_t * x, int64_t incx, scalar_t * y, int64_t incy);
INSTANTIATE_DOT_IMPL(uint8_t);
INSTANTIATE_DOT_IMPL(int8_t);
INSTANTIATE_DOT_IMPL(int16_t);
INSTANTIATE_DOT_IMPL(int);
INSTANTIATE_DOT_IMPL(int64_t);
INSTANTIATE_DOT_IMPL(c10::Half);
INSTANTIATE_DOT_IMPL(c10::BFloat16);

#define INSTANTIATE_VDOT_IMPL(scalar_t)  \
  template scalar_t vdot_impl<scalar_t>( \
      int64_t n, scalar_t * x, int64_t incx, scalar_t * y, int64_t incy);
INSTANTIATE_VDOT_IMPL(c10::complex<float>);
INSTANTIATE_VDOT_IMPL(c10::complex<double>);

#undef INSTANTIATE_DOT_IMPL

} // namespace at::native
C10_DIAGNOSTIC_POP()