xref: /aosp_15_r20/external/pytorch/c10/util/signal_handler.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <c10/util/Backtrace.h>
2 #include <c10/util/Logging.h>
3 #include <c10/util/signal_handler.h>
4 
5 #if defined(C10_SUPPORTS_SIGNAL_HANDLER)
6 
7 // Normal signal handler implementation.
8 #include <dirent.h>
9 #include <fmt/core.h>
10 #include <sys/syscall.h>
11 #include <unistd.h>
12 
13 #include <atomic>
14 #include <chrono>
15 #include <condition_variable>
16 #include <cstdint>
17 #include <cstdio>
18 #include <cstdlib>
19 #include <iostream>
20 #include <mutex>
21 
22 #ifdef C10_ANDROID
23 #ifndef SYS_gettid
24 #define SYS_gettid __NR_gettid
25 #endif
26 #ifndef SYS_tgkill
27 #define SYS_tgkill __NR_tgkill
28 #endif
29 #endif
30 
31 namespace {
32 
// Dispositions that were installed for SIGHUP/SIGINT before hookupHandler()
// replaced them; handleSignal forwards to them and unhookHandler restores them.
struct sigaction previousSighup;
struct sigaction previousSigint;
// Number of SIGINT/SIGHUP deliveries observed by handleSignal.
std::atomic<int> sigintCount(0);
std::atomic<int> sighupCount(0);
// Number of live hookupHandler() registrations (handler installed while > 0).
std::atomic<int> hookedUpCount(0);

// Forwards `signal` to the disposition saved in `previous`.
// SIG_DFL and SIG_IGN are sentinel values, not callable functions (SIG_IGN is
// a small non-null constant), so they must never be invoked; a null check
// alone is not enough, e.g. a process started under `nohup` has SIGHUP set to
// SIG_IGN.
void forwardToPreviousHandler(const struct sigaction& previous, int signal) {
  if (previous.sa_handler == nullptr || previous.sa_handler == SIG_DFL ||
      previous.sa_handler == SIG_IGN) {
    return;
  }
  // TODO: if the previous handler was registered with SA_SIGINFO it expects
  // (int, siginfo_t*, void*); our handler does not receive that information.
  previous.sa_handler(signal);
}

// Handler installed for SIGHUP and SIGINT: counts the delivery, then chains
// to whatever handler was installed before ours.
void handleSignal(int signal) {
  switch (signal) {
    case SIGHUP:
      sighupCount += 1;
      forwardToPreviousHandler(previousSighup, signal);
      break;
    case SIGINT:
      sigintCount += 1;
      forwardToPreviousHandler(previousSigint, signal);
      break;
  }
}
56 
hookupHandler()57 void hookupHandler() {
58   if (hookedUpCount++) {
59     return;
60   }
61   struct sigaction sa {};
62   // Setup the handler
63   sa.sa_handler = &handleSignal;
64   // Restart the system call, if at all possible
65   sa.sa_flags = SA_RESTART;
66   // Block every signal during the handler
67   sigfillset(&sa.sa_mask);
68   // Intercept SIGHUP and SIGINT
69   if (sigaction(SIGHUP, &sa, &previousSighup) == -1) {
70     LOG(FATAL) << "Cannot install SIGHUP handler.";
71   }
72   if (sigaction(SIGINT, &sa, &previousSigint) == -1) {
73     LOG(FATAL) << "Cannot install SIGINT handler.";
74   }
75 }
76 
77 // Set the signal handlers to the default.
unhookHandler()78 void unhookHandler() {
79   if (--hookedUpCount > 0) {
80     return;
81   }
82   struct sigaction sa {};
83   // Setup the sighub handler
84   sa.sa_handler = SIG_DFL;
85   // Restart the system call, if at all possible
86   sa.sa_flags = SA_RESTART;
87   // Block every signal during the handler
88   sigfillset(&sa.sa_mask);
89   // Intercept SIGHUP and SIGINT
90   if (sigaction(SIGHUP, &previousSighup, nullptr) == -1) {
91     LOG(FATAL) << "Cannot uninstall SIGHUP handler.";
92   }
93   if (sigaction(SIGINT, &previousSigint, nullptr) == -1) {
94     LOG(FATAL) << "Cannot uninstall SIGINT handler.";
95   }
96 }
97 
98 } // namespace
99 
100 namespace c10 {
101 
102 #if defined(C10_SUPPORTS_FATAL_SIGNAL_HANDLERS)
103 
getInstance()104 FatalSignalHandler& FatalSignalHandler::getInstance() {
105   // Leaky singleton to avoid module destructor race.
106   static FatalSignalHandler* handler = new FatalSignalHandler();
107   return *handler;
108 }
109 
110 FatalSignalHandler::~FatalSignalHandler() = default;
111 
FatalSignalHandler()112 FatalSignalHandler::FatalSignalHandler()
113     : fatalSignalHandlersInstalled(false),
114       fatalSignalReceived(false),
115       fatalSignalName("<UNKNOWN>"),
116       writingCond(),
117       writingMutex(),
118       signalReceived(false) {}
119 
// Fatal signals we intercept. Each entry carries storage for the disposition
// that was installed before ours: installFatalSignalHandlers() fills it, and
// uninstallFatalSignalHandlers() / fatalSignalHandler() reinstall it.
// The table ends with an entry whose name is nullptr; lookups stop there.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
FatalSignalHandler::signal_handler FatalSignalHandler::kSignalHandlers[] = {
    {"SIGABRT", SIGABRT, {}},
    {"SIGINT", SIGINT, {}},
    {"SIGILL", SIGILL, {}},
    {"SIGFPE", SIGFPE, {}},
    {"SIGBUS", SIGBUS, {}},
    {"SIGSEGV", SIGSEGV, {}},
    {nullptr, 0, {}}}; // sentinel
129 
getPreviousSigaction(int signum)130 struct sigaction* FatalSignalHandler::getPreviousSigaction(int signum) {
131   for (auto handler = kSignalHandlers; handler->name != nullptr; handler++) {
132     if (handler->signum == signum) {
133       return &handler->previous;
134     }
135   }
136   return nullptr;
137 }
138 
getSignalName(int signum)139 const char* FatalSignalHandler::getSignalName(int signum) {
140   for (auto handler = kSignalHandlers; handler->name != nullptr; handler++) {
141     if (handler->signum == signum) {
142       return handler->name;
143     }
144   }
145   return nullptr;
146 }
147 
callPreviousSignalHandler(struct sigaction * action,int signum,siginfo_t * info,void * ctx)148 void FatalSignalHandler::callPreviousSignalHandler(
149     struct sigaction* action,
150     int signum,
151     siginfo_t* info,
152     void* ctx) {
153   if (!action->sa_handler) {
154     return;
155   }
156   if ((action->sa_flags & SA_SIGINFO) == SA_SIGINFO) {
157     action->sa_sigaction(signum, info, ctx);
158   } else {
159     action->sa_handler(signum);
160   }
161 }
162 
163 // needsLock signals whether we need to lock our writing mutex.
stacktraceSignalHandler(bool needsLock)164 void FatalSignalHandler::stacktraceSignalHandler(bool needsLock) {
165   std::unique_lock<std::mutex> ul(writingMutex, std::defer_lock);
166   if (needsLock) {
167     ul.lock();
168     signalReceived = true;
169   }
170   pid_t tid = static_cast<pid_t>(syscall(SYS_gettid));
171   std::string backtrace = fmt::format(
172       "{}({}), PID: {}, Thread {}: \n {}",
173       fatalSignalName,
174       fatalSignum,
175       ::getpid(),
176       tid,
177       c10::get_backtrace());
178   std::cerr << backtrace << std::endl;
179   if (needsLock) {
180     ul.unlock();
181     writingCond.notify_all();
182   }
183 }
184 
fatalSignalHandlerPostProcess()185 void FatalSignalHandler::fatalSignalHandlerPostProcess() {}
186 
fatalSignalHandlerStatic(int signum)187 void FatalSignalHandler::fatalSignalHandlerStatic(int signum) {
188   getInstance().fatalSignalHandler(signum);
189 }
190 
191 // Our fatal signal entry point
fatalSignalHandler(int signum)192 void FatalSignalHandler::fatalSignalHandler(int signum) {
193   // Check if this is a proper signal that we declared above.
194   const char* name = getSignalName(signum);
195   if (!name) {
196     return;
197   }
198   if (fatalSignalReceived) {
199     return;
200   }
201   // Set the flag so that our SIGUSR2 handler knows that we're aborting and
202   // that it should intercept any SIGUSR2 signal.
203   fatalSignalReceived = true;
204   // Set state for other threads.
205   fatalSignum = signum;
206   fatalSignalName = name;
207   // Linux doesn't have a nice userland API for enumerating threads so we
208   // need to use the proc pseudo-filesystem.
209   DIR* procDir = opendir("/proc/self/task");
210   if (procDir) {
211     pid_t pid = getpid();
212     pid_t currentTid = static_cast<pid_t>(syscall(SYS_gettid));
213     struct dirent* entry = nullptr;
214     std::unique_lock<std::mutex> ul(writingMutex);
215     while ((entry = readdir(procDir)) != nullptr) {
216       if (entry->d_name[0] == '.') {
217         continue;
218       }
219       pid_t tid = atoi(entry->d_name);
220       // If we've found the current thread then we'll jump into the SIGUSR2
221       // handler instead of signaling to avoid deadlocking.
222       if (tid != currentTid) {
223         signalReceived = false;
224         syscall(SYS_tgkill, pid, tid, SIGUSR2);
225         auto now = std::chrono::system_clock::now();
226         using namespace std::chrono_literals;
227         // we use wait_until instead of wait because on ROCm there was
228         // a single thread that wouldn't receive the SIGUSR2
229         if (std::cv_status::timeout == writingCond.wait_until(ul, now + 2s)) {
230           if (!signalReceived) {
231             std::cerr << "signal lost waiting for stacktrace " << pid << ":"
232                       << tid << std::endl;
233             break;
234           }
235         }
236       } else {
237         stacktraceSignalHandler(false);
238       }
239     }
240   } else {
241     perror("Failed to open /proc/self/task");
242   }
243   fatalSignalHandlerPostProcess();
244   sigaction(signum, getPreviousSigaction(signum), nullptr);
245   raise(signum);
246 }
247 
248 // Our SIGUSR2 entry point
stacktraceSignalHandlerStatic(int signum,siginfo_t * info,void * ctx)249 void FatalSignalHandler::stacktraceSignalHandlerStatic(
250     int signum,
251     siginfo_t* info,
252     void* ctx) {
253   getInstance().stacktraceSignalHandler(signum, info, ctx);
254 }
255 
stacktraceSignalHandler(int signum,siginfo_t * info,void * ctx)256 void FatalSignalHandler::stacktraceSignalHandler(
257     int signum,
258     siginfo_t* info,
259     void* ctx) {
260   if (fatalSignalReceived) {
261     stacktraceSignalHandler(true);
262   } else {
263     // We don't want to actually change the signal handler as we want to
264     // remain the signal handler so that we may get the usr2 signal later.
265     callPreviousSignalHandler(&previousSigusr2, signum, info, ctx);
266   }
267 }
268 
269 // Installs SIGABRT signal handler so that we get stack traces
270 // from every thread on SIGABRT caused exit. Also installs SIGUSR2 handler
271 // so that threads can communicate with each other (be sure if you use SIGUSR2)
272 // to install your handler before initing caffe2 (we properly fall back to
273 // the previous handler if we didn't initiate the SIGUSR2).
installFatalSignalHandlers()274 void FatalSignalHandler::installFatalSignalHandlers() {
275   std::lock_guard<std::mutex> locker(fatalSignalHandlersInstallationMutex);
276   if (fatalSignalHandlersInstalled) {
277     return;
278   }
279   fatalSignalHandlersInstalled = true;
280   struct sigaction sa {};
281   sigemptyset(&sa.sa_mask);
282   // Since we'll be in an exiting situation it's possible there's memory
283   // corruption, so make our own stack just in case.
284   sa.sa_flags = SA_ONSTACK | SA_SIGINFO;
285   sa.sa_handler = FatalSignalHandler::fatalSignalHandlerStatic;
286   for (auto* handler = kSignalHandlers; handler->name != nullptr; handler++) {
287     if (sigaction(handler->signum, &sa, &handler->previous)) {
288       std::string str("Failed to add ");
289       str += handler->name;
290       str += " handler!";
291       perror(str.c_str());
292     }
293   }
294   sa.sa_sigaction = FatalSignalHandler::stacktraceSignalHandlerStatic;
295   if (sigaction(SIGUSR2, &sa, &previousSigusr2)) {
296     perror("Failed to add SIGUSR2 handler!");
297   }
298 }
299 
uninstallFatalSignalHandlers()300 void FatalSignalHandler::uninstallFatalSignalHandlers() {
301   std::lock_guard<std::mutex> locker(fatalSignalHandlersInstallationMutex);
302   if (!fatalSignalHandlersInstalled) {
303     return;
304   }
305   fatalSignalHandlersInstalled = false;
306   for (auto* handler = kSignalHandlers; handler->name != nullptr; handler++) {
307     if (sigaction(handler->signum, &handler->previous, nullptr)) {
308       std::string str("Failed to remove ");
309       str += handler->name;
310       str += " handler!";
311       perror(str.c_str());
312     } else {
313       handler->previous = {};
314     }
315   }
316   if (sigaction(SIGUSR2, &previousSigusr2, nullptr)) {
317     perror("Failed to add SIGUSR2 handler!");
318   } else {
319     previousSigusr2 = {};
320   }
321 }
322 #endif // defined(C10_SUPPORTS_FATAL_SIGNAL_HANDLERS)
323 
SignalHandler(SignalHandler::Action SIGINT_action,SignalHandler::Action SIGHUP_action)324 SignalHandler::SignalHandler(
325     SignalHandler::Action SIGINT_action,
326     SignalHandler::Action SIGHUP_action)
327     : SIGINT_action_(SIGINT_action),
328       SIGHUP_action_(SIGHUP_action),
329       my_sigint_count_(sigintCount),
330       my_sighup_count_(sighupCount) {
331   hookupHandler();
332 }
333 
~SignalHandler()334 SignalHandler::~SignalHandler() {
335   unhookHandler();
336 }
337 
338 // Return true iff a SIGINT has been received since the last time this
339 // function was called.
GotSIGINT()340 bool SignalHandler::GotSIGINT() {
341   uint64_t count = sigintCount;
342   uint64_t localCount = my_sigint_count_.exchange(count);
343   return (localCount != count);
344 }
345 
346 // Return true iff a SIGHUP has been received since the last time this
347 // function was called.
GotSIGHUP()348 bool SignalHandler::GotSIGHUP() {
349   uint64_t count = sighupCount;
350   uint64_t localCount = my_sighup_count_.exchange(count);
351   return (localCount != count);
352 }
353 
CheckForSignals()354 SignalHandler::Action SignalHandler::CheckForSignals() {
355   if (GotSIGHUP()) {
356     return SIGHUP_action_;
357   }
358   if (GotSIGINT()) {
359     return SIGINT_action_;
360   }
361   return SignalHandler::Action::NONE;
362 }
363 
364 #if defined(C10_SUPPORTS_FATAL_SIGNAL_HANDLERS)
setPrintStackTracesOnFatalSignal(bool print)365 void FatalSignalHandler::setPrintStackTracesOnFatalSignal(bool print) {
366   if (print) {
367     installFatalSignalHandlers();
368   } else {
369     uninstallFatalSignalHandlers();
370   }
371 }
printStackTracesOnFatalSignal()372 bool FatalSignalHandler::printStackTracesOnFatalSignal() {
373   std::lock_guard<std::mutex> locker(fatalSignalHandlersInstallationMutex);
374   return fatalSignalHandlersInstalled;
375 }
376 
377 #endif // defined(C10_SUPPORTS_FATAL_SIGNAL_HANDLERS)
378 } // namespace c10
379 
380 #else // defined(C10_SUPPORTS_SIGNAL_HANDLER)
381 
382 // TODO: Currently we do not support signal handling in non-Linux yet - below is
383 // a minimal implementation that makes things compile.
384 namespace c10 {
SignalHandler(SignalHandler::Action SIGINT_action,SignalHandler::Action SIGHUP_action)385 SignalHandler::SignalHandler(
386     SignalHandler::Action SIGINT_action,
387     SignalHandler::Action SIGHUP_action) {
388   SIGINT_action_ = SIGINT_action;
389   SIGHUP_action_ = SIGHUP_action;
390   my_sigint_count_ = 0;
391   my_sighup_count_ = 0;
392 }
~SignalHandler()393 SignalHandler::~SignalHandler() {}
GotSIGINT()394 bool SignalHandler::GotSIGINT() {
395   return false;
396 }
GotSIGHUP()397 bool SignalHandler::GotSIGHUP() {
398   return false;
399 }
CheckForSignals()400 SignalHandler::Action SignalHandler::CheckForSignals() {
401   return SignalHandler::Action::NONE;
402 }
403 } // namespace c10
404 
405 #endif // defined(C10_SUPPORTS_SIGNAL_HANDLER)
406