1*da0073e9SAndroid Build Coastguard Worker #include <c10/cuda/CUDADeviceAssertionHost.h>
2*da0073e9SAndroid Build Coastguard Worker #include <c10/cuda/CUDAException.h>
3*da0073e9SAndroid Build Coastguard Worker #include <c10/cuda/CUDAFunctions.h>
4*da0073e9SAndroid Build Coastguard Worker #include <c10/util/Backtrace.h>
5*da0073e9SAndroid Build Coastguard Worker #include <c10/util/Exception.h>
6*da0073e9SAndroid Build Coastguard Worker #include <c10/util/irange.h>
7*da0073e9SAndroid Build Coastguard Worker #include <cuda_runtime.h>
8*da0073e9SAndroid Build Coastguard Worker
9*da0073e9SAndroid Build Coastguard Worker #include <memory>
10*da0073e9SAndroid Build Coastguard Worker #include <string>
11*da0073e9SAndroid Build Coastguard Worker #ifdef TORCH_USE_CUDA_DSA
12*da0073e9SAndroid Build Coastguard Worker #include <chrono>
13*da0073e9SAndroid Build Coastguard Worker #include <thread>
14*da0073e9SAndroid Build Coastguard Worker #endif
15*da0073e9SAndroid Build Coastguard Worker
// Error-check a CUDA call WITHOUT going through the device-side-assertion
// (DSA) machinery. The helpers below run while CUDAKernelLaunchRegistry is
// being initialized; using the regular DSA-aware check there would re-enter
// the registry and cause an infinite initialization loop.
#define C10_CUDA_CHECK_WO_DSA(EXPR)                                 \
  do {                                                              \
    const cudaError_t __err = EXPR;                                 \
    c10::cuda::c10_cuda_check_implementation(                       \
        static_cast<int32_t>(__err),                                \
        __FILE__,                                                   \
        __func__, /* Line number data type not well-defined between \
                     compilers, so we perform an explicit cast */   \
        static_cast<uint32_t>(__LINE__),                            \
        false);                                                     \
  } while (0)
