xref: /aosp_15_r20/external/pytorch/c10/cuda/CUDADeviceAssertion.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/cuda/CUDAException.h>
4 #include <c10/macros/Macros.h>
5 
6 namespace c10::cuda {
7 
8 #ifdef TORCH_USE_CUDA_DSA
9 // Copy string from `src` to `dst`
// Bounded string copy for device code: copies `src` into `dst`,
// writing at most `C10_CUDA_DSA_MAX_STR_LEN - 1` characters followed
// by a terminating NUL. Longer inputs are silently truncated.
static __device__ void dstrcpy(char* dst, const char* src) {
  for (int copied = 0; copied < C10_CUDA_DSA_MAX_STR_LEN - 1 && *src != '\0';
       ++copied) {
    *dst++ = *src++;
  }
  *dst = '\0';
}
19 
// Records a single device-side assertion failure into the shared
// assertions buffer that the host inspects (via UVM).
//
// Safe for concurrent callers: each thread claims a slot with an atomic
// increment of `assertion_count`. Note that the increment itself is what
// signals the host that a failure happened, so the host may begin
// responding before the slot contents below are fully written. When the
// buffer is full, further failures are dropped silently. `assertions_data`
// may be nullptr when device-side assertion checking is disabled at
// run time, in which case this is a no-op; if it is disabled at compile
// time this function is never called at all.
static __device__ void dsa_add_new_assertion_failure(
    DeviceAssertionsData* assertions_data,
    const char* assertion_msg,
    const char* filename,
    const char* function_name,
    const int line_number,
    const uint32_t caller,
    const dim3 block_id,
    const dim3 thread_id) {
  // Run-time disabled: nothing to record.
  if (assertions_data == nullptr) {
    return;
  }

  // Claim a slot atomically so multiple threads can fail at the same
  // time. The host can observe the bumped count before the slot below
  // has been populated.
  const auto slot = atomicAdd(&(assertions_data->assertion_count), 1);

  if (slot >= C10_CUDA_DSA_ASSERTION_COUNT) {
    // We've run out of assertion buffer space. Printing here would get
    // spammy if many threads hit it, so extra failures are dropped
    // silently; in most cases they'd be analogous to the ones already
    // recorded anyway.
    return;
  }

  // Populate the claimed slot. This happens only after the
  // `assertion_count` increment has broadcast that there's a problem.
  auto& record = assertions_data->assertions[slot];
  dstrcpy(record.assertion_msg, assertion_msg);
  dstrcpy(record.filename, filename);
  dstrcpy(record.function_name, function_name);
  record.line_number = line_number;
  record.caller = caller;
  record.block_id[0] = block_id.x;
  record.block_id[1] = block_id.y;
  record.block_id[2] = block_id.z;
  record.thread_id[0] = thread_id.x;
  record.thread_id[1] = thread_id.y;
  record.thread_id[2] = thread_id.z;
}
67 
// Emulates a kernel assertion. The assertion won't stop the kernel's progress,
// so you should assume everything the kernel produces is garbage if there's an
// assertion failure.
// NOTE: This assumes that `assertions_data` and `assertion_caller_id` are
//       arguments of the kernel and therefore accessible.
#define CUDA_KERNEL_ASSERT2(condition)                                   \
  do {                                                                   \
    if (C10_UNLIKELY(!(condition))) {                                    \
      /* Has an atomic element so threads can fail at the same time */   \
      c10::cuda::dsa_add_new_assertion_failure(                          \
          assertions_data,                                               \
          C10_STRINGIZE(condition),                                      \
          __FILE__,                                                      \
          __FUNCTION__,                                                  \
          __LINE__,                                                      \
          assertion_caller_id,                                           \
          blockIdx,                                                      \
          threadIdx);                                                    \
      /* Now that the kernel has failed we early exit the kernel, but */ \
      /* otherwise keep going and rely on the host to check UVM and */   \
      /* determine we've had a problem */                                \
      return;                                                            \
    }                                                                    \
  } while (false)
#else
/* Device-side assertion machinery disabled at compile time: fall back */
/* to the standard CUDA device-side assert. */
#define CUDA_KERNEL_ASSERT2(condition) assert(condition)
#endif
95 
96 } // namespace c10::cuda
97