xref: /aosp_15_r20/external/pytorch/torch/csrc/utils/device_lazy_init.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <c10/core/impl/TorchDispatchModeTLS.h>
2 #include <torch/csrc/utils/device_lazy_init.h>
3 
4 #include <torch/csrc/Exceptions.h>
5 #include <torch/csrc/python_headers.h>
6 #include <torch/csrc/utils/object_ptr.h>
7 #include <iostream>
8 namespace torch::utils {
9 namespace {
10 
11 std::array<bool, at::COMPILE_TIME_MAX_DEVICE_TYPES> is_initialized{};
12 
13 } // anonymous namespace
14 
is_device_initialized(at::DeviceType device_type)15 bool is_device_initialized(at::DeviceType device_type) {
16   pybind11::gil_scoped_acquire g;
17   return is_initialized[static_cast<int>(device_type)];
18 }
19 
device_lazy_init(at::DeviceType device_type)20 void device_lazy_init(at::DeviceType device_type) {
21   pybind11::gil_scoped_acquire g;
22   // Protected by the GIL.  We don't use call_once because under ASAN it
23   // has a buggy implementation that deadlocks if an instance throws an
24   // exception.  In any case, call_once isn't necessary, because we
25   // have taken a lock.
26   if (is_device_initialized(device_type)) {
27     return;
28   }
29 
30   auto maybe_mode = c10::impl::TorchDispatchModeTLS::get_mode(
31       c10::impl::TorchDispatchModeKey::FAKE);
32   if (maybe_mode) {
33     return;
34   }
35 
36   std::string module_name = "torch." + at::DeviceTypeName(device_type, true);
37   auto module = THPObjectPtr(PyImport_ImportModule(module_name.c_str()));
38   if (!module) {
39     throw python_error();
40   }
41 
42   if (device_type == at::DeviceType::PrivateUse1) {
43     auto has_lazy_init_method =
44         PyObject_HasAttrString(module.get(), "_lazy_init") == 1;
45     if (!has_lazy_init_method) {
46       return;
47     }
48   }
49 
50   auto res = THPObjectPtr(PyObject_CallMethod(module.get(), "_lazy_init", ""));
51   if (!res) {
52     throw python_error();
53   }
54 
55   is_initialized[static_cast<int>(device_type)] = true;
56 }
57 
set_requires_device_init(at::DeviceType device_type,bool value)58 void set_requires_device_init(at::DeviceType device_type, bool value) {
59   is_initialized[static_cast<int>(device_type)] = !value;
60 }
61 
62 } // namespace torch::utils
63