// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/tunable/TunableOp.h>
#include <ATen/cuda/tunable/GemmCommon.h>
#include <c10/util/StringUtil.h>

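// ROCBLAS_BETA_FEATURES_API exposes rocBLAS beta entry points such as
// rocblas_gemm_ex_get_solutions_by_type, which this header uses below to
// enumerate the available GEMM solutions.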
#define ROCBLAS_BETA_FEATURES_API
#include <rocblas/rocblas.h>

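// Checks a rocBLAS status code and raises an error (via TORCH_CHECK) with the
// failing expression and the decoded status string on anything other than
// rocblas_status_success.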
#define TORCH_ROCBLAS_CHECK(EXPR)                 \
  do {                                            \
    rocblas_status __err = EXPR;                  \
    TORCH_CHECK(__err == rocblas_status_success,  \
                "rocblas error: ",                \
                rocblas_status_to_string(__err),  \
                " when calling `" #EXPR "`");     \
  } while (0)

namespace at::cuda::tunable {

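// Maps a C++ element type to the matching rocblas_datatype enumerator used
// for the A/B/C matrices.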
template <typename T>
constexpr rocblas_datatype RocBlasDataTypeFor();

template <>
constexpr rocblas_datatype RocBlasDataTypeFor<float>() {
  return rocblas_datatype_f32_r;
}

template <>
constexpr rocblas_datatype RocBlasDataTypeFor<double>() {
  return rocblas_datatype_f64_r;
}

template <>
constexpr rocblas_datatype RocBlasDataTypeFor<Half>() {
  return rocblas_datatype_f16_r;
}

template <>
constexpr rocblas_datatype RocBlasDataTypeFor<BFloat16>() {
  return rocblas_datatype_bf16_r;
}

template <>
constexpr rocblas_datatype RocBlasDataTypeFor<c10::complex<float>>() {
  return rocblas_datatype_f32_c;
}

template <>
constexpr rocblas_datatype RocBlasDataTypeFor<c10::complex<double>>() {
  return rocblas_datatype_f64_c;
}

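// Maps a C++ element type to the rocblas_datatype used as the GEMM compute
// (accumulation) type, which may be wider than the storage type.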
template <typename T>
constexpr rocblas_datatype RocBlasComputeTypeFor();

template <>
constexpr rocblas_datatype RocBlasComputeTypeFor<float>() {
  return rocblas_datatype_f32_r;
}

template <>
constexpr rocblas_datatype RocBlasComputeTypeFor<double>() {
  return rocblas_datatype_f64_r;
}

template <>
constexpr rocblas_datatype RocBlasComputeTypeFor<Half>() {
  // Note that we're returning the _compute_ type for a given datatype.
  // As of 12/2022, using compute type FP16 for 16-bit floats was much
  // slower than using compute type FP32. So we use FP32 compute even for
  // FP16 datatypes. This is how GEMM is implemented even in the function
  // rocblasGemmHelper (see fpgeneric.h)
  return rocblas_datatype_f32_r;
}

template <>
constexpr rocblas_datatype RocBlasComputeTypeFor<BFloat16>() {
  // Note that we're returning the _compute_ type for a given datatype.
  // As of 12/2022, using compute type FP16 for 16-bit floats was much
  // slower than using compute type FP32. So we use FP32 compute even for
  // BF16 datatypes. This is how GEMM is implemented even in the function
  // rocblasGemmHelper (see fpgeneric.h)
  return rocblas_datatype_f32_r;
}

template <>
constexpr rocblas_datatype RocBlasComputeTypeFor<c10::complex<float>>() {
  return rocblas_datatype_f32_c;
}

template <>
constexpr rocblas_datatype RocBlasComputeTypeFor<c10::complex<double>>() {
  return rocblas_datatype_f64_c;
}

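// rocBLAS expects the alpha/beta scalars to match the compute type. Since
// Half and BFloat16 compute in FP32 (see above), widen those scalars to
// float; all other types pass through unchanged.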
template <typename T>
auto DoCastForHalfOrBfloat16(const T fp) {
  return fp;
}

template <>
inline auto DoCastForHalfOrBfloat16<Half>(const Half fp) {
  // alpha and beta must match the compute type; for Half that is float.
  float h = fp;
  return h;
}

template <>
inline auto DoCastForHalfOrBfloat16<BFloat16>(const BFloat16 fp) {
  // alpha and beta must match the compute type; for BFloat16 that is float.
  float h = fp;
  return h;
}

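// Translates a BLAS transpose character ('n'/'t'/'c', case-insensitive) into
// the corresponding rocblas_operation.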
static rocblas_operation _rocblasOpFromChar(char op) {
  switch (op) {
    case 'n':
    case 'N':
      return rocblas_operation_none;
    case 't':
    case 'T':
      return rocblas_operation_transpose;
    case 'c':
    case 'C':
      return rocblas_operation_conjugate_transpose;
  }
  AT_ERROR(
      "_rocblasOpFromChar input should be 't', 'n' or 'c' but got `", op, "`");
}

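// Callable wrapper that runs rocblas_gemm_ex pinned to one specific solution
// index, so the TunableOp framework can benchmark each rocBLAS solution
// individually.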
template <typename T>
class RocblasGemmOp : public Callable<GemmParams<T>> {
  public:
    RocblasGemmOp(int solution) : solution_{solution} {}

    TuningStatus Call(const GemmParams<T>* params) override {
      auto input_output_type = RocBlasDataTypeFor<T>();
      auto compute_type = RocBlasComputeTypeFor<T>();
      auto h_a = DoCastForHalfOrBfloat16(params->alpha);
      auto h_b = DoCastForHalfOrBfloat16(params->beta);
      auto status = rocblas_gemm_ex(
          (rocblas_handle)at::cuda::getCurrentCUDABlasHandle(),
          _rocblasOpFromChar(params->transa),
          _rocblasOpFromChar(params->transb),
          params->m, params->n, params->k,
          &h_a,
          params->a, input_output_type, params->lda,
          params->b, input_output_type, params->ldb,
          &h_b,
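          // rocblas_gemm_ex computes D = alpha*op(A)*op(B) + beta*C; passing
          // params->c for both the C and D arguments runs the GEMM in place.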
          params->c, input_output_type, params->ldc,
          params->c, input_output_type, params->ldc,
          compute_type,
          rocblas_gemm_algo_solution_index,
          solution_,
          rocblas_gemm_flags_none);
      if (status != rocblas_status_success) {
        return FAIL;
      }
      return OK;
    }

  private:
    int solution_;
};

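// Enumerates every rocBLAS GEMM solution available for element type T and
// wraps each one in a named RocblasGemmOp so the tuner can time them.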
template <typename T>
auto GetRocBlasGemmTypeStringAndOps() {
  rocblas_handle handle = (rocblas_handle)at::cuda::getCurrentCUDABlasHandle();
  int solution_size;
  auto input_output_type = RocBlasDataTypeFor<T>();
  auto compute_type = RocBlasComputeTypeFor<T>();
  // Get the number of available solutions
  TORCH_ROCBLAS_CHECK(rocblas_gemm_ex_get_solutions_by_type(handle,
                                                            input_output_type,
                                                            input_output_type,
                                                            compute_type,
                                                            rocblas_gemm_flags_none,
                                                            nullptr,
                                                            &solution_size));
  std::vector<int> solutions(solution_size);
  // Get the list of available solutions
  TORCH_ROCBLAS_CHECK(rocblas_gemm_ex_get_solutions_by_type(handle,
                                                            input_output_type,
                                                            input_output_type,
                                                            compute_type,
                                                            rocblas_gemm_flags_none,
                                                            solutions.data(),
                                                            &solution_size));
  // Sort the solutions in ascending order to make the solution vector deterministic across runs
  std::sort(solutions.begin(), solutions.end());

  std::vector<std::pair<std::string, std::unique_ptr<Callable<GemmParams<T>>>>> ret;
  for (size_t i = 0; i < solutions.size(); ++i) {
    auto callable = std::make_unique<RocblasGemmOp<T>>(solutions[i]);
    ret.emplace_back(std::make_pair(c10::str("Gemm_Rocblas_", solutions[i]), std::move(callable)));
  }
  return ret;
}

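// Example (a sketch, not part of this header): how a caller might consume the
// (name, op) pairs returned above; `params` stands for a caller-populated
// GemmParams<float>.
//
//   auto ops = GetRocBlasGemmTypeStringAndOps<float>();
//   for (auto& [name, op] : ops) {
//     TuningStatus status = op->Call(&params);
//   }
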
template <typename T>
class RocblasGemmStridedBatchedOp : public Callable<GemmStridedBatchedParams<T>> {
  public:
    RocblasGemmStridedBatchedOp(int solution) : solution_{solution} {}

    TuningStatus Call(const GemmStridedBatchedParams<T>* params) override {
      auto input_output_type = RocBlasDataTypeFor<T>();
      auto compute_type = RocBlasComputeTypeFor<T>();
      auto h_a = DoCastForHalfOrBfloat16(params->alpha);
      auto h_b = DoCastForHalfOrBfloat16(params->beta);
      auto status = rocblas_gemm_strided_batched_ex(
          (rocblas_handle)at::cuda::getCurrentCUDABlasHandle(),
          _rocblasOpFromChar(params->transa),
          _rocblasOpFromChar(params->transb),
          params->m, params->n, params->k,
          &h_a,
          params->a, input_output_type, params->lda, params->stride_a,
          params->b, input_output_type, params->ldb, params->stride_b,
          &h_b,
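          // As in RocblasGemmOp: the same buffer serves as both C and D, so
          // each batched GEMM is computed in place.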
          params->c, input_output_type, params->ldc, params->stride_c,
          params->c, input_output_type, params->ldc, params->stride_c,
          params->batch,
          compute_type,
          rocblas_gemm_algo_solution_index,
          solution_,
          rocblas_gemm_flags_none);
      if (status != rocblas_status_success) {
        return FAIL;
      }
      return OK;
    }

  private:
    int solution_;
};

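// Strided-batched analogue of GetRocBlasGemmTypeStringAndOps: enumerates the
// available solutions and wraps each in a RocblasGemmStridedBatchedOp.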
template <typename T>
auto GetRocBlasGemmStridedBatchedTypeStringAndOps() {
  rocblas_handle handle = (rocblas_handle)at::cuda::getCurrentCUDABlasHandle();
  int solution_size;
  auto input_output_type = RocBlasDataTypeFor<T>();
  auto compute_type = RocBlasComputeTypeFor<T>();
  // Get the number of available solutions
  TORCH_ROCBLAS_CHECK(rocblas_gemm_ex_get_solutions_by_type(handle,
                                                            input_output_type,
                                                            input_output_type,
                                                            compute_type,
                                                            rocblas_gemm_flags_none,
                                                            nullptr,
                                                            &solution_size));
  std::vector<int> solutions(solution_size);
  // Get the list of available solutions
  TORCH_ROCBLAS_CHECK(rocblas_gemm_ex_get_solutions_by_type(handle,
                                                            input_output_type,
                                                            input_output_type,
                                                            compute_type,
                                                            rocblas_gemm_flags_none,
                                                            solutions.data(),
                                                            &solution_size));
  // Sort the solutions in ascending order to make the solution vector deterministic across runs
  std::sort(solutions.begin(), solutions.end());

  std::vector<std::pair<std::string, std::unique_ptr<Callable<GemmStridedBatchedParams<T>>>>> ret;
  for (size_t i = 0; i < solutions.size(); ++i) {
    auto callable = std::make_unique<RocblasGemmStridedBatchedOp<T>>(solutions[i]);
    ret.emplace_back(std::make_pair(c10::str("Gemm_Rocblas_", solutions[i]), std::move(callable)));
  }
  return ret;
}

}  // namespace at::cuda::tunable