#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/nvrtc_stub/ATenNVRTC.h>
#include <ATen/cuda/detail/DeviceThreadHandles.h>

#include <c10/cuda/CUDACachingAllocator.h>

#include <map>
#include <memory>
#include <regex>
#include <string>
#include <tuple>

/**
 * Note [hipblaslt handles]
 * ~~~~~~~~~~~~~~~~~~~~~~~~
 * The cuBLAS documentation states:
 *   A cuBLAS handle (cublasHandle_t) encapsulates a cuBLASLt handle.
 *   Any valid cublasHandle_t can be used in place of cublasLtHandle_t with a simple cast.
 *
 * hipblaslt does not behave this way:
 * a hipBLAS handle does not encapsulate a hipblaslt handle.
 *
 * To work around this difference in behavior, a separate handle pool is maintained for ROCm builds.
 * For CUDA builds, getCurrentCUDABlasLtHandle is an alias for getCurrentCUDABlasHandle,
 * whereas for ROCm builds, it is a distinct function.
 */
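
// Illustrative sketch (not exercised directly at this point in the file): on
// CUDA builds the "aliasing" described above amounts to a reinterpret_cast of
// the pooled cuBLAS handle, as done in getCurrentCUDABlasLtHandle() below:
//   cublasHandle_t h = getCurrentCUDABlasHandle();
//   cublasLtHandle_t lt = reinterpret_cast<cublasLtHandle_t>(h);
// On ROCm, by contrast, a dedicated hipblaslt handle pool is used.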

namespace at::cuda {

namespace {

#if defined(USE_ROCM)
void createCublasLtHandle(cublasLtHandle_t *handle) {
  TORCH_CUDABLAS_CHECK(cublasLtCreate(handle));
}

void destroyCublasLtHandle(cublasLtHandle_t handle) {
  // This is because of something dumb in the ordering of destruction.
  // Sometimes, at exit, the CUDA context (or something) would already be
  // destroyed by the time this gets destroyed. It happens in the fbcode
  // setting. @colesbury and @soumith decided not to destroy the handle as a
  // workaround.
  //   - Comments of @soumith copied from the cuDNN handle pool implementation
#ifndef NO_CUDNN_DESTROY_HANDLE
  cublasLtDestroy(handle);
#endif
}

using CuBlasLtPoolType = DeviceThreadHandlePool<cublasLtHandle_t, createCublasLtHandle, destroyCublasLtHandle>;
#endif

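// Map from (cublas handle, CUDA stream) to the workspace buffer bound to that
// pair, so each handle/stream combination reuses a single allocation made
// through PyTorch's caching allocator (see getCurrentCUDABlasHandle below).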
std::map<std::tuple<void *, void *>, at::DataPtr>& cublas_handle_stream_to_workspace() {
  static auto& instance = *new std::map<std::tuple<void *, void *>, at::DataPtr>;
  return instance;
}

void createCublasHandle(cublasHandle_t *handle) {
  TORCH_CUDABLAS_CHECK(cublasCreate(handle));
}

void destroyCublasHandle(cublasHandle_t handle) {
  // This is because of something dumb in the ordering of destruction.
  // Sometimes, at exit, the CUDA context (or something) would already be
  // destroyed by the time this gets destroyed. It happens in the fbcode
  // setting. @colesbury and @soumith decided not to destroy the handle as a
  // workaround.
  //   - Comments of @soumith copied from the cuDNN handle pool implementation
#ifndef NO_CUDNN_DESTROY_HANDLE
  cublasDestroy(handle);
#endif
}

using CuBlasPoolType = DeviceThreadHandlePool<cublasHandle_t, createCublasHandle, destroyCublasHandle>;

} // namespace

void clearCublasWorkspaces() {
#if !defined(USE_ROCM)
  cublas_handle_stream_to_workspace().clear();
#endif
}

size_t parseChosenWorkspaceSize() {
  const char * val = getenv("CUBLAS_WORKSPACE_CONFIG");
  /* :4096:2:16:8 default, 32 MiB for Hopper */
  cudaDeviceProp* properties = at::cuda::getCurrentDeviceProperties();
  const bool sm90 = properties != nullptr && properties->major == 9 && properties->minor == 0;
  const size_t default_size = sm90 ? 4096 * 8 * 1024 : 4096 * 1024 * 2 + 16 * 1024 * 8;
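
  // The environment variable encodes one or more ":SIZE:COUNT" pairs, with
  // SIZE in KiB (see the parsing loop below). For example, matching the
  // defaults above:
  //   CUBLAS_WORKSPACE_CONFIG=":4096:2:16:8" -> 4096 KiB * 2 + 16 KiB * 8 = 8,519,680 bytes
  //   CUBLAS_WORKSPACE_CONFIG=":4096:8"      -> 4096 KiB * 8 = 32 MiB (the sm90 default)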

  if (val) {
    size_t total_size = 0;
    const std::string config(val);
    std::regex exp(":([0-9]+):([0-9]+)");
    std::sregex_iterator next(config.begin(), config.end(), exp);
    std::sregex_iterator end;
    if (next == end) {
      TORCH_WARN("Could not parse CUBLAS_WORKSPACE_CONFIG, using default workspace size of ", default_size, " bytes.");
      return default_size;
    }
    while (next != end) {
      std::smatch match = *next;
      TORCH_CHECK(match.size() == 3, "Expected CUBLAS_WORKSPACE_CONFIG match of size 3 (Format :SIZE:COUNT)");
      size_t curr_size = (size_t) std::stoi(match.str(1));
      size_t count = (size_t) std::stoi(match.str(2));
      total_size += curr_size * 1024 * count;
      next++;
    }
    return total_size;
  } else {
    return default_size;
  }
}

size_t getChosenWorkspaceSize() {
  size_t pool_size = parseChosenWorkspaceSize();
  return pool_size;
}

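// Workspaces are allocated through PyTorch's CUDA caching allocator, so the
// memory is tracked by (and reusable from) the same pool as tensor allocations.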
at::DataPtr getNewWorkspace() {
  return c10::cuda::CUDACachingAllocator::get()->allocate(getChosenWorkspaceSize());
}

cublasHandle_t getCurrentCUDABlasHandle() {
  c10::DeviceIndex device = 0;
  AT_CUDA_CHECK(c10::cuda::GetDevice(&device));

#if !defined(USE_ROCM)
  CUcontext pctx = nullptr;
  at::globalContext().getNVRTC().cuCtxGetCurrent(&pctx);
  if (C10_UNLIKELY(!pctx)) {
    // Workaround for a corner case where a primary context exists but is not
    // the current context, seen in multithreaded use-cases.
    TORCH_WARN_ONCE("Attempting to run cuBLAS, but there was no current CUDA context! Attempting to set the primary context...");
    at::globalContext().getNVRTC().cuDevicePrimaryCtxRetain(&pctx, device);
    at::globalContext().getNVRTC().cuCtxSetCurrent(pctx);
  }
#endif

  // Thread-local PoolWindows are lazily initialized to avoid initialization
  // issues that caused hangs on Windows.
  // See: https://github.com/pytorch/pytorch/pull/22405
  // This thread-local unique_ptr will be destroyed when the thread terminates,
  // releasing its reserved handles back to the pool.

  // Use a leaky singleton for the pool following standard practice around
  // singletons: https://isocpp.org/wiki/faq/ctors#construct-on-first-use-v2
  static auto pool = std::shared_ptr<CuBlasPoolType>(
      new CuBlasPoolType(), [](CuBlasPoolType* p) {
        // Leak the memory.
      });
  thread_local std::unique_ptr<CuBlasPoolType::PoolWindow> myPoolWindow(
      pool->newPoolWindow());

  auto handle = myPoolWindow->reserve(device);
  auto stream = c10::cuda::getCurrentCUDAStream();
  TORCH_CUDABLAS_CHECK(cublasSetStream(handle, stream));
#if !defined(USE_ROCM)
  // We explicitly set the cuBLAS workspace even though CUDA 12.2+ fixed the
  // issue where memory usage increased during graph capture.
  // Original issue: https://github.com/pytorch/pytorch/pull/83461
  // This is because in CUDA 12.2+, cuBLAS uses cudaMallocAsync, which
  // allocates memory dynamically (even if the allocations are cheap) outside
  // PyTorch's CUDA caching allocator (CCA). If CCA has already reserved all
  // device memory, cuBLAS's cudaMallocAsync can fail with OOM.
  cudaStream_t _stream = stream;
  auto key = std::make_tuple(static_cast<void *>(handle), static_cast<void *>(_stream));
  auto workspace_it = cublas_handle_stream_to_workspace().find(key);
  if (workspace_it == cublas_handle_stream_to_workspace().end()) {
    workspace_it = cublas_handle_stream_to_workspace().insert(workspace_it, {key, getNewWorkspace()});
  }
  TORCH_CUDABLAS_CHECK(cublasSetWorkspace(handle, workspace_it->second.get(), getChosenWorkspaceSize()));
  // On CUDA >= 11 and architectures >= Ampere, cuBLAS can use TF32 to speed up
  // FP32 computations depending on the value of the allow_tf32 flag.
  // To enable TF32, set the math mode of the handle to CUBLAS_TF32_TENSOR_OP_MATH.
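  // (In typical usage, allowTF32CuBLAS() reflects the Python-level flag
  // torch.backends.cuda.matmul.allow_tf32.)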
  if (!NoTF32Guard::should_disable_tf32() && at::globalContext().allowTF32CuBLAS()) {
    TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TF32_TENSOR_OP_MATH));
  } else {
    TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
  }
#else
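  // On ROCm, disallow hipBLAS atomics when deterministic algorithms are
  // requested, since atomic reductions can produce run-to-run nondeterminism.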
  hipblasAtomicsMode_t hipblas_mode;
  if (at::globalContext().deterministicAlgorithms()) {
    hipblas_mode = HIPBLAS_ATOMICS_NOT_ALLOWED;
  } else {
    hipblas_mode = HIPBLAS_ATOMICS_ALLOWED;
  }
  TORCH_CUDABLAS_CHECK(hipblasSetAtomicsMode(handle, hipblas_mode));
#endif
  return handle;
}

cublasLtHandle_t getCurrentCUDABlasLtHandle() {
#ifdef USE_ROCM
  c10::DeviceIndex device = 0;
  AT_CUDA_CHECK(c10::cuda::GetDevice(&device));

  // Thread-local PoolWindows are lazily initialized to avoid initialization
  // issues that caused hangs on Windows.
  // See: https://github.com/pytorch/pytorch/pull/22405
  // This thread-local unique_ptr will be destroyed when the thread terminates,
  // releasing its reserved handles back to the pool.

  // Use a leaky singleton for the pool following standard practice around
  // singletons: https://isocpp.org/wiki/faq/ctors#construct-on-first-use-v2
  static auto pool = std::shared_ptr<CuBlasLtPoolType>(
      new CuBlasLtPoolType(), [](CuBlasLtPoolType* p) {
        // Leak the memory.
      });
  thread_local std::unique_ptr<CuBlasLtPoolType::PoolWindow> myPoolWindow(
      pool->newPoolWindow());

  auto handle = myPoolWindow->reserve(device);
  return handle;
#else
  return reinterpret_cast<cublasLtHandle_t>(getCurrentCUDABlasHandle());
#endif
}

} // namespace at::cuda
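
// Usage sketch (illustrative only, not part of this translation unit): callers
// typically fetch the pooled handle, which is already bound to the current
// stream and workspace, and issue cuBLAS calls on it, e.g.
//   cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
//   TORCH_CUDABLAS_CHECK(cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N,
//                                    m, n, k, &alpha, A, lda, B, ldb, &beta, C, ldc));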