1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/Context.h>
3 #include <ATen/Config.h>
4 #include <ATen/OpMathType.h>
5 #include <ATen/Parallel.h>
6 #include <c10/core/ScalarType.h>
7 #include <c10/macros/Macros.h>
8 #include <c10/util/Exception.h>
9 #include <c10/util/Unroll.h>
10 #include <c10/util/complex.h>
11 #include <c10/util/irange.h>
12 #include <algorithm>
13 #include <climits>
14 #include <limits>
15
16 #if defined(__aarch64__) && !defined(C10_MOBILE)
17 #include <arm_neon.h>
18 #endif
19
20 C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-function")
namespace {

/// Strips `const` from a pointer, deducing the pointee type from the argument.
///
/// Intended only for handing read-only buffers to C/Fortran APIs (such as the
/// BLAS routines declared in this file) whose prototypes are not
/// const-correct. The callee must not write through the returned pointer.
template <typename T>
auto remove_const(const T* x) -> T* {
  T* mutable_ptr = const_cast<T*>(x);
  return mutable_ptr;
}

} // namespace
32
#if AT_BUILD_WITH_BLAS()
// Fortran BLAS entry points. Every argument is passed by pointer (Fortran
// calling convention) and the prototypes are not const-correct, which is why
// callers below go through remove_const().
extern "C" double ddot_(int *n, double *x, int *incx, double *y, int *incy);
extern "C" void dscal_(int *n, double *a, double *x, int *incx);
extern "C" void sscal_(int *n, float *a, float *x, int *incx);
extern "C" void dgemv_(char *trans, int *m, int *n, double *alpha, double *a, int *lda, double *x, int *incx, double *beta, double *y, int *incy);
extern "C" void sgemv_(char *trans, int *m, int *n, float *alpha, float *a, int *lda, float *x, int *incx, float *beta, float *y, int *incy);

// `ffloat` is the return type used for single-precision BLAS functions such as
// sdot_: double when the BLAS is f2c-style (AT_BLAS_F2C), float otherwise.
#if AT_BLAS_F2C()
# define ffloat double
#else
# define ffloat float
#endif

#if AT_BLAS_USE_CBLAS_DOT()
// CBLAS dot-product entry points. The static wrappers that follow adapt them to
// the pointer-argument, Fortran-style sdot_/cdotu_/zdotu_/cdotc_/zdotc_ names.
extern "C" float cblas_sdot(const int n, const float *x, const int incx, const float *y, const int incy);
extern "C" void cblas_cdotu_sub(const int n, const void *x, const int incx, const void *y, const int incy, void *dotu);
extern "C" void cblas_zdotu_sub(const int n, const void *x, const int incx, const void *y, const int incy, void *dotu);
extern "C" void cblas_cdotc_sub(const int n, const void *x, const int incx, const void *y, const int incy, void *dotc);
extern "C" void cblas_zdotc_sub(const int n, const void *x, const int incx, const void *y, const int incy, void *dotc);

