xref: /aosp_15_r20/external/pytorch/torch/csrc/mtia/Module.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/ATen.h>
2 #include <c10/core/DeviceType.h>
3 #include <c10/core/Stream.h>
4 #include <c10/util/CallOnce.h>
5 #include <torch/csrc/Generator.h>
6 #include <torch/csrc/Stream.h>
7 #include <torch/csrc/python_headers.h>
8 #include <torch/csrc/utils/device_lazy_init.h>
9 #include <torch/csrc/utils/pybind.h>
10 #ifndef WIN32
11 #include <pthread.h>
12 #endif
13 
14 namespace torch {
15 namespace mtia {
16 
// Process-wide flag: starts out false and is flipped to true by
// forked_child(), i.e. in a child process forked after mtia init.
static bool in_bad_fork{false};
18 
19 #ifndef WIN32
20 // Called in the forked child if mtia has already been initialized
forked_child()21 static void forked_child() {
22   in_bad_fork = true;
23   torch::utils::set_requires_device_init(at::kMTIA, true);
24 }
25 #endif
26 
27 // Should be called before the first mtia call.
28 // Note: This is distinct from initExtension because a stub mtia implementation
29 // has some working functions (e.g. device_count) but cannot fully initialize.
poison_fork()30 static void poison_fork() {
31 #ifndef WIN32
32   static c10::once_flag flag;
33   c10::call_once(flag, [] { pthread_atfork(nullptr, nullptr, forked_child); });
34 #endif
35 }
36 
initModule(PyObject * module)37 void initModule(PyObject* module) {
38   auto m = py::handle(module).cast<py::module>();
39 
40   m.def("_mtia_init", []() {
41     TORCH_INTERNAL_ASSERT(!in_bad_fork); // Handled at python level
42     poison_fork();
43     at::globalContext().lazyInitMTIA();
44   });
45 
46   m.def("_mtia_isBuilt", []() {
47     // Check if the MTIAHooks class has been registered with the registry.
48     return at::detail::isMTIAHooksBuilt();
49   });
50 
51   m.def("_mtia_isInBadFork", []() { return in_bad_fork; });
52 
53   m.def("_mtia_getCurrentStream", [](c10::DeviceIndex device_index) {
54     torch::utils::device_lazy_init(at::kMTIA);
55     return at::detail::getMTIAHooks().getCurrentStream(device_index);
56   });
57 
58   m.def("_mtia_deviceSynchronize", []() {
59     torch::utils::device_lazy_init(at::kMTIA);
60     at::detail::getMTIAHooks().deviceSynchronize(
61         at::detail::getMTIAHooks().getCurrentDevice());
62   });
63 
64   m.def("_mtia_getDefaultStream", [](c10::DeviceIndex device_index) {
65     torch::utils::device_lazy_init(at::kMTIA);
66     return at::detail::getMTIAHooks().getDefaultStream(device_index);
67   });
68 
69   m.def("_mtia_setCurrentStream", [](const c10::Stream& stream) {
70     torch::utils::device_lazy_init(at::kMTIA);
71     auto device = at::detail::getMTIAHooks().getCurrentDevice();
72     if (device != stream.device_index()) {
73       at::detail::getMTIAHooks().setCurrentDevice(stream.device_index());
74     }
75     at::detail::getMTIAHooks().setCurrentStream(stream);
76   });
77 
78   m.def("_mtia_memoryStats", [](c10::DeviceIndex device_index) {
79     PyObject* raw_pyobject =
80         at::detail::getMTIAHooks().memoryStats(device_index);
81     return py::reinterpret_steal<py::object>(raw_pyobject);
82   });
83 }
84 
85 } // namespace mtia
86 } // namespace torch
87