1 #include <ATen/detail/CUDAHooksInterface.h>
2
3 #include <c10/util/CallOnce.h>
4
5 #include <memory>
6
7 namespace at {
8 namespace detail {
9
// NB: We purposely leak the CUDA hooks object. This is because under some
// situations, we may need to reference the CUDA hooks while running destructors
// of objects which were constructed *prior* to the first invocation of
// getCUDAHooks. The example which precipitated this change was the fused
// kernel cache in the JIT. The kernel cache is a global variable which caches
// both CPU and CUDA kernels; CUDA kernels must interact with CUDA hooks on
// destruction. Because the kernel cache handles CPU kernels too, it can be
// constructed before we initialize CUDA; if it contains CUDA kernels at program
// destruction time, you will destruct the CUDA kernels after CUDA hooks has
// been unloaded. In principle, we could have also fixed the kernel cache to
// store CUDA kernels in a separate global variable, but this solution is much
// simpler.
//
// CUDAHooks doesn't actually contain any data, so leaking it is very benign;
// you're probably losing only a word (the vptr in the allocated object).
//
// The pointer starts out null and is populated lazily by getCUDAHooks() on
// first use; nothing in this file ever deletes it.
static CUDAHooksInterface* cuda_hooks = nullptr;
26
getCUDAHooks()27 const CUDAHooksInterface& getCUDAHooks() {
28 // NB: The once_flag here implies that if you try to call any CUDA
29 // functionality before libATen_cuda.so is loaded, CUDA is permanently
30 // disabled for that copy of ATen. In principle, we can relax this
31 // restriction, but you might have to fix some code. See getVariableHooks()
32 // for an example where we relax this restriction (but if you try to avoid
33 // needing a lock, be careful; it doesn't look like Registry.h is thread
34 // safe...)
35 #if !defined C10_MOBILE
36 static c10::once_flag once;
37 c10::call_once(once, [] {
38 cuda_hooks =
39 CUDAHooksRegistry()->Create("CUDAHooks", CUDAHooksArgs{}).release();
40 if (!cuda_hooks) {
41 cuda_hooks = new CUDAHooksInterface();
42 }
43 });
44 #else
45 if (cuda_hooks == nullptr) {
46 cuda_hooks = new CUDAHooksInterface();
47 }
48 #endif
49 return *cuda_hooks;
50 }
51 } // namespace detail
52
// Defines CUDAHooksRegistry, the registry that getCUDAHooks() queries for an
// entry under the key "CUDAHooks". Entries are CUDAHooksInterface objects
// created from a CUDAHooksArgs argument.
// NOTE(review): presumably the CUDA library registers the real implementation
// here when it is loaded -- confirm against the CUDA-side registration.
C10_DEFINE_REGISTRY(CUDAHooksRegistry, CUDAHooksInterface, CUDAHooksArgs)
54
55 } // namespace at
56