xref: /aosp_15_r20/external/pytorch/c10/cuda/CUDADeviceAssertionHost.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #pragma once
2*da0073e9SAndroid Build Coastguard Worker 
3*da0073e9SAndroid Build Coastguard Worker #include <c10/cuda/CUDAMacros.h>
4*da0073e9SAndroid Build Coastguard Worker 
5*da0073e9SAndroid Build Coastguard Worker #include <cstdint>
6*da0073e9SAndroid Build Coastguard Worker #include <memory>
7*da0073e9SAndroid Build Coastguard Worker #include <mutex>
8*da0073e9SAndroid Build Coastguard Worker #include <string>
9*da0073e9SAndroid Build Coastguard Worker #include <utility>
10*da0073e9SAndroid Build Coastguard Worker #include <vector>
11*da0073e9SAndroid Build Coastguard Worker 
// Define TORCH_USE_CUDA_DSA whenever USE_CUDA is defined. This header tests
// the macro with #ifdef to decide whether device-side assertion (DSA) support
// counts as enabled at compile time.
// The !defined() check avoids redefining the macro (and the resulting
// redefinition diagnostic) when the build has already supplied it, for
// example through a -DTORCH_USE_CUDA_DSA compiler flag.
#if defined(USE_CUDA) && !defined(TORCH_USE_CUDA_DSA)
#define TORCH_USE_CUDA_DSA
#endif
15*da0073e9SAndroid Build Coastguard Worker 
/// Number of assertion failure messages we can store. If this is too small
/// threads will fail silently.
constexpr int C10_CUDA_DSA_ASSERTION_COUNT = 10;
/// Size, in chars, of each fixed-length string buffer (assertion text, file
/// name, function name) kept for one recorded assertion failure.
constexpr int C10_CUDA_DSA_MAX_STR_LEN = 512;
20*da0073e9SAndroid Build Coastguard Worker 
21*da0073e9SAndroid Build Coastguard Worker namespace c10::cuda {
22*da0073e9SAndroid Build Coastguard Worker 
23*da0073e9SAndroid Build Coastguard Worker /// Holds information about any device-side assertions that fail.
24*da0073e9SAndroid Build Coastguard Worker /// Held in managed memory and access by both the CPU and the GPU.
25*da0073e9SAndroid Build Coastguard Worker struct DeviceAssertionData {
26*da0073e9SAndroid Build Coastguard Worker   /// Stringification of the assertion
27*da0073e9SAndroid Build Coastguard Worker   // NOLINTNEXTLINE(*-c-arrays)
28*da0073e9SAndroid Build Coastguard Worker   char assertion_msg[C10_CUDA_DSA_MAX_STR_LEN]{};
29*da0073e9SAndroid Build Coastguard Worker   /// File the assertion was in
30*da0073e9SAndroid Build Coastguard Worker   // NOLINTNEXTLINE(*-c-arrays)
31*da0073e9SAndroid Build Coastguard Worker   char filename[C10_CUDA_DSA_MAX_STR_LEN]{};
32*da0073e9SAndroid Build Coastguard Worker   /// Name of the function the assertion was in
33*da0073e9SAndroid Build Coastguard Worker   // NOLINTNEXTLINE(*-c-arrays)
34*da0073e9SAndroid Build Coastguard Worker   char function_name[C10_CUDA_DSA_MAX_STR_LEN]{};
35*da0073e9SAndroid Build Coastguard Worker   /// Line number the assertion was at
36*da0073e9SAndroid Build Coastguard Worker   int line_number{};
37*da0073e9SAndroid Build Coastguard Worker   /// Number uniquely identifying the kernel launch that triggered the assertion
38*da0073e9SAndroid Build Coastguard Worker   uint32_t caller{};
39*da0073e9SAndroid Build Coastguard Worker   /// block_id of the thread that failed the assertion
40*da0073e9SAndroid Build Coastguard Worker   // NOLINTNEXTLINE(*-c-arrays)
41*da0073e9SAndroid Build Coastguard Worker   int32_t block_id[3]{};
42*da0073e9SAndroid Build Coastguard Worker   /// third_id of the thread that failed the assertion
43*da0073e9SAndroid Build Coastguard Worker   // NOLINTNEXTLINE(*-c-arrays)
44*da0073e9SAndroid Build Coastguard Worker   int32_t thread_id[3]{};
45*da0073e9SAndroid Build Coastguard Worker };
46*da0073e9SAndroid Build Coastguard Worker 
47*da0073e9SAndroid Build Coastguard Worker /// Used to hold assertions generated by the device
48*da0073e9SAndroid Build Coastguard Worker /// Held in managed memory and access by both the CPU and the GPU.
49*da0073e9SAndroid Build Coastguard Worker struct DeviceAssertionsData {
50*da0073e9SAndroid Build Coastguard Worker   /// Total number of assertions found; a subset of thse will be recorded
51*da0073e9SAndroid Build Coastguard Worker   /// in `assertions`
52*da0073e9SAndroid Build Coastguard Worker   int32_t assertion_count{};
53*da0073e9SAndroid Build Coastguard Worker   /// An array of assertions that will be written to in a race-free manner
54*da0073e9SAndroid Build Coastguard Worker   // NOLINTNEXTLINE(*-c-arrays)
55*da0073e9SAndroid Build Coastguard Worker   DeviceAssertionData assertions[C10_CUDA_DSA_ASSERTION_COUNT]{};
56*da0073e9SAndroid Build Coastguard Worker };
57*da0073e9SAndroid Build Coastguard Worker 
/// Used to hold info about kernel launches so that we can run kernels
/// asynchronously and still associate launches with device-side
/// assertion failures.
///
/// Every member carries a default initializer so that a default-initialized
/// instance is in a well-defined empty state (null pointers, zeros, empty
/// string) instead of holding indeterminate values. This is the same state
/// that value-initialization produces.
struct CUDAKernelLaunchInfo {
  /// Filename of the code where the kernel was launched from
  const char* launch_filename = nullptr;
  /// Function from which the kernel was launched
  const char* launch_function = nullptr;
  /// Line number of where the code was launched from
  uint32_t launch_linenum = 0;
  /// Backtrace of where the kernel was launched from, only populated if
  /// CUDAKernelLaunchRegistry::gather_launch_stacktrace is true
  std::string launch_stacktrace;
  /// Kernel that was launched
  const char* kernel_name = nullptr;
  /// Device the kernel was launched on
  int device = 0;
  /// Stream the kernel was launched on
  int32_t stream = 0;
  /// A number that uniquely identifies the kernel launch
  uint64_t generation_number = 0;
};
80*da0073e9SAndroid Build Coastguard Worker 
81*da0073e9SAndroid Build Coastguard Worker /// Circular buffer used to hold information about kernel launches
82*da0073e9SAndroid Build Coastguard Worker /// this is later used to reconstruct how a device-side kernel assertion failure
83*da0073e9SAndroid Build Coastguard Worker /// occurred CUDAKernelLaunchRegistry is used as a singleton
84*da0073e9SAndroid Build Coastguard Worker class C10_CUDA_API CUDAKernelLaunchRegistry {
85*da0073e9SAndroid Build Coastguard Worker  private:
86*da0073e9SAndroid Build Coastguard Worker   /// Assume that this is the max number of kernel launches that might ever be
87*da0073e9SAndroid Build Coastguard Worker   /// enqueued across all streams on a single device
88*da0073e9SAndroid Build Coastguard Worker   static constexpr int max_kernel_launches = 1024;
89*da0073e9SAndroid Build Coastguard Worker   /// How many kernel launch infos we've inserted. Used to ensure that circular
90*da0073e9SAndroid Build Coastguard Worker   /// queue doesn't provide false information by always increasing, but also to
91*da0073e9SAndroid Build Coastguard Worker   /// mark where we are inserting into the queue
92*da0073e9SAndroid Build Coastguard Worker #ifdef TORCH_USE_CUDA_DSA
93*da0073e9SAndroid Build Coastguard Worker   uint64_t generation_number = 0;
94*da0073e9SAndroid Build Coastguard Worker #endif
95*da0073e9SAndroid Build Coastguard Worker   /// Shared mutex between writer and accessor to ensure multi-threaded safety.
96*da0073e9SAndroid Build Coastguard Worker   mutable std::mutex read_write_mutex;
97*da0073e9SAndroid Build Coastguard Worker   /// Used to ensure prevent race conditions in GPU memory allocation
98*da0073e9SAndroid Build Coastguard Worker   mutable std::mutex gpu_alloc_mutex;
99*da0073e9SAndroid Build Coastguard Worker   /// Pointer to managed memory keeping track of device-side assertions. There
100*da0073e9SAndroid Build Coastguard Worker   /// is one entry for each possible device the process might work with. Unused
101*da0073e9SAndroid Build Coastguard Worker   /// entries are nullptrs. We could also use an unordered_set here, but this
102*da0073e9SAndroid Build Coastguard Worker   /// vector design will be faster and the wasted memory is small since we
103*da0073e9SAndroid Build Coastguard Worker   /// expect the number of GPUs per node will always be small
104*da0073e9SAndroid Build Coastguard Worker   std::vector<
105*da0073e9SAndroid Build Coastguard Worker       std::unique_ptr<DeviceAssertionsData, void (*)(DeviceAssertionsData*)>>
106*da0073e9SAndroid Build Coastguard Worker       uvm_assertions;
107*da0073e9SAndroid Build Coastguard Worker   /// A single circular buffer holds information about every kernel launch the
108*da0073e9SAndroid Build Coastguard Worker   /// process makes across all devices.
109*da0073e9SAndroid Build Coastguard Worker   std::vector<CUDAKernelLaunchInfo> kernel_launches;
110*da0073e9SAndroid Build Coastguard Worker   bool check_env_for_enable_launch_stacktracing() const;
111*da0073e9SAndroid Build Coastguard Worker   bool check_env_for_dsa_enabled() const;
112*da0073e9SAndroid Build Coastguard Worker 
113*da0073e9SAndroid Build Coastguard Worker  public:
114*da0073e9SAndroid Build Coastguard Worker   CUDAKernelLaunchRegistry();
115*da0073e9SAndroid Build Coastguard Worker   /// Register a new kernel launch and obtain a generation number back to be
116*da0073e9SAndroid Build Coastguard Worker   /// passed to the kernel
117*da0073e9SAndroid Build Coastguard Worker   uint32_t insert(
118*da0073e9SAndroid Build Coastguard Worker       const char* launch_filename,
119*da0073e9SAndroid Build Coastguard Worker       const char* launch_function,
120*da0073e9SAndroid Build Coastguard Worker       const uint32_t launch_linenum,
121*da0073e9SAndroid Build Coastguard Worker       const char* kernel_name,
122*da0073e9SAndroid Build Coastguard Worker       const int32_t stream_id);
123*da0073e9SAndroid Build Coastguard Worker   /// Get copies of the kernel launch registry and each device's assertion
124*da0073e9SAndroid Build Coastguard Worker   /// failure buffer so they can be inspected without raising race conditions
125*da0073e9SAndroid Build Coastguard Worker   std::
126*da0073e9SAndroid Build Coastguard Worker       pair<std::vector<DeviceAssertionsData>, std::vector<CUDAKernelLaunchInfo>>
127*da0073e9SAndroid Build Coastguard Worker       snapshot() const;
128*da0073e9SAndroid Build Coastguard Worker   /// Get a pointer to the current device's assertion failure buffer. If no such
129*da0073e9SAndroid Build Coastguard Worker   /// buffer exists then one is created. This means that the first kernel launch
130*da0073e9SAndroid Build Coastguard Worker   /// made on each device will be slightly slower because memory allocations are
131*da0073e9SAndroid Build Coastguard Worker   /// required
132*da0073e9SAndroid Build Coastguard Worker   DeviceAssertionsData* get_uvm_assertions_ptr_for_current_device();
133*da0073e9SAndroid Build Coastguard Worker   /// Gets the global singleton of the registry
134*da0073e9SAndroid Build Coastguard Worker   static CUDAKernelLaunchRegistry& get_singleton_ref();
135*da0073e9SAndroid Build Coastguard Worker   /// If not all devices support DSA, we disable it
136*da0073e9SAndroid Build Coastguard Worker   const bool do_all_devices_support_managed_memory = false;
137*da0073e9SAndroid Build Coastguard Worker   /// Whether or not to gather stack traces when launching kernels
138*da0073e9SAndroid Build Coastguard Worker   bool gather_launch_stacktrace = false;
139*da0073e9SAndroid Build Coastguard Worker   /// Whether or not host-side DSA is enabled or disabled at run-time
140*da0073e9SAndroid Build Coastguard Worker   /// Note: Device-side code cannot be enabled/disabled at run-time
141*da0073e9SAndroid Build Coastguard Worker   bool enabled_at_runtime = false;
142*da0073e9SAndroid Build Coastguard Worker   /// Whether or not a device has indicated a failure
143*da0073e9SAndroid Build Coastguard Worker   bool has_failed() const;
144*da0073e9SAndroid Build Coastguard Worker #ifdef TORCH_USE_CUDA_DSA
145*da0073e9SAndroid Build Coastguard Worker   const bool enabled_at_compile_time = true;
146*da0073e9SAndroid Build Coastguard Worker #else
147*da0073e9SAndroid Build Coastguard Worker   const bool enabled_at_compile_time = false;
148*da0073e9SAndroid Build Coastguard Worker #endif
149*da0073e9SAndroid Build Coastguard Worker };
150*da0073e9SAndroid Build Coastguard Worker 
/// Returns a string built by the out-of-line definition of this function.
/// NOTE(review): judging by the name, the string describes the device-side
/// assertion failures recorded so far; the exact content and format are not
/// visible from this declaration and should be confirmed in the definition.
std::string c10_retrieve_device_side_assertion_info();
152*da0073e9SAndroid Build Coastguard Worker 
153*da0073e9SAndroid Build Coastguard Worker } // namespace c10::cuda
154*da0073e9SAndroid Build Coastguard Worker 
// Each kernel launched with TORCH_DSA_KERNEL_LAUNCH requires the same input
// arguments. The following macro standardizes them: it expands to the two
// parameter declarations `assertions_data` and `assertion_caller_id`, both
// marked [[maybe_unused]] so a kernel that ignores them compiles without
// unused-parameter warnings.
#define TORCH_DSA_KERNEL_ARGS                                              \
  [[maybe_unused]] c10::cuda::DeviceAssertionsData *const assertions_data, \
      [[maybe_unused]] uint32_t assertion_caller_id
161*da0073e9SAndroid Build Coastguard Worker 
// This macro can be used to pass the DSA arguments onward to another
// function. It expands to the bare names `assertions_data` and
// `assertion_caller_id`, so it only works where identifiers with those names
// are in scope (for example, parameters declared via TORCH_DSA_KERNEL_ARGS).
#define TORCH_DSA_KERNEL_ARGS_PASS assertions_data, assertion_caller_id
165