xref: /aosp_15_r20/external/pytorch/aten/src/ATen/detail/CUDAHooksInterface.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/detail/CUDAHooksInterface.h>
2 
3 #include <c10/util/CallOnce.h>
4 
5 #include <memory>
6 
7 namespace at {
8 namespace detail {
9 
10 // NB: We purposely leak the CUDA hooks object.  This is because under some
11 // situations, we may need to reference the CUDA hooks while running destructors
12 // of objects which were constructed *prior* to the first invocation of
13 // getCUDAHooks.  The example which precipitated this change was the fused
14 // kernel cache in the JIT.  The kernel cache is a global variable which caches
15 // both CPU and CUDA kernels; CUDA kernels must interact with CUDA hooks on
16 // destruction.  Because the kernel cache handles CPU kernels too, it can be
17 // constructed before we initialize CUDA; if it contains CUDA kernels at program
18 // destruction time, you will destruct the CUDA kernels after CUDA hooks has
19 // been unloaded.  In principle, we could have also fixed the kernel cache store
20 // CUDA kernels in a separate global variable, but this solution is much
21 // simpler.
22 //
23 // CUDAHooks doesn't actually contain any data, so leaking it is very benign;
24 // you're probably losing only a word (the vptr in the allocated object.)
25 static CUDAHooksInterface* cuda_hooks = nullptr;
26 
getCUDAHooks()27 const CUDAHooksInterface& getCUDAHooks() {
28   // NB: The once_flag here implies that if you try to call any CUDA
29   // functionality before libATen_cuda.so is loaded, CUDA is permanently
30   // disabled for that copy of ATen.  In principle, we can relax this
31   // restriction, but you might have to fix some code.  See getVariableHooks()
32   // for an example where we relax this restriction (but if you try to avoid
33   // needing a lock, be careful; it doesn't look like Registry.h is thread
34   // safe...)
35 #if !defined C10_MOBILE
36   static c10::once_flag once;
37   c10::call_once(once, [] {
38     cuda_hooks =
39         CUDAHooksRegistry()->Create("CUDAHooks", CUDAHooksArgs{}).release();
40     if (!cuda_hooks) {
41       cuda_hooks = new CUDAHooksInterface();
42     }
43   });
44 #else
45   if (cuda_hooks == nullptr) {
46     cuda_hooks = new CUDAHooksInterface();
47   }
48 #endif
49   return *cuda_hooks;
50 }
51 } // namespace detail
52 
53 C10_DEFINE_REGISTRY(CUDAHooksRegistry, CUDAHooksInterface, CUDAHooksArgs)
54 
55 } // namespace at
56