xref: /aosp_15_r20/external/pytorch/aten/src/ATen/cuda/tunable/TunableGemm.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 // Original TunableOp is from onnxruntime.
2 // https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h
3 // https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable
4 // Copyright (c) Microsoft Corporation.
5 // Licensed under the MIT license.
6 //
7 // Adapting TunableOp into PyTorch
8 // Copyright (c) Advanced Micro Devices, Inc.
9 //
10 #pragma once
11 
12 #include <ATen/cuda/tunable/GemmCommon.h>
13 #ifdef USE_ROCM
14 #include <ATen/cuda/tunable/GemmHipblaslt.h>
15 #include <ATen/cuda/tunable/GemmRocblas.h>
16 #endif
17 #include <ATen/cuda/tunable/StreamTimer.h>
18 #include <ATen/cuda/tunable/TunableOp.h>
19 #include <c10/cuda/CUDACachingAllocator.h>
20 #include <c10/util/Float8_e4m3fn.h>
21 #include <c10/util/Float8_e4m3fnuz.h>
22 #include <c10/util/Float8_e5m2.h>
23 #include <c10/util/Float8_e5m2fnuz.h>
24 #include <c10/util/StringUtil.h>
25 
26 namespace at::cuda::tunable {
27 
28 template <typename T>
29 class DefaultGemmOp : public Callable<GemmParams<T>> {
30   public:
Call(const GemmParams<T> * params)31     TuningStatus Call(const GemmParams<T>* params) override {
32       at::cuda::blas::gemm_internal<T>(
33           params->transa, params->transb,
34           params->m, params->n, params->k,
35           params->alpha,
36           params->a, params->lda,
37           params->b, params->ldb,
38           params->beta,
39           params->c, params->ldc);
40       return OK;
41     }
42 };
43 
// Converts a transpose flag character into a bool: true for 't' or 'T',
// false for every other character.
static bool _transposeBoolFromChar(char op) {
  switch (op) {
    case 't':
    case 'T':
      return true;
    default:
      return false;
  }
}
47 
48 template <typename T>
49 class DefaultGemmAndBiasOp : public Callable<GemmAndBiasParams<T>> {
50   public:
Call(const GemmAndBiasParams<T> * params)51     TuningStatus Call(const GemmAndBiasParams<T>* params) override {
52       at::cuda::blas::gemm_and_bias<T>(
53           _transposeBoolFromChar(params->transa),
54           _transposeBoolFromChar(params->transb),
55           params->m, params->n, params->k,
56           params->alpha,
57           params->a, params->lda,
58           params->b, params->ldb,
59           params->bias,
60           params->c, params->ldc,
61           params->activation);
62       return OK;
63     }
64 };
65 
66 template <typename T>
67 class DefaultGemmStridedBatchedOp : public Callable<GemmStridedBatchedParams<T>> {
68   public:
Call(const GemmStridedBatchedParams<T> * params)69     TuningStatus Call(const GemmStridedBatchedParams<T>* params) override {
70       at::cuda::blas::bgemm_internal<T>(
71           params->transa, params->transb,
72           params->m, params->n, params->k,
73           params->alpha,
74           params->a, params->lda, params->stride_a,
75           params->b, params->ldb, params->stride_b,
76           params->beta,
77           params->c, params->ldc, params->stride_c,
78           params->batch);
79       return OK;
80     }
81 };
82 
83 template <typename T>
84 class DefaultScaledGemmOp : public Callable<ScaledGemmParams<T>> {
85   public:
Call(const ScaledGemmParams<T> * params)86     TuningStatus Call(const ScaledGemmParams<T>* params) override {
87       at::cuda::blas::scaled_gemm(
88           params->transa,
89           params->transb,
90           params->m,
91           params->n,
92           params->k,
93           params->a,
94           params->a_scale_ptr,
95           params->lda,
96           params->a_dtype,
97           params->b,
98           params->b_scale_ptr,
99           params->ldb,
100           params->b_dtype,
101           params->bias_ptr,
102           params->bias_dtype,
103           params->c,
104           params->c_scale_ptr,
105           params->ldc,
106           params->c_dtype,
107           params->amax_ptr,
108           params->use_fast_accum);
109       return OK;
110     }
111 };
112 
// Generic zero test: returns true when `v` compares equal to the float
// literal 0.0f using T's operator==.
template <typename T>
inline bool IsZero(T v) {
  constexpr float kZero = 0.0f;
  return v == kZero;
}
117 
118 template <>
IsZero(BFloat16 v)119 inline bool IsZero(BFloat16 v) {
120   return v.x == 0;
121 }
122 
123 template <>
IsZero(Half v)124 inline bool IsZero(Half v) {
125   return float(v) == 0.0f;
126 }
127 
128 template <>
IsZero(c10::complex<double> v)129 inline bool IsZero(c10::complex<double> v) {
130   return v == 0.0;
131 }
132 
133 template <>
IsZero(c10::complex<float> v)134 inline bool IsZero(c10::complex<float> v) {
135   return v == 0.0f;
136 }
137 
// Fallback type tag: any T without a dedicated specialization is reported as
// "unknown". The argument is used only to select the type; its value is
// ignored.
template <typename T>
inline std::string TypeName(T /*v*/) {
  return std::string("unknown");
}
142 
143 template <>
TypeName(float v)144 inline std::string TypeName(float v) {
145   return "float";
146 }
147 
148 template <>
TypeName(double v)149 inline std::string TypeName(double v) {
150   return "double";
151 }
152 
153 template <>
TypeName(BFloat16 v)154 inline std::string TypeName(BFloat16 v) {
155   return "BFloat16";
156 }
157 
158 template <>
TypeName(Half v)159 inline std::string TypeName(Half v) {
160   return "Half";
161 }
162 
163 template <>
TypeName(Float8_e4m3fn v)164 inline std::string TypeName(Float8_e4m3fn v) {
165   return "Float8_e4m3fn";
166 }
167 
168 template <>
TypeName(Float8_e5m2 v)169 inline std::string TypeName(Float8_e5m2 v) {
170   return "Float8_e5m2";
171 }
172 
173 template <>
TypeName(Float8_e4m3fnuz v)174 inline std::string TypeName(Float8_e4m3fnuz v) {
175   return "Float8_e4m3fnuz";
176 }
177 
178 template <>
TypeName(Float8_e5m2fnuz v)179 inline std::string TypeName(Float8_e5m2fnuz v) {
180   return "Float8_e5m2fnuz";
181 }
182 
183 template <>
TypeName(c10::complex<double> v)184 inline std::string TypeName(c10::complex<double> v) {
185   return "c10::complex<double>";
186 }
187 
188 template <>
TypeName(c10::complex<float> v)189 inline std::string TypeName(c10::complex<float> v) {
190   return "c10::complex<float>";
191 }
192 
193 template <typename T, BlasOp ALayout, BlasOp BLayout>
194 class GemmTunableOp : public TunableOp<GemmParams<T>, StreamTimer> {
195  public:
GemmTunableOp()196   GemmTunableOp() {
197     this->RegisterOp(std::string("Default"), std::make_unique<DefaultGemmOp<T>>());
198 
199 #ifdef USE_ROCM
200     static const char *env_rocblas = std::getenv("PYTORCH_TUNABLEOP_ROCBLAS_ENABLED");
201     if (env_rocblas == nullptr || strcmp(env_rocblas, "1") == 0) {
202       for (auto&& [name, op] : GetRocBlasGemmTypeStringAndOps<T>()) {
203         this->RegisterOp(std::move(name), std::move(op));
204       }
205     }
206 
207     static const char *env_hipblaslt = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
208     if (env_hipblaslt == nullptr || strcmp(env_hipblaslt, "1") == 0) {
209       // disallow tuning of hipblaslt with c10::complex
210       if constexpr (
211           !std::is_same_v<T, c10::complex<float>> &&
212           !std::is_same_v<T, c10::complex<double>>) {
213         for (auto&& [name, op] : GetHipBlasLtGemmTypeStringAndOps<T, ALayout, BLayout>()) {
214           this->RegisterOp(std::move(name), std::move(op));
215         }
216       }
217     }
218 #endif
219   }
220 
Signature()221   std::string Signature() override {
222     return c10::str("GemmTunableOp_", TypeName<T>(T{}), "_", BlasOpToString(ALayout), BlasOpToString(BLayout));
223   }
224 };
225 
226 template <typename T, BlasOp ALayout, BlasOp BLayout>
227 class GemmAndBiasTunableOp : public TunableOp<GemmAndBiasParams<T>, StreamTimer> {
228  public:
GemmAndBiasTunableOp()229   GemmAndBiasTunableOp() {
230     this->RegisterOp(std::string("Default"), std::make_unique<DefaultGemmAndBiasOp<T>>());
231 
232 #ifdef USE_ROCM
233     static const char *env_hipblaslt = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
234     if (env_hipblaslt == nullptr || strcmp(env_hipblaslt, "1") == 0) {
235       // disallow tuning of hipblaslt with c10::complex
236       if constexpr (
237           !std::is_same_v<T, c10::complex<float>> &&
238           !std::is_same_v<T, c10::complex<double>>) {
239         for (auto&& [name, op] : GetHipBlasLtGemmAndBiasTypeStringAndOps<T, ALayout, BLayout>()) {
240           this->RegisterOp(std::move(name), std::move(op));
241         }
242       }
243     }
244 #endif
245   }
246 
Signature()247   std::string Signature() override {
248     return c10::str("GemmAndBiasTunableOp_", TypeName<T>(T{}), "_", BlasOpToString(ALayout), BlasOpToString(BLayout));
249   }
250 };
251 
252 template <typename T, BlasOp ALayout, BlasOp BLayout>
253 class GemmStridedBatchedTunableOp : public TunableOp<GemmStridedBatchedParams<T>, StreamTimer> {
254  public:
GemmStridedBatchedTunableOp()255   GemmStridedBatchedTunableOp() {
256     this->RegisterOp(std::string("Default"), std::make_unique<DefaultGemmStridedBatchedOp<T>>());
257 
258 #ifdef USE_ROCM
259     static const char *env_rocblas = std::getenv("PYTORCH_TUNABLEOP_ROCBLAS_ENABLED");
260     if (env_rocblas == nullptr || strcmp(env_rocblas, "1") == 0) {
261       for (auto&& [name, op] : GetRocBlasGemmStridedBatchedTypeStringAndOps<T>()) {
262         this->RegisterOp(std::move(name), std::move(op));
263       }
264     }
265 
266     static const char *env_hipblaslt = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
267     if (env_hipblaslt == nullptr || strcmp(env_hipblaslt, "1") == 0) {
268       // disallow tuning of hipblaslt with c10::complex
269       if constexpr (
270           !std::is_same_v<T, c10::complex<float>> &&
271           !std::is_same_v<T, c10::complex<double>>) {
272         for (auto&& [name, op] : GetHipBlasLtGemmStridedBatchedTypeStringAndOps<T, ALayout, BLayout>()) {
273           this->RegisterOp(std::move(name), std::move(op));
274         }
275       }
276     }
277 #endif
278   }
279 
Signature()280   std::string Signature() override {
281     return c10::str("GemmStridedBatchedTunableOp_", TypeName<T>(T{}), "_", BlasOpToString(ALayout), BlasOpToString(BLayout));
282   }
283 };
284 
285 template <typename AT, typename BT, typename CT, BlasOp ALayout, BlasOp BLayout>
286 class ScaledGemmTunableOp : public TunableOp<ScaledGemmParams<CT>, StreamTimer> {
287  public:
ScaledGemmTunableOp()288   ScaledGemmTunableOp() {
289     this->RegisterOp(std::string("Default"), std::make_unique<DefaultScaledGemmOp<CT>>());
290 
291 #ifdef USE_ROCM
292     for (auto&& [name, op] : GetHipBlasLtScaledGemmTypeStringAndOps<AT, BT, CT, ALayout, BLayout>()) {
293       this->RegisterOp(std::move(name), std::move(op));
294     }
295 #endif
296   }
297 
Signature()298   std::string Signature() override {
299     return c10::str("ScaledGemmTunableOp",
300             "_", TypeName<AT>(AT{}),
301             "_", TypeName<BT>(BT{}),
302             "_", TypeName<CT>(CT{}),
303             "_", BlasOpToString(ALayout), BlasOpToString(BLayout));
304   }
305 };
306 
307 } // namespace at::cuda::tunable
308