// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/tunable/TunableOp.h>
#include <ATen/cuda/tunable/GemmCommon.h>
#include <c10/util/StringUtil.h>

#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Expose the beta rocBLAS APIs (rocblas_gemm_ex_get_solutions_by_type et al.)
// before including the rocBLAS header.
#define ROCBLAS_BETA_FEATURES_API
#include <rocblas/rocblas.h>

#define TORCH_ROCBLAS_CHECK(EXPR)                  \
  do {                                             \
    rocblas_status __err = EXPR;                   \
    TORCH_CHECK(__err == rocblas_status_success,   \
                "rocblas error: ",                 \
                rocblas_status_to_string(__err),   \
                " when calling `" #EXPR "`");      \
  } while (0)
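
// Illustrative usage (a sketch, not required by this header): TORCH_ROCBLAS_CHECK
// wraps any expression returning rocblas_status and raises a c10::Error carrying
// the rocBLAS status string on failure, e.g.
//
//   rocblas_handle handle = (rocblas_handle)at::cuda::getCurrentCUDABlasHandle();
//   TORCH_ROCBLAS_CHECK(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_host));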

namespace at::cuda::tunable {

template <typename T>
constexpr rocblas_datatype RocBlasDataTypeFor();

template <>
constexpr rocblas_datatype RocBlasDataTypeFor<float>() {
  return rocblas_datatype_f32_r;
}

template <>
constexpr rocblas_datatype RocBlasDataTypeFor<double>() {
  return rocblas_datatype_f64_r;
}

template <>
constexpr rocblas_datatype RocBlasDataTypeFor<Half>() {
  return rocblas_datatype_f16_r;
}

template <>
constexpr rocblas_datatype RocBlasDataTypeFor<BFloat16>() {
  return rocblas_datatype_bf16_r;
}

template <>
constexpr rocblas_datatype RocBlasDataTypeFor<c10::complex<float>>() {
  return rocblas_datatype_f32_c;
}

template <>
constexpr rocblas_datatype RocBlasDataTypeFor<c10::complex<double>>() {
  return rocblas_datatype_f64_c;
}

template <typename T>
constexpr rocblas_datatype RocBlasComputeTypeFor();

template <>
constexpr rocblas_datatype RocBlasComputeTypeFor<float>() {
  return rocblas_datatype_f32_r;
}

template <>
constexpr rocblas_datatype RocBlasComputeTypeFor<double>() {
  return rocblas_datatype_f64_r;
}

template <>
constexpr rocblas_datatype RocBlasComputeTypeFor<Half>() {
  // Note that we're returning the _compute_ type for a given datatype.
  // As of 12/2022, using compute type FP16 for 16-bit floats was much
  // slower than using compute type FP32. So we use FP32 compute even for
  // FP16 datatypes. This is how GEMM is implemented even in the function
  // rocblasGemmHelper (see fpgeneric.h)
  return rocblas_datatype_f32_r;
}

template <>
constexpr rocblas_datatype RocBlasComputeTypeFor<BFloat16>() {
  // Note that we're returning the _compute_ type for a given datatype.
  // As of 12/2022, using compute type FP16 for 16-bit floats was much
  // slower than using compute type FP32. So we use FP32 compute even for
  // BF16 datatypes. This is how GEMM is implemented even in the function
  // rocblasGemmHelper (see fpgeneric.h)
  return rocblas_datatype_f32_r;
}

template <>
constexpr rocblas_datatype RocBlasComputeTypeFor<c10::complex<float>>() {
  return rocblas_datatype_f32_c;
}

template <>
constexpr rocblas_datatype RocBlasComputeTypeFor<c10::complex<double>>() {
  return rocblas_datatype_f64_c;
}

template <typename T>
auto DoCastForHalfOrBfloat16(const T fp) {
  return fp;
}

template <>
inline auto DoCastForHalfOrBfloat16<Half>(const Half fp) {
  // alpha and beta should be the same as compute_type, in Half case it is float.
  float h = fp;
  return h;
}

template <>
inline auto DoCastForHalfOrBfloat16<BFloat16>(const BFloat16 fp) {
  // alpha and beta should be the same as compute_type, in bfloat16 case it is float.
  float h = fp;
  return h;
}
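
// Illustrative sketch of how the helpers above combine (values are hypothetical,
// not taken from this file): because RocBlasComputeTypeFor<Half>() selects FP32
// compute, alpha/beta for Half GEMMs must be widened to float before being passed
// to rocBLAS, which is exactly what DoCastForHalfOrBfloat16 does:
//
//   Half alpha{1.0f};
//   auto h_alpha = DoCastForHalfOrBfloat16(alpha);
//   static_assert(std::is_same_v<decltype(h_alpha), float>);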

static rocblas_operation _rocblasOpFromChar(char op) {
  switch (op) {
    case 'n':
    case 'N':
      return rocblas_operation_none;
    case 't':
    case 'T':
      return rocblas_operation_transpose;
    case 'c':
    case 'C':
      return rocblas_operation_conjugate_transpose;
  }
  AT_ERROR(
      "_rocblasOpFromChar input should be 't', 'n' or 'c' but got `", op, "`");
}
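
// Illustrative mapping (a sketch): GemmParams carries BLAS-style transpose
// characters, so _rocblasOpFromChar('N') yields rocblas_operation_none,
// _rocblasOpFromChar('t') yields rocblas_operation_transpose, and any other
// character raises an error via AT_ERROR.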

template <typename T>
class RocblasGemmOp : public Callable<GemmParams<T>> {
 public:
  RocblasGemmOp(int solution) : solution_{solution} {}

  TuningStatus Call(const GemmParams<T>* params) override {
    auto input_output_type = RocBlasDataTypeFor<T>();
    auto compute_type = RocBlasComputeTypeFor<T>();
    auto h_a = DoCastForHalfOrBfloat16(params->alpha);
    auto h_b = DoCastForHalfOrBfloat16(params->beta);
    auto status = rocblas_gemm_ex(
        (rocblas_handle)at::cuda::getCurrentCUDABlasHandle(),
        _rocblasOpFromChar(params->transa),
        _rocblasOpFromChar(params->transb),
        params->m, params->n, params->k,
        &h_a,
        params->a, input_output_type, params->lda,
        params->b, input_output_type, params->ldb,
        &h_b,
        params->c, input_output_type, params->ldc,
        params->c, input_output_type, params->ldc,
        compute_type,
        rocblas_gemm_algo_solution_index,
        solution_,
        rocblas_gemm_flags_none);
    if (status != rocblas_status_success) {
      return FAIL;
    }
    return OK;
  }

 private:
  int solution_;
};
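
// Illustrative sketch (not the tuning path itself, which goes through
// GetRocBlasGemmTypeStringAndOps below): a RocblasGemmOp can be invoked directly
// once a GemmParams<T> is populated and a rocBLAS solution index is known; the
// index 0 below is hypothetical.
//
//   GemmParams<float> params;                  // filled in by the caller
//   RocblasGemmOp<float> op{/*solution=*/0};
//   TuningStatus status = op.Call(&params);    // OK on success, FAIL otherwise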

template <typename T>
auto GetRocBlasGemmTypeStringAndOps() {
  rocblas_handle handle = (rocblas_handle)at::cuda::getCurrentCUDABlasHandle();
  int solution_size;
  auto input_output_type = RocBlasDataTypeFor<T>();
  auto compute_type = RocBlasComputeTypeFor<T>();
  // Get the number of available solutions
  TORCH_ROCBLAS_CHECK(rocblas_gemm_ex_get_solutions_by_type(handle,
                                                            input_output_type,
                                                            input_output_type,
                                                            compute_type,
                                                            rocblas_gemm_flags_none,
                                                            nullptr,
                                                            &solution_size));
  std::vector<int> solutions(solution_size);
  // Get the list of available solutions
  TORCH_ROCBLAS_CHECK(rocblas_gemm_ex_get_solutions_by_type(handle,
                                                            input_output_type,
                                                            input_output_type,
                                                            compute_type,
                                                            rocblas_gemm_flags_none,
                                                            solutions.data(),
                                                            &solution_size));
  // Sort the solutions in ascending order to make the solution vector deterministic across runs
  std::sort(solutions.begin(), solutions.end());

  std::vector<std::pair<std::string, std::unique_ptr<Callable<GemmParams<T>>>>> ret;
  for (size_t i = 0; i < solutions.size(); ++i) {
    auto callable = std::make_unique<RocblasGemmOp<T>>(solutions[i]);
    ret.emplace_back(std::make_pair(c10::str("Gemm_Rocblas_", solutions[i]), std::move(callable)));
  }
  return ret;
}
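
// Illustrative sketch of how the (name, op) pairs returned above are typically
// consumed: a GEMM TunableOp (see TunableOp.h) registers each rocBLAS solution
// as a tunable candidate. RegisterOp here is the registration hook assumed to be
// provided by TunableOp.
//
//   for (auto&& [name, op] : GetRocBlasGemmTypeStringAndOps<float>()) {
//     this->RegisterOp(std::move(name), std::move(op));
//   }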

template <typename T>
class RocblasGemmStridedBatchedOp : public Callable<GemmStridedBatchedParams<T>> {
 public:
  RocblasGemmStridedBatchedOp(int solution) : solution_{solution} {}

  TuningStatus Call(const GemmStridedBatchedParams<T>* params) override {
    auto input_output_type = RocBlasDataTypeFor<T>();
    auto compute_type = RocBlasComputeTypeFor<T>();
    auto h_a = DoCastForHalfOrBfloat16(params->alpha);
    auto h_b = DoCastForHalfOrBfloat16(params->beta);
    auto status = rocblas_gemm_strided_batched_ex(
        (rocblas_handle)at::cuda::getCurrentCUDABlasHandle(),
        _rocblasOpFromChar(params->transa),
        _rocblasOpFromChar(params->transb),
        params->m, params->n, params->k,
        &h_a,
        params->a, input_output_type, params->lda, params->stride_a,
        params->b, input_output_type, params->ldb, params->stride_b,
        &h_b,
        params->c, input_output_type, params->ldc, params->stride_c,
        params->c, input_output_type, params->ldc, params->stride_c,
        params->batch,
        compute_type,
        rocblas_gemm_algo_solution_index,
        solution_,
        rocblas_gemm_flags_none);
    if (status != rocblas_status_success) {
      return FAIL;
    }
    return OK;
  }

 private:
  int solution_;
};

template <typename T>
auto GetRocBlasGemmStridedBatchedTypeStringAndOps() {
  rocblas_handle handle = (rocblas_handle)at::cuda::getCurrentCUDABlasHandle();
  int solution_size;
  auto input_output_type = RocBlasDataTypeFor<T>();
  auto compute_type = RocBlasComputeTypeFor<T>();
  // Get the number of available solutions
  TORCH_ROCBLAS_CHECK(rocblas_gemm_ex_get_solutions_by_type(handle,
                                                            input_output_type,
                                                            input_output_type,
                                                            compute_type,
                                                            rocblas_gemm_flags_none,
                                                            nullptr,
                                                            &solution_size));
  std::vector<int> solutions(solution_size);
  // Get the list of available solutions
  TORCH_ROCBLAS_CHECK(rocblas_gemm_ex_get_solutions_by_type(handle,
                                                            input_output_type,
                                                            input_output_type,
                                                            compute_type,
                                                            rocblas_gemm_flags_none,
                                                            solutions.data(),
                                                            &solution_size));
  // Sort the solutions in ascending order to make the solution vector deterministic across runs
  std::sort(solutions.begin(), solutions.end());

  std::vector<std::pair<std::string, std::unique_ptr<Callable<GemmStridedBatchedParams<T>>>>> ret;
  for (size_t i = 0; i < solutions.size(); ++i) {
    auto callable = std::make_unique<RocblasGemmStridedBatchedOp<T>>(solutions[i]);
    ret.emplace_back(std::make_pair(c10::str("Gemm_Rocblas_", solutions[i]), std::move(callable)));
  }
  return ret;
}

} // namespace at::cuda::tunable