1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #pragma once
10
11 #include <cstdint>
12 #include <type_traits>
13
14 #include <executorch/kernels/optimized/blas/BlasKernel.h>
15 #include <executorch/runtime/core/exec_aten/exec_aten.h>
16
17 namespace executorch {
18 namespace cpublas {
19
/// Describes how a matrix operand is to be read by the GEMM entry points in
/// this header. to_blas() converts each value to its single-character flag.
enum class TransposeType {
  /// Converted to 'N' by to_blas().
  NoTranspose = 0,
  /// Converted to 'T' by to_blas().
  Transpose = 1,
  /// Converted to 'C' by to_blas().
  ConjTranspose = 2,
};
25
26 // clang-format off
27 void normalize_last_dims(
28 TransposeType transa, TransposeType transb,
29 int64_t m, int64_t n, int64_t k,
30 int64_t *lda, int64_t *ldb, int64_t *ldc);
31 // clang-format on
32
to_blas(TransposeType trans)33 inline char to_blas(TransposeType trans) {
34 switch (trans) {
35 case TransposeType::Transpose:
36 return 'T';
37 case TransposeType::NoTranspose:
38 return 'N';
39 case TransposeType::ConjTranspose:
40 return 'C';
41 }
42 // Assume no transpose by default
43 return 'N';
44 }
45
46 // clang-format off
47 template <typename scalar_t, typename opmath_t>
gemm_impl(TransposeType transa,TransposeType transb,int64_t m,int64_t n,int64_t k,opmath_t alpha,const scalar_t * a,int64_t lda,const scalar_t * b,int64_t ldb,opmath_t beta,scalar_t * c,int64_t ldc)48 void gemm_impl(
49 TransposeType transa, TransposeType transb,
50 int64_t m, int64_t n, int64_t k,
51 opmath_t alpha,
52 const scalar_t *a, int64_t lda,
53 const scalar_t *b, int64_t ldb,
54 opmath_t beta,
55 scalar_t *c, int64_t ldc) {
56 if (transa == TransposeType::NoTranspose &&
57 transb == TransposeType::NoTranspose) {
58 return gemm_notrans_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
59 } else if (
60 transa == TransposeType::Transpose &&
61 transb != TransposeType::Transpose) {
62 gemm_transa_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
63 } else if (
64 transa == TransposeType::NoTranspose &&
65 transb == TransposeType::Transpose) {
66 gemm_transb_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
67 } else { // transa == TransposeType::Transpose && transb ==
68 // TransposeType::Transpose
69 gemm_transab_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
70 }
71 }
72 // clang-format on
73
74 // clang-format off
75 void gemm(
76 TransposeType transa, TransposeType transb,
77 int64_t m, int64_t n, int64_t k,
78 double alpha,
79 const double *a, int64_t lda,
80 const double *b, int64_t ldb,
81 double beta,
82 double *c, int64_t ldc);
83 // clang-format on
84
85 // clang-format off
86 void gemm(
87 TransposeType transa, TransposeType transb,
88 int64_t m, int64_t n, int64_t k,
89 const float alpha,
90 const float *a, int64_t lda,
91 const float *b, int64_t ldb,
92 const float beta,
93 float *c, int64_t ldc);
94 // clang-format on
95
96 // clang-format off
97 void gemm(
98 TransposeType transa, TransposeType transb,
99 int64_t m, int64_t n, int64_t k,
100 const exec_aten::Half alpha,
101 const exec_aten::Half *a, int64_t lda,
102 const exec_aten::Half *b, int64_t ldb,
103 const exec_aten::Half beta,
104 exec_aten::Half *c, int64_t ldc);
105
106 void gemm(
107 TransposeType transa, TransposeType transb,
108 int64_t m, int64_t n, int64_t k,
109 const exec_aten::BFloat16 alpha,
110 const exec_aten::BFloat16 *a, int64_t lda,
111 const exec_aten::BFloat16 *b, int64_t ldb,
112 const exec_aten::BFloat16 beta,
113 exec_aten::BFloat16 *c, int64_t ldc);
114 // clang-format on
115
116 // clang-format off
117 template <typename T,
118 typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
gemm(TransposeType transa,TransposeType transb,int64_t m,int64_t n,int64_t k,const T alpha,const T * a,int64_t lda,const T * b,int64_t ldb,const T beta,T * c,int64_t ldc)119 void gemm(
120 TransposeType transa, TransposeType transb,
121 int64_t m, int64_t n, int64_t k,
122 const T alpha,
123 const T *a, int64_t lda,
124 const T *b, int64_t ldb,
125 const T beta,
126 T *c, int64_t ldc) {
127 normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
128
129 using acc_type = utils::compute_dtype<T>;
130 gemm_impl(
131 transa, transb,
132 m, n, k,
133 static_cast<const acc_type>(alpha),
134 a, lda,
135 b, ldb,
136 static_cast<const acc_type>(beta),
137 c, ldc);
138 }
139 // clang-format on
140
141 } // namespace cpublas
142 } // namespace executorch
143