xref: /aosp_15_r20/external/pytorch/aten/src/ATen/miopen/Handle.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/miopen/Exceptions.h>
2 #include <ATen/miopen/Handle.h>
3 #include <ATen/hip/detail/DeviceThreadHandles.h>
4 #include <c10/hip/HIPStream.h>
5 
6 namespace at { namespace native {
7 namespace {
8 
createMIOpenHandle(miopenHandle_t * handle)9 void createMIOpenHandle(miopenHandle_t *handle) {
10   MIOPEN_CHECK(miopenCreate(handle));
11 }
12 
destroyMIOpenHandle(miopenHandle_t handle)13 void destroyMIOpenHandle(miopenHandle_t handle) {
14 // this is because of something dumb in the ordering of
15 // destruction. Sometimes atexit, the cuda context (or something)
16 // would already be destroyed by the time this gets destroyed. It
17 // happens in fbcode setting. @colesbury and I decided to not destroy
18 // the handle as a workaround.
19 //   - @soumith
20 //
21 // Further note: this is now disabled globally, because we are seeing
22 // the same issue as mentioned above in CUDA 11 CI.
23 //   - @zasdfgbnm
24 //
25 // #ifdef NO_MIOPEN_DESTROY_HANDLE
26 // #else
27 //   miopenDestroy(handle);
28 // #endif
29 }
30 
31 using MIOpenPoolType = at::cuda::DeviceThreadHandlePool<miopenHandle_t, createMIOpenHandle, destroyMIOpenHandle>;
32 
33 } // namespace
34 
getMiopenHandle()35 miopenHandle_t getMiopenHandle() {
36   int device;
37   HIP_CHECK(hipGetDevice(&device));
38 
39   // Thread local PoolWindows are lazily-initialized
40   // to avoid initialization issues that caused hangs on Windows.
41   // See: https://github.com/pytorch/pytorch/pull/22405
42   // This thread local unique_ptrs will be destroyed when the thread terminates,
43   // releasing its reserved handles back to the pool.
44   static auto pool = std::make_shared<MIOpenPoolType>();
45   thread_local std::unique_ptr<MIOpenPoolType::PoolWindow> myPoolWindow(
46       pool->newPoolWindow());
47 
48   auto handle = myPoolWindow->reserve(device);
49   MIOPEN_CHECK(miopenSetStream(handle, at::hip::getCurrentHIPStream()));
50   return handle;
51 }
52 
53 }} // namespace at::native
54