xref: /aosp_15_r20/external/pytorch/torch/csrc/DataLoader.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/DataLoader.h>
2 
3 // Together with `torch/utils/data/_utils/signal_handling.py`, the following
4 // is an effort to do our best to provide some error message to users when a
5 // worker dies due to error / critical signals.
6 //
7 // See NOTE [ Signal handling in multiprocessing data loading ] for more
8 // details.
9 
10 // TODO: The following don't work on Windows. Specifically, sigaction, waitid
11 // calls, and SIGCHLD handler. Currently, dummy implementations are provided
12 // for Windows.
13 
14 #ifndef _WIN32
15 
16 #include <torch/csrc/Exceptions.h>
17 #include <torch/csrc/utils/python_numbers.h>
18 
19 #include <c10/util/irange.h>
20 #include <fmt/format.h>
21 
22 #include <sys/wait.h>
23 #include <csignal>
24 #include <map>
25 #include <set>
26 #include <sstream>
27 
28 using namespace torch;
29 
// Critical signal handlers should be registered on worker processes before
// doing work.
// The handler will raise default handler so that the kill information will be
// retrieved from main process.
// Python handle is _set_worker_signal_handlers().
//
// SIGNAL_HANDLER(SIGNAL, HANDLER_NAME, ERROR_MSG) defines a static function
// HANDLER_NAME with the three-argument (SA_SIGINFO style) handler signature.
// The generated handler:
//   1. writes ERROR_MSG to stderr with write(2);
//   2. resets the disposition of SIGNAL to SIG_DFL;
//   3. re-raises SIGNAL, or calls _exit(EXIT_FAILURE) if the disposition
//      could not be reset.
//
// ERROR_MSG must be a string literal (or char array): its length is computed
// with sizeof, and the terminating NUL is excluded so that only the visible
// message is written.
#define SIGNAL_HANDLER(SIGNAL, HANDLER_NAME, ERROR_MSG)                   \
  static void HANDLER_NAME(int sig, siginfo_t* info, void* ctx) {         \
    /* sizeof counts the terminating NUL of the literal; leave it out */  \
    /* so that no stray NUL byte ends up on stderr. */                    \
    auto _w = write(STDERR_FILENO, ERROR_MSG, sizeof(ERROR_MSG) - 1);     \
    (void)_w;                                                             \
    struct sigaction sa {};                                               \
    sa.sa_handler = SIG_DFL;                                              \
    sa.sa_flags = 0;                                                      \
    if (sigemptyset(&sa.sa_mask) != 0 ||                                  \
        sigaction(SIGNAL, &sa, nullptr) != 0) {                           \
      _exit(EXIT_FAILURE);                                                \
    } else {                                                              \
      raise(SIGNAL);                                                      \
    }                                                                     \
  }
50 
// signal(2) is really not portable. So use sigaction.
// http://man7.org/linux/man-pages/man2/signal.2.html
//
// Installs `handler` (three-argument, SA_SIGINFO style) for `signal` with an
// empty signal mask and the flags
// SA_RESTART | SA_SIGINFO | SA_NOCLDSTOP | SA_NODEFER.
// If `old_sa_ptr` is non-null, sigaction() stores the previous disposition
// there.
// Throws std::runtime_error if sigemptyset() or sigaction() reports failure.
static inline void setSignalHandler(
    int signal,
    void (*handler)(int, siginfo_t*, void*),
    struct sigaction* old_sa_ptr) {
  struct sigaction action {};
  action.sa_flags = SA_RESTART | SA_SIGINFO | SA_NOCLDSTOP | SA_NODEFER;
  action.sa_sigaction = handler;

  const bool installed = sigemptyset(&action.sa_mask) == 0 &&
      sigaction(signal, &action, old_sa_ptr) == 0;
  if (installed) {
    return;
  }

  std::ostringstream message;
  message << "An error occurred while setting handler for ";
  message << strsignal(signal);
  message << ".";
  throw std::runtime_error(message.str());
}
68 
// Instantiate one handler per critical signal via SIGNAL_HANDLER. Each one
// prints its message to stderr and then re-raises the signal with the default
// disposition.

// SIGBUS: the message points at shared memory (shm) exhaustion as a likely
// cause.
SIGNAL_HANDLER(
    SIGBUS,
    handler_SIGBUS,
    "ERROR: Unexpected bus error encountered in worker. "
    "This might be caused by insufficient shared memory (shm).\n");
// SIGSEGV: segmentation fault in a worker.
SIGNAL_HANDLER(
    SIGSEGV,
    handler_SIGSEGV,
    "ERROR: Unexpected segmentation fault encountered in worker.\n");
// SIGFPE: floating-point exception in a worker.
SIGNAL_HANDLER(
    SIGFPE,
    handler_SIGFPE,
    "ERROR: Unexpected floating-point exception encountered in worker.\n");
