1*523fa7a6SAndroid Build Coastguard Worker /*
2*523fa7a6SAndroid Build Coastguard Worker * Copyright (c) Meta Platforms, Inc. and affiliates.
3*523fa7a6SAndroid Build Coastguard Worker * All rights reserved.
4*523fa7a6SAndroid Build Coastguard Worker *
5*523fa7a6SAndroid Build Coastguard Worker * This source code is licensed under the BSD-style license found in the
6*523fa7a6SAndroid Build Coastguard Worker * LICENSE file in the root directory of this source tree.
7*523fa7a6SAndroid Build Coastguard Worker */
8*523fa7a6SAndroid Build Coastguard Worker
9*523fa7a6SAndroid Build Coastguard Worker #pragma once
10*523fa7a6SAndroid Build Coastguard Worker
11*523fa7a6SAndroid Build Coastguard Worker #include <cstdint>
12*523fa7a6SAndroid Build Coastguard Worker #include <type_traits>
13*523fa7a6SAndroid Build Coastguard Worker
14*523fa7a6SAndroid Build Coastguard Worker #include <executorch/kernels/optimized/blas/BlasKernel.h>
15*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/core/exec_aten/exec_aten.h>
16*523fa7a6SAndroid Build Coastguard Worker
17*523fa7a6SAndroid Build Coastguard Worker namespace executorch {
18*523fa7a6SAndroid Build Coastguard Worker namespace cpublas {
19*523fa7a6SAndroid Build Coastguard Worker
/// How a GEMM operand is used. The enumerators correspond to the BLAS TRANS
/// characters produced by to_blas: 'N', 'T' and 'C' respectively.
enum class TransposeType {
  NoTranspose = 0,   ///< Use the operand as stored.
  Transpose = 1,     ///< Use the transpose of the operand.
  ConjTranspose = 2, ///< Use the conjugate transpose of the operand.
};
25*523fa7a6SAndroid Build Coastguard Worker
26*523fa7a6SAndroid Build Coastguard Worker // clang-format off
27*523fa7a6SAndroid Build Coastguard Worker void normalize_last_dims(
28*523fa7a6SAndroid Build Coastguard Worker TransposeType transa, TransposeType transb,
29*523fa7a6SAndroid Build Coastguard Worker int64_t m, int64_t n, int64_t k,
30*523fa7a6SAndroid Build Coastguard Worker int64_t *lda, int64_t *ldb, int64_t *ldc);
31*523fa7a6SAndroid Build Coastguard Worker // clang-format on
32*523fa7a6SAndroid Build Coastguard Worker
to_blas(TransposeType trans)33*523fa7a6SAndroid Build Coastguard Worker inline char to_blas(TransposeType trans) {
34*523fa7a6SAndroid Build Coastguard Worker switch (trans) {
35*523fa7a6SAndroid Build Coastguard Worker case TransposeType::Transpose:
36*523fa7a6SAndroid Build Coastguard Worker return 'T';
37*523fa7a6SAndroid Build Coastguard Worker case TransposeType::NoTranspose:
38*523fa7a6SAndroid Build Coastguard Worker return 'N';
39*523fa7a6SAndroid Build Coastguard Worker case TransposeType::ConjTranspose:
40*523fa7a6SAndroid Build Coastguard Worker return 'C';
41*523fa7a6SAndroid Build Coastguard Worker }
42*523fa7a6SAndroid Build Coastguard Worker // Assume no transpose by default
43*523fa7a6SAndroid Build Coastguard Worker return 'N';
44*523fa7a6SAndroid Build Coastguard Worker }
45*523fa7a6SAndroid Build Coastguard Worker
46*523fa7a6SAndroid Build Coastguard Worker // clang-format off
47*523fa7a6SAndroid Build Coastguard Worker template <typename scalar_t, typename opmath_t>
gemm_impl(TransposeType transa,TransposeType transb,int64_t m,int64_t n,int64_t k,opmath_t alpha,const scalar_t * a,int64_t lda,const scalar_t * b,int64_t ldb,opmath_t beta,scalar_t * c,int64_t ldc)48*523fa7a6SAndroid Build Coastguard Worker void gemm_impl(
49*523fa7a6SAndroid Build Coastguard Worker TransposeType transa, TransposeType transb,
50*523fa7a6SAndroid Build Coastguard Worker int64_t m, int64_t n, int64_t k,
51*523fa7a6SAndroid Build Coastguard Worker opmath_t alpha,
52*523fa7a6SAndroid Build Coastguard Worker const scalar_t *a, int64_t lda,
53*523fa7a6SAndroid Build Coastguard Worker const scalar_t *b, int64_t ldb,
54*523fa7a6SAndroid Build Coastguard Worker opmath_t beta,
55*523fa7a6SAndroid Build Coastguard Worker scalar_t *c, int64_t ldc) {
56*523fa7a6SAndroid Build Coastguard Worker if (transa == TransposeType::NoTranspose &&
57*523fa7a6SAndroid Build Coastguard Worker transb == TransposeType::NoTranspose) {
58*523fa7a6SAndroid Build Coastguard Worker return gemm_notrans_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
59*523fa7a6SAndroid Build Coastguard Worker } else if (
60*523fa7a6SAndroid Build Coastguard Worker transa == TransposeType::Transpose &&
61*523fa7a6SAndroid Build Coastguard Worker transb != TransposeType::Transpose) {
62*523fa7a6SAndroid Build Coastguard Worker gemm_transa_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
63*523fa7a6SAndroid Build Coastguard Worker } else if (
64*523fa7a6SAndroid Build Coastguard Worker transa == TransposeType::NoTranspose &&
65*523fa7a6SAndroid Build Coastguard Worker transb == TransposeType::Transpose) {
66*523fa7a6SAndroid Build Coastguard Worker gemm_transb_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
67*523fa7a6SAndroid Build Coastguard Worker } else { // transa == TransposeType::Transpose && transb ==
68*523fa7a6SAndroid Build Coastguard Worker // TransposeType::Transpose
69*523fa7a6SAndroid Build Coastguard Worker gemm_transab_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
70*523fa7a6SAndroid Build Coastguard Worker }
71*523fa7a6SAndroid Build Coastguard Worker }
72*523fa7a6SAndroid Build Coastguard Worker // clang-format on
73*523fa7a6SAndroid Build Coastguard Worker
74*523fa7a6SAndroid Build Coastguard Worker // clang-format off
75*523fa7a6SAndroid Build Coastguard Worker void gemm(
76*523fa7a6SAndroid Build Coastguard Worker TransposeType transa, TransposeType transb,
77*523fa7a6SAndroid Build Coastguard Worker int64_t m, int64_t n, int64_t k,
78*523fa7a6SAndroid Build Coastguard Worker double alpha,
79*523fa7a6SAndroid Build Coastguard Worker const double *a, int64_t lda,
80*523fa7a6SAndroid Build Coastguard Worker const double *b, int64_t ldb,
81*523fa7a6SAndroid Build Coastguard Worker double beta,
82*523fa7a6SAndroid Build Coastguard Worker double *c, int64_t ldc);
83*523fa7a6SAndroid Build Coastguard Worker // clang-format on
84*523fa7a6SAndroid Build Coastguard Worker
85*523fa7a6SAndroid Build Coastguard Worker // clang-format off
86*523fa7a6SAndroid Build Coastguard Worker void gemm(
87*523fa7a6SAndroid Build Coastguard Worker TransposeType transa, TransposeType transb,
88*523fa7a6SAndroid Build Coastguard Worker int64_t m, int64_t n, int64_t k,
89*523fa7a6SAndroid Build Coastguard Worker const float alpha,
90*523fa7a6SAndroid Build Coastguard Worker const float *a, int64_t lda,
91*523fa7a6SAndroid Build Coastguard Worker const float *b, int64_t ldb,
92*523fa7a6SAndroid Build Coastguard Worker const float beta,
93*523fa7a6SAndroid Build Coastguard Worker float *c, int64_t ldc);
94*523fa7a6SAndroid Build Coastguard Worker // clang-format on
95*523fa7a6SAndroid Build Coastguard Worker
96*523fa7a6SAndroid Build Coastguard Worker // clang-format off
97*523fa7a6SAndroid Build Coastguard Worker void gemm(
98*523fa7a6SAndroid Build Coastguard Worker TransposeType transa, TransposeType transb,
99*523fa7a6SAndroid Build Coastguard Worker int64_t m, int64_t n, int64_t k,
100*523fa7a6SAndroid Build Coastguard Worker const exec_aten::Half alpha,
101*523fa7a6SAndroid Build Coastguard Worker const exec_aten::Half *a, int64_t lda,
102*523fa7a6SAndroid Build Coastguard Worker const exec_aten::Half *b, int64_t ldb,
103*523fa7a6SAndroid Build Coastguard Worker const exec_aten::Half beta,
104*523fa7a6SAndroid Build Coastguard Worker exec_aten::Half *c, int64_t ldc);
105*523fa7a6SAndroid Build Coastguard Worker
106*523fa7a6SAndroid Build Coastguard Worker void gemm(
107*523fa7a6SAndroid Build Coastguard Worker TransposeType transa, TransposeType transb,
108*523fa7a6SAndroid Build Coastguard Worker int64_t m, int64_t n, int64_t k,
109*523fa7a6SAndroid Build Coastguard Worker const exec_aten::BFloat16 alpha,
110*523fa7a6SAndroid Build Coastguard Worker const exec_aten::BFloat16 *a, int64_t lda,
111*523fa7a6SAndroid Build Coastguard Worker const exec_aten::BFloat16 *b, int64_t ldb,
112*523fa7a6SAndroid Build Coastguard Worker const exec_aten::BFloat16 beta,
113*523fa7a6SAndroid Build Coastguard Worker exec_aten::BFloat16 *c, int64_t ldc);
114*523fa7a6SAndroid Build Coastguard Worker // clang-format on
115*523fa7a6SAndroid Build Coastguard Worker
116*523fa7a6SAndroid Build Coastguard Worker // clang-format off
117*523fa7a6SAndroid Build Coastguard Worker template <typename T,
118*523fa7a6SAndroid Build Coastguard Worker typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
gemm(TransposeType transa,TransposeType transb,int64_t m,int64_t n,int64_t k,const T alpha,const T * a,int64_t lda,const T * b,int64_t ldb,const T beta,T * c,int64_t ldc)119*523fa7a6SAndroid Build Coastguard Worker void gemm(
120*523fa7a6SAndroid Build Coastguard Worker TransposeType transa, TransposeType transb,
121*523fa7a6SAndroid Build Coastguard Worker int64_t m, int64_t n, int64_t k,
122*523fa7a6SAndroid Build Coastguard Worker const T alpha,
123*523fa7a6SAndroid Build Coastguard Worker const T *a, int64_t lda,
124*523fa7a6SAndroid Build Coastguard Worker const T *b, int64_t ldb,
125*523fa7a6SAndroid Build Coastguard Worker const T beta,
126*523fa7a6SAndroid Build Coastguard Worker T *c, int64_t ldc) {
127*523fa7a6SAndroid Build Coastguard Worker normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
128*523fa7a6SAndroid Build Coastguard Worker
129*523fa7a6SAndroid Build Coastguard Worker using acc_type = utils::compute_dtype<T>;
130*523fa7a6SAndroid Build Coastguard Worker gemm_impl(
131*523fa7a6SAndroid Build Coastguard Worker transa, transb,
132*523fa7a6SAndroid Build Coastguard Worker m, n, k,
133*523fa7a6SAndroid Build Coastguard Worker static_cast<const acc_type>(alpha),
134*523fa7a6SAndroid Build Coastguard Worker a, lda,
135*523fa7a6SAndroid Build Coastguard Worker b, ldb,
136*523fa7a6SAndroid Build Coastguard Worker static_cast<const acc_type>(beta),
137*523fa7a6SAndroid Build Coastguard Worker c, ldc);
138*523fa7a6SAndroid Build Coastguard Worker }
139*523fa7a6SAndroid Build Coastguard Worker // clang-format on
140*523fa7a6SAndroid Build Coastguard Worker
141*523fa7a6SAndroid Build Coastguard Worker } // namespace cpublas
142*523fa7a6SAndroid Build Coastguard Worker } // namespace executorch
143