#include <c10/util/Exception.h>
#include <gtest/gtest.h>
#include <exception>
#include <stdexcept>
#include <utility>

using c10::Error;
6
7 namespace {
8
9 template <class Functor>
expectThrowsEq(Functor && functor,const char * expectedMessage)10 inline void expectThrowsEq(Functor&& functor, const char* expectedMessage) {
11 try {
12 std::forward<Functor>(functor)();
13 } catch (const Error& e) {
14 EXPECT_STREQ(e.what_without_backtrace(), expectedMessage);
15 return;
16 }
17 ADD_FAILURE() << "Expected to throw exception with message \""
18 << expectedMessage << "\" but didn't throw";
19 }
20 } // namespace
21
TEST(ExceptionTest, TORCH_INTERNAL_ASSERT_DEBUG_ONLY) {
#ifndef NDEBUG
  // Debug build: a false condition raises c10::Error and a true condition
  // passes silently.
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
  ASSERT_THROW(TORCH_INTERNAL_ASSERT_DEBUG_ONLY(false), c10::Error);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
  ASSERT_NO_THROW(TORCH_INTERNAL_ASSERT_DEBUG_ONLY(true));
#else
  // Release build (NDEBUG): the check is a no-op, even for a false condition.
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
  ASSERT_NO_THROW(TORCH_INTERNAL_ASSERT_DEBUG_ONLY(false));
  // The condition must not even be evaluated: if it were, the embedded
  // `throw` would fire and ASSERT_NO_THROW would fail.
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
  ASSERT_NO_THROW(TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      (throw std::runtime_error("I'm throwing..."), true)));
#endif
}
35
36 // On these platforms there's no assert
37 #if !defined(__ANDROID__) && !defined(__APPLE__)
TEST(ExceptionTest,CUDA_KERNEL_ASSERT)38 TEST(ExceptionTest, CUDA_KERNEL_ASSERT) {
39 // This function always throws even in NDEBUG mode
40 ASSERT_DEATH_IF_SUPPORTED({ CUDA_KERNEL_ASSERT(false); }, "Assert");
41 }
42 #endif
43
TEST(WarningTest, JustPrintWarning) {
  // Smoke test only: nothing is asserted. The test passes as long as emitting
  // a warning through TORCH_WARN neither throws nor crashes.
  TORCH_WARN("I'm a warning");
}
47
// Checks how the message returned by c10::Error::what_without_backtrace()
// grows as TORCH_RETHROW attaches context while the error propagates.
TEST(ExceptionTest, ErrorFormatting) {
  // The innermost failure shared by every scenario below.
  const auto failWithInvalid = []() { TORCH_CHECK(false, "This is invalid"); };

  // No context: only the original message is reported.
  expectThrowsEq(failWithInvalid, "This is invalid");

  // One layer of context is appended on the same line, in parentheses.
  const auto failWithContextX = [&]() {
    try {
      failWithInvalid();
    } catch (Error& e) {
      TORCH_RETHROW(e, "While checking X");
    }
  };
  expectThrowsEq(failWithContextX, "This is invalid (While checking X)");

  // Two layers are listed one per line, innermost context first.
  expectThrowsEq(
      [&]() {
        try {
          failWithContextX();
        } catch (Error& e) {
          TORCH_RETHROW(e, "While checking Y");
        }
      },
      R"msg(This is invalid
While checking X
While checking Y)msg");
}
78
// Counts calls to getAssertionArgument(); tests that read it reset it to 0
// first.
static int assertionArgumentCounter = 0;

// Bumps assertionArgumentCounter and returns the updated count. Passed as a
// message argument to assertion macros so a test can observe how many times
// the macro evaluated that argument.
static int getAssertionArgument() {
  assertionArgumentCounter += 1;
  return assertionArgumentCounter;
}
83
failCheck()84 static void failCheck() {
85 TORCH_CHECK(false, "message ", getAssertionArgument());
86 }
87
failInternalAssert()88 static void failInternalAssert() {
89 TORCH_INTERNAL_ASSERT(false, "message ", getAssertionArgument());
90 }
91
// A failing TORCH_CHECK / TORCH_INTERNAL_ASSERT must evaluate its message
// arguments exactly once. getAssertionArgument() increments
// assertionArgumentCounter on every call, so the counter shows how many times
// each macro evaluated it.
TEST(ExceptionTest, DontCallArgumentFunctionsTwiceOnFailure) {
  // The counter is reset before each macro so the two checks are independent.
  // Otherwise a miscount in TORCH_CHECK would also fail the
  // TORCH_INTERNAL_ASSERT expectation.
  assertionArgumentCounter = 0;
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
  EXPECT_THROW(failCheck(), c10::Error);
  EXPECT_EQ(assertionArgumentCounter, 1)
      << "TORCH_CHECK did not evaluate its argument exactly once";

  assertionArgumentCounter = 0;
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
  EXPECT_THROW(failInternalAssert(), c10::Error);
  EXPECT_EQ(assertionArgumentCounter, 1)
      << "TORCH_INTERNAL_ASSERT did not evaluate its argument exactly once";
}
103