xref: /aosp_15_r20/external/pytorch/aten/src/ATen/test/test_assert.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
#include <stdarg.h>
#include <stdio.h>

#include <stdexcept>
#include <string>
4 
// Formats a printf-style message into a fixed-size buffer and throws it as a
// std::runtime_error. Never returns normally.
//
// @param fmt  printf-style format string; remaining arguments are consumed
//             according to its conversion specifications.
// @throws std::runtime_error carrying the formatted message. Messages longer
//         than the buffer are truncated (vsnprintf always null-terminates on
//         success).
[[noreturn]] static inline void barf(const char *fmt, ...) {
  char msg[2048];
  va_list args;
  va_start(args, fmt);
  // sizeof(msg) keeps the bound in sync with the buffer declaration.
  const int written = vsnprintf(msg, sizeof(msg), fmt, args);
  va_end(args);
  if (written < 0) {
    // vsnprintf signals an encoding error with a negative return; the buffer
    // is then not guaranteed to hold a valid string, so report the raw format.
    throw std::runtime_error(fmt);
  }
  throw std::runtime_error(msg);
}
13 
// Compatibility shim: on MSVC with _MSC_VER up to and including 1900,
// __func__ is redirected to the MSVC-specific __FUNCTION__ so that ASSERT can
// report the enclosing function name. Other compilers are unaffected.
#ifdef _MSC_VER
#if _MSC_VER <= 1900
#define __func__ __FUNCTION__
#endif
#endif
17 
// Branch-prediction hint: AT_EXPECT(x, y) forwards to __builtin_expect on
// GCC, Clang and Intel (__ICL) compilers; everywhere else the hint is dropped
// and the macro simply yields (x).
#if !defined(__GNUC__) && !defined(__ICL) && !defined(__clang__)
#define AT_EXPECT(x, y) (x)
#else
#define AT_EXPECT(x, y) (__builtin_expect((x),(y)))
#endif
23 
// Evaluates `cond` once; if it is false, calls barf() which throws a
// std::runtime_error of the form
//   "<file>:<line>: <function>: Assertion `<cond text>` failed."
// where <cond text> is the stringified condition as written by the caller.
// __LINE__ is an int, so it is cast to unsigned to match the "%u" conversion.
// NOTE: this expands to a bare `if` statement. Do not use it as the body of
// an unbraced `if` that has an `else` branch, and do not follow it directly
// with `else` -- the `else` would bind to the macro's internal `if`.
#define ASSERT(cond) \
  if (AT_EXPECT(!(cond), 0)) { \
    barf("%s:%u: %s: Assertion `%s` failed.", __FILE__, static_cast<unsigned>(__LINE__), __func__, #cond); \
  }
28 
// Runs `fn`. If it completes without throwing, `els` runs next. If a
// std::exception escapes, `catc` runs with the exception bound to `e`, but
// only when that exception came from `fn`: an exception raised by `els`
// (after `fn` already succeeded) trips ASSERT(!_passed) instead, so a throwing
// `els` can never be mistaken for the expected throw from `fn`.
#define TRY_CATCH_ELSE(fn, catc, els)        \
  {                                          \
    bool _passed{false};                     \
    try {                                    \
      fn;                                    \
      _passed = true;                        \
      els;                                   \
    } catch (std::exception const& e) {      \
      ASSERT(!_passed);                      \
      catc;                                  \
    }                                        \
  }
42 
// Asserts that `fn` throws a std::exception whose what() text contains
// `message` as a substring. If `fn` completes without throwing, the
// ASSERT(false) in the else-branch fails instead.
#define ASSERT_THROWSM(fn, message)                                      \
  TRY_CATCH_ELSE(                                                        \
      fn,                                                                \
      ASSERT(std::string(e.what()).find(message) != std::string::npos),  \
      ASSERT(false))
45 
// Asserts that `fn` throws any std::exception: the empty message is a
// substring of every what() text. The expansion carries its own trailing
// semicolon.
#define ASSERT_THROWS(fn) ASSERT_THROWSM(fn, "");
48 
// Asserts that t1.equal(t2) holds. `t1` is parenthesized so that expression
// arguments such as `a + b` are evaluated as a whole before `.equal` is
// applied (unparenthesized, `a + b` would expand to `a + b.equal(t2)`).
// The expansion carries its own trailing semicolon.
#define ASSERT_EQUAL(t1, t2) \
  ASSERT((t1).equal(t2));
51 
// allclose broadcasts, so check same size before allclose.
// Both checks are wrapped in one compound statement so the macro acts as a
// single statement under an unbraced `if` (otherwise only the size check
// would be guarded and the allclose check would always run). `t1` is
// parenthesized so expression arguments bind as a whole; note that `t1` and
// `t2` are each evaluated twice.
#define ASSERT_ALLCLOSE(t1, t2)        \
  {                                    \
    ASSERT((t1).is_same_size(t2));     \
    ASSERT((t1).allclose(t2));         \
  }
56 
// allclose broadcasts, so check same size before allclose.
// Tensor::allclose takes its tolerances in (other, rtol, atol) order, so they
// are forwarded as (t2, rtol, atol) to honour this macro's (atol, rtol)
// parameter names. Both checks form one compound statement so the macro acts
// as a single statement under an unbraced `if`; `t1` is parenthesized so
// expression arguments bind as a whole.
#define ASSERT_ALLCLOSE_TOLERANCES(t1, t2, atol, rtol)  \
  {                                                     \
    ASSERT((t1).is_same_size(t2));                      \
    ASSERT((t1).allclose(t2, rtol, atol));              \
  }
61