1 #include <c10/core/impl/TorchDispatchModeTLS.h> 2 #include <torch/csrc/utils/device_lazy_init.h> 3 4 #include <torch/csrc/Exceptions.h> 5 #include <torch/csrc/python_headers.h> 6 #include <torch/csrc/utils/object_ptr.h> 7 #include <iostream> 8 namespace torch::utils { 9 namespace { 10 11 std::array<bool, at::COMPILE_TIME_MAX_DEVICE_TYPES> is_initialized{}; 12 13 } // anonymous namespace 14 is_device_initialized(at::DeviceType device_type)15bool is_device_initialized(at::DeviceType device_type) { 16 pybind11::gil_scoped_acquire g; 17 return is_initialized[static_cast<int>(device_type)]; 18 } 19 device_lazy_init(at::DeviceType device_type)20void device_lazy_init(at::DeviceType device_type) { 21 pybind11::gil_scoped_acquire g; 22 // Protected by the GIL. We don't use call_once because under ASAN it 23 // has a buggy implementation that deadlocks if an instance throws an 24 // exception. In any case, call_once isn't necessary, because we 25 // have taken a lock. 26 if (is_device_initialized(device_type)) { 27 return; 28 } 29 30 auto maybe_mode = c10::impl::TorchDispatchModeTLS::get_mode( 31 c10::impl::TorchDispatchModeKey::FAKE); 32 if (maybe_mode) { 33 return; 34 } 35 36 std::string module_name = "torch." + at::DeviceTypeName(device_type, true); 37 auto module = THPObjectPtr(PyImport_ImportModule(module_name.c_str())); 38 if (!module) { 39 throw python_error(); 40 } 41 42 if (device_type == at::DeviceType::PrivateUse1) { 43 auto has_lazy_init_method = 44 PyObject_HasAttrString(module.get(), "_lazy_init") == 1; 45 if (!has_lazy_init_method) { 46 return; 47 } 48 } 49 50 auto res = THPObjectPtr(PyObject_CallMethod(module.get(), "_lazy_init", "")); 51 if (!res) { 52 throw python_error(); 53 } 54 55 is_initialized[static_cast<int>(device_type)] = true; 56 } 57 set_requires_device_init(at::DeviceType device_type,bool value)58void set_requires_device_init(at::DeviceType device_type, bool value) { 59 is_initialized[static_cast<int>(device_type)] = !value; 60 } 61 62 } // namespace torch::utils 63