1 #include <gmock/gmock.h>
2 #include <gtest/gtest.h>
3
4 #include <c10/cuda/CUDADeviceAssertion.h>
5 #include <c10/cuda/CUDAException.h>
6 #include <c10/cuda/CUDAFunctions.h>
7 #include <c10/cuda/CUDAStream.h>
8
9 #include <chrono>
10 #include <iostream>
11 #include <string>
12 #include <thread>
13
14 using ::testing::HasSubstr;
15
did_not_fail_diagnostics()16 void did_not_fail_diagnostics() {
17 std::cerr
18 << "c10::cuda::CUDAKernelLaunchRegistry::get_singleton_ref().enabled_at_runtime = "
19 << c10::cuda::CUDAKernelLaunchRegistry::get_singleton_ref().enabled_at_runtime
20 << std::endl;
21 std::cerr
22 << "c10::cuda::CUDAKernelLaunchRegistry::get_singleton_ref().enabled_at_compile_time = "
23 << c10::cuda::CUDAKernelLaunchRegistry::get_singleton_ref().enabled_at_compile_time
24 << std::endl;
25 std::cerr
26 << "c10::cuda::CUDAKernelLaunchRegistry::get_singleton_ref().do_all_devices_support_managed_memory = "
27 << c10::cuda::CUDAKernelLaunchRegistry::get_singleton_ref()
28 .do_all_devices_support_managed_memory
29 << std::endl;
30 }
31
32 /**
33 * Device kernel that takes a single integer parameter as argument and
34 * will always trigger a device side assertion.
35 */
cuda_always_fail_assertion_kernel(const int a,TORCH_DSA_KERNEL_ARGS)36 __global__ void cuda_always_fail_assertion_kernel(
37 const int a,
38 TORCH_DSA_KERNEL_ARGS) {
39 CUDA_KERNEL_ASSERT2(a != a);
40 }
41
42 /**
43 * TEST: Triggering device side assertion on a simple <<<1,1>>> config.
44 * kernel used takes only 1 variable as parameter function.
45 */
cuda_device_assertions_1_var_test()46 void cuda_device_assertions_1_var_test() {
47 const auto stream = c10::cuda::getStreamFromPool();
48 TORCH_DSA_KERNEL_LAUNCH(
49 cuda_always_fail_assertion_kernel,
50 1, /* Blocks */
51 1, /* Threads */
52 0, /* Shared mem */
53 stream, /* Stream */
54 1);
55
56 try {
57 c10::cuda::device_synchronize();
58 did_not_fail_diagnostics();
59 throw std::runtime_error("Test didn't fail, but should have.");
60 } catch (const c10::Error& err) {
61 const auto err_str = std::string(err.what());
62 ASSERT_THAT(
63 err_str,
64 HasSubstr("CUDA device-side assertion failures were found on GPU #0!"));
65 ASSERT_THAT(
66 err_str, HasSubstr("Thread ID that failed assertion = [0,0,0]"));
67 ASSERT_THAT(err_str, HasSubstr("Block ID that failed assertion = [0,0,0]"));
68 ASSERT_THAT(err_str, HasSubstr("Device that launched kernel = 0"));
69 ASSERT_THAT(
70 err_str,
71 HasSubstr(
72 "Name of kernel launched that led to failure = cuda_always_fail_assertion_kernel"));
73 ASSERT_THAT(
74 err_str, HasSubstr("File containing kernel launch = " __FILE__));
75 ASSERT_THAT(
76 err_str,
77 HasSubstr(
78 "Function containing kernel launch = " +
79 std::string(__FUNCTION__)));
80 ASSERT_THAT(
81 err_str,
82 HasSubstr(
83 "Stream kernel was launched on = " + std::to_string(stream.id())));
84 }
85 }
86
/**
 * GoogleTest entry point for the single-argument device-side assertion test.
 *
 * When TORCH_USE_CUDA_DSA is not defined the test is skipped. Otherwise the
 * registry's `enabled_at_runtime` flag is set, the registry flags are logged
 * to stderr, and the actual test body runs.
 */
TEST(CUDATest, cuda_device_assertions_1_var_test) {
#ifndef TORCH_USE_CUDA_DSA
  GTEST_SKIP() << "CUDA device-side assertions (DSA) was not enabled at compile time.";
#else
  auto& registry = c10::cuda::CUDAKernelLaunchRegistry::get_singleton_ref();
  registry.enabled_at_runtime = true;

  // Log the registry flags up front so they appear in the test output.
  did_not_fail_diagnostics();
  cuda_device_assertions_1_var_test();
#endif
}
96