// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/cuda/tunable/GemmCommon.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 // Original TunableOp is from onnxruntime.
2 // https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h
3 // https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable
4 // Copyright (c) Microsoft Corporation.
5 // Licensed under the MIT license.
6 //
7 // Adapting TunableOp into PyTorch
8 // Copyright (c) Advanced Micro Devices, Inc.
9 //
10 #pragma once
11 
12 #include <string>
13 
14 #include <ATen/cuda/tunable/TunableOp.h>
15 #include <ATen/cuda/Exceptions.h>
16 #include <c10/util/StringUtil.h>
17 
18 #ifndef AT_PER_OPERATOR_HEADERS
19 #include <ATen/Functions.h>
20 #include <ATen/NativeFunctions.h>
21 #else
22 #include <ATen/ops/allclose.h>
23 #include <ATen/ops/from_blob.h>
24 #endif
25 
26 namespace at::cuda::tunable {
27 
// Transpose option for one GEMM operand, mirroring the BLAS 'N'/'T'
// (no-transpose / transpose) operation characters.
enum class BlasOp {
  N = 0,
  T = 1
};
32 
BlasOpToString(BlasOp op)33 inline std::string BlasOpToString(BlasOp op) {
34   switch (op) {
35     case BlasOp::N:
36       return "N";
37     case BlasOp::T:
38       return "T";
39   }
40   TORCH_CHECK(false, "unrecognized BlasOp");
41   return "N";
42 }
43 
44 namespace detail {
45 
NumericalCheck(ScalarType dtype,void * c,void * other_c,int64_t size)46 static bool NumericalCheck(ScalarType dtype, void* c, void* other_c, int64_t size) {
47   auto options = at::TensorOptions().dtype(dtype).device(at::kCUDA);
48   // comparison done as 1D tensor
49   at::Tensor ref = at::from_blob(c,       {size}, options);
50   at::Tensor oth = at::from_blob(other_c, {size}, options);
51   at::Tensor ref_float = ref.to(at::kFloat);
52   at::Tensor oth_float = oth.to(at::kFloat);
53   std::vector<double> atols{1e-1, 1e-2, 1e-3, 1e-4, 1e-5};
54   std::vector<double> rtols{1e-1, 1e-2, 1e-3, 1e-4, 1e-5};
55   double last_succeed_atol = 1;
56   double last_succeed_rtol = 1;
57   for (auto& atol : atols) {
58     for (auto& rtol : rtols) {
59       if (at::allclose(ref_float, oth_float, rtol, atol)) {
60         last_succeed_atol = atol;
61         last_succeed_rtol = rtol;
62       }
63     }
64   }
65   if (last_succeed_atol == 1) {
66     return false;
67   }
68   else {
69     TUNABLE_LOG3("├──verify numerics: atol=", last_succeed_atol, ", rtol=", last_succeed_rtol);
70   }
71 
72   return true;
73 }
74 
75 }
76 
77 template <typename T>
78 struct GemmParams : OpParams {
GemmParamsGemmParams79   GemmParams() {
80     duplicate_inputs_ = false;
81   }
82 
SignatureGemmParams83   std::string Signature() const override {
84     return c10::str(transa, transb, "_", m, "_", n, "_", k);
85   }
86 
GetSizeAGemmParams87   size_t GetSizeA() const {
88     return sizeof(T) * lda * ((transa == 'n' || transa == 'N') ? k : m);
89   }
90 
GetSizeBGemmParams91   size_t GetSizeB() const {
92     return sizeof(T) * ldb * ((transb == 'n' || transb == 'N') ? n : k);
93   }
94 
GetSizeCGemmParams95   size_t GetSizeC() const {
96     return sizeof(T) * ldc * n;
97   }
98 
GetSizeGemmParams99   size_t GetSize(bool duplicate_inputs) const {
100     size_t size = GetSizeC();
101     if (duplicate_inputs) {
102       size += GetSizeA();
103       size += GetSizeB();
104     }
105     return size;
106   }
107 
DeepCopyGemmParams108   GemmParams* DeepCopy(bool duplicate_inputs) const {
109     GemmParams* copy = new GemmParams;
110     *copy = *this;
111     c10::DeviceIndex device = 0;
112     AT_CUDA_CHECK(c10::cuda::GetDevice(&device));
113     size_t c_size = GetSizeC();
114     copy->c = static_cast<T*>(c10::cuda::CUDACachingAllocator::raw_alloc(c_size));
115     AT_CUDA_CHECK(c10::cuda::CUDACachingAllocator::memcpyAsync(
116         copy->c, device, c, device, c_size, getCurrentCUDAStream(device), true));
117     if (duplicate_inputs) {
118       size_t a_size = GetSizeA();
119       size_t b_size = GetSizeB();
120       copy->a = static_cast<const T*>(c10::cuda::CUDACachingAllocator::raw_alloc(a_size));
121       copy->b = static_cast<const T*>(c10::cuda::CUDACachingAllocator::raw_alloc(b_size));
122       copy->duplicate_inputs_ = true;
123     }
124     return copy;
125   }
126 
127   // only call on object returned by DeepCopy
DeleteGemmParams128   void Delete() {
129     c10::cuda::CUDACachingAllocator::raw_delete(c);
130     if (duplicate_inputs_) {
131       c10::cuda::CUDACachingAllocator::raw_delete(const_cast<T*>(a));
132       c10::cuda::CUDACachingAllocator::raw_delete(const_cast<T*>(b));
133     }
134   }
135 
NumericalCheckGemmParams136   TuningStatus NumericalCheck(GemmParams<T> *other) {
137     auto c_dtype = c10::CppTypeToScalarType<T>::value;
138     return detail::NumericalCheck(c_dtype, c, other->c, ldc*n) ? OK : FAIL;
139   }
140 
141   char transa;
142   char transb;
143   int64_t m;
144   int64_t n;
145   int64_t k;
146   at::opmath_type<T> alpha;
147   const T* a;
148   int64_t lda;
149   const T* b;
150   int64_t ldb;
151   at::opmath_type<T> beta;
152   T* c;
153   int64_t ldc;
154 private:
155   bool duplicate_inputs_;
156 };
157 
// Parameter pack for a strided-batched GEMM problem, used by TunableOp to
// key the tuning cache, size scratch buffers, and numerically validate
// candidate solutions. Layout conventions follow column-major BLAS.
template <typename T>
struct GemmStridedBatchedParams : OpParams {
  GemmStridedBatchedParams() {
    duplicate_inputs_ = false;
  }

  // Tuning-cache key: transpose flags, m/n/k dims, and the batch count.
  std::string Signature() const override {
    return c10::str(transa, transb, "_", m, "_", n, "_", k, "_B_", batch);
  }

  // Bytes of A across all batches; the non-leading extent depends on transa.
  // NOTE(review): sized from lda * extent * batch rather than stride_a *
  // batch — assumes batches are packed at that pitch; confirm against callers.
  size_t GetSizeA() const {
    return sizeof(T) * lda * ((transa == 'n' || transa == 'N') ? k : m) * batch;
  }

  // Bytes of B across all batches; the non-leading extent depends on transb.
  size_t GetSizeB() const {
    return sizeof(T) * ldb * ((transb == 'n' || transb == 'N') ? n : k) * batch;
  }

  // Bytes of the output C across all batches.
  size_t GetSizeC() const {
    return sizeof(T) * ldc * n * batch;
  }

  // Total scratch a DeepCopy would need: always C, plus A and B when
  // duplicate_inputs is requested.
  size_t GetSize(bool duplicate_inputs) const {
    size_t size = GetSizeC();
    if (duplicate_inputs) {
      size += GetSizeA();
      size += GetSizeB();
    }
    return size;
  }

  // Clone with a freshly allocated C (copied asynchronously on the current
  // stream of the current device). When duplicate_inputs is set, A and B
  // buffers are allocated but their contents are NOT copied — presumably the
  // data values are irrelevant for timing; confirm before relying on them.
  // Caller owns the returned object and must call Delete() on it.
  GemmStridedBatchedParams* DeepCopy(bool duplicate_inputs) const {
    GemmStridedBatchedParams* copy = new GemmStridedBatchedParams;
    *copy = *this;
    c10::DeviceIndex device = 0;
    AT_CUDA_CHECK(c10::cuda::GetDevice(&device));
    size_t c_size = GetSizeC();
    copy->c = static_cast<T*>(c10::cuda::CUDACachingAllocator::raw_alloc(c_size));
    AT_CUDA_CHECK(c10::cuda::CUDACachingAllocator::memcpyAsync(
        copy->c, device, c, device, c_size, getCurrentCUDAStream(device), true));
    if (duplicate_inputs) {
      size_t a_size = GetSizeA();
      size_t b_size = GetSizeB();
      copy->a = static_cast<const T*>(c10::cuda::CUDACachingAllocator::raw_alloc(a_size));
      copy->b = static_cast<const T*>(c10::cuda::CUDACachingAllocator::raw_alloc(b_size));
      copy->duplicate_inputs_ = true;
    }
    return copy;
  }

  // only call on object returned by DeepCopy
  void Delete() {
    c10::cuda::CUDACachingAllocator::raw_delete(c);
    if (duplicate_inputs_) {
      // a/b were allocated by DeepCopy as non-const buffers; cast away the
      // const of the member pointers to free them.
      c10::cuda::CUDACachingAllocator::raw_delete(const_cast<T*>(a));
      c10::cuda::CUDACachingAllocator::raw_delete(const_cast<T*>(b));
    }
  }

  // Compare this output against another candidate's over batch*stride_c
  // elements (the full batched C region, assuming packed stride_c).
  TuningStatus NumericalCheck(GemmStridedBatchedParams<T> *other) {
    auto c_dtype = c10::CppTypeToScalarType<T>::value;
    return detail::NumericalCheck(c_dtype, c, other->c, batch*stride_c) ? OK : FAIL;
  }

  char transa;
  char transb;
  int64_t m;
  int64_t n;
  int64_t k;
  at::opmath_type<T> alpha;
  const T* a;
  int64_t lda;
  int64_t stride_a;
  const T* b;
  int64_t ldb;
  int64_t stride_b;
  at::opmath_type<T> beta;
  T* c;
  int64_t ldc;
  int64_t stride_c;
  int64_t batch;
private:
  // True only on objects produced by DeepCopy(true); tells Delete() that
  // a and b are owned allocations rather than caller-provided pointers.
  bool duplicate_inputs_;
};
242 
// Parameter pack for a scaled GEMM (e.g. FP8-style GEMM with per-operand
// scale pointers, optional bias, and an amax output). Buffers are type-erased
// void* since operand dtypes are carried separately.
// NOTE(review): GetSizeA/B/C compute bytes with sizeof(T) even though
// a_dtype/b_dtype/c_dtype are tracked per operand — presumably T matches the
// element width of all three buffers; confirm against callers.
template <typename T>
struct ScaledGemmParams : OpParams {
  ScaledGemmParams() {
    duplicate_inputs_ = false;
  }

  // Tuning-cache key: transpose flags plus the m/n/k problem dimensions.
  std::string Signature() const override {
    return c10::str(transa, transb, "_", m, "_", n, "_", k);
  }

  // Bytes of A; the non-leading extent depends on transa.
  size_t GetSizeA() const {
    return sizeof(T) * lda * ((transa == 'n' || transa == 'N') ? k : m);
  }

  // Bytes of B; the non-leading extent depends on transb.
  size_t GetSizeB() const {
    return sizeof(T) * ldb * ((transb == 'n' || transb == 'N') ? n : k);
  }

  // Bytes of the output C.
  size_t GetSizeC() const {
    return sizeof(T) * ldc * n;
  }

  // Total scratch a DeepCopy would need: always C, plus A and B when
  // duplicate_inputs is requested.
  size_t GetSize(bool duplicate_inputs) const {
    size_t size = GetSizeC();
    if (duplicate_inputs) {
      size += GetSizeA();
      size += GetSizeB();
    }
    return size;
  }

  // Clone with a freshly allocated C (copied asynchronously on the current
  // stream of the current device). When duplicate_inputs is set, A and B
  // buffers are allocated but their contents are NOT copied — presumably the
  // data values are irrelevant for timing; confirm before relying on them.
  // Caller owns the returned object and must call Delete() on it.
  ScaledGemmParams* DeepCopy(bool duplicate_inputs) const {
    ScaledGemmParams* copy = new ScaledGemmParams;
    *copy = *this;
    c10::DeviceIndex device = 0;
    AT_CUDA_CHECK(c10::cuda::GetDevice(&device));
    size_t c_size = GetSizeC();
    copy->c = c10::cuda::CUDACachingAllocator::raw_alloc(c_size);
    AT_CUDA_CHECK(c10::cuda::CUDACachingAllocator::memcpyAsync(
        copy->c, device, c, device, c_size, getCurrentCUDAStream(device), true));
    if (duplicate_inputs) {
      size_t a_size = GetSizeA();
      size_t b_size = GetSizeB();
      copy->a = c10::cuda::CUDACachingAllocator::raw_alloc(a_size);
      copy->b = c10::cuda::CUDACachingAllocator::raw_alloc(b_size);
      copy->duplicate_inputs_ = true;
    }
    return copy;
  }

  // only call on object returned by DeepCopy
  void Delete() {
    c10::cuda::CUDACachingAllocator::raw_delete(c);
    if (duplicate_inputs_) {
      // a/b were allocated by DeepCopy; cast away the const of the member
      // pointers to free them.
      c10::cuda::CUDACachingAllocator::raw_delete(const_cast<void*>(a));
      c10::cuda::CUDACachingAllocator::raw_delete(const_cast<void*>(b));
    }
  }

  // Compare this output against another candidate's over the ldc*n output
  // region, interpreting the buffers with the recorded c_dtype.
  TuningStatus NumericalCheck(ScaledGemmParams<T> *other) {
    return detail::NumericalCheck(c_dtype, c, other->c, ldc*n) ? OK : FAIL;
  }

  char transa;
  char transb;
  int64_t m;
  int64_t n;
  int64_t k;
  const void* a;
  const void* a_scale_ptr;
  int64_t lda;
  ScalarType a_dtype;
  const void* b;
  const void* b_scale_ptr;
  int64_t ldb;
  ScalarType b_dtype;
  const void* bias_ptr;
  ScalarType bias_dtype;
  void* c;
  const void* c_scale_ptr;
  int64_t ldc;
  ScalarType c_dtype;
  void* amax_ptr;
  bool use_fast_accum;
private:
  // True only on objects produced by DeepCopy(true); tells Delete() that
  // a and b are owned allocations rather than caller-provided pointers.
  bool duplicate_inputs_;
};
330 
331 } // namespace at::cuda::tunable
332