sdot_(const int * n,const float * x,const int * incx,const float * y,const int * incy)53 static inline ffloat sdot_(const int *n, const float *x, const int *incx, const float *y, const int *incy)
54 {
55 return cblas_sdot(*n, x, *incx, y, *incy);
56 }
cdotu_(std::complex<float> * res,const int * n,const std::complex<float> * x,const int * incx,const std::complex<float> * y,const int * incy)57 static inline void cdotu_(std::complex<float> *res, const int *n, const std::complex<float> *x, const int *incx,
58 const std::complex<float> *y, const int *incy) {
59 cblas_cdotu_sub(*n, x, *incx, y, *incy, res);
60 }
zdotu_(std::complex<double> * res,const int * n,const std::complex<double> * x,const int * incx,const std::complex<double> * y,const int * incy)61 static inline void zdotu_(std::complex<double> *res, const int *n, const std::complex<double> *x, const int *incx,
62 const std::complex<double> *y, const int *incy) {
63 cblas_zdotu_sub(*n, x, *incx, y, *incy, res);
64 }
cdotc_(std::complex<float> * res,const int * n,const std::complex<float> * x,const int * incx,const std::complex<float> * y,const int * incy)65 static inline void cdotc_(std::complex<float> *res, const int *n, const std::complex<float> *x, const int *incx,
66 const std::complex<float> *y, const int *incy) {
67 cblas_cdotc_sub(*n, x, *incx, y, *incy, res);
68 }
zdotc_(std::complex<double> * res,const int * n,const std::complex<double> * x,const int * incx,const std::complex<double> * y,const int * incy)69 static inline void zdotc_(std::complex<double> *res, const int *n, const std::complex<double> *x, const int *incx,
70 const std::complex<double> *y, const int *incy) {
71 cblas_zdotc_sub(*n, x, *incx, y, *incy, res);
72 }
73
#else
// Without the CBLAS shims, link directly against the Fortran-style dot symbols.
// The complex variants are declared to return their result through the `res`
// out-parameter. NOTE(review): Fortran BLAS libraries differ in how they return
// complex function results; confirm this convention matches the linked BLAS.
extern "C" ffloat sdot_(int *n, float *x, int *incx, float *y, int *incy);
extern "C" void cdotu_(std::complex<float> *res, int *n, std::complex<float> *x, int *incx, std::complex<float> *y, int *incy);
extern "C" void zdotu_(std::complex<double> *res, int *n, std::complex<double> *x, int *incx, std::complex<double> *y, int *incy);
extern "C" void cdotc_(std::complex<float> *res, int *n, std::complex<float> *x, int *incx, std::complex<float> *y, int *incy);
extern "C" void zdotc_(std::complex<double> *res, int *n, std::complex<double> *x, int *incx, std::complex<double> *y, int *incy);
#endif // AT_BLAS_USE_CBLAS_DOT
#endif // AT_BUILD_WITH_BLAS
82
83 namespace at::native {
84
85 namespace blas_impl {
#if defined(__aarch64__) && !defined(C10_MOBILE)
// Forward declarations of the aarch64 NEON kernels defined later in this file.
// `float16_t` is the NEON half-precision type; gemv_fast_path<at::Half>
// reinterprets at::Half buffers as float16_t before calling these.

// Non-transposed half-precision gemv over an m x n matrix `a` with leading
// dimension `lda`. A vectorized path handles incx == 1, incy == 1, alpha == 1,
// beta == 0 and m % 4 == 0; other arguments take a scalar strided fallback.
void fp16_gemv_notrans(
    const int m,
    const int n,
    const float alpha,
    const float16_t* a,
    const int lda,
    const float16_t* x,
    const int incx,
    const float beta,
    float16_t* y,
    const int incy);

// Transposed half-precision gemv: for each of the n outputs, y[i * incy] is the
// dot product of x with the m elements starting at a + lda * i. Only
// incx == 1, alpha == 1, beta == 0 is supported (debug-asserted).
void fp16_gemv_trans(
    const int m,
    const int n,
    const float alpha,
    const float16_t* a,
    const int lda,
    const float16_t* x,
    const int incx,
    const float beta,
    float16_t* y,
    const int incy);

// Dot product of two contiguous fp16 vectors of length `len`, accumulated in
// fp32.
float fp16_dot_with_fp32_arith(
    const float16_t* vec1,
    const float16_t* vec2,
    int64_t len);

// bfloat16 counterpart of fp16_gemv_trans, always accumulating in fp32. Same
// incx == 1, alpha == 1, beta == 0 restriction.
void bf16_gemv_trans(
    const int m,
    const int n,
    const at::BFloat16 alpha,
    const at::BFloat16* a,
    const int lda,
    const at::BFloat16* x,
    const int incx,
    const at::BFloat16 beta,
    at::BFloat16* y,
    const int incy);

// Dot product of two contiguous bfloat16 vectors of length `len`, accumulated
// in fp32.
float bf16_dot_with_fp32_arith(
    const at::BFloat16* vec1,
    const at::BFloat16* vec2,
    int64_t len);
#endif
133
134 template <typename scalar_t>
scal_use_fast_path(C10_UNUSED int64_t n,C10_UNUSED int64_t incx)135 bool scal_use_fast_path(C10_UNUSED int64_t n, C10_UNUSED int64_t incx) {
136 return false;
137 }
138
139 template <typename scalar_t>
gemv_use_fast_path(C10_UNUSED char trans,C10_UNUSED int64_t m,C10_UNUSED int64_t n,C10_UNUSED scalar_t alpha,C10_UNUSED int64_t lda,C10_UNUSED int64_t incx,C10_UNUSED scalar_t beta,C10_UNUSED int64_t incy)140 bool gemv_use_fast_path(C10_UNUSED char trans, C10_UNUSED int64_t m,
141 C10_UNUSED int64_t n, C10_UNUSED scalar_t alpha,
142 C10_UNUSED int64_t lda,
143 C10_UNUSED int64_t incx, C10_UNUSED scalar_t beta,
144 C10_UNUSED int64_t incy) {
145 return false;
146 }
147
148 template <typename scalar_t>
scal_fast_path(C10_UNUSED int * n,C10_UNUSED scalar_t * a,C10_UNUSED scalar_t * x,C10_UNUSED int * incx)149 void scal_fast_path(C10_UNUSED int *n, C10_UNUSED scalar_t *a, C10_UNUSED scalar_t *x, C10_UNUSED int *incx) {
150 TORCH_INTERNAL_ASSERT(false, "scal_fast_path shouldn't be called for this configuration");
151 }
152
153 template <typename scalar_t>
gemv_fast_path(C10_UNUSED const char * trans,C10_UNUSED const int * m,C10_UNUSED const int * n,C10_UNUSED const scalar_t * alpha,C10_UNUSED const scalar_t * a,C10_UNUSED const int * lda,C10_UNUSED const scalar_t * x,C10_UNUSED const int * incx,C10_UNUSED const scalar_t * beta,C10_UNUSED scalar_t * y,C10_UNUSED const int * incy)154 void gemv_fast_path(C10_UNUSED const char *trans, C10_UNUSED const int *m, C10_UNUSED const int *n,
155 C10_UNUSED const scalar_t *alpha, C10_UNUSED const scalar_t *a, C10_UNUSED const int *lda,
156 C10_UNUSED const scalar_t *x, C10_UNUSED const int *incx, C10_UNUSED const scalar_t *beta,
157 C10_UNUSED scalar_t *y, C10_UNUSED const int *incy) {
158 TORCH_INTERNAL_ASSERT(false, "gemv_fast_path shouldn't be called for this configuration");
159 }
160
// Emits explicit instantiations of the four generic fast-path templates above
// for `scalar_t`. Used for each type that has no dedicated specialization in the
// current build configuration; #undef'd at the end of blas_impl.
#define INSTANTIATE(scalar_t) \
template bool scal_use_fast_path<scalar_t>(int64_t n, int64_t incx); \
template bool gemv_use_fast_path<scalar_t>(char trans, int64_t m, int64_t n, scalar_t alpha, int64_t lda, int64_t incx, scalar_t beta, int64_t incy); \
template void gemv_fast_path<scalar_t>(const char *trans, const int *m, const int *n, const scalar_t *alpha, const scalar_t *a, const int *lda, const scalar_t *x, const int *incx, const scalar_t *beta, scalar_t *y, const int *incy); \
template void scal_fast_path<scalar_t>(int *n, scalar_t *a, scalar_t *x, int *incx);
166
167 #if AT_BUILD_WITH_BLAS()
168 template <>
scal_use_fast_path(int64_t n,int64_t incx)169 bool scal_use_fast_path<double>(int64_t n, int64_t incx) {
170 auto intmax = std::numeric_limits<int>::max();
171 return n <= intmax && incx <= intmax;
172 }
173
174 template <>
scal_use_fast_path(int64_t n,int64_t incx)175 bool scal_use_fast_path<float>(int64_t n, int64_t incx) {
176 return scal_use_fast_path<double>(n, incx);
177 }
178
179 template <>
scal_fast_path(int * n,double * a,double * x,int * incx)180 void scal_fast_path<double>(int *n, double *a, double *x, int *incx) {
181 dscal_(n, a, x, incx);
182 }
183
184 template <>
scal_fast_path(int * n,float * a,float * x,int * incx)185 void scal_fast_path<float>(int *n, float *a, float *x, int *incx) {
186 sscal_(n, a, x, incx);
187 }
188
189 template <>
gemv_use_fast_path(C10_UNUSED char trans,int64_t m,int64_t n,C10_UNUSED float alpha,int64_t lda,int64_t incx,C10_UNUSED float beta,int64_t incy)190 bool gemv_use_fast_path<float>(C10_UNUSED char trans, int64_t m, int64_t n, C10_UNUSED float alpha, int64_t lda, int64_t incx, C10_UNUSED float beta, int64_t incy) {
191 auto intmax = std::numeric_limits<int>::max();
192 return (m <= intmax) && (n <= intmax) && (lda <= intmax) &&
193 (incx > 0) && (incx <= intmax) && (incy > 0) && (incy <= intmax);
194 }
195
196 template <>
gemv_use_fast_path(C10_UNUSED char trans,int64_t m,int64_t n,C10_UNUSED double alpha,int64_t lda,int64_t incx,C10_UNUSED double beta,int64_t incy)197 bool gemv_use_fast_path<double>(C10_UNUSED char trans, int64_t m, int64_t n, C10_UNUSED double alpha, int64_t lda, int64_t incx, C10_UNUSED double beta, int64_t incy) {
198 return gemv_use_fast_path<float>(trans, m, n, (float)alpha, lda, incx, (float)beta, incy);
199 }
200
201 template <>
gemv_fast_path(const char * trans,const int * m,const int * n,const double * alpha,const double * a,const int * lda,const double * x,const int * incx,const double * beta,double * y,const int * incy)202 void gemv_fast_path<double>(const char *trans, const int *m, const int *n, const double *alpha, const double *a, const int *lda, const double *x, const int *incx, const double *beta, double *y, const int *incy) {
203 dgemv_(remove_const(trans), remove_const(m), remove_const(n), remove_const(alpha), remove_const(a), remove_const(lda), remove_const(x), remove_const(incx), remove_const(beta), y, remove_const(incy));
204 }
205
206 template <>
gemv_fast_path(const char * trans,const int * m,const int * n,const float * alpha,const float * a,const int * lda,const float * x,const int * incx,const float * beta,float * y,const int * incy)207 void gemv_fast_path<float>(const char *trans, const int *m, const int *n, const float *alpha, const float *a, const int *lda, const float *x, const int *incx, const float *beta, float *y, const int *incy) {
208 sgemv_(remove_const(trans), remove_const(m), remove_const(n), remove_const(alpha), remove_const(a), remove_const(lda), remove_const(x), remove_const(incx), remove_const(beta), y, remove_const(incy));
209 }
#else
// Without BLAS there are no float/double specializations, so these types use
// the generic (always-false / asserting) stubs.
INSTANTIATE(float);
INSTANTIATE(double);
#endif // AT_BUILD_WITH_BLAS

// Integer types never have a fast path in any configuration.
INSTANTIATE(uint8_t);
INSTANTIATE(int8_t);
INSTANTIATE(int16_t);
INSTANTIATE(int);
INSTANTIATE(int64_t);
220 #if defined(__aarch64__) && !defined(C10_MOBILE)
221 template <>
scal_use_fast_path(C10_UNUSED int64_t n,C10_UNUSED int64_t incx)222 bool scal_use_fast_path<at::Half>(C10_UNUSED int64_t n, C10_UNUSED int64_t incx) {
223 return false;
224 }
225
226 template <>
gemv_use_fast_path(C10_UNUSED char trans,C10_UNUSED int64_t m,C10_UNUSED int64_t n,at::Half alpha,C10_UNUSED int64_t lda,C10_UNUSED int64_t incx,at::Half beta,C10_UNUSED int64_t incy)227 bool gemv_use_fast_path<at::Half>(
228 C10_UNUSED char trans,
229 C10_UNUSED int64_t m,
230 C10_UNUSED int64_t n,
231 at::Half alpha,
232 C10_UNUSED int64_t lda,
233 C10_UNUSED int64_t incx,
234 at::Half beta,
235 C10_UNUSED int64_t incy) {
236 return incx == 1 && c10::detail::fp16_from_bits(alpha.x) == 1.0f &&
237 c10::detail::fp16_from_bits(beta.x) == 0.0f;
238 }
239
240 template <>
gemv_use_fast_path(C10_UNUSED char trans,C10_UNUSED int64_t m,C10_UNUSED int64_t n,at::BFloat16 alpha,C10_UNUSED int64_t lda,C10_UNUSED int64_t incx,at::BFloat16 beta,C10_UNUSED int64_t incy)241 bool gemv_use_fast_path<at::BFloat16>(
242 C10_UNUSED char trans,
243 C10_UNUSED int64_t m,
244 C10_UNUSED int64_t n,
245 at::BFloat16 alpha,
246 C10_UNUSED int64_t lda,
247 C10_UNUSED int64_t incx,
248 at::BFloat16 beta,
249 C10_UNUSED int64_t incy) {
250 return (trans == 'T' || trans == 't') && incx == 1 && alpha == 1.0 && beta == 0.0;
251 }
252
253
254 #ifdef __ARM_FEATURE_FP16_SCALAR_ARITHMETIC
reduce(float16x4_t x)255 static inline float16_t reduce(float16x4_t x) {
256 auto sum = vpadd_f16(x, x);
257 return vget_lane_f16(vpadd_f16(sum, sum), 0);
258 }
reduce(float16x8_t x)259 static inline float16_t reduce(float16x8_t x) {
260 return reduce(vadd_f16(vget_low_f16(x), vget_high_f16(x)));
261 }
262
263 /*
264 * NOTE [ GGML Copyright Notice ]
 * The reduce overload and the fp16_dot_with_fp16_arith function below are
 * adapted from llama.cpp's ggml_vec_dot_f16 and surrounding utility
 * functions, so the required copyright notice follows:
268 *
269 * MIT License
270 *
271 * Copyright (c) 2023-2024 The ggml authors
272 *
273 * Permission is hereby granted, free of charge, to any person obtaining a copy
274 * of this software and associated documentation files (the "Software"), to deal
275 * in the Software without restriction, including without limitation the rights
276 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
277 * copies of the Software, and to permit persons to whom the Software is
278 * furnished to do so, subject to the following conditions:
279 *
280 * The above copyright notice and this permission notice shall be included in all
281 * copies or substantial portions of the Software.
282 *
283 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
284 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
285 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
286 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
287 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
288 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
289 * SOFTWARE.
290 */
// We need the shift for reduce(), hence the extra constants.
// fp16 elements consumed per main-loop iteration of fp16_dot_with_fp16_arith
// (2^7 = 128).
static constexpr auto kF16ElementsPerIterationShift = 7;
static constexpr auto kF16ElementsPerIteration = 1 << kF16ElementsPerIterationShift;
static_assert(kF16ElementsPerIteration == 128);

// fp16 lanes per 128-bit NEON register (2^3 = 8).
static constexpr auto kF16ElementsPerRegisterShift = 3;
static constexpr auto kF16ElementsPerRegister = 1 << kF16ElementsPerRegisterShift;
static_assert(kF16ElementsPerRegister == 8);

// Independent accumulator registers per iteration (128 / 8 = 16). reduce()
// folds them pairwise in kF16RegistersPerIterationShift (= 4) steps.
static constexpr auto kF16RegistersPerIterationShift = kF16ElementsPerIterationShift - kF16ElementsPerRegisterShift;
static constexpr auto kF16RegistersPerIteration = 1 << kF16RegistersPerIterationShift;
static_assert(kF16RegistersPerIteration == kF16ElementsPerIteration / kF16ElementsPerRegister);
303
// Sums all lanes of the kF16RegistersPerIteration fp16 accumulators.
// The registers are folded in place as a tree (each step adds the upper half of
// the live registers into the lower half, in fp16), so the contents of `x` are
// clobbered. The final register is widened to fp32 for the last 8-lane sum.
static inline double reduce(float16x8_t x[kF16RegistersPerIteration]) {
  int offset = kF16RegistersPerIteration;
  // ForcedUnroll runs the lambda kF16RegistersPerIterationShift times at compile
  // time; `idx` is unused, and `offset` halves on each step.
  c10::ForcedUnroll<kF16RegistersPerIterationShift>{}([&offset, &x](auto idx) {
    offset /= 2;
    for (int i = 0; i < offset; ++i) {
      x[i] = vaddq_f16(x[i], x[offset + i]);
    }
  });
  const float32x4_t t0 = vcvt_f32_f16(vget_low_f16(x[0]));
  const float32x4_t t1 = vcvt_f32_f16(vget_high_f16(x[0]));
  return (double)vaddvq_f32(vaddq_f32(t0, t1));
}
316
f16_fma(float16x8_t a,float16x8_t b,float16x8_t c)317 static inline float16x8_t f16_fma(float16x8_t a, float16x8_t b, float16x8_t c) {
318 #ifdef __ARM_FEATURE_FMA
319 return vfmaq_f16(a, b, c);
320 #else
321 return vaddq_f16(a, vmulq_f16(b, c));
322 #endif
323 }
324
fp16_dot_with_fp16_arith(const float16_t * x,const float16_t * a,int len)325 static float fp16_dot_with_fp16_arith(const float16_t* x, const float16_t* a, int len) {
326 float16x8_t sum[kF16RegistersPerIteration] = {vdupq_n_f16(0)};
327
328 const auto len_aligned = len & ~(kF16ElementsPerIteration - 1);
329 for (int j = 0; j < len_aligned ; j += kF16ElementsPerIteration) {
330 for (int k = 0; k < kF16RegistersPerIteration; ++k) {
331 const auto temp_x = vld1q_f16(x + j + k * kF16ElementsPerRegister);
332 const auto temp_a = vld1q_f16(a + j + k * kF16ElementsPerRegister);
333 sum[k] = f16_fma(sum[k], temp_x, temp_a);
334 }
335 }
336 auto reducedSum = reduce(sum);
337
338 for (int j = len_aligned; j < len; ++j) {
339 reducedSum += x[j] * a[j];
340 }
341 return reducedSum;
342 }
343
344 // Rather than unrolling to process multiple rows (transposed columns)
345 // of matrix A at once as done in fp16_gemv_trans_fp16_arith, unroll
346 // along an individual dot product.
fp16_gemv_trans_fp16_arith_by_dot_products(const int m,const int n,const float16_t * a,const int lda,const float16_t * x,float16_t * y,int incy)347 static void fp16_gemv_trans_fp16_arith_by_dot_products(const int m, const int n, const float16_t* a, const int lda, const float16_t *x, float16_t* y, int incy) {
348 parallel_for(0, n, 1, [&](int begin, int end) {
349 for (int i = begin; i < end; ++i) {
350 y[i * incy] = fp16_dot_with_fp16_arith(x, a + lda * i, m);
351 }
352 });
353 }
354
355 #endif
356
reduce(float32x4_t x)357 static inline float reduce(float32x4_t x) {
358 auto sum = vpaddq_f32(x, x);
359 return vgetq_lane_f32(vpaddq_f32(sum, sum), 0);
360 }
361
f32_fma(float32x4_t a,float32x4_t b,float32x4_t c)362 static inline float32x4_t f32_fma(float32x4_t a, float32x4_t b, float32x4_t c) {
363 #ifdef __ARM_FEATURE_FMA
364 return vfmaq_f32(a, b, c);
365 #else
366 return vaddq_f32(a, vmulq_f32(b, c));
367 #endif
368 }
369
f32_fma_low_f16(float32x4_t a,float16x8_t b,float16x8_t c)370 static inline float32x4_t f32_fma_low_f16(float32x4_t a, float16x8_t b, float16x8_t c) {
371 #ifdef __ARM_FEATURE_FP16_FML
372 // NOTE: this instruction is an optional instruction in ARM v8.2 and
373 // v8.3, but mandatory in v8.4 per
374 // https://developer.arm.com/documentation/ddi0596/2021-03/SIMD-FP-Instructions/FMLAL--FMLAL2--vector---Floating-point-fused-Multiply-Add-Long-to-accumulator--vector--?lang=en
375 // I'm not certain that I have the right feature test macro.
376 return vfmlalq_low_f16(a, b, c);
377 #else
378 return f32_fma(a, vcvt_f32_f16(vget_low_f16(b)), vcvt_f32_f16(vget_low_f16(c)));
379 #endif
380 }
381
f32_fma_high_f16(float32x4_t a,float16x8_t b,float16x8_t c)382 static inline float32x4_t f32_fma_high_f16(float32x4_t a, float16x8_t b, float16x8_t c) {
383 #ifdef __ARM_FEATURE_FP16_FML
384 // See above note about this instruction.
385 return vfmlalq_high_f16(a, b, c);
386 #else
387 return f32_fma(a, vcvt_f32_f16(vget_high_f16(b)), vcvt_f32_f16(vget_high_f16(c)));
388 #endif
389 }
390
f32_fma_f16(float32x4_t a,float16x4_t b,float16x4_t c)391 static inline float32x4_t f32_fma_f16(float32x4_t a, float16x4_t b, float16x4_t c) {
392 return f32_fma_low_f16(a, vcombine_f16(b, vdup_n_f16(0)), vcombine_f16(c, vdup_n_f16(0)));
393 }
394
// The below reduce overload and dot_with_fp32_arith (used by both
// fp16_dot_with_fp32_arith and bf16_dot_with_fp32_arith) are adapted from
// llama.cpp's ggml_vec_dot_f32 and surrounding utility functions. See
// NOTE [ GGML Copyright Notice ] above for the required notice.

// We need the shift for reduce(), hence the extra constants.
// 16-bit elements consumed per main-loop iteration of dot_with_fp32_arith
// (2^5 = 32).
static constexpr auto kF32ElementsPerIterationShift = 5;
static constexpr auto kF32ElementsPerIteration = 1 << kF32ElementsPerIterationShift;
static_assert(kF32ElementsPerIteration == 32);

// fp32 lanes per 128-bit NEON register (2^2 = 4).
static constexpr auto kF32ElementsPerRegisterShift = 2;
static constexpr auto kF32ElementsPerRegister = 1 << kF32ElementsPerRegisterShift;
static_assert(kF32ElementsPerRegister == 4);

// Each main-loop step loads one 8-lane 16-bit vector per register pair and
// splits it into low/high halves across two fp32 accumulators, giving
// 4 pairs * 2 = 8 accumulators. reduce() folds them in
// kF32RegistersPerIterationShift (= 3) pairwise steps.
static constexpr auto kF32RegisterPairsPerIteration = 4;
static constexpr auto kF32RegistersPerIteration = kF32RegisterPairsPerIteration * 2;
static constexpr auto kF32RegistersPerIterationShift = 3;
static_assert(kF32RegistersPerIteration == kF32ElementsPerIteration / kF32ElementsPerRegister);
static_assert(kF32RegistersPerIteration == 1 << kF32RegistersPerIterationShift);
414
// Sums all lanes of the kF32RegistersPerIteration fp32 accumulators. The
// registers are folded in place as a tree (each step adds the upper half of the
// live registers into the lower half), so the contents of `x` are clobbered.
static inline double reduce(float32x4_t x[kF32RegistersPerIteration]) {
  int offset = kF32RegistersPerIteration;
  // Runs kF32RegistersPerIterationShift times, unrolled at compile time; `idx`
  // is unused and `offset` halves on each step.
  c10::ForcedUnroll<kF32RegistersPerIterationShift>{}([&offset, &x](auto idx) {
    offset /= 2;
    for (int i = 0; i < offset; ++i) {
      x[i] = vaddq_f32(x[i], x[offset + i]);
    }
  });
  return vaddvq_f32(x[0]);
}
425
dot_with_fp32_arith_main_inner_loop(const float16_t * vec1,const float16_t * vec2,float32x4_t sum[kF32RegistersPerIteration],int registerPairIndex)426 static C10_ALWAYS_INLINE void dot_with_fp32_arith_main_inner_loop(
427 const float16_t* vec1,
428 const float16_t* vec2,
429 float32x4_t sum[kF32RegistersPerIteration],
430 int registerPairIndex) {
431 // Load a pair of f32 registers at a time.
432 const auto temp_vec1 = vld1q_f16(&vec1[registerPairIndex * 2 * kF32ElementsPerRegister]);
433 const auto temp_vec2 = vld1q_f16(&vec2[registerPairIndex * 2 * kF32ElementsPerRegister]);
434
435 sum[2 * registerPairIndex] = f32_fma_low_f16(sum[2 * registerPairIndex], temp_vec1, temp_vec2);
436 sum[2 * registerPairIndex + 1] = f32_fma_high_f16(sum[2 * registerPairIndex + 1], temp_vec1, temp_vec2);
437 }
438
dot_with_fp32_arith_vectorized_tail_inner_loop(const float16_t * vec1,const float16_t * vec2,float32x4_t * tailSum,int idx)439 static C10_ALWAYS_INLINE void dot_with_fp32_arith_vectorized_tail_inner_loop(
440 const float16_t* vec1,
441 const float16_t* vec2,
442 float32x4_t* tailSum,
443 int idx) {
444 const auto temp_vec1 = vld1_f16(&vec1[idx]);
445 const auto temp_vec2 = vld1_f16(&vec2[idx]);
446 *tailSum = f32_fma_f16(*tailSum, temp_vec1, temp_vec2);
447 }
448
to_bfloat16(uint16x4_t u16)449 static C10_ALWAYS_INLINE float32x4_t to_bfloat16(uint16x4_t u16) {
450 int32x4_t shift = vdupq_n_s32(16);
451 return vreinterpretq_f32_u32(vshlq_u32(vmovl_u16(u16), shift));
452 }
453
f32_fma_bf16(float32x4_t a,uint16x4_t b,uint16x4_t c)454 static C10_ALWAYS_INLINE float32x4_t f32_fma_bf16(float32x4_t a, uint16x4_t b, uint16x4_t c) {
455 return f32_fma(a, to_bfloat16(b), to_bfloat16(c));
456 }
457
dot_with_fp32_arith_main_inner_loop(const at::BFloat16 * vec1,const at::BFloat16 * vec2,float32x4_t sum[kF32RegistersPerIteration],int registerPairIndex)458 static C10_ALWAYS_INLINE void dot_with_fp32_arith_main_inner_loop(
459 const at::BFloat16* vec1,
460 const at::BFloat16* vec2,
461 float32x4_t sum[kF32RegistersPerIteration],
462 int registerPairIndex) {
463 // TODO: detect intrinsic availability, use them if they're available. __ARM_FEATURE_BF16
464 // Load a pair of f32 registers at a time.
465 const uint16x8_t temp_vec1 = vld1q_u16(reinterpret_cast<const uint16_t*>(&vec1[registerPairIndex * 2 * kF32ElementsPerRegister]));
466 const uint16x8_t temp_vec2 = vld1q_u16(reinterpret_cast<const uint16_t*>(&vec2[registerPairIndex * 2 * kF32ElementsPerRegister]));
467
468 sum[2 * registerPairIndex] = f32_fma_bf16(sum[2 * registerPairIndex], vget_low_u16(temp_vec1), vget_low_u16(temp_vec2));
469 sum[2 * registerPairIndex + 1] = f32_fma_bf16(sum[2 * registerPairIndex + 1], vget_high_u16(temp_vec1), vget_high_u16(temp_vec2));
470 }
471
dot_with_fp32_arith_vectorized_tail_inner_loop(const at::BFloat16 * vec1,const at::BFloat16 * vec2,float32x4_t * tailSum,int idx)472 static C10_ALWAYS_INLINE void dot_with_fp32_arith_vectorized_tail_inner_loop(
473 const at::BFloat16* vec1,
474 const at::BFloat16* vec2,
475 float32x4_t* tailSum,
476 int idx) {
477 const auto temp_vec1 = vld1_u16(reinterpret_cast<const uint16_t*>(&vec1[idx]));
478 const auto temp_vec2 = vld1_u16(reinterpret_cast<const uint16_t*>(&vec2[idx]));
479 *tailSum = f32_fma_bf16(*tailSum, temp_vec1, temp_vec2);
480 }
481
482 template <typename T>
dot_with_fp32_arith(const T * vec1,const T * vec2,int64_t len)483 float dot_with_fp32_arith(const T* vec1, const T* vec2, int64_t len) {
484 float32x4_t sum[kF32RegistersPerIteration] = {vdupq_n_f32(0)};
485 const auto len_aligned = len & ~(kF32ElementsPerIteration - 1);
486 for (int j = 0; j < len_aligned ; j += kF32ElementsPerIteration) {
487 const auto* vec1_ = vec1 + j;
488 const auto* vec2_ = vec2 + j;
489 c10::ForcedUnroll<kF32RegisterPairsPerIteration>{}([vec1_, vec2_, &sum](auto k) {
490 dot_with_fp32_arith_main_inner_loop(vec1_, vec2_, sum, k);
491 });
492 }
493 auto reducedSum = reduce(sum);
494
495 // First-tier tail fixup: make sure we handle workloads that can
496 // benefit from vectorization, but don't fit into our fully unrolled
497 // loop above.
498 float32x4_t tailSum = vdupq_n_f32(0);
499 const auto len_aligned_4 = len & ~3;
500 for (int j = len_aligned; j < len_aligned_4; j += 4) {
501 dot_with_fp32_arith_vectorized_tail_inner_loop(vec1, vec2, &tailSum, j);
502 }
503 auto reducedTail = vpaddq_f32(tailSum, tailSum);
504 reducedSum += vgetq_lane_f32(vpaddq_f32(reducedTail, reducedTail), 0);
505
506 // Second-tier tail fixup: handle all workloads.
507 for (int j = len_aligned_4; j < len; ++j) {
508 reducedSum += vec1[j] * vec2[j];
509 }
510 return reducedSum;
511 }
512
fp16_dot_with_fp32_arith(const float16_t * vec1,const float16_t * vec2,int64_t len)513 float fp16_dot_with_fp32_arith(const float16_t* vec1, const float16_t* vec2, int64_t len) {
514 return dot_with_fp32_arith(vec1, vec2, len);
515 }
516
bf16_dot_with_fp32_arith(const at::BFloat16 * vec1,const at::BFloat16 * vec2,int64_t len)517 float bf16_dot_with_fp32_arith(const at::BFloat16* vec1, const at::BFloat16* vec2, int64_t len) {
518 return dot_with_fp32_arith(vec1, vec2, len);
519 }
520
521 // On my Apple M1 Macbook (which is ARM v8.5 and thus has the
522 // instructions f32_fma_{low,high}_f16 is targeting), this kernel has
523 // equivalent performance to the fp16-native kernel.
fp16_gemv_trans_fp32_arith_by_dot_products(const int m,const int n,const float16_t * a,const int lda,const float16_t * x,float16_t * y,int incy)524 static void fp16_gemv_trans_fp32_arith_by_dot_products(const int m, const int n, const float16_t* a, const int lda, const float16_t *x, float16_t* y, int incy) {
525 parallel_for(0, n, 1, [&](int begin, int end) {
526 for (int i = begin; i < end; ++i) {
527 y[i * incy] = fp16_dot_with_fp32_arith(x, a + lda * i, m);
528 }
529 });
530 }
531
bf16_gemv_trans_fp32_arith_by_dot_products(const int m,const int n,const at::BFloat16 * a,const int lda,const at::BFloat16 * x,at::BFloat16 * y,int incy)532 static void bf16_gemv_trans_fp32_arith_by_dot_products(const int m, const int n, const at::BFloat16* a, const int lda, const at::BFloat16 *x, at::BFloat16* y, int incy) {
533 parallel_for(0, n, 1, [&](int begin, int end) {
534 for (int i = begin; i < end; ++i) {
535 y[i * incy] = bf16_dot_with_fp32_arith(x, a + lda * i, m);
536 }
537 });
538 }
539
fp16_gemv_trans(const int m,const int n,const float alpha,const float16_t * a,const int lda,const float16_t * x,const int incx,const float beta,float16_t * y,const int incy)540 void fp16_gemv_trans(
541 const int m,
542 const int n,
543 const float alpha,
544 const float16_t* a,
545 const int lda,
546 const float16_t* x,
547 const int incx,
548 const float beta,
549 float16_t* y,
550 const int incy) {
551 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(incx == 1 && alpha == 1.0 && beta == 0.0);
552 #ifdef __ARM_FEATURE_FP16_SCALAR_ARITHMETIC
553 if (at::globalContext().allowFP16ReductionCPU()) {
554 return fp16_gemv_trans_fp16_arith_by_dot_products(m, n, a, lda, x, y, incy);
555 }
556 #endif
557 return fp16_gemv_trans_fp32_arith_by_dot_products(m, n, a, lda, x, y, incy);
558 }
559
bf16_gemv_trans(const int m,const int n,const at::BFloat16 alpha,const at::BFloat16 * a,const int lda,const at::BFloat16 * x,const int incx,const at::BFloat16 beta,at::BFloat16 * y,const int incy)560 void bf16_gemv_trans(
561 const int m,
562 const int n,
563 const at::BFloat16 alpha,
564 const at::BFloat16* a,
565 const int lda,
566 const at::BFloat16* x,
567 const int incx,
568 const at::BFloat16 beta,
569 at::BFloat16* y,
570 const int incy) {
571 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(incx == 1 && alpha == 1.0 && beta == 0.0);
572 return bf16_gemv_trans_fp32_arith_by_dot_products(m, n, a, lda, x, y, incy);
573 }
574
575
#ifdef __ARM_FEATURE_FP16_SCALAR_ARITHMETIC
// Computes y = A * x for an m x n fp16 matrix A stored column by column with
// leading dimension lda, accumulating directly into y in fp16. Rows are handled
// 4 at a time, so m must be a multiple of 4 and y must be contiguous; the caller
// (fp16_gemv_notrans) checks both before dispatching here.
//
// Column offsets are computed in int64_t because lda * j can exceed INT_MAX for
// large matrices.
static void fp16_gemv_notrans_fp16_arith(int m, int n, const float16_t* a, const int lda, const float16_t *x, float16_t *y) {
  if (n == 0) {
    // No columns: the product is the zero vector (beta == 0 on this path). This
    // matches fp16_gemv_notrans_fp32_arith, which also writes zeros, instead of
    // leaving y's previous contents in place.
    for (auto i = 0; i < m; i += 4) {
      vst1_f16(y + i, vdup_n_f16(0));
    }
    return;
  }
  for (auto j = 0; j < n; j++) {
    const auto vecCol = vdup_n_f16(x[j]);
    const auto* column = a + static_cast<int64_t>(lda) * j;
    for (auto i = 0; i < m; i += 4) {
      auto* yf16 = y + i;
      const auto matRow = vld1_f16(column + i);
      // The first column initializes y without reading it; later columns
      // accumulate into it.
      auto resVec = j != 0 ? vld1_f16(yf16) : vdup_n_f16(0);
      resVec = vfma_lane_f16(resVec, matRow, vecCol, 0);
      vst1_f16(yf16, resVec);
    }
  }
}
#endif
591
fp16_gemv_notrans_fp32_arith(int m,int n,const float16_t * a,const int lda,const float16_t * x,float16_t * y)592 static void fp16_gemv_notrans_fp32_arith(int m, int n, const float16_t* a, const int lda, const float16_t *x, float16_t *y) {
593 std::vector<float> sum(m);
594 for (auto j = 0; j < n; j++) {
595 auto vecCol = vdup_n_f32(x[j]);
596 const auto* column = a + lda * j;
597 for (auto i = 0; i < m; i += 4) {
598 auto sf32 = sum.data() + i;
599 auto matRow = vcvt_f32_f16(vld1_f16(column + i));
600 auto resVec = j != 0 ? vld1q_f32(sf32) : vdupq_n_f32(0);
601 resVec = vfmaq_lane_f32(resVec, matRow, vecCol, 0);
602 vst1q_f32(sf32, resVec);
603 }
604 }
605
606 for (auto i = 0; i < m; i+= 4) {
607 vst1_f16(y + i, vcvt_f16_f32(vld1q_f32(sum.data() + i)));
608 }
609 }
610
fp16_gemv_notrans(const int m,const int n,const float alpha,const float16_t * a,const int lda,const float16_t * x,const int incx,const float beta,float16_t * y,const int incy)611 void fp16_gemv_notrans(
612 const int m,
613 const int n,
614 const float alpha,
615 const float16_t* a,
616 const int lda,
617 const float16_t* x,
618 const int incx,
619 const float beta,
620 float16_t* y,
621 const int incy) {
622 if (incx == 1 && alpha == 1.0 && beta == 0.0 && m % 4 == 0 && incy == 1) {
623 #ifdef __ARM_FEATURE_FP16_SCALAR_ARITHMETIC
624 return at::globalContext().allowFP16ReductionCPU() ? fp16_gemv_notrans_fp16_arith(m, n, a, lda, x, y)
625 : fp16_gemv_notrans_fp32_arith(m, n, a, lda, x, y);
626 #else
627 return fp16_gemv_notrans_fp32_arith(m, n, a, lda, x, y);
628 #endif
629 }
630 std::vector<float> sum(m);
631 for (const auto j : c10::irange(n)) {
632 const auto* column_ = a + lda * j;
633 auto z = alpha * x[j * incx];
634 for (const auto i : c10::irange(m)) {
635 sum[i] += z * column_[i];
636 }
637 }
638 if (beta == 0.0) {
639 for (const auto i : c10::irange(m)) {
640 y[i * incy] = sum[i];
641 }
642 } else {
643 for (const auto i : c10::irange(m)) {
644 y[i * incy] += sum[i];
645 }
646 }
647 }
648
649 template <>
gemv_fast_path(const char * trans,const int * m,const int * n,const at::Half * alpha,const at::Half * a,const int * lda,const at::Half * x,const int * incx,const at::Half * beta,at::Half * y,const int * incy)650 void gemv_fast_path<at::Half>(
651 const char* trans,
652 const int* m,
653 const int* n,
654 const at::Half* alpha,
655 const at::Half* a,
656 const int* lda,
657 const at::Half* x,
658 const int* incx,
659 const at::Half* beta,
660 at::Half* y,
661 const int* incy) {
662 using namespace c10::detail;
663 if ((trans[0] == 'T') || (trans[0] == 't')) {
664 fp16_gemv_trans(
665 *m,
666 *n,
667 fp16_from_bits(alpha->x),
668 reinterpret_cast<const float16_t*>(a),
669 *lda,
670 reinterpret_cast<const float16_t*>(x),
671 *incx,
672 fp16_from_bits(beta->x),
673 reinterpret_cast<float16_t*>(y),
674 *incy);
675 } else {
676 fp16_gemv_notrans(
677 *m,
678 *n,
679 fp16_from_bits(alpha->x),
680 reinterpret_cast<const float16_t*>(a),
681 *lda,
682 reinterpret_cast<const float16_t*>(x),
683 *incx,
684 fp16_from_bits(beta->x),
685 reinterpret_cast<float16_t*>(y),
686 *incy);
687 }
688 }
689
690 template <>
gemv_fast_path(const char * trans,const int * m,const int * n,const at::BFloat16 * alpha,const at::BFloat16 * a,const int * lda,const at::BFloat16 * x,const int * incx,const at::BFloat16 * beta,at::BFloat16 * y,const int * incy)691 void gemv_fast_path<at::BFloat16>(
692 const char* trans,
693 const int* m,
694 const int* n,
695 const at::BFloat16* alpha,
696 const at::BFloat16* a,
697 const int* lda,
698 const at::BFloat16* x,
699 const int* incx,
700 const at::BFloat16* beta,
701 at::BFloat16* y,
702 const int* incy) {
703 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(trans[0] == 'T' || trans[0] == 't');
704 bf16_gemv_trans(
705 *m,
706 *n,
707 *alpha,
708 a,
709 *lda,
710 x,
711 *incx,
712 *beta,
713 y,
714 *incy);
715 }
716 #else // defined(__aarch64__) && !defined(C10_MOBILE)
717 INSTANTIATE(c10::Half);
718 INSTANTIATE(c10::BFloat16);
719 #endif // defined(__aarch64__) && !defined(C10_MOBILE)
720 #undef INSTANTIATE
721
722 } // namespace blas_impl
723
724 template <typename scalar_t>
scal(int64_t n,scalar_t a,scalar_t * x,int64_t incx)725 inline void scal(int64_t n, scalar_t a, scalar_t *x, int64_t incx)
726 {
727 if (n == 1) incx = 1;
728 #if AT_BUILD_WITH_BLAS()
729 if (blas_impl::scal_use_fast_path<scalar_t>(n, incx)) {
730 int i_n = (int)n;
731 int i_incx = (int)incx;
732 blas_impl::scal_fast_path<scalar_t>(&i_n, &a, x, &i_incx);
733 return;
734 }
735 #endif
736 for (const auto i : c10::irange(n)) {
737 if (a == scalar_t(0)) {
738 x[i * incx] = 0;
739 } else {
740 x[i * incx] *= a;
741 }
742 }
743 }
744
745 template<typename scalar_t>
gemv(char trans,int64_t m,int64_t n,scalar_t alpha,const scalar_t * a,int64_t lda,const scalar_t * x,int64_t incx,scalar_t beta,scalar_t * y,int64_t incy)746 void gemv(char trans, int64_t m, int64_t n, scalar_t alpha, const scalar_t *a, int64_t lda, const scalar_t *x, int64_t incx, scalar_t beta, scalar_t *y, int64_t incy) {
747 if(n == 1) lda = m;
748
749 #if AT_BUILD_WITH_BLAS()
750 if (blas_impl::gemv_use_fast_path<scalar_t>(trans, m, n, alpha, lda, incx, beta, incy)) {
751 TORCH_CHECK(lda >= std::max<int64_t>(1L, m), "lda should be at least max(1,", m, "), but have ", lda);
752 int i_m = (int)m;
753 int i_n = (int)n;
754 int i_lda = (int)lda;
755 int i_incx = (int)incx;
756 int i_incy = (int)incy;
757 blas_impl::gemv_fast_path<scalar_t>(&trans, &i_m, &i_n, &alpha, a, &i_lda, x, &i_incx, &beta, y, &i_incy);
758 return;
759 }
760 #endif
761
762 using opmath_t = at::opmath_type<scalar_t>;
763 if ((trans == 'T') || (trans == 't')) {
764 for (const auto i : c10::irange(n)) {
765 opmath_t sum = 0;
766 const scalar_t *row_ = a + lda * i;
767 for (const auto j : c10::irange(m)) {
768 sum += x[j * incx] * row_[j];
769 }
770 if (beta == scalar_t(0)) {
771 y[i * incy] = alpha * sum;
772 } else {
773 y[i * incy] = beta * y[i * incy] + alpha * sum;
774 }
775 }
776 } else {
777 if (beta != scalar_t(1) && beta != scalar_t(0)) scal<scalar_t>(m, beta, y, incy);
778
779 constexpr bool is_low_precision = !std::is_same_v<opmath_t, scalar_t>;
780 std::vector<opmath_t> sum;
781 if constexpr (is_low_precision) {
782 sum.resize(m);
783 }
784 for (const auto j : c10::irange(n)) {
785 const scalar_t *column_ = a + lda * j;
786 opmath_t z = alpha * static_cast<opmath_t>(x[j * incx]);
787 for (const auto i : c10::irange(m)) {
788 //output values are ignored if beta is 0, and set to 0, nans and infs are not propagated
789 if (j==0 && beta==scalar_t(0)) {
790 if constexpr (!is_low_precision) {
791 y[i * incy] = 0;
792 }
793 }
794 if constexpr (is_low_precision) {
795 sum[i] += z * column_[i];
796 } else {
797 y[i * incy] += z * column_[i];
798 }
799 }
800 }
801 if constexpr (is_low_precision) {
802 if (beta == scalar_t(0)) {
803 for (const auto i : c10::irange(m)) {
804 y[i * incy] = sum[i];
805 }
806 } else {
807 for (const auto i : c10::irange(m)) {
808 y[i * incy] += sum[i];
809 }
810 }
811 }
812 }
813 return;
814 }
815
816 #define INSTANTIATE(scalar_t, _) \
817 template void gemv<scalar_t>(char trans, int64_t m, int64_t n, scalar_t alpha, const scalar_t *a, int64_t lda, const scalar_t *x, int64_t incx, scalar_t beta, scalar_t *y, int64_t incy);
818 AT_FORALL_SCALAR_TYPES_AND2(BFloat16, Half, INSTANTIATE);
819 AT_FORALL_COMPLEX_TYPES(INSTANTIATE);
820 #undef INSTANTIATE
821
822 namespace blas_impl {
823 #if AT_BUILD_WITH_BLAS()
dot_fast_path(int n,float * x,int incx,float * y,int incy)824 static float dot_fast_path(int n, float* x, int incx, float* y, int incy) {
825 // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
826 return sdot_(&n, x, &incx, y, &incy);
827 }
828
dot_fast_path(int n,double * x,int incx,double * y,int incy)829 static double dot_fast_path(int n, double* x, int incx, double* y, int incy) {
830 return ddot_(&n, x, &incx, y, &incy);
831 }
832
vdot_fast_path(int n,c10::complex<float> * x,int incx,c10::complex<float> * y,int incy)833 static c10::complex<float> vdot_fast_path(int n, c10::complex<float>* x, int incx, c10::complex<float>* y, int incy) {
834 c10::complex<float> result;
835 cdotc_(reinterpret_cast<std::complex<float>* >(&result), &n, reinterpret_cast<std::complex<float>*>(x), &incx, reinterpret_cast<std::complex<float>*>(y), &incy);
836 return result;
837 }
838
vdot_fast_path(int n,c10::complex<double> * x,int incx,c10::complex<double> * y,int incy)839 static c10::complex<double> vdot_fast_path(int n, c10::complex<double>* x, int incx, c10::complex<double>* y, int incy) {
840 c10::complex<double> result;
841 zdotc_(reinterpret_cast<std::complex<double>* >(&result), &n, reinterpret_cast<std::complex<double>*>(x), &incx, reinterpret_cast<std::complex<double>*>(y), &incy);
842 return result;
843 }
844
dot_fast_path(int n,c10::complex<double> * x,int incx,c10::complex<double> * y,int incy)845 static c10::complex<double> dot_fast_path(int n, c10::complex<double>* x, int incx, c10::complex<double>* y, int incy) {
846 c10::complex<double> result;
847 zdotu_(reinterpret_cast<std::complex<double>* >(&result), &n, reinterpret_cast<std::complex<double>*>(x), &incx, reinterpret_cast<std::complex<double>*>(y), &incy);
848 return result;
849 }
850
dot_fast_path(int n,c10::complex<float> * x,int incx,c10::complex<float> * y,int incy)851 static c10::complex<float> dot_fast_path(int n, c10::complex<float>* x, int incx, c10::complex<float>* y, int incy) {
852 c10::complex<float> result;
853 cdotu_(reinterpret_cast<std::complex<float>* >(&result), &n, reinterpret_cast<std::complex<float>*>(x), &incx, reinterpret_cast<std::complex<float>*>(y), &incy);
854 return result;
855 }
856 #endif
857
858 template <typename scalar_t, typename Functor>
dot_naive(int64_t n,scalar_t * x,int64_t incx,scalar_t * y,int64_t incy,Functor op)859 scalar_t dot_naive(
860 int64_t n,
861 scalar_t* x,
862 int64_t incx,
863 scalar_t* y,
864 int64_t incy,
865 Functor op) {
866 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
867 int64_t i;
868 using opmath_t = at::opmath_type<scalar_t>;
869 opmath_t sum = 0;
870 for (i = 0; i < n; i++) {
871 sum += op(static_cast<opmath_t>(x[i * incx]), static_cast<opmath_t>(y[i * incy]));
872 }
873 return static_cast<scalar_t>(sum);
874 }
875
876 } // namespace blas_impl
877
878 template <typename scalar_t>
dot_impl_floating(int64_t n,scalar_t * x,int64_t incx,scalar_t * y,int64_t incy)879 scalar_t dot_impl_floating(int64_t n, scalar_t* x, int64_t incx, scalar_t* y, int64_t incy)
880 {
881 if (n == 1) {
882 incx = 1;
883 incy = 1;
884 }
885 #if AT_BUILD_WITH_BLAS()
886 if ((n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX)) {
887 return blas_impl::dot_fast_path(n, x, incx, y, incy);
888 } else {
889 return blas_impl::dot_naive(n, x, incx, y, incy, std::multiplies<scalar_t>{});
890 }
891 #else
892 { return blas_impl::dot_naive(n, x, incx, y, incy, std::multiplies<scalar_t>{}); }
893 #endif
894 }
895
896 template <typename scalar_t>
dot_impl(int64_t n,scalar_t * x,int64_t incx,scalar_t * y,int64_t incy)897 scalar_t dot_impl(int64_t n, scalar_t* x, int64_t incx, scalar_t* y, int64_t incy) {
898 if (n == 1) {
899 incx = 1;
900 incy = 1;
901 }
902 return blas_impl::dot_naive(n, x, incx, y, incy, std::multiplies<scalar_t>{});
903 }
904
905 template <>
dot_impl(int64_t n,float * x,int64_t incx,float * y,int64_t incy)906 float dot_impl(int64_t n, float* x, int64_t incx, float* y, int64_t incy) {
907 return dot_impl_floating(n, x, incx, y, incy);
908 }
909
910 template <>
dot_impl(int64_t n,double * x,int64_t incx,double * y,int64_t incy)911 double dot_impl(int64_t n, double* x, int64_t incx, double* y, int64_t incy) {
912 return dot_impl_floating(n, x, incx, y, incy);
913 }
914
915 template <>
dot_impl(int64_t n,c10::complex<double> * x,int64_t incx,c10::complex<double> * y,int64_t incy)916 c10::complex<double> dot_impl(int64_t n, c10::complex<double>* x, int64_t incx, c10::complex<double>* y, int64_t incy) {
917 return dot_impl_floating(n, x, incx, y, incy);
918 }
919
920 template <>
dot_impl(int64_t n,c10::complex<float> * x,int64_t incx,c10::complex<float> * y,int64_t incy)921 c10::complex<float> dot_impl(int64_t n, c10::complex<float>* x, int64_t incx, c10::complex<float>* y, int64_t incy) {
922 return dot_impl_floating(n, x, incx, y, incy);
923 }
924
namespace {
/// Stateless binary functor for the conjugated dot product: conj(x) * y.
/// Used as the element operation of vdot_impl's non-BLAS fallback. The call
/// operator is const so the functor can also be invoked through a const
/// object or reference.
template <typename scalar_t>
struct vdot_op {
  scalar_t operator()(scalar_t x, scalar_t y) const {
    return std::conj(x) * y;
  }
};
} // anonymous namespace
933
934 template <typename scalar_t>
vdot_impl(int64_t n,scalar_t * x,int64_t incx,scalar_t * y,int64_t incy)935 scalar_t vdot_impl(int64_t n, scalar_t* x, int64_t incx, scalar_t* y, int64_t incy) {
936 if (n == 1) {
937 incx = 1;
938 incy = 1;
939 }
940 #if AT_BUILD_WITH_BLAS()
941 if ((n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX)) {
942 return blas_impl::vdot_fast_path(n, x, incx, y, incy);
943 } else {
944 return blas_impl::dot_naive(n, x, incx, y, incy, vdot_op<scalar_t>{});
945 }
946 #else
947 { return blas_impl::dot_naive(n, x, incx, y, incy, vdot_op<scalar_t>{}); }
948 #endif
949 }
950
951 // Skip reinstantiating the explicitly specialized types `float` and `double`.
952 #define INSTANTIATE_DOT_IMPL(scalar_t) \
953 template scalar_t dot_impl<scalar_t>( \
954 int64_t n, scalar_t * x, int64_t incx, scalar_t * y, int64_t incy);
955 INSTANTIATE_DOT_IMPL(uint8_t);
956 INSTANTIATE_DOT_IMPL(int8_t);
957 INSTANTIATE_DOT_IMPL(int16_t);
958 INSTANTIATE_DOT_IMPL(int);
959 INSTANTIATE_DOT_IMPL(int64_t);
960 INSTANTIATE_DOT_IMPL(c10::Half);
961 INSTANTIATE_DOT_IMPL(c10::BFloat16);
962
963 #define INSTANTIATE_VDOT_IMPL(scalar_t) \
964 template scalar_t vdot_impl<scalar_t>( \
965 int64_t n, scalar_t * x, int64_t incx, scalar_t * y, int64_t incy);
966 INSTANTIATE_VDOT_IMPL(c10::complex<float>);
967 INSTANTIATE_VDOT_IMPL(c10::complex<double>);
968
969 #undef INSTANTIATE_DOT_IMPL
970
971 } // namespace at::native
972 C10_DIAGNOSTIC_POP()
973