xref: /aosp_15_r20/external/pytorch/aten/src/ATen/cuda/CublasHandlePool.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/nvrtc_stub/ATenNVRTC.h>
#include <ATen/cuda/detail/DeviceThreadHandles.h>

#include <c10/cuda/CUDACachingAllocator.h>

#include <map>
#include <memory>
#include <regex>
#include <string>
#include <tuple>

/**
 * Note [hipblaslt handles]
 * ~~~~~~~~~~~~~~~~~~~~~~~~
 * The cublas documentation states:
 * cuBLAS handle (cublasHandle_t) encapsulates a cuBLASLt handle.
 * Any valid cublasHandle_t can be used in place of cublasLtHandle_t with a simple cast.
 *
 * hipblaslt does not behave in this way.
 * A hipblas handle does not encapsulate a hipblaslt handle.
 *
 * To work around this difference in behavior, a separate handle pool is available for ROCm builds.
 * For CUDA builds, getCurrentCUDABlasLtHandle is an alias for getCurrentCUDABlasHandle,
 * whereas for ROCm builds, it is a distinct function.
 */
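
// An illustrative sketch of the CUDA-side aliasing described above, for
// exposition only (the functions named below are the ones defined in this file):
//
//   cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
//   auto lt_handle = reinterpret_cast<cublasLtHandle_t>(handle);  // valid on CUDA, not on ROCm
//
// getCurrentCUDABlasLtHandle() performs exactly this cast for CUDA builds,
// while ROCm builds reserve a genuine hipblaslt handle from a separate pool.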

namespace at::cuda {

namespace {

#if defined(USE_ROCM)
void createCublasLtHandle(cublasLtHandle_t *handle) {
  TORCH_CUDABLAS_CHECK(cublasLtCreate(handle));
}

void destroyCublasLtHandle(cublasLtHandle_t handle) {
// this is because of something dumb in the ordering of
// destruction. Sometimes atexit, the cuda context (or something)
// would already be destroyed by the time this gets destroyed. It
// happens in fbcode setting. @colesbury and @soumith decided to not destroy
// the handle as a workaround.
//   - Comments of @soumith copied from cuDNN handle pool implementation
#ifdef NO_CUDNN_DESTROY_HANDLE
#else
    cublasLtDestroy(handle);
#endif
}

using CuBlasLtPoolType = DeviceThreadHandlePool<cublasLtHandle_t, createCublasLtHandle, destroyCublasLtHandle>;
#endif

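// Maps a (cublas handle, cuda stream) pair to the cublas workspace bound to
// that combination, so each handle/stream pair reuses a single cached
// allocation (see getCurrentCUDABlasHandle below).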
std::map<std::tuple<void *, void *>, at::DataPtr>& cublas_handle_stream_to_workspace() {
  static auto& instance = *new std::map<std::tuple<void *, void *>, at::DataPtr>;
  return instance;
}

void createCublasHandle(cublasHandle_t *handle) {
  TORCH_CUDABLAS_CHECK(cublasCreate(handle));
}

void destroyCublasHandle(cublasHandle_t handle) {
// this is because of something dumb in the ordering of
// destruction. Sometimes atexit, the cuda context (or something)
// would already be destroyed by the time this gets destroyed. It
// happens in fbcode setting. @colesbury and @soumith decided to not destroy
// the handle as a workaround.
//   - Comments of @soumith copied from cuDNN handle pool implementation
#ifdef NO_CUDNN_DESTROY_HANDLE
#else
    cublasDestroy(handle);
#endif
}

using CuBlasPoolType = DeviceThreadHandlePool<cublasHandle_t, createCublasHandle, destroyCublasHandle>;

} // namespace

void clearCublasWorkspaces() {
  #if !defined(USE_ROCM)
      cublas_handle_stream_to_workspace().clear();
  #endif
}

size_t parseChosenWorkspaceSize() {
  const char * val = getenv("CUBLAS_WORKSPACE_CONFIG");
  /* :4096:2:16:8 default, 32MiB for Hopper */
  cudaDeviceProp* properties = at::cuda::getCurrentDeviceProperties();
  const bool sm90 = properties != nullptr && properties->major == 9 && properties->minor == 0;
  const size_t default_size = sm90 ? 4096 * 8 * 1024 : 4096 * 1024 * 2 + 16 * 1024 * 8;

  if (val) {
    size_t total_size = 0;
    const std::string config(val);
    std::regex exp(":([0-9]+):([0-9]+)");
    std::sregex_iterator next(config.begin(), config.end(), exp);
    std::sregex_iterator end;
    if (next == end) {
      TORCH_WARN("Could not parse CUBLAS_WORKSPACE_CONFIG, using default workspace size of ", default_size, " bytes.");
      return default_size;
    }
    while (next != end) {
      std::smatch match = *next;
      TORCH_CHECK(match.size() == 3, "Expected CUBLAS_WORKSPACE_CONFIG match of size 3 (Format :SIZE:COUNT)");
      size_t curr_size = (size_t) std::stoi(match.str(1));
      size_t count = (size_t) std::stoi(match.str(2));
      total_size += curr_size * 1024 * count;
      next++;
    }
    return total_size;
  } else {
    return default_size;
  }
}
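
// Worked example (for exposition): with CUBLAS_WORKSPACE_CONFIG=":4096:2:16:8",
// the regex above yields the pairs (4096, 2) and (16, 8), so the parsed size is
//   4096 * 1024 * 2 + 16 * 1024 * 8 = 8388608 + 131072 = 8519680 bytes,
// which matches the non-Hopper default computed above.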

size_t getChosenWorkspaceSize() {
  size_t pool_size = parseChosenWorkspaceSize();
  return pool_size;
}

at::DataPtr getNewWorkspace() {
  return c10::cuda::CUDACachingAllocator::get()->allocate(getChosenWorkspaceSize());
}

cublasHandle_t getCurrentCUDABlasHandle() {
  c10::DeviceIndex device = 0;
  AT_CUDA_CHECK(c10::cuda::GetDevice(&device));

#if !defined(USE_ROCM)
  CUcontext pctx = nullptr;
  at::globalContext().getNVRTC().cuCtxGetCurrent(&pctx);
  if (C10_UNLIKELY(!pctx)) {
    // workaround for corner case where a primary context exists but is not
    // the current context, seen in multithreaded use-cases
    TORCH_WARN_ONCE("Attempting to run cuBLAS, but there was no current CUDA context! Attempting to set the primary context...");
    at::globalContext().getNVRTC().cuDevicePrimaryCtxRetain(&pctx, device);
    at::globalContext().getNVRTC().cuCtxSetCurrent(pctx);
  }
#endif

  // Thread local PoolWindows are lazily-initialized
  // to avoid initialization issues that caused hangs on Windows.
  // See: https://github.com/pytorch/pytorch/pull/22405
  // This thread-local unique_ptr will be destroyed when the thread terminates,
  // releasing its reserved handles back to the pool.

  // Use a leaky singleton for the pool following standard practice around
  // singletons: https://isocpp.org/wiki/faq/ctors#construct-on-first-use-v2
  static auto pool = std::shared_ptr<CuBlasPoolType>(
      new CuBlasPoolType(), [](CuBlasPoolType* p) {
        // Leak the memory.
      });
  thread_local std::unique_ptr<CuBlasPoolType::PoolWindow> myPoolWindow(
      pool->newPoolWindow());

  auto handle = myPoolWindow->reserve(device);
  auto stream = c10::cuda::getCurrentCUDAStream();
  TORCH_CUDABLAS_CHECK(cublasSetStream(handle, stream));
#if !defined(USE_ROCM)
  // We explicitly set the cublas workspace even though CUDA 12.2+ fixed the
  // issue where memory usage increased during graph capture.
  // original issue: https://github.com/pytorch/pytorch/pull/83461
  // This is because in CUDA 12.2+, cublas's use of cudaMallocAsync allocates
  // memory dynamically (even if the allocations are cheap) outside PyTorch's
  // CUDA caching allocator (CCA). If the CCA has already claimed all device
  // memory, cublas's cudaMallocAsync can then fail with an out-of-memory error.
  cudaStream_t _stream = stream;
  auto key = std::make_tuple(static_cast<void *>(handle), static_cast<void *>(_stream));
  auto workspace_it = cublas_handle_stream_to_workspace().find(key);
  if (workspace_it == cublas_handle_stream_to_workspace().end()) {
    workspace_it = cublas_handle_stream_to_workspace().insert(workspace_it, {key, getNewWorkspace()});
  }
  TORCH_CUDABLAS_CHECK(cublasSetWorkspace(handle, workspace_it->second.get(), getChosenWorkspaceSize()));
  // On CUDA >= 11 and architectures >= Ampere, cuBLAS can use TF32 to speed up
  // FP32 data type calculations based on the value of the allow_tf32 flag.
  // To enable TF32, set the math mode of the handle to CUBLAS_TF32_TENSOR_OP_MATH.
  if (!NoTF32Guard::should_disable_tf32() && at::globalContext().allowTF32CuBLAS()) {
    TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TF32_TENSOR_OP_MATH));
  } else {
    TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
  }
#else
  hipblasAtomicsMode_t hipblas_mode;
  if (at::globalContext().deterministicAlgorithms()) {
    hipblas_mode = HIPBLAS_ATOMICS_NOT_ALLOWED;
  } else {
    hipblas_mode = HIPBLAS_ATOMICS_ALLOWED;
  }
  TORCH_CUDABLAS_CHECK(hipblasSetAtomicsMode(handle, hipblas_mode));
#endif
  return handle;
}
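
// Typical caller pattern, shown for exposition only (cuBLAS arguments elided;
// cublasSgemm is just an example routine, not something this file calls):
//
//   cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
//   // The handle is already bound to the current stream, workspace, and math
//   // mode, so it can be passed straight to cuBLAS, e.g.
//   TORCH_CUDABLAS_CHECK(cublasSgemm(handle, /* ... */));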

cublasLtHandle_t getCurrentCUDABlasLtHandle() {
#ifdef USE_ROCM
  c10::DeviceIndex device = 0;
  AT_CUDA_CHECK(c10::cuda::GetDevice(&device));

  // Thread local PoolWindows are lazily-initialized
  // to avoid initialization issues that caused hangs on Windows.
  // See: https://github.com/pytorch/pytorch/pull/22405
  // This thread-local unique_ptr will be destroyed when the thread terminates,
  // releasing its reserved handles back to the pool.

  // Use a leaky singleton for the pool following standard practice around
  // singletons: https://isocpp.org/wiki/faq/ctors#construct-on-first-use-v2
  static auto pool = std::shared_ptr<CuBlasLtPoolType>(
      new CuBlasLtPoolType(), [](CuBlasLtPoolType* p) {
        // Leak the memory.
      });
  thread_local std::unique_ptr<CuBlasLtPoolType::PoolWindow> myPoolWindow(
      pool->newPoolWindow());

  auto handle = myPoolWindow->reserve(device);
  return handle;
#else
  return reinterpret_cast<cublasLtHandle_t>(getCurrentCUDABlasHandle());
#endif
}

} // namespace at::cuda