1 #include <torch/csrc/DataLoader.h>
2
3 // Together with `torch/utils/data/_utils/signal_handling.py`, the following
4 // is an effort to do our best to provide some error message to users when a
5 // worker dies due to error / critical signals.
6 //
7 // See NOTE [ Signal handling in multiprocessing data loading ] for more
8 // details.
9
10 // TODO: The following don't work on Windows. Specifically, sigaction, waitid
11 // calls, and SIGCHLD handler. Currently, dummy implementations are provided
12 // for Windows.
13
14 #ifndef _WIN32
15
16 #include <torch/csrc/Exceptions.h>
17 #include <torch/csrc/utils/python_numbers.h>
18
19 #include <c10/util/irange.h>
20 #include <fmt/format.h>
21
22 #include <sys/wait.h>
23 #include <csignal>
24 #include <map>
25 #include <set>
26 #include <sstream>
27
28 using namespace torch;
29
// Critical signal handlers should be registered on worker processes before
// doing work.
// Each handler restores the default disposition and re-raises the signal, so
// that the main process can retrieve the kill information.
// The Python entry point is _set_worker_signal_handlers().
// SIGNAL_HANDLER(SIGNAL, HANDLER_NAME, ERROR_MSG) defines a static,
// sigaction-style handler named HANDLER_NAME that:
//   1. writes ERROR_MSG to stderr with write(2),
//   2. resets SIGNAL to its default disposition, and
//   3. re-raises SIGNAL so the default action runs, or calls
//      _exit(EXIT_FAILURE) if the disposition could not be reset.
// ERROR_MSG must be a string literal: its length is taken with sizeof, and the
// terminating NUL is excluded so that no stray '\0' byte is written to stderr.
#define SIGNAL_HANDLER(SIGNAL, HANDLER_NAME, ERROR_MSG)                 \
  static void HANDLER_NAME(int sig, siginfo_t* info, void* ctx) {       \
    /* sizeof counts the literal's trailing NUL; do not write it. */    \
    auto _w = write(STDERR_FILENO, ERROR_MSG, sizeof(ERROR_MSG) - 1);   \
    (void)_w;                                                           \
    struct sigaction sa {};                                             \
    sa.sa_handler = SIG_DFL;                                            \
    sa.sa_flags = 0;                                                    \
    if (sigemptyset(&sa.sa_mask) != 0 ||                                \
        sigaction(SIGNAL, &sa, nullptr) != 0) {                         \
      _exit(EXIT_FAILURE);                                              \
    } else {                                                            \
      raise(SIGNAL);                                                    \
    }                                                                   \
  }
