xref: /aosp_15_r20/external/pytorch/aten/src/ATen/cudnn/Handle.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/cuda/detail/DeviceThreadHandles.h>
2 #include <ATen/cudnn/Handle.h>
3 #include <c10/cuda/CUDAStream.h>
4 
5 #include <ATen/cuda/Exceptions.h>
6 
7 namespace at::native {
8 namespace {
9 
createCuDNNHandle(cudnnHandle_t * handle)10 void createCuDNNHandle(cudnnHandle_t* handle) {
11   AT_CUDNN_CHECK(cudnnCreate(handle));
12 }
13 
destroyCuDNNHandle(cudnnHandle_t)14 void destroyCuDNNHandle(cudnnHandle_t /*handle*/) {
15   // this is because of something dumb in the ordering of
16   // destruction. Sometimes atexit, the cuda context (or something)
17   // would already be destroyed by the time this gets destroyed. It
18   // happens in fbcode setting. @colesbury and I decided to not destroy
19   // the handle as a workaround.
20   //   - @soumith
21   //
22   // Further note: this is now disabled globally, because we are seeing
23   // the same issue as mentioned above in CUDA 11 CI.
24   //   - @zasdfgbnm
25   //
26   // #ifdef NO_CUDNN_DESTROY_HANDLE
27   // #else
28   //   cudnnDestroy(handle);
29   // #endif
30 }
31 
32 using CudnnPoolType = at::cuda::DeviceThreadHandlePool<
33     cudnnHandle_t,
34     createCuDNNHandle,
35     destroyCuDNNHandle>;
36 
37 } // namespace
38 
getCudnnHandle()39 cudnnHandle_t getCudnnHandle() {
40   c10::DeviceIndex device = 0;
41   AT_CUDA_CHECK(c10::cuda::GetDevice(&device));
42 
43   // Thread local PoolWindows are lazily-initialized
44   // to avoid initialization issues that caused hangs on Windows.
45   // See: https://github.com/pytorch/pytorch/pull/22405
46   // This thread local unique_ptrs will be destroyed when the thread terminates,
47   // releasing its reserved handles back to the pool.
48   static auto pool = std::make_shared<CudnnPoolType>();
49   thread_local std::unique_ptr<CudnnPoolType::PoolWindow> myPoolWindow(
50       pool->newPoolWindow());
51 
52   auto handle = myPoolWindow->reserve(device);
53   AT_CUDNN_CHECK(cudnnSetStream(handle, c10::cuda::getCurrentCUDAStream()));
54   return handle;
55 }
56 
57 } // namespace at::native
58