1 #pragma once
2
3 #include <c10/core/SymNodeImpl.h>
4 #include <c10/macros/Export.h>
5 #include <c10/util/Exception.h>
6 #include <c10/util/intrusive_ptr.h>
7 #include <cstdint>
8 #include <optional>
9 #include <ostream>
10 #include <utility>
11
12 namespace c10 {
13
14 class C10_API SymBool {
15 public:
SymBool(bool b)16 /*implicit*/ SymBool(bool b) : data_(b){};
SymBool(SymNode ptr)17 SymBool(SymNode ptr) : data_(false), ptr_(std::move(ptr)) {
18 TORCH_CHECK(ptr_->is_bool());
19 };
SymBool()20 SymBool() : data_(false) {}
21
toSymNodeImplUnowned()22 SymNodeImpl* toSymNodeImplUnowned() const {
23 return ptr_.get();
24 }
25
release()26 SymNodeImpl* release() && {
27 return std::move(ptr_).release();
28 }
29
30 // Only valid if is_heap_allocated()
31 SymNode toSymNodeImpl() const;
32
33 // Guaranteed to return a SymNode, wrapping using base if necessary
34 SymNode wrap_node(const SymNode& base) const;
35
expect_bool()36 bool expect_bool() const {
37 std::optional<bool> c = maybe_as_bool();
38 TORCH_CHECK(c.has_value());
39 return *c;
40 }
41
42 SymBool sym_and(const SymBool&) const;
43 SymBool sym_or(const SymBool&) const;
44 SymBool sym_not() const;
45
46 SymBool operator&(const SymBool& other) const {
47 return sym_and(other);
48 }
49 SymBool operator|(const SymBool& other) const {
50 return sym_or(other);
51 }
52 SymBool operator~() const {
53 return sym_not();
54 }
55
56 // Insert a guard for the bool to be its concrete value, and then return
57 // that value. Note that C++ comparison operations default to returning
58 // bool, so it's not so common to have to call this
59 bool guard_bool(const char* file, int64_t line) const;
60 bool expect_true(const char* file, int64_t line) const;
61 bool guard_size_oblivious(const char* file, int64_t line) const;
62
63 bool has_hint() const;
64
as_bool_unchecked()65 bool as_bool_unchecked() const {
66 return data_;
67 }
68
maybe_as_bool()69 std::optional<bool> maybe_as_bool() const {
70 if (!is_heap_allocated()) {
71 return std::make_optional(data_);
72 }
73 return toSymNodeImplUnowned()->constant_bool();
74 }
75
is_heap_allocated()76 bool is_heap_allocated() const {
77 return ptr_;
78 }
79
80 private:
81 // TODO: optimize to union
82 bool data_;
83 SymNode ptr_;
84 };
85
86 C10_API std::ostream& operator<<(std::ostream& os, const SymBool& s);
87
// Symbolic-condition variants of TORCH_CHECK / TORCH_INTERNAL_ASSERT. The
// condition is resolved by calling its expect_true() with the invoking
// __FILE__ and __LINE__, and the resulting bool is checked as usual. Any
// extra arguments are forwarded unchanged.
#define TORCH_SYM_CHECK(sym_cond, ...) \
  TORCH_CHECK((sym_cond).expect_true(__FILE__, __LINE__), __VA_ARGS__)
#define TORCH_SYM_INTERNAL_ASSERT(sym_cond, ...) \
  TORCH_INTERNAL_ASSERT((sym_cond).expect_true(__FILE__, __LINE__), __VA_ARGS__)
92
guard_size_oblivious(bool b,const char * file,int64_t line)93 inline bool guard_size_oblivious(
94 bool b,
95 const char* file [[maybe_unused]],
96 int64_t line [[maybe_unused]]) {
97 return b;
98 }
99
guard_size_oblivious(const c10::SymBool & b,const char * file,int64_t line)100 inline bool guard_size_oblivious(
101 const c10::SymBool& b,
102 const char* file,
103 int64_t line) {
104 return b.guard_size_oblivious(file, line);
105 }
106
// Calls c10::guard_size_oblivious on `expr`, which may be a bool or a SymBool.
// The call is tagged with the invoking __FILE__ and __LINE__.
#define TORCH_GUARD_SIZE_OBLIVIOUS(expr) \
  c10::guard_size_oblivious((expr), __FILE__, __LINE__)
109
110 } // namespace c10
111