27*da0073e9SAndroid Build Coastguard Worker
28*da0073e9SAndroid Build Coastguard Worker namespace c10::cuda {
29*da0073e9SAndroid Build Coastguard Worker
30*da0073e9SAndroid Build Coastguard Worker namespace {
31*da0073e9SAndroid Build Coastguard Worker
32*da0073e9SAndroid Build Coastguard Worker #ifdef TORCH_USE_CUDA_DSA
33*da0073e9SAndroid Build Coastguard Worker /// Get current device id
34*da0073e9SAndroid Build Coastguard Worker /// We need our own implementation of this function to prevent
35*da0073e9SAndroid Build Coastguard Worker /// an infinite initialization loop for CUDAKernelLaunchRegistry
dsa_get_device_id()36*da0073e9SAndroid Build Coastguard Worker int dsa_get_device_id() {
37*da0073e9SAndroid Build Coastguard Worker c10::DeviceIndex device = -1;
38*da0073e9SAndroid Build Coastguard Worker C10_CUDA_CHECK_WO_DSA(c10::cuda::GetDevice(&device));
39*da0073e9SAndroid Build Coastguard Worker return device;
40*da0073e9SAndroid Build Coastguard Worker }
41*da0073e9SAndroid Build Coastguard Worker
42*da0073e9SAndroid Build Coastguard Worker /// Get a device's compute capability - note that this dangerously assumes
43*da0073e9SAndroid Build Coastguard Worker /// that if one CUDA GPU supports device-side assertions they all do. This is
44*da0073e9SAndroid Build Coastguard Worker /// probably fine since the latest CUDA GPU that doesn't support UVM is the
45*da0073e9SAndroid Build Coastguard Worker /// K80 released 2014-11-17. Mixing that GPU with a newer one is likely to be
46*da0073e9SAndroid Build Coastguard Worker /// rare enough that the defensive
47*da0073e9SAndroid Build Coastguard Worker /// We need our own implementation of this function to prevent
48*da0073e9SAndroid Build Coastguard Worker /// an infinite initialization loop for CUDAKernelLaunchRegistry
dsa_get_device_compute_capability(const int device_num)49*da0073e9SAndroid Build Coastguard Worker int dsa_get_device_compute_capability(const int device_num) {
50*da0073e9SAndroid Build Coastguard Worker int compute_capability = -1;
51*da0073e9SAndroid Build Coastguard Worker C10_CUDA_CHECK_WO_DSA(cudaDeviceGetAttribute(
52*da0073e9SAndroid Build Coastguard Worker &compute_capability, cudaDevAttrComputeCapabilityMajor, device_num));
53*da0073e9SAndroid Build Coastguard Worker return compute_capability;
54*da0073e9SAndroid Build Coastguard Worker }
55*da0073e9SAndroid Build Coastguard Worker #endif
56*da0073e9SAndroid Build Coastguard Worker
57*da0073e9SAndroid Build Coastguard Worker /// Get the number of CUDA devices
58*da0073e9SAndroid Build Coastguard Worker /// We need our own implementation of this function to prevent
59*da0073e9SAndroid Build Coastguard Worker /// an infinite initialization loop for CUDAKernelLaunchRegistry
dsa_get_device_count()60*da0073e9SAndroid Build Coastguard Worker int dsa_get_device_count() {
61*da0073e9SAndroid Build Coastguard Worker int device_count = -1;
62*da0073e9SAndroid Build Coastguard Worker C10_CUDA_CHECK_WO_DSA(c10::cuda::GetDeviceCount(&device_count));
63*da0073e9SAndroid Build Coastguard Worker return device_count;
64*da0073e9SAndroid Build Coastguard Worker }
65*da0073e9SAndroid Build Coastguard Worker
/// Returns true only when every visible CUDA device has major compute
/// capability >= 6 (Pascal or newer), which is where managed memory works
/// well per
/// https://developer.nvidia.com/blog/unified-memory-cuda-beginners/
/// Always false when device-side assertions are compiled out.
bool dsa_check_if_all_devices_support_managed_memory() {
#ifdef TORCH_USE_CUDA_DSA
  const auto num_devices = dsa_get_device_count();
  for (const auto device : c10::irange(num_devices)) {
    if (dsa_get_device_compute_capability(device) < 6) {
      return false;
    }
  }
  return true;
#else
  return false;
#endif
}
81*da0073e9SAndroid Build Coastguard Worker
/// Returns true iff the named environment variable is set to any value
/// other than "0" (an unset variable counts as false).
bool env_flag_set(const char* env_var_name) {
  const char* const value = std::getenv(env_var_name);
  if (value == nullptr) {
    return false;
  }
  return std::strcmp(value, "0") != 0;
}
86*da0073e9SAndroid Build Coastguard Worker
87*da0073e9SAndroid Build Coastguard Worker /// Deleter for UVM/managed memory pointers
uvm_deleter(DeviceAssertionsData * uvm_assertions_ptr)88*da0073e9SAndroid Build Coastguard Worker void uvm_deleter(DeviceAssertionsData* uvm_assertions_ptr) {
89*da0073e9SAndroid Build Coastguard Worker // Ignore error in destructor
90*da0073e9SAndroid Build Coastguard Worker if (uvm_assertions_ptr) {
91*da0073e9SAndroid Build Coastguard Worker C10_CUDA_IGNORE_ERROR(cudaFree(uvm_assertions_ptr));
92*da0073e9SAndroid Build Coastguard Worker }
93*da0073e9SAndroid Build Coastguard Worker }
94*da0073e9SAndroid Build Coastguard Worker
95*da0073e9SAndroid Build Coastguard Worker } // namespace
96*da0073e9SAndroid Build Coastguard Worker
/// Check that kernels ran correctly by checking the message buffer. BLOCKING.
/// Builds a human-readable report of all device-side assertion failures
/// recorded in managed memory, cross-referenced against the kernel-launch
/// circular buffer; returns a short status string when DSA is compiled out,
/// disabled at runtime, or unsupported by the hardware.
std::string c10_retrieve_device_side_assertion_info() {
#ifdef TORCH_USE_CUDA_DSA
  const auto& launch_registry = CUDAKernelLaunchRegistry::get_singleton_ref();
  if (!launch_registry.enabled_at_runtime) {
    return "Device-side assertion tracking was not enabled by user.";
  } else if (!launch_registry.do_all_devices_support_managed_memory) {
    return "Device-side assertions disabled because not all devices support managed memory.";
  }

  // Hack that saves a lot of challenging sync logic.
  // The GPU increments the number of errors it's observed and the CPU can see
  // that happening immediately which means we can make it here before the GPU
  // is done writing information about those errors to memory.
  // A short pause gives it time to finish. Since something's gone wrong, this
  // pause shouldn't affect perf.
  std::this_thread::sleep_for(std::chrono::seconds(1));

  // The snapshot causes a brief block. That's okay because this function only
  // executes if something's gone wrong such that speed is no longer a priority.
  const auto launch_data = launch_registry.snapshot();
  const auto& assertion_data = launch_data.first;
  const auto& launch_infos = launch_data.second;

  std::stringstream oss;

  oss << "Looking for device-side assertion failure information...\n";

  // Loop over each device that could be managed by the process
  for (const auto device_num : c10::irange(assertion_data.size())) {
    const auto& assertion_data_for_device = assertion_data.at(device_num);

    // Did anything fail? The device may have recorded more failures than the
    // buffer can hold, so clamp to the buffer capacity.
    const auto failures_found = std::min(
        assertion_data_for_device.assertion_count,
        C10_CUDA_DSA_ASSERTION_COUNT);
    if (failures_found == 0) {
      continue;
    }

    // Something failed, let's talk about that
    oss << failures_found
        << " CUDA device-side assertion failures were found on GPU #"
        << device_num << "!" << std::endl;
    if (assertion_data_for_device.assertion_count >
        C10_CUDA_DSA_ASSERTION_COUNT) {
      oss << "But at least " << assertion_data_for_device.assertion_count
          << " assertion failures occurred on the device" << std::endl;
      oss << "Adjust `C10_CUDA_DSA_ASSERTION_COUNT` if you need more assertion failure info"
          << std::endl;
    }

    for (const auto i : c10::irange(failures_found)) {
      const auto& self = assertion_data_for_device.assertions[i];
      const auto& launch_info = launch_infos[self.caller % launch_infos.size()];
      oss << "Assertion failure " << i << std::endl;
      oss << "  GPU assertion failure message = " << self.assertion_msg
          << std::endl;
      oss << "  File containing assertion = " << self.filename << ":"
          << self.line_number << std::endl;
      oss << "  Device function containing assertion = " << self.function_name
          << std::endl;
      oss << "  Thread ID that failed assertion = [" << self.thread_id[0] << ","
          << self.thread_id[1] << "," << self.thread_id[2] << "]" << std::endl;
      oss << "  Block ID that failed assertion = [" << self.block_id[0] << ","
          << self.block_id[1] << "," << self.block_id[2] << "]" << std::endl;
      // The circular launch buffer still holds this launch only when the
      // stored generation number matches the assertion's `caller` field.
      if (launch_info.generation_number == self.caller) {
        oss << "  File containing kernel launch = "
            << launch_info.launch_filename << ":" << launch_info.launch_linenum
            << std::endl;
        oss << "  Function containing kernel launch = "
            << launch_info.launch_function << std::endl;
        oss << "  Name of kernel launched that led to failure = "
            << launch_info.kernel_name << std::endl;
        oss << "  Device that launched kernel = " << launch_info.device
            << std::endl;
        oss << "  Stream kernel was launched on = " << launch_info.stream
            << std::endl;
        oss << "  Backtrace of kernel launch site = ";
        // BUGFIX: this branch was previously inverted. When
        // `gather_launch_stacktrace` is true a backtrace WAS captured by
        // CUDAKernelLaunchRegistry::insert(), so print it; when false the
        // stored stacktrace is empty and we report stacktracing as disabled.
        if (launch_registry.gather_launch_stacktrace) {
          oss << "\n" << launch_info.launch_stacktrace << std::endl;
        } else {
          oss << "Launch stacktracing disabled." << std::endl;
        }
      } else {
        oss << "  CPU launch site info: Unavailable, the circular queue wrapped around. Increase `CUDAKernelLaunchRegistry::max_size`."
            << std::endl;
      }
    }
  }
  return oss.str();
#else
  return "Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.\n";
#endif
}
192*da0073e9SAndroid Build Coastguard Worker
CUDAKernelLaunchRegistry()193*da0073e9SAndroid Build Coastguard Worker CUDAKernelLaunchRegistry::CUDAKernelLaunchRegistry()
194*da0073e9SAndroid Build Coastguard Worker : do_all_devices_support_managed_memory(
195*da0073e9SAndroid Build Coastguard Worker dsa_check_if_all_devices_support_managed_memory()),
196*da0073e9SAndroid Build Coastguard Worker gather_launch_stacktrace(check_env_for_enable_launch_stacktracing()),
197*da0073e9SAndroid Build Coastguard Worker enabled_at_runtime(check_env_for_dsa_enabled()) {
198*da0073e9SAndroid Build Coastguard Worker for (C10_UNUSED const auto _ : c10::irange(dsa_get_device_count())) {
199*da0073e9SAndroid Build Coastguard Worker uvm_assertions.emplace_back(nullptr, uvm_deleter);
200*da0073e9SAndroid Build Coastguard Worker }
201*da0073e9SAndroid Build Coastguard Worker
202*da0073e9SAndroid Build Coastguard Worker kernel_launches.resize(max_kernel_launches);
203*da0073e9SAndroid Build Coastguard Worker }
204*da0073e9SAndroid Build Coastguard Worker
check_env_for_enable_launch_stacktracing() const205*da0073e9SAndroid Build Coastguard Worker bool CUDAKernelLaunchRegistry::check_env_for_enable_launch_stacktracing()
206*da0073e9SAndroid Build Coastguard Worker const {
207*da0073e9SAndroid Build Coastguard Worker return env_flag_set("PYTORCH_CUDA_DSA_STACKTRACING");
208*da0073e9SAndroid Build Coastguard Worker }
209*da0073e9SAndroid Build Coastguard Worker
check_env_for_dsa_enabled() const210*da0073e9SAndroid Build Coastguard Worker bool CUDAKernelLaunchRegistry::check_env_for_dsa_enabled() const {
211*da0073e9SAndroid Build Coastguard Worker return env_flag_set("PYTORCH_USE_CUDA_DSA");
212*da0073e9SAndroid Build Coastguard Worker }
213*da0073e9SAndroid Build Coastguard Worker
insert(const char * launch_filename,const char * launch_function,const uint32_t launch_linenum,const char * kernel_name,const int32_t stream_id)214*da0073e9SAndroid Build Coastguard Worker uint32_t CUDAKernelLaunchRegistry::insert(
215*da0073e9SAndroid Build Coastguard Worker const char* launch_filename,
216*da0073e9SAndroid Build Coastguard Worker const char* launch_function,
217*da0073e9SAndroid Build Coastguard Worker const uint32_t launch_linenum,
218*da0073e9SAndroid Build Coastguard Worker const char* kernel_name,
219*da0073e9SAndroid Build Coastguard Worker const int32_t stream_id) {
220*da0073e9SAndroid Build Coastguard Worker #ifdef TORCH_USE_CUDA_DSA
221*da0073e9SAndroid Build Coastguard Worker if (!enabled_at_runtime) {
222*da0073e9SAndroid Build Coastguard Worker return 0;
223*da0073e9SAndroid Build Coastguard Worker }
224*da0073e9SAndroid Build Coastguard Worker
225*da0073e9SAndroid Build Coastguard Worker const auto backtrace = gather_launch_stacktrace ? c10::get_backtrace() : "";
226*da0073e9SAndroid Build Coastguard Worker
227*da0073e9SAndroid Build Coastguard Worker const std::lock_guard<std::mutex> lock(read_write_mutex);
228*da0073e9SAndroid Build Coastguard Worker
229*da0073e9SAndroid Build Coastguard Worker const auto my_gen_number = generation_number++;
230*da0073e9SAndroid Build Coastguard Worker // TODO: It would probably be good to get a stack trace here so that
231*da0073e9SAndroid Build Coastguard Worker // we can better indicate which launch caused the failure.
232*da0073e9SAndroid Build Coastguard Worker kernel_launches[my_gen_number % max_kernel_launches] = {
233*da0073e9SAndroid Build Coastguard Worker launch_filename,
234*da0073e9SAndroid Build Coastguard Worker launch_function,
235*da0073e9SAndroid Build Coastguard Worker launch_linenum,
236*da0073e9SAndroid Build Coastguard Worker backtrace,
237*da0073e9SAndroid Build Coastguard Worker kernel_name,
238*da0073e9SAndroid Build Coastguard Worker dsa_get_device_id(),
239*da0073e9SAndroid Build Coastguard Worker stream_id,
240*da0073e9SAndroid Build Coastguard Worker my_gen_number};
241*da0073e9SAndroid Build Coastguard Worker return my_gen_number;
242*da0073e9SAndroid Build Coastguard Worker #else
243*da0073e9SAndroid Build Coastguard Worker return 0;
244*da0073e9SAndroid Build Coastguard Worker #endif
245*da0073e9SAndroid Build Coastguard Worker }
246*da0073e9SAndroid Build Coastguard Worker
247*da0073e9SAndroid Build Coastguard Worker std::pair<std::vector<DeviceAssertionsData>, std::vector<CUDAKernelLaunchInfo>>
snapshot() const248*da0073e9SAndroid Build Coastguard Worker CUDAKernelLaunchRegistry::snapshot() const {
249*da0073e9SAndroid Build Coastguard Worker // This is likely to be the longest-lasting hold on the mutex, but
250*da0073e9SAndroid Build Coastguard Worker // we only expect it to be called in cases where we're already failing
251*da0073e9SAndroid Build Coastguard Worker // and speed is no longer important
252*da0073e9SAndroid Build Coastguard Worker const std::lock_guard<std::mutex> lock(read_write_mutex);
253*da0073e9SAndroid Build Coastguard Worker
254*da0073e9SAndroid Build Coastguard Worker std::vector<DeviceAssertionsData> device_assertions_data;
255*da0073e9SAndroid Build Coastguard Worker for (const auto& x : uvm_assertions) {
256*da0073e9SAndroid Build Coastguard Worker if (x) {
257*da0073e9SAndroid Build Coastguard Worker device_assertions_data.push_back(*x);
258*da0073e9SAndroid Build Coastguard Worker } else {
259*da0073e9SAndroid Build Coastguard Worker device_assertions_data.emplace_back();
260*da0073e9SAndroid Build Coastguard Worker }
261*da0073e9SAndroid Build Coastguard Worker }
262*da0073e9SAndroid Build Coastguard Worker
263*da0073e9SAndroid Build Coastguard Worker return std::make_pair(device_assertions_data, kernel_launches);
264*da0073e9SAndroid Build Coastguard Worker }
265*da0073e9SAndroid Build Coastguard Worker
/// Return the managed-memory (UVM) assertion buffer for the current device,
/// allocating and configuring it on first use. Returns nullptr when
/// device-side assertions are compiled out or not enabled at runtime.
/// Thread-safe: uses a lock-free fast path plus double-checked locking
/// around the one-time allocation.
DeviceAssertionsData* CUDAKernelLaunchRegistry::
    get_uvm_assertions_ptr_for_current_device() {
#ifdef TORCH_USE_CUDA_DSA
  if (!enabled_at_runtime) {
    return nullptr;
  }

  const auto device_num = dsa_get_device_id();

  // If we've already set up this GPU with managed memory, return a pointer to
  // the managed memory. This is a lock-free quick-return path.
  if (uvm_assertions.at(device_num)) {
    return uvm_assertions.at(device_num).get();
  }

  // Need a lock here so there's not race-condition on creating the new device
  // assertions buffer
  const std::lock_guard<std::mutex> lock(gpu_alloc_mutex);

  // If we've already set up this GPU with managed memory, return a pointer to
  // the managed memory. This locked path ensures that the device memory is
  // allocated only once
  if (uvm_assertions.at(device_num)) {
    return uvm_assertions.at(device_num).get();
  }

  // Otherwise, set up the GPU to be able to use the device-side assertion
  // system
  DeviceAssertionsData* uvm_assertions_ptr = nullptr;

  C10_CUDA_CHECK_WO_DSA(
      cudaMallocManaged(&uvm_assertions_ptr, sizeof(DeviceAssertionsData)));

  // Ask the driver to keep the buffer resident on the host, since we want to
  // read it from the CPU after an assertion fires.
  C10_CUDA_CHECK_WO_DSA(cudaMemAdvise(
      uvm_assertions_ptr,
      sizeof(DeviceAssertionsData),
      cudaMemAdviseSetPreferredLocation,
      cudaCpuDeviceId));

  // GPU will establish direct mapping of data in CPU memory, no page faults
  // will be generated
  C10_CUDA_CHECK_WO_DSA(cudaMemAdvise(
      uvm_assertions_ptr,
      sizeof(DeviceAssertionsData),
      cudaMemAdviseSetAccessedBy,
      cudaCpuDeviceId));

  // Initialize the memory from the CPU; otherwise, pages may have to be created
  // on demand. We think that UVM documentation indicates that first access may
  // not honor preferred location, which would be bad, if true, because we want
  // this memory on the host so we can access it post-assertion. Initializing
  // this on the CPU helps ensure that that's where the memory will live.
  *uvm_assertions_ptr = DeviceAssertionsData();

  // Ownership and lifetime management of `uvm_assertions_ptr` now passes to the
  // uvm_assertions unique_ptr vector
  uvm_assertions.at(device_num).reset(uvm_assertions_ptr);

  return uvm_assertions_ptr;
#else
  return nullptr;
#endif
}
329*da0073e9SAndroid Build Coastguard Worker
get_singleton_ref()330*da0073e9SAndroid Build Coastguard Worker CUDAKernelLaunchRegistry& CUDAKernelLaunchRegistry::get_singleton_ref() {
331*da0073e9SAndroid Build Coastguard Worker static CUDAKernelLaunchRegistry launch_registry;
332*da0073e9SAndroid Build Coastguard Worker return launch_registry;
333*da0073e9SAndroid Build Coastguard Worker }
334*da0073e9SAndroid Build Coastguard Worker
has_failed() const335*da0073e9SAndroid Build Coastguard Worker bool CUDAKernelLaunchRegistry::has_failed() const {
336*da0073e9SAndroid Build Coastguard Worker for (const auto& x : uvm_assertions) {
337*da0073e9SAndroid Build Coastguard Worker if (x && x->assertion_count > 0) {
338*da0073e9SAndroid Build Coastguard Worker return true;
339*da0073e9SAndroid Build Coastguard Worker }
340*da0073e9SAndroid Build Coastguard Worker }
341*da0073e9SAndroid Build Coastguard Worker return false;
342*da0073e9SAndroid Build Coastguard Worker }
343*da0073e9SAndroid Build Coastguard Worker
344*da0073e9SAndroid Build Coastguard Worker } // namespace c10::cuda
345