xref: /aosp_15_r20/external/pytorch/c10/util/signal_handler.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <atomic>
4 #include <condition_variable>
5 #include <csignal>
6 #include <cstdint>
7 #include <mutex>
8 
9 #include <c10/macros/Export.h>
10 
// Platform feature detection for the signal-handling utilities below.
//   C10_SUPPORTS_SIGNAL_HANDLER        -> SignalHandler is available.
//   C10_SUPPORTS_FATAL_SIGNAL_HANDLERS -> FatalSignalHandler is available.
// NOTE(review): C10_DISABLE_SIGNAL_HANDLERS is only consulted on Linux; Apple
// builds always get C10_SUPPORTS_SIGNAL_HANDLER. Confirm this is intended.
#if defined(__APPLE__)
#define C10_SUPPORTS_SIGNAL_HANDLER
#elif defined(__linux__)
#if !defined(C10_DISABLE_SIGNAL_HANDLERS)
#define C10_SUPPORTS_SIGNAL_HANDLER
#define C10_SUPPORTS_FATAL_SIGNAL_HANDLERS
#endif // !defined(C10_DISABLE_SIGNAL_HANDLERS)
#endif // defined(__APPLE__) / defined(__linux__)
17 
18 #if defined(C10_SUPPORTS_FATAL_SIGNAL_HANDLERS)
19 #include <pthread.h>
20 #endif
21 
22 namespace c10 {
23 
24 class C10_API SignalHandler {
25  public:
26   enum class Action { NONE, STOP };
27 
28   // Constructor. Specify what action to take when a signal is received.
29   SignalHandler(Action SIGINT_action, Action SIGHUP_action);
30   ~SignalHandler();
31 
32   Action CheckForSignals();
33 
34   bool GotSIGINT();
35   bool GotSIGHUP();
36 
37   Action SIGINT_action_;
38   Action SIGHUP_action_;
39   std::atomic<uint64_t> my_sigint_count_;
40   std::atomic<uint64_t> my_sighup_count_;
41 };
42 
#if defined(C10_SUPPORTS_FATAL_SIGNAL_HANDLERS)
// Manager for fatal-signal handlers that can print stack traces.
//
// This works by setting up certain fatal signal handlers. Previous fatal
// signal handlers will still be called when the signal is raised. Defaults
// to being off.
//
// The shared instance is obtained through getInstance(); the constructor is
// protected, so only this class and derived classes can construct one.
class C10_API FatalSignalHandler {
 public:
  // Turns stack-trace printing on fatal signals on (true) or off (false).
  // NOTE(review): presumably installs/uninstalls the handlers through
  // installFatalSignalHandlers()/uninstallFatalSignalHandlers() — confirm in
  // the .cpp.
  void setPrintStackTracesOnFatalSignal(bool print);
  // Reports whether stack-trace printing on fatal signals is enabled
  // (presumably the value last passed to setPrintStackTracesOnFatalSignal).
  bool printStackTracesOnFatalSignal();
  // Accessor for the shared instance.
  static FatalSignalHandler& getInstance();
  virtual ~FatalSignalHandler();

  // Not copyable or assignable. The std::mutex / std::condition_variable
  // members already make the implicit versions deleted; declared for clarity.
  FatalSignalHandler(const FatalSignalHandler&) = delete;
  FatalSignalHandler& operator=(const FatalSignalHandler&) = delete;

 protected:
  explicit FatalSignalHandler();

 private:
  // Install / remove the fatal signal handlers (presumably those listed in
  // kSignalHandlers — confirm in the .cpp).
  void installFatalSignalHandlers();
  void uninstallFatalSignalHandlers();
  // Static entry point with the `void(int)` shape of sigaction's sa_handler.
  // NOTE(review): presumably forwards to fatalSignalHandler() on the shared
  // instance.
  static void fatalSignalHandlerStatic(int signum);
  void fatalSignalHandler(int signum);
  // Virtual hook that derived classes may override; its call site is in the
  // .cpp (not visible here).
  virtual void fatalSignalHandlerPostProcess();
  // Returns the saved handler that was in place before ours for `signum`.
  // NOTE(review): behavior for unknown signals not visible here.
  struct sigaction* getPreviousSigaction(int signum);
  const char* getSignalName(int signum);
  // Chains to a previously installed handler `action` for `signum`.
  void callPreviousSignalHandler(
      struct sigaction* action,
      int signum,
      siginfo_t* info,
      void* ctx);
  void stacktraceSignalHandler(bool needsLock);
  // Static entry point with the `void(int, siginfo_t*, void*)` shape of
  // sigaction's sa_sigaction.
  static void stacktraceSignalHandlerStatic(
      int signum,
      siginfo_t* info,
      void* ctx);
  void stacktraceSignalHandler(int signum, siginfo_t* info, void* ctx);

  // Guards fatalSignalHandlersInstalled against concurrent install/uninstall.
  std::mutex fatalSignalHandlersInstallationMutex;
  bool fatalSignalHandlersInstalled{false};
  // Saved SIGUSR2 handler, kept so it can still be called when a SIGUSR2
  // arrives that was not raised by this class.
  struct sigaction previousSigusr2 {};
  // Flag dictating whether the SIGUSR2 handler falls back to previous handlers
  // or is intercepted in order to print a stack trace.
  std::atomic<bool> fatalSignalReceived{false};
  // Global state set when a fatal signal is received so that backtracing
  // threads know why they're printing a stacktrace.
  const char* fatalSignalName{nullptr};
  int fatalSignum = -1;
  // This wait condition is used to wait for other threads to finish writing
  // their stack trace when in fatal sig handler (we can't use pthread_join
  // because there's no way to convert from a tid to a pthread_t).
  std::condition_variable writingCond;
  std::mutex writingMutex;
  // Used to indicate whether the other thread responded to the signal.
  bool signalReceived{false};

  // One entry per handled fatal signal: its name, its number, and the
  // sigaction that was in place before ours was installed.
  struct signal_handler {
    const char* name;
    int signum;
    struct sigaction previous;
  };

  // Table of handled fatal signals; defined (and sized) in the .cpp.
  // NOLINTNEXTLINE(*c-arrays*)
  static signal_handler kSignalHandlers[];
};

#endif // defined(C10_SUPPORTS_FATAL_SIGNAL_HANDLERS)
109 
110 } // namespace c10
111