1 #include <gmock/gmock.h>
2 #include <gtest/gtest.h>
3
4 #include <c10/cuda/CUDADeviceAssertion.h>
5 #include <c10/cuda/CUDAException.h>
6 #include <c10/cuda/CUDAFunctions.h>
7 #include <c10/cuda/CUDAStream.h>
8
9 #include <chrono>
10 #include <iostream>
11 #include <string>
12 #include <thread>
13
14 using ::testing::HasSubstr;
15
// Substring the DSA error report is expected to contain: the header for
// assertion number (C10_CUDA_DSA_ASSERTION_COUNT - 1).
// NOTE(review): presumably this is the last of C10_CUDA_DSA_ASSERTION_COUNT
// 0-based failure slots, reached because many threads assert at once — confirm.
const std::string max_assertions_failure_str =
    std::string("Assertion failure ") +
    std::to_string(C10_CUDA_DSA_ASSERTION_COUNT - 1);
18
/**
 * Device kernel that takes a single integer parameter and always triggers a
 * device-side assertion.
 *
 * `a != a` is false for every `int`, so the asserted condition fails in every
 * thread that evaluates it, whatever the launch configuration.
 *
 * @param a  Arbitrary value; used only to form an always-false condition.
 * The trailing TORCH_DSA_KERNEL_ARGS adds the extra parameters the DSA
 * machinery needs (presumably supplied by TORCH_DSA_KERNEL_LAUNCH — confirm).
 */
__global__ void cuda_always_fail_assertion_kernel(
    const int a,
    TORCH_DSA_KERNEL_ARGS) {
  // NOTE(review): keep this exact expression; the assertion macro may record
  // its text in the failure report — confirm against CUDA_KERNEL_ASSERT2.
  CUDA_KERNEL_ASSERT2(a != a);
}
28
/**
 * TEST: Trigger device-side assertions from a single block with many threads
 * (<<<1, 128>>>), all of which hit the always-failing assertion.
 *
 * Because only one block is launched, the block id of any recorded failure
 * must be [0,0,0]. The resulting c10::Error is checked for the failure
 * header, the block id, the launching device, the kernel name, the launch
 * file, the launching function and the stream the kernel ran on.
 *
 * Uses gtest fatal assertions; being a void helper, a failed ASSERT_* / FAIL()
 * only returns from this function (the calling TEST does nothing afterwards).
 */
void cuda_device_assertions_multiple_writes_from_same_block() {
  // Stream from the current device's pool; its device is what we expect the
  // failure report to name as the launching device.
  const auto stream = c10::cuda::getStreamFromPool();
  TORCH_DSA_KERNEL_LAUNCH(
      cuda_always_fail_assertion_kernel,
      1, /* Blocks */
      128, /* Threads */
      0, /* Shared mem */
      stream, /* Stream */
      1);

  try {
    // Kernel failures surface asynchronously; synchronizing is what raises
    // the device-side assertion as a c10::Error.
    c10::cuda::device_synchronize();
    // Report a missing failure through gtest rather than throwing an
    // exception type that the catch below would not handle.
    FAIL() << "Test didn't fail, but should have.";
  } catch (const c10::Error& err) {
    const std::string err_str = err.what();
    ASSERT_THAT(err_str, HasSubstr(max_assertions_failure_str));
    ASSERT_THAT(err_str, HasSubstr("Block ID that failed assertion = [0,0,0]"));
    // Compare against the stream's device instead of hard-coding device 0.
    ASSERT_THAT(
        err_str,
        HasSubstr(
            "Device that launched kernel = " +
            std::to_string(static_cast<int>(stream.device_index()))));
    ASSERT_THAT(
        err_str,
        HasSubstr(
            "Name of kernel launched that led to failure = cuda_always_fail_assertion_kernel"));
    ASSERT_THAT(
        err_str, HasSubstr("File containing kernel launch = " __FILE__));
    ASSERT_THAT(
        err_str,
        HasSubstr(
            "Function containing kernel launch = " +
            std::string(__FUNCTION__)));
    ASSERT_THAT(
        err_str,
        HasSubstr(
            "Stream kernel was launched on = " + std::to_string(stream.id())));
  }
}
70
// Runs the single-block/multi-thread DSA scenario when device-side assertions
// were compiled in (TORCH_USE_CUDA_DSA); otherwise the test is skipped.
TEST(CUDATest, cuda_device_assertions_multiple_writes_from_same_block) {
#ifndef TORCH_USE_CUDA_DSA
  GTEST_SKIP() << "CUDA device-side assertions (DSA) was not enabled at compile time.";
#else
  // Switch DSA on at runtime via the process-wide launch registry singleton.
  // The flag is left set; it is not restored after the scenario runs.
  auto& launch_registry =
      c10::cuda::CUDAKernelLaunchRegistry::get_singleton_ref();
  launch_registry.enabled_at_runtime = true;
  cuda_device_assertions_multiple_writes_from_same_block();
#endif
}
79