xref: /aosp_15_r20/external/pytorch/torch/csrc/multiprocessing/init.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <c10/util/thread_name.h>
#include <torch/csrc/Exceptions.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/object_ptr.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/python_strings.h>

#include <cerrno>
#include <stdexcept>
#include <system_error>
10 #if defined(__linux__)
11 #include <sys/prctl.h>
12 #endif
13 
// SYSASSERT(rv, what...): if `rv` is negative (the POSIX "call failed"
// convention), throw std::system_error built from the current errno and
// std::system_category(). Any extra arguments are forwarded to the
// std::system_error constructor, typically a what-prefix such as the name of
// the failing syscall.
//
// The body is wrapped in `do { ... } while (0)` so the macro expands to exactly
// one statement: it requires a trailing semicolon and is safe to use as the
// body of an unbraced if/else.
#define SYSASSERT(rv, ...)                                                   \
  do {                                                                       \
    if ((rv) < 0) {                                                          \
      throw std::system_error(errno, std::system_category(), ##__VA_ARGS__); \
    }                                                                        \
  } while (0)
18 
19 namespace torch::multiprocessing {
20 
21 namespace {
22 
multiprocessing_init(PyObject * _unused,PyObject * noargs)23 PyObject* multiprocessing_init(PyObject* _unused, PyObject* noargs) {
24   auto multiprocessing_module =
25       THPObjectPtr(PyImport_ImportModule("torch.multiprocessing"));
26   if (!multiprocessing_module) {
27     throw python_error();
28   }
29 
30   auto module = py::handle(multiprocessing_module).cast<py::module>();
31 
32   module.def("_prctl_pr_set_pdeathsig", [](int signal) {
33 #if defined(__linux__)
34     auto rv = prctl(PR_SET_PDEATHSIG, signal);
35     SYSASSERT(rv, "prctl");
36 #endif
37   });
38 
39   Py_RETURN_TRUE;
40 }
41 
set_thread_name(PyObject * _unused,PyObject * arg)42 PyObject* set_thread_name(PyObject* _unused, PyObject* arg) {
43   TORCH_CHECK(THPUtils_checkString(arg), "invalid argument to setDevice");
44 
45   auto name = THPUtils_unpackString(arg);
46   c10::setThreadName(name);
47 
48   Py_RETURN_TRUE;
49 }
50 
get_thread_name(PyObject * _unused,PyObject * noargs)51 PyObject* get_thread_name(PyObject* _unused, PyObject* noargs) {
52   return THPUtils_packString(c10::getThreadName());
53 }
54 
55 } // namespace
56 
57 // multiprocessing methods on torch._C
58 // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
59 static PyMethodDef methods[] = {
60     {
61         "_multiprocessing_init",
62         multiprocessing_init,
63         METH_NOARGS,
64         nullptr,
65     },
66     {
67         "_set_thread_name",
68         set_thread_name,
69         METH_O,
70         nullptr,
71     },
72     {
73         "_get_thread_name",
74         get_thread_name,
75         METH_NOARGS,
76         nullptr,
77     },
78     {nullptr, nullptr, 0, nullptr},
79 };
80 
python_functions()81 PyMethodDef* python_functions() {
82   return methods;
83 }
84 
85 } // namespace torch::multiprocessing
86