50
51 // signal(2) is really not portable. So use sigaction.
52 // http://man7.org/linux/man-pages/man2/signal.2.html
// Installs `handler` as the three-argument (SA_SIGINFO) handler for `signal`
// via sigaction(2).
//
// signal:     signal number to hook.
// handler:    sigaction-style callback to install.
// old_sa_ptr: if non-null, receives the previously installed action; passed
//             straight through to sigaction().
//
// Throws std::runtime_error when either sigemptyset() or sigaction() fails.
static inline void setSignalHandler(
    int signal,
    void (*handler)(int, siginfo_t*, void*),
    struct sigaction* old_sa_ptr) {
  struct sigaction action {};
  action.sa_flags = SA_RESTART | SA_SIGINFO | SA_NOCLDSTOP | SA_NODEFER;
  action.sa_sigaction = handler;

  // Short-circuit on purpose: sigaction() is only attempted once the mask
  // has been cleared successfully.
  const bool installed = sigemptyset(&action.sa_mask) == 0 &&
      sigaction(signal, &action, old_sa_ptr) == 0;
  if (installed) {
    return;
  }

  std::ostringstream message;
  message << "An error occurred while setting handler for "
          << strsignal(signal) << ".";
  throw std::runtime_error(message.str());
}
68
69 SIGNAL_HANDLER(
70 SIGBUS,
71 handler_SIGBUS,
72 "ERROR: Unexpected bus error encountered in worker. "
73 "This might be caused by insufficient shared memory (shm).\n");
74 SIGNAL_HANDLER(
75 SIGSEGV,
76 handler_SIGSEGV,
77 "ERROR: Unexpected segmentation fault encountered in worker.\n");
78 SIGNAL_HANDLER(
79 SIGFPE,
80 handler_SIGFPE,
81 "ERROR: Unexpected floating-point exception encountered in worker.\n");
82
// When an error happens in DataLoader methods and Python starts to exit, the
// error trace will keep the loader alive, and Python may kill the child
// processes first before deleting the loader object. Then the cleanup
// methods in DataLoader.__del__ are not yet called, and SIGCHLD will print an
// error saying a worker is killed by SIGTERM. So we suppress SIGTERM from the
// main loader process here to avoid this by _exit(EXIT_SUCCESS). Note that if
// we exit with a nonzero code, the loader SIGCHLD handler may report
// RuntimeError again, which defeats the whole purpose.
static void handler_SIGTERM(int sig, siginfo_t* info, void* ctx) {
  // A SIGTERM sent by our parent process ends this process quietly with a
  // zero exit status.
  const bool sent_by_parent = (info->si_pid == getppid());
  if (sent_by_parent) {
    _exit(EXIT_SUCCESS);
  }

  // Any other sender: put the default disposition back and re-deliver the
  // signal so the default action applies.
  struct sigaction default_action {};
  default_action.sa_flags = 0;
  default_action.sa_handler = SIG_DFL;
  const bool restored = sigemptyset(&default_action.sa_mask) == 0 &&
      sigaction(SIGTERM, &default_action, nullptr) == 0;
  if (!restored) {
    _exit(EXIT_FAILURE);
  }
  raise(SIGTERM);
}
104
// Weak, do-nothing default. It is invoked at the end of
// THPModule_setWorkerSignalHandlers; because the symbol is weak, a non-weak
// definition linked into the binary takes precedence over this one.
__attribute__((weak)) void setDataLoaderSignalHandlers() {
  // Intentionally empty.
}
106
THPModule_setWorkerSignalHandlers(PyObject * module,PyObject * arg)107 static PyObject* THPModule_setWorkerSignalHandlers(
108 PyObject* module,
109 PyObject* arg) {
110 HANDLE_TH_ERRORS
111 setSignalHandler(SIGBUS, &handler_SIGBUS, nullptr);
112 setSignalHandler(SIGSEGV, &handler_SIGSEGV, nullptr);
113 setSignalHandler(SIGTERM, &handler_SIGTERM, nullptr);
114 setSignalHandler(SIGFPE, &handler_SIGFPE, nullptr);
115 setDataLoaderSignalHandlers();
116 Py_RETURN_NONE;
117 END_HANDLE_TH_ERRORS
118 }
119
// Registry of tracked workers: maps the integer key passed as the first
// argument of _set_worker_pids to the set of worker pids passed as the second.
// Entries are added by THPModule_setWorkerPIDs, removed by
// THPModule_removeWorkerPIDs, and scanned by THPModule_errorIfAnyWorkerFails.
static std::map<int64_t, std::set<pid_t>> worker_pids;
121
THPModule_errorIfAnyWorkerFails(PyObject * module,PyObject * noargs)122 static PyObject* THPModule_errorIfAnyWorkerFails(
123 PyObject* module,
124 PyObject* noargs) {
125 HANDLE_TH_ERRORS
126
127 // Only check the pids we care about
128 for (auto& w : worker_pids) {
129 auto& pid_set = w.second;
130 for (auto worker_pid : pid_set) {
131 // Use waitid rather than waitpid so that we can set NOWAIT, and that
132 // Python and other handlers can get whatever info they want about the
133 // child.
134 siginfo_t infop{};
135 infop.si_pid = 0;
136 auto error =
137 waitid(P_PID, worker_pid, &infop, WEXITED | WNOHANG | WNOWAIT);
138 // ignore errors and case with no waitable child
139 if (error < 0 || infop.si_pid == 0)
140 continue;
141 if (infop.si_code == CLD_EXITED &&
142 infop.si_status != EXIT_SUCCESS) { // exit with error
143 std::ostringstream oss;
144 oss << "DataLoader worker (pid " << worker_pid << ") exited "
145 << "unexpectedly with exit code " << infop.si_status << ". "
146 << "Details are lost due to multiprocessing. Rerunning with "
147 << "num_workers=0 may give better error trace.";
148 // This is necessary. Otherwise, the runtime error will kill the other
149 // workers, and trigger this again.
150 pid_set.clear();
151 throw std::runtime_error(oss.str());
152 } else if (
153 infop.si_code == CLD_KILLED ||
154 infop.si_code == CLD_DUMPED) { // killed by signal
155 std::ostringstream oss;
156 oss << "DataLoader worker (pid " << worker_pid << ") is killed "
157 << "by signal: " << strsignal(infop.si_status) << ". ";
158 if (infop.si_status == SIGBUS) {
159 oss << "It is possible that dataloader's workers are out of shared memory. "
160 << "Please try to raise your shared memory limit.";
161 }
162 // This is necessary. Otherwise, the runtime error will kill the other
163 // workers, and trigger this again.
164 pid_set.clear();
165 throw std::runtime_error(oss.str());
166 }
167 }
168 }
169 Py_RETURN_NONE;
170 END_HANDLE_TH_ERRORS
171 }
172
173 // We don't want to exit on any SIGCHLD from any child. child_pids is a tuple
174 // of pids we are interested in.
THPModule_setWorkerPIDs(PyObject * module,PyObject * args)175 static PyObject* THPModule_setWorkerPIDs(PyObject* module, PyObject* args) {
176 HANDLE_TH_ERRORS
177 TORCH_CHECK_TYPE(
178 PyTuple_GET_SIZE(args) == 2,
179 "_set_worker_pids expects exactly 2 arguments.");
180 int64_t key = THPUtils_unpackLong(PyTuple_GET_ITEM(args, 0));
181 TORCH_CHECK_VALUE(
182 worker_pids.find(key) == worker_pids.end(),
183 "_set_worker_pids should be called only once for each _BaseDataLoaderIter.");
184 PyObject* child_pids = PyTuple_GET_ITEM(args, 1);
185 TORCH_CHECK_TYPE(
186 PyTuple_Check(child_pids),
187 "_set_worker_pids expects a tuple for child_pids, but got ",
188 Py_TYPE(child_pids)->tp_name,
189 ".");
190 std::set<pid_t> pids_set = {};
191 auto size = PyTuple_GET_SIZE(child_pids);
192 for (const auto idx : c10::irange(size)) {
193 PyObject* obj = PyTuple_GET_ITEM(child_pids, idx);
194 pids_set.insert(static_cast<pid_t>(THPUtils_unpackLong(obj)));
195 }
196
197 worker_pids[key] = pids_set;
198
199 Py_RETURN_NONE;
200 END_HANDLE_TH_ERRORS
201 }
202
THPModule_removeWorkerPIDs(PyObject * module,PyObject * loader_id)203 static PyObject* THPModule_removeWorkerPIDs(
204 PyObject* module,
205 PyObject* loader_id) {
206 HANDLE_TH_ERRORS
207
208 int64_t key = THPUtils_unpackLong(loader_id);
209 auto it = worker_pids.find(key);
210 TORCH_CHECK_VALUE(
211 it != worker_pids.end(),
212 "Cannot find worker information for _BaseDataLoaderIter with id ",
213 key);
214 worker_pids.erase(it);
215
216 Py_RETURN_NONE;
217 END_HANDLE_TH_ERRORS
218 }
219
220 #undef SIGNAL_HANDLER
221
222 #else
223 // dummy implementations for windows
224
THPModule_setWorkerSignalHandlers(PyObject * module,PyObject * _ignored)225 static PyObject* THPModule_setWorkerSignalHandlers(
226 PyObject* module,
227 PyObject* _ignored) {
228 Py_RETURN_NONE;
229 }
230
THPModule_setWorkerPIDs(PyObject * module,PyObject * _ignored)231 static PyObject* THPModule_setWorkerPIDs(PyObject* module, PyObject* _ignored) {
232 Py_RETURN_NONE;
233 }
234
THPModule_removeWorkerPIDs(PyObject * module,PyObject * _ignored)235 static PyObject* THPModule_removeWorkerPIDs(
236 PyObject* module,
237 PyObject* _ignored) {
238 Py_RETURN_NONE;
239 }
240
THPModule_errorIfAnyWorkerFails(PyObject * module,PyObject * _ignored)241 static PyObject* THPModule_errorIfAnyWorkerFails(
242 PyObject* module,
243 PyObject* _ignored) {
244 Py_RETURN_NONE;
245 }
246
247 #endif
248
249 // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays)
250 PyMethodDef DataLoaderMethods[] = {
251 {"_set_worker_signal_handlers",
252 THPModule_setWorkerSignalHandlers,
253 METH_NOARGS,
254 nullptr},
255 {"_set_worker_pids", THPModule_setWorkerPIDs, METH_VARARGS, nullptr},
256 {"_remove_worker_pids", THPModule_removeWorkerPIDs, METH_O, nullptr},
257 {"_error_if_any_worker_fails",
258 THPModule_errorIfAnyWorkerFails,
259 METH_NOARGS,
260 nullptr},
261 {nullptr, nullptr, 0, nullptr}};
262