1 #include <c10/util/thread_name.h>
2 #include <torch/csrc/Exceptions.h>
3 #include <torch/csrc/python_headers.h>
4 #include <torch/csrc/utils/object_ptr.h>
5 #include <torch/csrc/utils/pybind.h>
6 #include <torch/csrc/utils/python_strings.h>
7
8 #include <stdexcept>
9
10 #if defined(__linux__)
11 #include <sys/prctl.h>
12 #endif
13
// Throws std::system_error built from the current errno when `rv` is
// negative. The extra arguments are forwarded to the std::system_error
// constructor after the error category (e.g. a string naming the failing
// call).
//
// The do { ... } while (0) wrapper makes the expansion a single statement,
// so `SYSASSERT(x, "call");` composes safely with a trailing semicolon,
// including inside an unbraced if/else.
#define SYSASSERT(rv, ...)                                                   \
  do {                                                                       \
    if ((rv) < 0) {                                                          \
      throw std::system_error(errno, std::system_category(), ##__VA_ARGS__); \
    }                                                                        \
  } while (0)
18
19 namespace torch::multiprocessing {
20
21 namespace {
22
multiprocessing_init(PyObject * _unused,PyObject * noargs)23 PyObject* multiprocessing_init(PyObject* _unused, PyObject* noargs) {
24 auto multiprocessing_module =
25 THPObjectPtr(PyImport_ImportModule("torch.multiprocessing"));
26 if (!multiprocessing_module) {
27 throw python_error();
28 }
29
30 auto module = py::handle(multiprocessing_module).cast<py::module>();
31
32 module.def("_prctl_pr_set_pdeathsig", [](int signal) {
33 #if defined(__linux__)
34 auto rv = prctl(PR_SET_PDEATHSIG, signal);
35 SYSASSERT(rv, "prctl");
36 #endif
37 });
38
39 Py_RETURN_TRUE;
40 }
41
set_thread_name(PyObject * _unused,PyObject * arg)42 PyObject* set_thread_name(PyObject* _unused, PyObject* arg) {
43 TORCH_CHECK(THPUtils_checkString(arg), "invalid argument to setDevice");
44
45 auto name = THPUtils_unpackString(arg);
46 c10::setThreadName(name);
47
48 Py_RETURN_TRUE;
49 }
50
get_thread_name(PyObject * _unused,PyObject * noargs)51 PyObject* get_thread_name(PyObject* _unused, PyObject* noargs) {
52 return THPUtils_packString(c10::getThreadName());
53 }
54
55 } // namespace
56
57 // multiprocessing methods on torch._C
58 // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
59 static PyMethodDef methods[] = {
60 {
61 "_multiprocessing_init",
62 multiprocessing_init,
63 METH_NOARGS,
64 nullptr,
65 },
66 {
67 "_set_thread_name",
68 set_thread_name,
69 METH_O,
70 nullptr,
71 },
72 {
73 "_get_thread_name",
74 get_thread_name,
75 METH_NOARGS,
76 nullptr,
77 },
78 {nullptr, nullptr, 0, nullptr},
79 };
80
python_functions()81 PyMethodDef* python_functions() {
82 return methods;
83 }
84
85 } // namespace torch::multiprocessing
86