// xref: /aosp_15_r20/external/pytorch/c10/test/util/exception_test.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <c10/util/Exception.h>
2 #include <gtest/gtest.h>
3 #include <stdexcept>
4 
5 using c10::Error;
6 
7 namespace {
8 
9 template <class Functor>
expectThrowsEq(Functor && functor,const char * expectedMessage)10 inline void expectThrowsEq(Functor&& functor, const char* expectedMessage) {
11   try {
12     std::forward<Functor>(functor)();
13   } catch (const Error& e) {
14     EXPECT_STREQ(e.what_without_backtrace(), expectedMessage);
15     return;
16   }
17   ADD_FAILURE() << "Expected to throw exception with message \""
18                 << expectedMessage << "\" but didn't throw";
19 }
20 } // namespace
21 
// TORCH_INTERNAL_ASSERT_DEBUG_ONLY is checked in both build flavours:
// - debug builds: a false condition raises c10::Error;
// - NDEBUG builds: it is a no-op and its condition is not even evaluated.
TEST(ExceptionTest, TORCH_INTERNAL_ASSERT_DEBUG_ONLY) {
#ifndef NDEBUG
  // Debug build: false -> c10::Error, true -> nothing.
  ASSERT_THROW(TORCH_INTERNAL_ASSERT_DEBUG_ONLY(false), c10::Error);
  ASSERT_NO_THROW(TORCH_INTERNAL_ASSERT_DEBUG_ONLY(true));
#else
  // Release build: even a false condition must not raise.
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
  ASSERT_NO_THROW(TORCH_INTERNAL_ASSERT_DEBUG_ONLY(false));
  // The condition expression itself must be skipped, so the embedded
  // `throw` below must never run.
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
  ASSERT_NO_THROW(TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      (throw std::runtime_error("I'm throwing..."), true)));
#endif
}
35 
36 // On these platforms there's no assert
37 #if !defined(__ANDROID__) && !defined(__APPLE__)
TEST(ExceptionTest,CUDA_KERNEL_ASSERT)38 TEST(ExceptionTest, CUDA_KERNEL_ASSERT) {
39   // This function always throws even in NDEBUG mode
40   ASSERT_DEATH_IF_SUPPORTED({ CUDA_KERNEL_ASSERT(false); }, "Assert");
41 }
42 #endif
43 
// Smoke test: emitting a warning through TORCH_WARN should complete without
// an exception escaping the test body.
TEST(WarningTest, JustPrintWarning) {
  constexpr const char* kWarningText = "I'm a warning";
  TORCH_WARN(kWarningText);
}
47 
// Verifies how c10::Error renders its message (without backtrace) as context
// strings are attached by TORCH_RETHROW: none, one, and two nested contexts.
TEST(ExceptionTest, ErrorFormatting) {
  // No context: the message is exactly what TORCH_CHECK was given.
  const auto failPlain = []() { TORCH_CHECK(false, "This is invalid"); };
  expectThrowsEq(failPlain, "This is invalid");

  // One context: it is appended in parentheses.
  const auto failWithContextX = []() {
    try {
      TORCH_CHECK(false, "This is invalid");
    } catch (Error& e) {
      // TORCH_RETHROW mutates `e` (adds context), hence the non-const catch.
      TORCH_RETHROW(e, "While checking X");
    }
  };
  expectThrowsEq(failWithContextX, "This is invalid (While checking X)");

  // Two contexts: each goes on its own two-space-indented line, innermost
  // (first-attached) context first.
  const auto failWithContextXY = [&failWithContextX]() {
    try {
      failWithContextX();
    } catch (Error& e) {
      TORCH_RETHROW(e, "While checking Y");
    }
  };
  expectThrowsEq(
      failWithContextXY,
      R"msg(This is invalid
  While checking X
  While checking Y)msg");
}
78 
// How many times getAssertionArgument() has run since the counter was last
// reset (tests assign 0 to it directly before measuring).
static int assertionArgumentCounter = 0;

// Increments the call counter and returns the updated value. Passing this call
// as a message argument to a check macro lets a test observe how many times
// the macro evaluated its arguments.
static int getAssertionArgument() {
  assertionArgumentCounter += 1;
  return assertionArgumentCounter;
}
83 
failCheck()84 static void failCheck() {
85   TORCH_CHECK(false, "message ", getAssertionArgument());
86 }
87 
failInternalAssert()88 static void failInternalAssert() {
89   TORCH_INTERNAL_ASSERT(false, "message ", getAssertionArgument());
90 }
91 
// A failing TORCH_CHECK / TORCH_INTERNAL_ASSERT must evaluate each message
// argument exactly once. Each helper bumps assertionArgumentCounter once per
// evaluation of its message argument, so the counter must advance by exactly
// one per failure.
TEST(ExceptionTest, DontCallArgumentFunctionsTwiceOnFailure) {
  // Start from a known baseline.
  assertionArgumentCounter = 0;

  // TORCH_CHECK: one failure -> counter == 1.
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
  EXPECT_ANY_THROW(failCheck());
  EXPECT_EQ(assertionArgumentCounter, 1) << "TORCH_CHECK called argument twice";

  // TORCH_INTERNAL_ASSERT: one more failure -> counter == 2.
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
  EXPECT_ANY_THROW(failInternalAssert());
  EXPECT_EQ(assertionArgumentCounter, 2)
      << "TORCH_INTERNAL_ASSERT called argument twice";
}
103