xref: /aosp_15_r20/external/executorch/kernels/optimized/blas/CPUBlas.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #pragma once
10 
11 #include <cstdint>
12 #include <type_traits>
13 
14 #include <executorch/kernels/optimized/blas/BlasKernel.h>
15 #include <executorch/runtime/core/exec_aten/exec_aten.h>
16 
17 namespace executorch {
18 namespace cpublas {
19 
// Transpose mode requested for a gemm operand.
// to_blas() converts each enumerator to its single-character BLAS flag.
enum class TransposeType {
  NoTranspose = 0, // BLAS flag 'N'
  Transpose = 1, // BLAS flag 'T'
  ConjTranspose = 2, // BLAS flag 'C'
};
25 
26 // clang-format off
27 void normalize_last_dims(
28     TransposeType transa, TransposeType transb,
29     int64_t m, int64_t n, int64_t k,
30     int64_t *lda, int64_t *ldb, int64_t *ldc);
31 // clang-format on
32 
to_blas(TransposeType trans)33 inline char to_blas(TransposeType trans) {
34   switch (trans) {
35     case TransposeType::Transpose:
36       return 'T';
37     case TransposeType::NoTranspose:
38       return 'N';
39     case TransposeType::ConjTranspose:
40       return 'C';
41   }
42   // Assume no transpose by default
43   return 'N';
44 }
45 
46 // clang-format off
47 template <typename scalar_t, typename opmath_t>
gemm_impl(TransposeType transa,TransposeType transb,int64_t m,int64_t n,int64_t k,opmath_t alpha,const scalar_t * a,int64_t lda,const scalar_t * b,int64_t ldb,opmath_t beta,scalar_t * c,int64_t ldc)48 void gemm_impl(
49     TransposeType transa, TransposeType transb,
50     int64_t m, int64_t n, int64_t k,
51     opmath_t alpha,
52     const scalar_t *a, int64_t lda,
53     const scalar_t *b, int64_t ldb,
54     opmath_t beta,
55     scalar_t *c, int64_t ldc) {
56   if (transa == TransposeType::NoTranspose &&
57       transb == TransposeType::NoTranspose) {
58     return gemm_notrans_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
59   } else if (
60       transa == TransposeType::Transpose &&
61       transb != TransposeType::Transpose) {
62     gemm_transa_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
63   } else if (
64       transa == TransposeType::NoTranspose &&
65       transb == TransposeType::Transpose) {
66     gemm_transb_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
67   } else { // transa == TransposeType::Transpose && transb ==
68            // TransposeType::Transpose
69     gemm_transab_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
70   }
71 }
72 // clang-format on
73 
74 // clang-format off
75 void gemm(
76     TransposeType transa, TransposeType transb,
77     int64_t m, int64_t n, int64_t k,
78     double alpha,
79     const double *a, int64_t lda,
80     const double *b, int64_t ldb,
81     double beta,
82     double *c, int64_t ldc);
83 // clang-format on
84 
85 // clang-format off
86 void gemm(
87     TransposeType transa, TransposeType transb,
88     int64_t m, int64_t n, int64_t k,
89     const float alpha,
90     const float *a, int64_t lda,
91     const float *b, int64_t ldb,
92     const float beta,
93     float *c, int64_t ldc);
94 // clang-format on
95 
96 // clang-format off
97 void gemm(
98     TransposeType transa, TransposeType transb,
99     int64_t m, int64_t n, int64_t k,
100     const exec_aten::Half alpha,
101     const exec_aten::Half *a, int64_t lda,
102     const exec_aten::Half *b, int64_t ldb,
103     const exec_aten::Half beta,
104     exec_aten::Half *c, int64_t ldc);
105 
106 void gemm(
107     TransposeType transa, TransposeType transb,
108     int64_t m, int64_t n, int64_t k,
109     const exec_aten::BFloat16 alpha,
110     const exec_aten::BFloat16 *a, int64_t lda,
111     const exec_aten::BFloat16 *b, int64_t ldb,
112     const exec_aten::BFloat16 beta,
113     exec_aten::BFloat16 *c, int64_t ldc);
114 // clang-format on
115 
116 // clang-format off
117 template <typename T,
118           typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
gemm(TransposeType transa,TransposeType transb,int64_t m,int64_t n,int64_t k,const T alpha,const T * a,int64_t lda,const T * b,int64_t ldb,const T beta,T * c,int64_t ldc)119 void gemm(
120     TransposeType transa, TransposeType transb,
121     int64_t m, int64_t n, int64_t k,
122     const T alpha,
123     const T *a, int64_t lda,
124     const T *b, int64_t ldb,
125     const T beta,
126     T *c, int64_t ldc) {
127   normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
128 
129   using acc_type = utils::compute_dtype<T>;
130   gemm_impl(
131       transa, transb,
132       m, n, k,
133       static_cast<const acc_type>(alpha),
134       a, lda,
135       b, ldb,
136       static_cast<const acc_type>(beta),
137       c, ldc);
138 }
139 // clang-format on
140 
141 } // namespace cpublas
142 } // namespace executorch
143