#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/Config.h>
#include <ATen/Context.h>
#include <ATen/Dispatch.h>
#include <ATen/OpMathType.h>
#include <ATen/Parallel.h>
#include <c10/core/Scalar.h>
#include <c10/util/Exception.h>
#include <c10/util/Unroll.h>
#include <c10/util/complex.h>
#include <c10/util/irange.h>
#include <algorithm>
#include <climits>
#include <limits>

#if defined(__aarch64__) && !defined(C10_MOBILE)
#include <arm_neon.h>
#endif

C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-function")

namespace {

/// Wrapper for const_cast<T*> with type-inference.
///
/// Use this to call into APIs that are not const-correct.
template <typename T>
T* remove_const(const T* x) {
  return const_cast<T*>(x);
}

} // namespace

#if AT_BUILD_WITH_BLAS()
extern "C" double ddot_(int *n, double *x, int *incx, double *y, int *incy);
extern "C" void dscal_(int *n, double *a, double *x, int *incx);
extern "C" void sscal_(int *n, float *a, float *x, int *incx);
extern "C" void dgemv_(char *trans, int *m, int *n, double *alpha, double *a, int *lda, double *x, int *incx, double *beta, double *y, int *incy);
extern "C" void sgemv_(char *trans, int *m, int *n, float *alpha, float *a, int *lda, float *x, int *incx, float *beta, float *y, int *incy);

#if AT_BLAS_F2C()
# define ffloat double
#else
# define ffloat float
#endif

#if AT_BLAS_USE_CBLAS_DOT()
extern "C" float cblas_sdot(const int n, const float *x, const int incx, const float *y, const int incy);
extern "C" void cblas_cdotu_sub(const int n, const void *x, const int incx, const void *y, const int incy, void *dotu);
extern "C" void cblas_zdotu_sub(const int n, const void *x, const int incx, const void *y, const int incy, void *dotu);
extern "C" void cblas_cdotc_sub(const int n, const void *x, const int incx, const void *y, const int incy, void *dotc);
extern "C" void cblas_zdotc_sub(const int n, const void *x, const int incx, const void *y, const int incy, void *dotc);

static inline ffloat sdot_(const int *n, const float *x, const int *incx, const float *y, const int *incy)
{
  return cblas_sdot(*n, x, *incx, y, *incy);
}
static inline void cdotu_(std::complex<float> *res, const int *n, const std::complex<float> *x, const int *incx,
                          const std::complex<float> *y, const int *incy) {
  cblas_cdotu_sub(*n, x, *incx, y, *incy, res);
}
static inline void zdotu_(std::complex<double> *res, const int *n, const std::complex<double> *x, const int *incx,
                          const std::complex<double> *y, const int *incy) {
  cblas_zdotu_sub(*n, x, *incx, y, *incy, res);
}
static inline void cdotc_(std::complex<float> *res, const int *n, const std::complex<float> *x, const int *incx,
                          const std::complex<float> *y, const int *incy) {
  cblas_cdotc_sub(*n, x, *incx, y, *incy, res);
}
static inline void zdotc_(std::complex<double> *res, const int *n, const std::complex<double> *x, const int *incx,
                          const std::complex<double> *y, const int *incy) {
  cblas_zdotc_sub(*n, x, *incx, y, *incy, res);
}
#else
extern "C" ffloat sdot_(int *n, float *x, int *incx, float *y, int *incy);
extern "C" void cdotu_(std::complex<float> *res, int *n, std::complex<float> *x, int *incx, std::complex<float> *y, int *incy);
extern "C" void zdotu_(std::complex<double> *res, int *n, std::complex<double> *x, int *incx, std::complex<double> *y, int *incy);
extern "C" void cdotc_(std::complex<float> *res, int *n, std::complex<float> *x, int *incx, std::complex<float> *y, int *incy);
extern "C" void zdotc_(std::complex<double> *res, int *n, std::complex<double> *x, int *incx, std::complex<double> *y, int *incy);
#endif // AT_BLAS_USE_CBLAS_DOT
#endif // AT_BUILD_WITH_BLAS

namespace at::native {

namespace blas_impl {
#if defined(__aarch64__) && !defined(C10_MOBILE)
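// Forward declarations of the aarch64 fp16/bf16 GEMV and dot kernels; the
// definitions follow further down in this file.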
void fp16_gemv_notrans(
    const int m,
    const int n,
    const float alpha,
    const float16_t* a,
    const int lda,
    const float16_t* x,
    const int incx,
    const float beta,
    float16_t* y,
    const int incy);

void fp16_gemv_trans(
    const int m,
    const int n,
    const float alpha,
    const float16_t* a,
    const int lda,
    const float16_t* x,
    const int incx,
    const float beta,
    float16_t* y,
    const int incy);

float fp16_dot_with_fp32_arith(
    const float16_t* vec1,
    const float16_t* vec2,
    int64_t len);

void bf16_gemv_trans(
    const int m,
    const int n,
    const at::BFloat16 alpha,
    const at::BFloat16* a,
    const int lda,
    const at::BFloat16* x,
    const int incx,
    const at::BFloat16 beta,
    at::BFloat16* y,
    const int incy);

float bf16_dot_with_fp32_arith(
    const at::BFloat16* vec1,
    const at::BFloat16* vec2,
    int64_t len);
#endif

template <typename scalar_t>
bool scal_use_fast_path(C10_UNUSED int64_t n, C10_UNUSED int64_t incx) {
  return false;
}

template <typename scalar_t>
bool gemv_use_fast_path(
    C10_UNUSED char trans,
    C10_UNUSED int64_t m,
    C10_UNUSED int64_t n,
    C10_UNUSED scalar_t alpha,
    C10_UNUSED int64_t lda,
    C10_UNUSED int64_t incx,
    C10_UNUSED scalar_t beta,
    C10_UNUSED int64_t incy) {
  return false;
}

template <typename scalar_t>
void scal_fast_path(C10_UNUSED int *n, C10_UNUSED scalar_t *a, C10_UNUSED scalar_t *x, C10_UNUSED int *incx) {
  TORCH_INTERNAL_ASSERT(false, "scal_fast_path shouldn't be called for this configuration");
}

template <typename scalar_t>
void gemv_fast_path(
    C10_UNUSED const char *trans,
    C10_UNUSED const int *m,
    C10_UNUSED const int *n,
    C10_UNUSED const scalar_t *alpha,
    C10_UNUSED const scalar_t *a,
    C10_UNUSED const int *lda,
    C10_UNUSED const scalar_t *x,
    C10_UNUSED const int *incx,
    C10_UNUSED const scalar_t *beta,
    C10_UNUSED scalar_t *y,
    C10_UNUSED const int *incy) {
  TORCH_INTERNAL_ASSERT(false, "gemv_fast_path shouldn't be called for this configuration");
}

#define INSTANTIATE(scalar_t) \
template bool scal_use_fast_path<scalar_t>(int64_t n, int64_t incx); \
template bool gemv_use_fast_path<scalar_t>(char trans, int64_t m, int64_t n, scalar_t alpha, int64_t lda, int64_t incx, scalar_t beta, int64_t incy); \
template void gemv_fast_path<scalar_t>(const char *trans, const int *m, const int *n, const scalar_t *alpha, const scalar_t *a, const int *lda, const scalar_t *x, const int *incx, const scalar_t *beta, scalar_t *y, const int *incy); \
template void scal_fast_path<scalar_t>(int *n, scalar_t *a, scalar_t *x, int *incx);

#if AT_BUILD_WITH_BLAS()
template <>
bool scal_use_fast_path<double>(int64_t n, int64_t incx) {
  auto intmax = std::numeric_limits<int>::max();
  return n <= intmax && incx <= intmax;
}

template <>
bool scal_use_fast_path<float>(int64_t n, int64_t incx) {
  return scal_use_fast_path<double>(n, incx);
}

template <>
void scal_fast_path<double>(int *n, double *a, double *x, int *incx) {
  dscal_(n, a, x, incx);
}

template <>
void scal_fast_path<float>(int *n, float *a, float *x, int *incx) {
  sscal_(n, a, x, incx);
}

template <>
bool gemv_use_fast_path<float>(
    C10_UNUSED char trans,
    int64_t m,
    int64_t n,
    C10_UNUSED float alpha,
    int64_t lda,
    int64_t incx,
    C10_UNUSED float beta,
    int64_t incy) {
  auto intmax = std::numeric_limits<int>::max();
  return (m <= intmax) && (n <= intmax) && (lda <= intmax) &&
         (incx > 0) && (incx <= intmax) && (incy > 0) && (incy <= intmax);
}

template <>
bool gemv_use_fast_path<double>(
    C10_UNUSED char trans,
    int64_t m,
    int64_t n,
    C10_UNUSED double alpha,
    int64_t lda,
    int64_t incx,
    C10_UNUSED double beta,
    int64_t incy) {
  return gemv_use_fast_path<float>(trans, m, n, (float)alpha, lda, incx, (float)beta, incy);
}
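// Reference Fortran BLAS takes every argument by mutable pointer even though
// it only writes to y, so the wrappers below strip const with remove_const
// before forwarding to dgemv_/sgemv_.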
template <>
void gemv_fast_path<double>(
    const char *trans,
    const int *m,
    const int *n,
    const double *alpha,
    const double *a,
    const int *lda,
    const double *x,
    const int *incx,
    const double *beta,
    double *y,
    const int *incy) {
  dgemv_(
      remove_const(trans),
      remove_const(m),
      remove_const(n),
      remove_const(alpha),
      remove_const(a),
      remove_const(lda),
      remove_const(x),
      remove_const(incx),
      remove_const(beta),
      y,
      remove_const(incy));
}

template <>
void gemv_fast_path<float>(
    const char *trans,
    const int *m,
    const int *n,
    const float *alpha,
    const float *a,
    const int *lda,
    const float *x,
    const int *incx,
    const float *beta,
    float *y,
    const int *incy) {
  sgemv_(
      remove_const(trans),
      remove_const(m),
      remove_const(n),
      remove_const(alpha),
      remove_const(a),
      remove_const(lda),
      remove_const(x),
      remove_const(incx),
      remove_const(beta),
      y,
      remove_const(incy));
}
#else
INSTANTIATE(float);
INSTANTIATE(double);
#endif // AT_BUILD_WITH_BLAS

INSTANTIATE(uint8_t);
INSTANTIATE(int8_t);
INSTANTIATE(int16_t);
INSTANTIATE(int);
INSTANTIATE(int64_t);

#if defined(__aarch64__) && !defined(C10_MOBILE)
template <>
bool scal_use_fast_path<at::Half>(C10_UNUSED int64_t n, C10_UNUSED int64_t incx) {
  return false;
}

template <>
bool gemv_use_fast_path<at::Half>(
    C10_UNUSED char trans,
    C10_UNUSED int64_t m,
    C10_UNUSED int64_t n,
    at::Half alpha,
    C10_UNUSED int64_t lda,
    C10_UNUSED int64_t incx,
    at::Half beta,
    C10_UNUSED int64_t incy) {
  return incx == 1 && c10::detail::fp16_from_bits(alpha.x) == 1.0f &&
      c10::detail::fp16_from_bits(beta.x) == 0.0f;
}

template <>
bool gemv_use_fast_path<at::BFloat16>(
    C10_UNUSED char trans,
    C10_UNUSED int64_t m,
    C10_UNUSED int64_t n,
    at::BFloat16 alpha,
    C10_UNUSED int64_t lda,
    C10_UNUSED int64_t incx,
    at::BFloat16 beta,
    C10_UNUSED int64_t incy) {
  return (trans == 'T' || trans == 't') && incx == 1 && alpha == 1.0 && beta == 0.0;
}

#ifdef __ARM_FEATURE_FP16_SCALAR_ARITHMETIC
static inline float16_t reduce(float16x4_t x) {
  auto sum = vpadd_f16(x, x);
  return vget_lane_f16(vpadd_f16(sum, sum), 0);
}
static inline float16_t reduce(float16x8_t x) {
  return reduce(vadd_f16(vget_low_f16(x), vget_high_f16(x)));
}

/*
 * NOTE [ GGML Copyright Notice ]
 * The below reduce overload and fp16_dot_with_fp16_arith function is
 * adapted from llama.cpp's ggml_vec_dot_f16 and surrounding utility
 * functions, so here is the required copyright notice:
 *
 * MIT License
 *
 * Copyright (c) 2023-2024 The ggml authors
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
// We need the shift for reduce(), hence the extra constants.
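// Each fully unrolled iteration handles 128 fp16 lanes: 1 << 7 == 128
// elements, split across 16 eight-lane NEON registers (128 / 8 == 16).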
static constexpr auto kF16ElementsPerIterationShift = 7;
static constexpr auto kF16ElementsPerIteration = 1 << kF16ElementsPerIterationShift;
static_assert(kF16ElementsPerIteration == 128);

static constexpr auto kF16ElementsPerRegisterShift = 3;
static constexpr auto kF16ElementsPerRegister = 1 << kF16ElementsPerRegisterShift;
static_assert(kF16ElementsPerRegister == 8);

static constexpr auto kF16RegistersPerIterationShift = kF16ElementsPerIterationShift - kF16ElementsPerRegisterShift;
static constexpr auto kF16RegistersPerIteration = 1 << kF16RegistersPerIterationShift;
static_assert(kF16RegistersPerIteration == kF16ElementsPerIteration / kF16ElementsPerRegister);

static inline double reduce(float16x8_t x[kF16RegistersPerIteration]) {
  int offset = kF16RegistersPerIteration;
  c10::ForcedUnroll<kF16RegistersPerIterationShift>{}([&offset, &x](auto idx) {
    offset /= 2;
    for (int i = 0; i < offset; ++i) {
      x[i] = vaddq_f16(x[i], x[offset + i]);
    }
  });
  const float32x4_t t0 = vcvt_f32_f16(vget_low_f16(x[0]));
  const float32x4_t t1 = vcvt_f32_f16(vget_high_f16(x[0]));
  return (double)vaddvq_f32(vaddq_f32(t0, t1));
}

static inline float16x8_t f16_fma(float16x8_t a, float16x8_t b, float16x8_t c) {
#ifdef __ARM_FEATURE_FMA
  return vfmaq_f16(a, b, c);
#else
  return vaddq_f16(a, vmulq_f16(b, c));
#endif
}

static float fp16_dot_with_fp16_arith(const float16_t* x, const float16_t* a, int len) {
  float16x8_t sum[kF16RegistersPerIteration] = {vdupq_n_f16(0)};

  const auto len_aligned = len & ~(kF16ElementsPerIteration - 1);
  for (int j = 0; j < len_aligned ; j += kF16ElementsPerIteration) {
    for (int k = 0; k < kF16RegistersPerIteration; ++k) {
      const auto temp_x = vld1q_f16(x + j + k * kF16ElementsPerRegister);
      const auto temp_a = vld1q_f16(a + j + k * kF16ElementsPerRegister);
      sum[k] = f16_fma(sum[k], temp_x, temp_a);
    }
  }
  auto reducedSum = reduce(sum);

  for (int j = len_aligned; j < len; ++j) {
    reducedSum += x[j] * a[j];
  }
  return reducedSum;
}

// Rather than unrolling to process multiple rows (transposed columns)
// of matrix A at once as done in fp16_gemv_trans_fp16_arith, unroll
// along an individual dot product.
static void fp16_gemv_trans_fp16_arith_by_dot_products(const int m, const int n, const float16_t* a, const int lda, const float16_t *x, float16_t* y, int incy) {
  parallel_for(0, n, 1, [&](int begin, int end) {
    for (int i = begin; i < end; ++i) {
      y[i * incy] = fp16_dot_with_fp16_arith(x, a + lda * i, m);
    }
  });
}

#endif

static inline float reduce(float32x4_t x) {
  auto sum = vpaddq_f32(x, x);
  return vgetq_lane_f32(vpaddq_f32(sum, sum), 0);
}

static inline float32x4_t f32_fma(float32x4_t a, float32x4_t b, float32x4_t c) {
#ifdef __ARM_FEATURE_FMA
  return vfmaq_f32(a, b, c);
#else
  return vaddq_f32(a, vmulq_f32(b, c));
#endif
}
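// f32_fma_{low,high}_f16 fuse the fp16->fp32 widening of one half of each
// input vector with the multiply-accumulate. When the compiler advertises
// __ARM_FEATURE_FP16_FML they map to the FMLAL/FMLAL2 instructions;
// otherwise they fall back to an explicit convert followed by f32_fma.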
static inline float32x4_t f32_fma_low_f16(float32x4_t a, float16x8_t b, float16x8_t c) {
#ifdef __ARM_FEATURE_FP16_FML
  // NOTE: this instruction is an optional instruction in ARM v8.2 and
  // v8.3, but mandatory in v8.4 per
  // https://developer.arm.com/documentation/ddi0596/2021-03/SIMD-FP-Instructions/FMLAL--FMLAL2--vector---Floating-point-fused-Multiply-Add-Long-to-accumulator--vector--?lang=en
  // I'm not certain that I have the right feature test macro.
  return vfmlalq_low_f16(a, b, c);
#else
  return f32_fma(a, vcvt_f32_f16(vget_low_f16(b)), vcvt_f32_f16(vget_low_f16(c)));
#endif
}

static inline float32x4_t f32_fma_high_f16(float32x4_t a, float16x8_t b, float16x8_t c) {
#ifdef __ARM_FEATURE_FP16_FML
  // See above note about this instruction.
  return vfmlalq_high_f16(a, b, c);
#else
  return f32_fma(a, vcvt_f32_f16(vget_high_f16(b)), vcvt_f32_f16(vget_high_f16(c)));
#endif
}

static inline float32x4_t f32_fma_f16(float32x4_t a, float16x4_t b, float16x4_t c) {
  return f32_fma_low_f16(a, vcombine_f16(b, vdup_n_f16(0)), vcombine_f16(c, vdup_n_f16(0)));
}

// The below reduce overload and fp16_dot_with_fp32_arith are adapted
// from llama.cpp's ggml_vec_dot_f32 and surrounding utility
// functions. See NOTE [ GGML Copyright Notice ] above for the
// required notice.

// We need the shift for reduce(), hence the extra constants.
static constexpr auto kF32ElementsPerIterationShift = 5;
static constexpr auto kF32ElementsPerIteration = 1 << kF32ElementsPerIterationShift;
static_assert(kF32ElementsPerIteration == 32);

static constexpr auto kF32ElementsPerRegisterShift = 2;
static constexpr auto kF32ElementsPerRegister = 1 << kF32ElementsPerRegisterShift;
static_assert(kF32ElementsPerRegister == 4);

static constexpr auto kF32RegisterPairsPerIteration = 4;
static constexpr auto kF32RegistersPerIteration = kF32RegisterPairsPerIteration * 2;
static constexpr auto kF32RegistersPerIterationShift = 3;
static_assert(kF32RegistersPerIteration == kF32ElementsPerIteration / kF32ElementsPerRegister);
static_assert(kF32RegistersPerIteration == 1 << kF32RegistersPerIterationShift);

static inline double reduce(float32x4_t x[kF32RegistersPerIteration]) {
  int offset = kF32RegistersPerIteration;
  c10::ForcedUnroll<kF32RegistersPerIterationShift>{}([&offset, &x](auto idx) {
    offset /= 2;
    for (int i = 0; i < offset; ++i) {
      x[i] = vaddq_f32(x[i], x[offset + i]);
    }
  });
  return vaddvq_f32(x[0]);
}

static C10_ALWAYS_INLINE void dot_with_fp32_arith_main_inner_loop(
  const float16_t* vec1,
  const float16_t* vec2,
  float32x4_t sum[kF32RegistersPerIteration],
  int registerPairIndex) {
  // Load a pair of f32 registers at a time.
  const auto temp_vec1 = vld1q_f16(&vec1[registerPairIndex * 2 * kF32ElementsPerRegister]);
  const auto temp_vec2 = vld1q_f16(&vec2[registerPairIndex * 2 * kF32ElementsPerRegister]);

  sum[2 * registerPairIndex] = f32_fma_low_f16(sum[2 * registerPairIndex], temp_vec1, temp_vec2);
  sum[2 * registerPairIndex + 1] = f32_fma_high_f16(sum[2 * registerPairIndex + 1], temp_vec1, temp_vec2);
}

static C10_ALWAYS_INLINE void dot_with_fp32_arith_vectorized_tail_inner_loop(
  const float16_t* vec1,
  const float16_t* vec2,
  float32x4_t* tailSum,
  int idx) {
  const auto temp_vec1 = vld1_f16(&vec1[idx]);
  const auto temp_vec2 = vld1_f16(&vec2[idx]);
  *tailSum = f32_fma_f16(*tailSum, temp_vec1, temp_vec2);
}
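// bfloat16 is the upper 16 bits of an IEEE-754 float32, so a lane widens to
// fp32 via a plain 16-bit left shift; to_bfloat16 below does four lanes at
// a time.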
static C10_ALWAYS_INLINE float32x4_t to_bfloat16(uint16x4_t u16) {
  int32x4_t shift = vdupq_n_s32(16);
  return vreinterpretq_f32_u32(vshlq_u32(vmovl_u16(u16), shift));
}

static C10_ALWAYS_INLINE float32x4_t f32_fma_bf16(float32x4_t a, uint16x4_t b, uint16x4_t c) {
  return f32_fma(a, to_bfloat16(b), to_bfloat16(c));
}

static C10_ALWAYS_INLINE void dot_with_fp32_arith_main_inner_loop(
  const at::BFloat16* vec1,
  const at::BFloat16* vec2,
  float32x4_t sum[kF32RegistersPerIteration],
  int registerPairIndex) {
  // TODO: detect intrinsic availability, use them if they're available. __ARM_FEATURE_BF16
  // Load a pair of f32 registers at a time.
  const uint16x8_t temp_vec1 = vld1q_u16(reinterpret_cast<const uint16_t*>(&vec1[registerPairIndex * 2 * kF32ElementsPerRegister]));
  const uint16x8_t temp_vec2 = vld1q_u16(reinterpret_cast<const uint16_t*>(&vec2[registerPairIndex * 2 * kF32ElementsPerRegister]));

  sum[2 * registerPairIndex] = f32_fma_bf16(sum[2 * registerPairIndex], vget_low_u16(temp_vec1), vget_low_u16(temp_vec2));
  sum[2 * registerPairIndex + 1] = f32_fma_bf16(sum[2 * registerPairIndex + 1], vget_high_u16(temp_vec1), vget_high_u16(temp_vec2));
}

static C10_ALWAYS_INLINE void dot_with_fp32_arith_vectorized_tail_inner_loop(
  const at::BFloat16* vec1,
  const at::BFloat16* vec2,
  float32x4_t* tailSum,
  int idx) {
  const auto temp_vec1 = vld1_u16(reinterpret_cast<const uint16_t*>(&vec1[idx]));
  const auto temp_vec2 = vld1_u16(reinterpret_cast<const uint16_t*>(&vec2[idx]));
  *tailSum = f32_fma_bf16(*tailSum, temp_vec1, temp_vec2);
}

template <typename T>
float dot_with_fp32_arith(const T* vec1, const T* vec2, int64_t len) {
  float32x4_t sum[kF32RegistersPerIteration] = {vdupq_n_f32(0)};
  const auto len_aligned = len & ~(kF32ElementsPerIteration - 1);
  for (int j = 0; j < len_aligned ; j += kF32ElementsPerIteration) {
    const auto* vec1_ = vec1 + j;
    const auto* vec2_ = vec2 + j;
    c10::ForcedUnroll<kF32RegisterPairsPerIteration>{}([vec1_, vec2_, &sum](auto k) {
      dot_with_fp32_arith_main_inner_loop(vec1_, vec2_, sum, k);
    });
  }
  auto reducedSum = reduce(sum);

  // First-tier tail fixup: make sure we handle workloads that can
  // benefit from vectorization, but don't fit into our fully unrolled
  // loop above.
  float32x4_t tailSum = vdupq_n_f32(0);
  const auto len_aligned_4 = len & ~3;
  for (int j = len_aligned; j < len_aligned_4; j += 4) {
    dot_with_fp32_arith_vectorized_tail_inner_loop(vec1, vec2, &tailSum, j);
  }
  auto reducedTail = vpaddq_f32(tailSum, tailSum);
  reducedSum += vgetq_lane_f32(vpaddq_f32(reducedTail, reducedTail), 0);

  // Second-tier tail fixup: handle all workloads.
  for (int j = len_aligned_4; j < len; ++j) {
    reducedSum += vec1[j] * vec2[j];
  }
  return reducedSum;
}

float fp16_dot_with_fp32_arith(const float16_t* vec1, const float16_t* vec2, int64_t len) {
  return dot_with_fp32_arith(vec1, vec2, len);
}

float bf16_dot_with_fp32_arith(const at::BFloat16* vec1, const at::BFloat16* vec2, int64_t len) {
  return dot_with_fp32_arith(vec1, vec2, len);
}
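// A transposed GEMV with unit-stride x is just n independent dot products
// between x and the rows of the stored matrix, so each output element can
// be computed on its own and parallel_for distributes them across threads.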
// On my Apple M1 Macbook (which is ARM v8.5 and thus has the
// instructions f32_fma_{low,high}_f16 is targeting), this kernel has
// equivalent performance to the fp16-native kernel.
static void fp16_gemv_trans_fp32_arith_by_dot_products(const int m, const int n, const float16_t* a, const int lda, const float16_t *x, float16_t* y, int incy) {
  parallel_for(0, n, 1, [&](int begin, int end) {
    for (int i = begin; i < end; ++i) {
      y[i * incy] = fp16_dot_with_fp32_arith(x, a + lda * i, m);
    }
  });
}

static void bf16_gemv_trans_fp32_arith_by_dot_products(const int m, const int n, const at::BFloat16* a, const int lda, const at::BFloat16 *x, at::BFloat16* y, int incy) {
  parallel_for(0, n, 1, [&](int begin, int end) {
    for (int i = begin; i < end; ++i) {
      y[i * incy] = bf16_dot_with_fp32_arith(x, a + lda * i, m);
    }
  });
}

void fp16_gemv_trans(
    const int m,
    const int n,
    const float alpha,
    const float16_t* a,
    const int lda,
    const float16_t* x,
    const int incx,
    const float beta,
    float16_t* y,
    const int incy) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(incx == 1 && alpha == 1.0 && beta == 0.0);
#ifdef __ARM_FEATURE_FP16_SCALAR_ARITHMETIC
  if (at::globalContext().allowFP16ReductionCPU()) {
    return fp16_gemv_trans_fp16_arith_by_dot_products(m, n, a, lda, x, y, incy);
  }
#endif
  return fp16_gemv_trans_fp32_arith_by_dot_products(m, n, a, lda, x, y, incy);
}

void bf16_gemv_trans(
    const int m,
    const int n,
    const at::BFloat16 alpha,
    const at::BFloat16* a,
    const int lda,
    const at::BFloat16* x,
    const int incx,
    const at::BFloat16 beta,
    at::BFloat16* y,
    const int incy) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(incx == 1 && alpha == 1.0 && beta == 0.0);
  return bf16_gemv_trans_fp32_arith_by_dot_products(m, n, a, lda, x, y, incy);
}

#ifdef __ARM_FEATURE_FP16_SCALAR_ARITHMETIC
static void fp16_gemv_notrans_fp16_arith(int m, int n, const float16_t* a, const int lda, const float16_t *x, float16_t *y) {
  for (auto j = 0; j < n; j++) {
    auto vecCol = vdup_n_f16(x[j]);
    const auto* column = a + lda * j;
    for (auto i = 0; i < m; i += 4) {
      auto yf16 = y + i;
      auto matRow = vld1_f16(column + i);
      auto resVec = j != 0 ? vld1_f16(yf16) : vdup_n_f16(0);
      resVec = vfma_lane_f16(resVec, matRow, vecCol, 0);
      vst1_f16(yf16, resVec);
    }
  }
}
#endif

static void fp16_gemv_notrans_fp32_arith(int m, int n, const float16_t* a, const int lda, const float16_t *x, float16_t *y) {
  std::vector<float> sum(m);
  for (auto j = 0; j < n; j++) {
    auto vecCol = vdup_n_f32(x[j]);
    const auto* column = a + lda * j;
    for (auto i = 0; i < m; i += 4) {
      auto sf32 = sum.data() + i;
      auto matRow = vcvt_f32_f16(vld1_f16(column + i));
      auto resVec = j != 0 ? vld1q_f32(sf32) : vdupq_n_f32(0);
      resVec = vfmaq_lane_f32(resVec, matRow, vecCol, 0);
      vst1q_f32(sf32, resVec);
    }
  }

  for (auto i = 0; i < m; i+= 4) {
    vst1_f16(y + i, vcvt_f16_f32(vld1q_f32(sum.data() + i)));
  }
}
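// fp16_gemv_notrans dispatches to the vectorized column-accumulation
// kernels above when the arguments allow it; otherwise it falls back to a
// scalar loop that accumulates alpha * x[j] * A[:, j] into an fp32 buffer.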
void fp16_gemv_notrans(
    const int m,
    const int n,
    const float alpha,
    const float16_t* a,
    const int lda,
    const float16_t* x,
    const int incx,
    const float beta,
    float16_t* y,
    const int incy) {
  if (incx == 1 && alpha == 1.0 && beta == 0.0 && m % 4 == 0 && incy == 1) {
#ifdef __ARM_FEATURE_FP16_SCALAR_ARITHMETIC
    return at::globalContext().allowFP16ReductionCPU() ? fp16_gemv_notrans_fp16_arith(m, n, a, lda, x, y)
                                                       : fp16_gemv_notrans_fp32_arith(m, n, a, lda, x, y);
#else
    return fp16_gemv_notrans_fp32_arith(m, n, a, lda, x, y);
#endif
  }
  std::vector<float> sum(m);
  for (const auto j : c10::irange(n)) {
    const auto* column_ = a + lda * j;
    auto z = alpha * x[j * incx];
    for (const auto i : c10::irange(m)) {
      sum[i] += z * column_[i];
    }
  }
  if (beta == 0.0) {
    for (const auto i : c10::irange(m)) {
      y[i * incy] = sum[i];
    }
  } else {
    for (const auto i : c10::irange(m)) {
      y[i * incy] += sum[i];
    }
  }
}

template <>
void gemv_fast_path<at::Half>(
    const char* trans,
    const int* m,
    const int* n,
    const at::Half* alpha,
    const at::Half* a,
    const int* lda,
    const at::Half* x,
    const int* incx,
    const at::Half* beta,
    at::Half* y,
    const int* incy) {
  using namespace c10::detail;
  if ((trans[0] == 'T') || (trans[0] == 't')) {
    fp16_gemv_trans(
        *m,
        *n,
        fp16_from_bits(alpha->x),
        reinterpret_cast<const float16_t*>(a),
        *lda,
        reinterpret_cast<const float16_t*>(x),
        *incx,
        fp16_from_bits(beta->x),
        reinterpret_cast<float16_t*>(y),
        *incy);
  } else {
    fp16_gemv_notrans(
        *m,
        *n,
        fp16_from_bits(alpha->x),
        reinterpret_cast<const float16_t*>(a),
        *lda,
        reinterpret_cast<const float16_t*>(x),
        *incx,
        fp16_from_bits(beta->x),
        reinterpret_cast<float16_t*>(y),
        *incy);
  }
}

template <>
void gemv_fast_path<at::BFloat16>(
    const char* trans,
    const int* m,
    const int* n,
    const at::BFloat16* alpha,
    const at::BFloat16* a,
    const int* lda,
    const at::BFloat16* x,
    const int* incx,
    const at::BFloat16* beta,
    at::BFloat16* y,
    const int* incy) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(trans[0] == 'T' || trans[0] == 't');
  bf16_gemv_trans(*m, *n, *alpha, a, *lda, x, *incx, *beta, y, *incy);
}
#else // defined(__aarch64__) && !defined(C10_MOBILE)
INSTANTIATE(c10::Half);
INSTANTIATE(c10::BFloat16);
#endif // defined(__aarch64__) && !defined(C10_MOBILE)
#undef INSTANTIATE

} // namespace blas_impl

template <typename scalar_t>
inline void scal(int64_t n, scalar_t a, scalar_t *x, int64_t incx)
{
  if (n == 1) incx = 1;
#if AT_BUILD_WITH_BLAS()
  if (blas_impl::scal_use_fast_path<scalar_t>(n, incx)) {
    int i_n = (int)n;
    int i_incx = (int)incx;
    blas_impl::scal_fast_path<scalar_t>(&i_n, &a, x, &i_incx);
    return;
  }
#endif
  for (const auto i : c10::irange(n)) {
    if (a == scalar_t(0)) {
      x[i * incx] = 0;
    } else {
      x[i * incx] *= a;
    }
  }
}
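// gemv tries the BLAS (or hand-written aarch64) fast path first and
// otherwise runs a reference implementation that accumulates in opmath_t,
// i.e. in fp32 for reduced-precision scalar types.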
template <typename scalar_t>
void gemv(char trans, int64_t m, int64_t n, scalar_t alpha, const scalar_t *a, int64_t lda, const scalar_t *x, int64_t incx, scalar_t beta, scalar_t *y, int64_t incy) {
  if(n == 1) lda = m;

#if AT_BUILD_WITH_BLAS()
  if (blas_impl::gemv_use_fast_path<scalar_t>(trans, m, n, alpha, lda, incx, beta, incy)) {
    TORCH_CHECK(lda >= std::max<int64_t>(1L, m), "lda should be at least max(1,", m, "), but have ", lda);
    int i_m = (int)m;
    int i_n = (int)n;
    int i_lda = (int)lda;
    int i_incx = (int)incx;
    int i_incy = (int)incy;
    blas_impl::gemv_fast_path<scalar_t>(&trans, &i_m, &i_n, &alpha, a, &i_lda, x, &i_incx, &beta, y, &i_incy);
    return;
  }
#endif

  using opmath_t = at::opmath_type<scalar_t>;
  if ((trans == 'T') || (trans == 't')) {
    for (const auto i : c10::irange(n)) {
      opmath_t sum = 0;
      const scalar_t *row_ = a + lda * i;
      for (const auto j : c10::irange(m)) {
        sum += x[j * incx] * row_[j];
      }
      if (beta == scalar_t(0)) {
        y[i * incy] = alpha * sum;
      } else {
        y[i * incy] = beta * y[i * incy] + alpha * sum;
      }
    }
  } else {
    if (beta != scalar_t(1) && beta != scalar_t(0)) scal<scalar_t>(m, beta, y, incy);

    constexpr bool is_low_precision = !std::is_same_v<scalar_t, opmath_t>;
    std::vector<opmath_t> sum;
    if constexpr (is_low_precision) {
      sum.resize(m);
    }
    for (const auto j : c10::irange(n)) {
      const scalar_t *column_ = a + lda * j;
      opmath_t z = alpha * static_cast<opmath_t>(x[j * incx]);
      for (const auto i : c10::irange(m)) {
        // output values are ignored if beta is 0 and simply set to 0;
        // NaNs and infs are not propagated
        if (j==0 && beta==scalar_t(0)) {
          if constexpr (!is_low_precision) {
            y[i * incy] = 0;
          }
        }
        if constexpr (is_low_precision) {
          sum[i] += z * column_[i];
        } else {
          y[i * incy] += z * column_[i];
        }
      }
    }
    if constexpr (is_low_precision) {
      if (beta == scalar_t(0)) {
        for (const auto i : c10::irange(m)) {
          y[i * incy] = sum[i];
        }
      } else {
        for (const auto i : c10::irange(m)) {
          y[i * incy] += sum[i];
        }
      }
    }
  }
  return;
}

#define INSTANTIATE(scalar_t, _) \
template void gemv<scalar_t>(char trans, int64_t m, int64_t n, scalar_t alpha, const scalar_t *a, int64_t lda, const scalar_t *x, int64_t incx, scalar_t beta, scalar_t *y, int64_t incy);
AT_FORALL_SCALAR_TYPES_AND2(BFloat16, Half, INSTANTIATE);
AT_FORALL_COMPLEX_TYPES(INSTANTIATE);
#undef INSTANTIATE

namespace blas_impl {
#if AT_BUILD_WITH_BLAS()
static float dot_fast_path(int n, float* x, int incx, float* y, int incy) {
  // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
  return sdot_(&n, x, &incx, y, &incy);
}

static double dot_fast_path(int n, double* x, int incx, double* y, int incy) {
  return ddot_(&n, x, &incx, y, &incy);
}

static c10::complex<float> vdot_fast_path(int n, c10::complex<float>* x, int incx, c10::complex<float>* y, int incy) {
  c10::complex<float> result;
  cdotc_(reinterpret_cast<std::complex<float>* >(&result), &n, reinterpret_cast<std::complex<float>*>(x), &incx, reinterpret_cast<std::complex<float>*>(y), &incy);
  return result;
}

static c10::complex<double> vdot_fast_path(int n, c10::complex<double>* x, int incx, c10::complex<double>* y, int incy) {
  c10::complex<double> result;
  zdotc_(reinterpret_cast<std::complex<double>* >(&result), &n, reinterpret_cast<std::complex<double>*>(x), &incx, reinterpret_cast<std::complex<double>*>(y), &incy);
  return result;
}

static c10::complex<double> dot_fast_path(int n, c10::complex<double>* x, int incx, c10::complex<double>* y, int incy) {
  c10::complex<double> result;
  zdotu_(reinterpret_cast<std::complex<double>* >(&result), &n, reinterpret_cast<std::complex<double>*>(x), &incx, reinterpret_cast<std::complex<double>*>(y), &incy);
  return result;
}

static c10::complex<float> dot_fast_path(int n, c10::complex<float>* x, int incx, c10::complex<float>* y, int incy) {
  c10::complex<float> result;
  cdotu_(reinterpret_cast<std::complex<float>* >(&result), &n, reinterpret_cast<std::complex<float>*>(x), &incx, reinterpret_cast<std::complex<float>*>(y), &incy);
  return result;
}
#endif

template <typename scalar_t, typename Functor>
scalar_t dot_naive(
    int64_t n,
    scalar_t* x,
    int64_t incx,
    scalar_t* y,
    int64_t incy,
    Functor op) {
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  int64_t i;
  using opmath_t = at::opmath_type<scalar_t>;
  opmath_t sum = 0;
  for (i = 0; i < n; i++) {
    sum += op(static_cast<opmath_t>(x[i * incx]), static_cast<opmath_t>(y[i * incy]));
  }
  return static_cast<scalar_t>(sum);
}

} // namespace blas_impl

template <typename scalar_t>
scalar_t dot_impl_floating(int64_t n, scalar_t* x, int64_t incx, scalar_t* y, int64_t incy)
{
  if (n == 1) {
    incx = 1;
    incy = 1;
  }
#if AT_BUILD_WITH_BLAS()
  if ((n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX)) {
    return blas_impl::dot_fast_path(n, x, incx, y, incy);
  } else {
    return blas_impl::dot_naive(n, x, incx, y, incy, std::multiplies<scalar_t>{});
  }
#else
  { return blas_impl::dot_naive(n, x, incx, y, incy, std::multiplies<scalar_t>{}); }
#endif
}

template <typename scalar_t>
scalar_t dot_impl(int64_t n, scalar_t* x, int64_t incx, scalar_t* y, int64_t incy) {
  if (n == 1) {
    incx = 1;
    incy = 1;
  }
  return blas_impl::dot_naive(n, x, incx, y, incy, std::multiplies<scalar_t>{});
}
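// float, double, and the complex types are specialized below to route
// through dot_impl_floating, which can use the BLAS dot fast path whenever
// n and both strides fit in an int; all other types take the naive loop.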
template <>
float dot_impl(int64_t n, float* x, int64_t incx, float* y, int64_t incy) {
  return dot_impl_floating(n, x, incx, y, incy);
}

template <>
double dot_impl(int64_t n, double* x, int64_t incx, double* y, int64_t incy) {
  return dot_impl_floating(n, x, incx, y, incy);
}

template <>
c10::complex<double> dot_impl(int64_t n, c10::complex<double>* x, int64_t incx, c10::complex<double>* y, int64_t incy) {
  return dot_impl_floating(n, x, incx, y, incy);
}

template <>
c10::complex<float> dot_impl(int64_t n, c10::complex<float>* x, int64_t incx, c10::complex<float>* y, int64_t incy) {
  return dot_impl_floating(n, x, incx, y, incy);
}

namespace {
template <typename scalar_t>
struct vdot_op {
  scalar_t operator()(scalar_t x, scalar_t y) {
    return std::conj(x) * y;
  }
};
} // anonymous namespace

template <typename scalar_t>
scalar_t vdot_impl(int64_t n, scalar_t* x, int64_t incx, scalar_t* y, int64_t incy) {
  if (n == 1) {
    incx = 1;
    incy = 1;
  }
#if AT_BUILD_WITH_BLAS()
  if ((n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX)) {
    return blas_impl::vdot_fast_path(n, x, incx, y, incy);
  } else {
    return blas_impl::dot_naive(n, x, incx, y, incy, vdot_op<scalar_t>{});
  }
#else
  { return blas_impl::dot_naive(n, x, incx, y, incy, vdot_op<scalar_t>{}); }
#endif
}

// Skip reinstantiating the explicitly specialized types `float` and `double`.
#define INSTANTIATE_DOT_IMPL(scalar_t)  \
  template scalar_t dot_impl<scalar_t>( \
      int64_t n, scalar_t * x, int64_t incx, scalar_t * y, int64_t incy);
INSTANTIATE_DOT_IMPL(uint8_t);
INSTANTIATE_DOT_IMPL(int8_t);
INSTANTIATE_DOT_IMPL(int16_t);
INSTANTIATE_DOT_IMPL(int);
INSTANTIATE_DOT_IMPL(int64_t);
INSTANTIATE_DOT_IMPL(c10::Half);
INSTANTIATE_DOT_IMPL(c10::BFloat16);

#define INSTANTIATE_VDOT_IMPL(scalar_t)  \
  template scalar_t vdot_impl<scalar_t>( \
      int64_t n, scalar_t * x, int64_t incx, scalar_t * y, int64_t incy);
INSTANTIATE_VDOT_IMPL(c10::complex<float>);
INSTANTIATE_VDOT_IMPL(c10::complex<double>);

#undef INSTANTIATE_DOT_IMPL

} // namespace at::native

C10_DIAGNOSTIC_POP()