82 
// SIGTERM handler for worker processes.
//
// When an error happens in DataLoader methods and Python starts to exit, the
// error trace keeps the loader alive, and Python may kill the child processes
// before the loader object is deleted. The cleanup in DataLoader.__del__ has
// not run yet at that point, so the SIGCHLD handling would report that a
// worker was killed by SIGTERM. To avoid that, a SIGTERM sent by the parent
// (main loader) process is turned into _exit(EXIT_SUCCESS). Exiting with a
// nonzero code would make the loader's SIGCHLD handler report a RuntimeError
// again, which would defeat the whole purpose.
//
// A SIGTERM from any other sender gets the default behaviour: the disposition
// is reset to SIG_DFL and the signal is re-raised; if resetting fails the
// process calls _exit(EXIT_FAILURE).
static void handler_SIGTERM(int sig, siginfo_t* info, void* ctx) {
  const bool sent_by_parent = (info->si_pid == getppid());
  if (sent_by_parent) {
    _exit(EXIT_SUCCESS);
  }

  struct sigaction default_action {};
  default_action.sa_flags = 0;
  default_action.sa_handler = SIG_DFL;

  const bool restored = sigemptyset(&default_action.sa_mask) == 0 &&
      sigaction(SIGTERM, &default_action, nullptr) == 0;
  if (!restored) {
    _exit(EXIT_FAILURE);
  }
  raise(SIGTERM);
}
104 
// Weak, empty default that is called at the end of
// THPModule_setWorkerSignalHandlers. Being a weak symbol, a strong definition
// linked in from elsewhere replaces it (presumably to install additional
// handlers -- confirm against the overriding definition, if any).
[[gnu::weak]] void setDataLoaderSignalHandlers() {}
106 
THPModule_setWorkerSignalHandlers(PyObject * module,PyObject * arg)107 static PyObject* THPModule_setWorkerSignalHandlers(
108     PyObject* module,
109     PyObject* arg) {
110   HANDLE_TH_ERRORS
111   setSignalHandler(SIGBUS, &handler_SIGBUS, nullptr);
112   setSignalHandler(SIGSEGV, &handler_SIGSEGV, nullptr);
113   setSignalHandler(SIGTERM, &handler_SIGTERM, nullptr);
114   setSignalHandler(SIGFPE, &handler_SIGFPE, nullptr);
115   setDataLoaderSignalHandlers();
116   Py_RETURN_NONE;
117   END_HANDLE_TH_ERRORS
118 }
119 
// Worker pids to monitor, grouped by the integer key passed to
// `_set_worker_pids` (the error messages below refer to it as the id of a
// _BaseDataLoaderIter). Entries are added by THPModule_setWorkerPIDs, removed
// by THPModule_removeWorkerPIDs, and scanned by
// THPModule_errorIfAnyWorkerFails.
static std::map<int64_t, std::set<pid_t>> worker_pids = {};
121 
THPModule_errorIfAnyWorkerFails(PyObject * module,PyObject * noargs)122 static PyObject* THPModule_errorIfAnyWorkerFails(
123     PyObject* module,
124     PyObject* noargs) {
125   HANDLE_TH_ERRORS
126 
127   // Only check the pids we care about
128   for (auto& w : worker_pids) {
129     auto& pid_set = w.second;
130     for (auto worker_pid : pid_set) {
131       // Use waitid rather than waitpid so that we can set NOWAIT, and that
132       // Python and other handlers can get whatever info they want about the
133       // child.
134       siginfo_t infop{};
135       infop.si_pid = 0;
136       auto error =
137           waitid(P_PID, worker_pid, &infop, WEXITED | WNOHANG | WNOWAIT);
138       // ignore errors and case with no waitable child
139       if (error < 0 || infop.si_pid == 0)
140         continue;
141       if (infop.si_code == CLD_EXITED &&
142           infop.si_status != EXIT_SUCCESS) { // exit with error
143         std::ostringstream oss;
144         oss << "DataLoader worker (pid " << worker_pid << ") exited "
145             << "unexpectedly with exit code " << infop.si_status << ". "
146             << "Details are lost due to multiprocessing. Rerunning with "
147             << "num_workers=0 may give better error trace.";
148         // This is necessary. Otherwise, the runtime error will kill the other
149         // workers, and trigger this again.
150         pid_set.clear();
151         throw std::runtime_error(oss.str());
152       } else if (
153           infop.si_code == CLD_KILLED ||
154           infop.si_code == CLD_DUMPED) { // killed by signal
155         std::ostringstream oss;
156         oss << "DataLoader worker (pid " << worker_pid << ") is killed "
157             << "by signal: " << strsignal(infop.si_status) << ". ";
158         if (infop.si_status == SIGBUS) {
159           oss << "It is possible that dataloader's workers are out of shared memory. "
160               << "Please try to raise your shared memory limit.";
161         }
162         // This is necessary. Otherwise, the runtime error will kill the other
163         // workers, and trigger this again.
164         pid_set.clear();
165         throw std::runtime_error(oss.str());
166       }
167     }
168   }
169   Py_RETURN_NONE;
170   END_HANDLE_TH_ERRORS
171 }
172 
173 // We don't want to exit on any SIGCHLD from any child. child_pids is a tuple
174 // of pids we are interested in.
THPModule_setWorkerPIDs(PyObject * module,PyObject * args)175 static PyObject* THPModule_setWorkerPIDs(PyObject* module, PyObject* args) {
176   HANDLE_TH_ERRORS
177   TORCH_CHECK_TYPE(
178       PyTuple_GET_SIZE(args) == 2,
179       "_set_worker_pids expects exactly 2 arguments.");
180   int64_t key = THPUtils_unpackLong(PyTuple_GET_ITEM(args, 0));
181   TORCH_CHECK_VALUE(
182       worker_pids.find(key) == worker_pids.end(),
183       "_set_worker_pids should be called only once for each _BaseDataLoaderIter.");
184   PyObject* child_pids = PyTuple_GET_ITEM(args, 1);
185   TORCH_CHECK_TYPE(
186       PyTuple_Check(child_pids),
187       "_set_worker_pids expects a tuple for child_pids, but got ",
188       Py_TYPE(child_pids)->tp_name,
189       ".");
190   std::set<pid_t> pids_set = {};
191   auto size = PyTuple_GET_SIZE(child_pids);
192   for (const auto idx : c10::irange(size)) {
193     PyObject* obj = PyTuple_GET_ITEM(child_pids, idx);
194     pids_set.insert(static_cast<pid_t>(THPUtils_unpackLong(obj)));
195   }
196 
197   worker_pids[key] = pids_set;
198 
199   Py_RETURN_NONE;
200   END_HANDLE_TH_ERRORS
201 }
202 
THPModule_removeWorkerPIDs(PyObject * module,PyObject * loader_id)203 static PyObject* THPModule_removeWorkerPIDs(
204     PyObject* module,
205     PyObject* loader_id) {
206   HANDLE_TH_ERRORS
207 
208   int64_t key = THPUtils_unpackLong(loader_id);
209   auto it = worker_pids.find(key);
210   TORCH_CHECK_VALUE(
211       it != worker_pids.end(),
212       "Cannot find worker information for _BaseDataLoaderIter with id ",
213       key);
214   worker_pids.erase(it);
215 
216   Py_RETURN_NONE;
217   END_HANDLE_TH_ERRORS
218 }
219 
220 #undef SIGNAL_HANDLER
221 
222 #else
223 // dummy implementations for windows
224 
// Windows stub for `_set_worker_signal_handlers`: installs nothing and
// returns None.
static PyObject* THPModule_setWorkerSignalHandlers(
    PyObject* module,
    PyObject* _ignored) {
  Py_RETURN_NONE;
}
230 
// Windows stub for `_set_worker_pids`: ignores its arguments and returns None.
static PyObject* THPModule_setWorkerPIDs(PyObject* module, PyObject* _ignored) {
  Py_RETURN_NONE;
}
234 
// Windows stub for `_remove_worker_pids`: ignores its argument and returns
// None.
static PyObject* THPModule_removeWorkerPIDs(
    PyObject* module,
    PyObject* _ignored) {
  Py_RETURN_NONE;
}
240 
// Windows stub for `_error_if_any_worker_fails`: performs no check and
// returns None.
static PyObject* THPModule_errorIfAnyWorkerFails(
    PyObject* module,
    PyObject* _ignored) {
  Py_RETURN_NONE;
}
246 
247 #endif
248 
// Python method table for the DataLoader helpers defined above. It sits
// outside the platform #if, so the same names are exposed whether the POSIX
// implementations or the Windows stubs were compiled. The all-null entry
// terminates the table.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays)
PyMethodDef DataLoaderMethods[] = {
    // No arguments.
    {"_set_worker_signal_handlers",
     THPModule_setWorkerSignalHandlers,
     METH_NOARGS,
     nullptr},
    // Positional argument tuple: (key, child_pids).
    {"_set_worker_pids", THPModule_setWorkerPIDs, METH_VARARGS, nullptr},
    // Exactly one argument: the loader id.
    {"_remove_worker_pids", THPModule_removeWorkerPIDs, METH_O, nullptr},
    // No arguments.
    {"_error_if_any_worker_fails",
     THPModule_errorIfAnyWorkerFails,
     METH_NOARGS,
     nullptr},
    {nullptr, nullptr, 0, nullptr}};
262