xref: /aosp_15_r20/external/executorch/kernels/optimized/blas/CPUBlas.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker /*
2*523fa7a6SAndroid Build Coastguard Worker  * Copyright (c) Meta Platforms, Inc. and affiliates.
3*523fa7a6SAndroid Build Coastguard Worker  * All rights reserved.
4*523fa7a6SAndroid Build Coastguard Worker  *
5*523fa7a6SAndroid Build Coastguard Worker  * This source code is licensed under the BSD-style license found in the
6*523fa7a6SAndroid Build Coastguard Worker  * LICENSE file in the root directory of this source tree.
7*523fa7a6SAndroid Build Coastguard Worker  */
8*523fa7a6SAndroid Build Coastguard Worker 
9*523fa7a6SAndroid Build Coastguard Worker #pragma once
10*523fa7a6SAndroid Build Coastguard Worker 
11*523fa7a6SAndroid Build Coastguard Worker #include <cstdint>
12*523fa7a6SAndroid Build Coastguard Worker #include <type_traits>
13*523fa7a6SAndroid Build Coastguard Worker 
14*523fa7a6SAndroid Build Coastguard Worker #include <executorch/kernels/optimized/blas/BlasKernel.h>
15*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/core/exec_aten/exec_aten.h>
16*523fa7a6SAndroid Build Coastguard Worker 
17*523fa7a6SAndroid Build Coastguard Worker namespace executorch {
18*523fa7a6SAndroid Build Coastguard Worker namespace cpublas {
19*523fa7a6SAndroid Build Coastguard Worker 
// Describes how a matrix operand should be interpreted by the gemm entry
// points in this header. The enumerators correspond to the BLAS transpose
// characters produced by to_blas(): 'N', 'T' and 'C' respectively.
enum class TransposeType {
  NoTranspose = 0, // operand is used as stored
  Transpose = 1, // operand is used transposed
  ConjTranspose = 2, // operand is used conjugate-transposed
};
25*523fa7a6SAndroid Build Coastguard Worker 
26*523fa7a6SAndroid Build Coastguard Worker // clang-format off
27*523fa7a6SAndroid Build Coastguard Worker void normalize_last_dims(
28*523fa7a6SAndroid Build Coastguard Worker     TransposeType transa, TransposeType transb,
29*523fa7a6SAndroid Build Coastguard Worker     int64_t m, int64_t n, int64_t k,
30*523fa7a6SAndroid Build Coastguard Worker     int64_t *lda, int64_t *ldb, int64_t *ldc);
31*523fa7a6SAndroid Build Coastguard Worker // clang-format on
32*523fa7a6SAndroid Build Coastguard Worker 
to_blas(TransposeType trans)33*523fa7a6SAndroid Build Coastguard Worker inline char to_blas(TransposeType trans) {
34*523fa7a6SAndroid Build Coastguard Worker   switch (trans) {
35*523fa7a6SAndroid Build Coastguard Worker     case TransposeType::Transpose:
36*523fa7a6SAndroid Build Coastguard Worker       return 'T';
37*523fa7a6SAndroid Build Coastguard Worker     case TransposeType::NoTranspose:
38*523fa7a6SAndroid Build Coastguard Worker       return 'N';
39*523fa7a6SAndroid Build Coastguard Worker     case TransposeType::ConjTranspose:
40*523fa7a6SAndroid Build Coastguard Worker       return 'C';
41*523fa7a6SAndroid Build Coastguard Worker   }
42*523fa7a6SAndroid Build Coastguard Worker   // Assume no transpose by default
43*523fa7a6SAndroid Build Coastguard Worker   return 'N';
44*523fa7a6SAndroid Build Coastguard Worker }
45*523fa7a6SAndroid Build Coastguard Worker 
46*523fa7a6SAndroid Build Coastguard Worker // clang-format off
47*523fa7a6SAndroid Build Coastguard Worker template <typename scalar_t, typename opmath_t>
gemm_impl(TransposeType transa,TransposeType transb,int64_t m,int64_t n,int64_t k,opmath_t alpha,const scalar_t * a,int64_t lda,const scalar_t * b,int64_t ldb,opmath_t beta,scalar_t * c,int64_t ldc)48*523fa7a6SAndroid Build Coastguard Worker void gemm_impl(
49*523fa7a6SAndroid Build Coastguard Worker     TransposeType transa, TransposeType transb,
50*523fa7a6SAndroid Build Coastguard Worker     int64_t m, int64_t n, int64_t k,
51*523fa7a6SAndroid Build Coastguard Worker     opmath_t alpha,
52*523fa7a6SAndroid Build Coastguard Worker     const scalar_t *a, int64_t lda,
53*523fa7a6SAndroid Build Coastguard Worker     const scalar_t *b, int64_t ldb,
54*523fa7a6SAndroid Build Coastguard Worker     opmath_t beta,
55*523fa7a6SAndroid Build Coastguard Worker     scalar_t *c, int64_t ldc) {
56*523fa7a6SAndroid Build Coastguard Worker   if (transa == TransposeType::NoTranspose &&
57*523fa7a6SAndroid Build Coastguard Worker       transb == TransposeType::NoTranspose) {
58*523fa7a6SAndroid Build Coastguard Worker     return gemm_notrans_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
59*523fa7a6SAndroid Build Coastguard Worker   } else if (
60*523fa7a6SAndroid Build Coastguard Worker       transa == TransposeType::Transpose &&
61*523fa7a6SAndroid Build Coastguard Worker       transb != TransposeType::Transpose) {
62*523fa7a6SAndroid Build Coastguard Worker     gemm_transa_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
63*523fa7a6SAndroid Build Coastguard Worker   } else if (
64*523fa7a6SAndroid Build Coastguard Worker       transa == TransposeType::NoTranspose &&
65*523fa7a6SAndroid Build Coastguard Worker       transb == TransposeType::Transpose) {
66*523fa7a6SAndroid Build Coastguard Worker     gemm_transb_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
67*523fa7a6SAndroid Build Coastguard Worker   } else { // transa == TransposeType::Transpose && transb ==
68*523fa7a6SAndroid Build Coastguard Worker            // TransposeType::Transpose
69*523fa7a6SAndroid Build Coastguard Worker     gemm_transab_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
70*523fa7a6SAndroid Build Coastguard Worker   }
71*523fa7a6SAndroid Build Coastguard Worker }
72*523fa7a6SAndroid Build Coastguard Worker // clang-format on
73*523fa7a6SAndroid Build Coastguard Worker 
74*523fa7a6SAndroid Build Coastguard Worker // clang-format off
75*523fa7a6SAndroid Build Coastguard Worker void gemm(
76*523fa7a6SAndroid Build Coastguard Worker     TransposeType transa, TransposeType transb,
77*523fa7a6SAndroid Build Coastguard Worker     int64_t m, int64_t n, int64_t k,
78*523fa7a6SAndroid Build Coastguard Worker     double alpha,
79*523fa7a6SAndroid Build Coastguard Worker     const double *a, int64_t lda,
80*523fa7a6SAndroid Build Coastguard Worker     const double *b, int64_t ldb,
81*523fa7a6SAndroid Build Coastguard Worker     double beta,
82*523fa7a6SAndroid Build Coastguard Worker     double *c, int64_t ldc);
83*523fa7a6SAndroid Build Coastguard Worker // clang-format on
84*523fa7a6SAndroid Build Coastguard Worker 
85*523fa7a6SAndroid Build Coastguard Worker // clang-format off
86*523fa7a6SAndroid Build Coastguard Worker void gemm(
87*523fa7a6SAndroid Build Coastguard Worker     TransposeType transa, TransposeType transb,
88*523fa7a6SAndroid Build Coastguard Worker     int64_t m, int64_t n, int64_t k,
89*523fa7a6SAndroid Build Coastguard Worker     const float alpha,
90*523fa7a6SAndroid Build Coastguard Worker     const float *a, int64_t lda,
91*523fa7a6SAndroid Build Coastguard Worker     const float *b, int64_t ldb,
92*523fa7a6SAndroid Build Coastguard Worker     const float beta,
93*523fa7a6SAndroid Build Coastguard Worker     float *c, int64_t ldc);
94*523fa7a6SAndroid Build Coastguard Worker // clang-format on
95*523fa7a6SAndroid Build Coastguard Worker 
96*523fa7a6SAndroid Build Coastguard Worker // clang-format off
97*523fa7a6SAndroid Build Coastguard Worker void gemm(
98*523fa7a6SAndroid Build Coastguard Worker     TransposeType transa, TransposeType transb,
99*523fa7a6SAndroid Build Coastguard Worker     int64_t m, int64_t n, int64_t k,
100*523fa7a6SAndroid Build Coastguard Worker     const exec_aten::Half alpha,
101*523fa7a6SAndroid Build Coastguard Worker     const exec_aten::Half *a, int64_t lda,
102*523fa7a6SAndroid Build Coastguard Worker     const exec_aten::Half *b, int64_t ldb,
103*523fa7a6SAndroid Build Coastguard Worker     const exec_aten::Half beta,
104*523fa7a6SAndroid Build Coastguard Worker     exec_aten::Half *c, int64_t ldc);
105*523fa7a6SAndroid Build Coastguard Worker 
106*523fa7a6SAndroid Build Coastguard Worker void gemm(
107*523fa7a6SAndroid Build Coastguard Worker     TransposeType transa, TransposeType transb,
108*523fa7a6SAndroid Build Coastguard Worker     int64_t m, int64_t n, int64_t k,
109*523fa7a6SAndroid Build Coastguard Worker     const exec_aten::BFloat16 alpha,
110*523fa7a6SAndroid Build Coastguard Worker     const exec_aten::BFloat16 *a, int64_t lda,
111*523fa7a6SAndroid Build Coastguard Worker     const exec_aten::BFloat16 *b, int64_t ldb,
112*523fa7a6SAndroid Build Coastguard Worker     const exec_aten::BFloat16 beta,
113*523fa7a6SAndroid Build Coastguard Worker     exec_aten::BFloat16 *c, int64_t ldc);
114*523fa7a6SAndroid Build Coastguard Worker // clang-format on
115*523fa7a6SAndroid Build Coastguard Worker 
116*523fa7a6SAndroid Build Coastguard Worker // clang-format off
117*523fa7a6SAndroid Build Coastguard Worker template <typename T,
118*523fa7a6SAndroid Build Coastguard Worker           typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
gemm(TransposeType transa,TransposeType transb,int64_t m,int64_t n,int64_t k,const T alpha,const T * a,int64_t lda,const T * b,int64_t ldb,const T beta,T * c,int64_t ldc)119*523fa7a6SAndroid Build Coastguard Worker void gemm(
120*523fa7a6SAndroid Build Coastguard Worker     TransposeType transa, TransposeType transb,
121*523fa7a6SAndroid Build Coastguard Worker     int64_t m, int64_t n, int64_t k,
122*523fa7a6SAndroid Build Coastguard Worker     const T alpha,
123*523fa7a6SAndroid Build Coastguard Worker     const T *a, int64_t lda,
124*523fa7a6SAndroid Build Coastguard Worker     const T *b, int64_t ldb,
125*523fa7a6SAndroid Build Coastguard Worker     const T beta,
126*523fa7a6SAndroid Build Coastguard Worker     T *c, int64_t ldc) {
127*523fa7a6SAndroid Build Coastguard Worker   normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
128*523fa7a6SAndroid Build Coastguard Worker 
129*523fa7a6SAndroid Build Coastguard Worker   using acc_type = utils::compute_dtype<T>;
130*523fa7a6SAndroid Build Coastguard Worker   gemm_impl(
131*523fa7a6SAndroid Build Coastguard Worker       transa, transb,
132*523fa7a6SAndroid Build Coastguard Worker       m, n, k,
133*523fa7a6SAndroid Build Coastguard Worker       static_cast<const acc_type>(alpha),
134*523fa7a6SAndroid Build Coastguard Worker       a, lda,
135*523fa7a6SAndroid Build Coastguard Worker       b, ldb,
136*523fa7a6SAndroid Build Coastguard Worker       static_cast<const acc_type>(beta),
137*523fa7a6SAndroid Build Coastguard Worker       c, ldc);
138*523fa7a6SAndroid Build Coastguard Worker }
139*523fa7a6SAndroid Build Coastguard Worker // clang-format on
140*523fa7a6SAndroid Build Coastguard Worker 
141*523fa7a6SAndroid Build Coastguard Worker } // namespace cpublas
142*523fa7a6SAndroid Build Coastguard Worker } // namespace executorch
143