1 #include <gmock/gmock.h>
2 #include <gtest/gtest.h>
3
4 #include <c10/cuda/CUDADeviceAssertion.h>
5 #include <c10/cuda/CUDAException.h>
6 #include <c10/cuda/CUDAFunctions.h>
7 #include <c10/cuda/CUDAStream.h>
8
9 #include <chrono>
10 #include <iostream>
11 #include <string>
12 #include <thread>
13
14 using ::testing::HasSubstr;
15
// Substring the tests look for in the error text: "Assertion failure "
// followed by the value of C10_CUDA_DSA_ASSERTION_COUNT - 1.
// NOTE(review): presumably this is the index of the last assertion failure
// the DSA buffer can record -- confirm against CUDADeviceAssertion.h.
const std::string max_assertions_failure_str =
    std::string("Assertion failure ") +
    std::to_string(C10_CUDA_DSA_ASSERTION_COUNT - 1);
18
/**
 * Kernel whose device-side assertion fails unconditionally.
 *
 * @param a  Arbitrary integer; it is only compared against itself, so the
 *           asserted condition `a != a` is false for every thread that runs.
 *
 * TORCH_DSA_KERNEL_ARGS appends the extra parameters the device-side
 * assertion (DSA) macros rely on.
 * NOTE(review): presumably these are filled in by TORCH_DSA_KERNEL_LAUNCH --
 * confirm against CUDADeviceAssertion.h.
 */
__global__ void cuda_always_fail_assertion_kernel(const int a, TORCH_DSA_KERNEL_ARGS) {
  // The condition text is left untouched on purpose: the assert macro may
  // embed the expression in the failure report.
  CUDA_KERNEL_ASSERT2(a != a);
}
28
/**
 * TEST: Triggering device-side assertions from multiple blocks with a single
 * thread each (<<<10,1>>>).
 *
 * Launches cuda_always_fail_assertion_kernel on a stream taken from the pool
 * using 10 blocks of 1 thread, synchronizes the device, and then verifies
 * that the c10::Error raised by the synchronization contains:
 *   - the highest-numbered assertion failure (max_assertions_failure_str),
 *   - thread ID [0,0,0] and every block ID [0,0,0] .. [9,0,0],
 *   - the launch metadata: device 0, kernel name, file, function and stream.
 *
 * If the synchronization does not raise c10::Error, a fatal GoogleTest
 * failure is recorded.
 *
 * NOTE(review): expecting a report from every block assumes all 10 blocks
 * actually execute and record their failure (the original comment attributes
 * this to the GPU having more than 10 SMs) -- confirm for the target devices.
 */
void cuda_device_assertions_multiple_writes_from_multiple_blocks() {
  // Single source of truth for the grid size: it drives both the launch
  // configuration and the per-block checks below.
  constexpr int kNumBlocks = 10;

  const auto stream = c10::cuda::getStreamFromPool();
  TORCH_DSA_KERNEL_LAUNCH(
      cuda_always_fail_assertion_kernel,
      kNumBlocks, /* Blocks */
      1, /* Threads */
      0, /* Shared mem */
      stream, /* Stream */
      1);

  try {
    c10::cuda::device_synchronize();
    // Reaching this point means no error was reported for the failing
    // kernel. Use GoogleTest's fatal failure instead of throwing an
    // unrelated exception type.
    FAIL() << "Test didn't fail, but should have.";
  } catch (const c10::Error& err) {
    const auto err_str = std::string(err.what());
    ASSERT_THAT(err_str, HasSubstr(max_assertions_failure_str));
    ASSERT_THAT(
        err_str, HasSubstr("Thread ID that failed assertion = [0,0,0]"));

    // Every launched block must have logged its own failure.
    for (int block = 0; block < kNumBlocks; ++block) {
      ASSERT_THAT(
          err_str,
          HasSubstr(
              "Block ID that failed assertion = [" + std::to_string(block) +
              ",0,0]"));
    }

    ASSERT_THAT(err_str, HasSubstr("Device that launched kernel = 0"));
    ASSERT_THAT(
        err_str,
        HasSubstr(
            "Name of kernel launched that led to failure = cuda_always_fail_assertion_kernel"));
    ASSERT_THAT(
        err_str, HasSubstr("File containing kernel launch = " __FILE__));
    // __FUNCTION__ must be evaluated in the same function that performs the
    // launch above, so this check stays inline rather than in a helper.
    ASSERT_THAT(
        err_str,
        HasSubstr(
            "Function containing kernel launch = " +
            std::string(__FUNCTION__)));
    ASSERT_THAT(
        err_str,
        HasSubstr(
            "Stream kernel was launched on = " + std::to_string(stream.id())));
  }
}
82
// GoogleTest entry point for the multi-block device-side assertion scenario.
// Without TORCH_USE_CUDA_DSA defined at compile time the test is skipped.
TEST(CUDATest, cuda_device_assertions_multiple_writes_from_multiple_blocks) {
#ifndef TORCH_USE_CUDA_DSA
  GTEST_SKIP() << "CUDA device-side assertions (DSA) was not enabled at compile time.";
#else
  // Set the launch registry's runtime-enable flag before running the scenario.
  auto& dsa_registry = c10::cuda::CUDAKernelLaunchRegistry::get_singleton_ref();
  dsa_registry.enabled_at_runtime = true;
  cuda_device_assertions_multiple_writes_from_multiple_blocks();
#endif
}
91