xref: /aosp_15_r20/external/pytorch/c10/core/SymBool.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/core/SymNodeImpl.h>
4 #include <c10/macros/Export.h>
5 #include <c10/util/Exception.h>
6 #include <c10/util/intrusive_ptr.h>
7 #include <cstdint>
8 #include <optional>
9 #include <ostream>
10 #include <utility>
11 
12 namespace c10 {
13 
14 class C10_API SymBool {
15  public:
SymBool(bool b)16   /*implicit*/ SymBool(bool b) : data_(b){};
SymBool(SymNode ptr)17   SymBool(SymNode ptr) : data_(false), ptr_(std::move(ptr)) {
18     TORCH_CHECK(ptr_->is_bool());
19   };
SymBool()20   SymBool() : data_(false) {}
21 
toSymNodeImplUnowned()22   SymNodeImpl* toSymNodeImplUnowned() const {
23     return ptr_.get();
24   }
25 
release()26   SymNodeImpl* release() && {
27     return std::move(ptr_).release();
28   }
29 
30   // Only valid if is_heap_allocated()
31   SymNode toSymNodeImpl() const;
32 
33   // Guaranteed to return a SymNode, wrapping using base if necessary
34   SymNode wrap_node(const SymNode& base) const;
35 
expect_bool()36   bool expect_bool() const {
37     std::optional<bool> c = maybe_as_bool();
38     TORCH_CHECK(c.has_value());
39     return *c;
40   }
41 
42   SymBool sym_and(const SymBool&) const;
43   SymBool sym_or(const SymBool&) const;
44   SymBool sym_not() const;
45 
46   SymBool operator&(const SymBool& other) const {
47     return sym_and(other);
48   }
49   SymBool operator|(const SymBool& other) const {
50     return sym_or(other);
51   }
52   SymBool operator~() const {
53     return sym_not();
54   }
55 
56   // Insert a guard for the bool to be its concrete value, and then return
57   // that value.  Note that C++ comparison operations default to returning
58   // bool, so it's not so common to have to call this
59   bool guard_bool(const char* file, int64_t line) const;
60   bool expect_true(const char* file, int64_t line) const;
61   bool guard_size_oblivious(const char* file, int64_t line) const;
62 
63   bool has_hint() const;
64 
as_bool_unchecked()65   bool as_bool_unchecked() const {
66     return data_;
67   }
68 
maybe_as_bool()69   std::optional<bool> maybe_as_bool() const {
70     if (!is_heap_allocated()) {
71       return std::make_optional(data_);
72     }
73     return toSymNodeImplUnowned()->constant_bool();
74   }
75 
is_heap_allocated()76   bool is_heap_allocated() const {
77     return ptr_;
78   }
79 
80  private:
81   // TODO: optimize to union
82   bool data_;
83   SymNode ptr_;
84 };
85 
86 C10_API std::ostream& operator<<(std::ostream& os, const SymBool& s);
87 
88 #define TORCH_SYM_CHECK(cond, ...) \
89   TORCH_CHECK((cond).expect_true(__FILE__, __LINE__), __VA_ARGS__)
90 #define TORCH_SYM_INTERNAL_ASSERT(cond, ...) \
91   TORCH_INTERNAL_ASSERT((cond).expect_true(__FILE__, __LINE__), __VA_ARGS__)
92 
// Overload for plain C++ bools: the value is already concrete, so it is
// returned unchanged and the call-site arguments are ignored.
inline bool guard_size_oblivious(bool b, const char* file, int64_t line) {
  static_cast<void>(file);
  static_cast<void>(line);
  return b;
}
99 
guard_size_oblivious(const c10::SymBool & b,const char * file,int64_t line)100 inline bool guard_size_oblivious(
101     const c10::SymBool& b,
102     const char* file,
103     int64_t line) {
104   return b.guard_size_oblivious(file, line);
105 }
106 
// Calls c10::guard_size_oblivious on `expr` (a bool or a SymBool), tagging
// the call with the current file and line.
#define TORCH_GUARD_SIZE_OBLIVIOUS(expr) \
  c10::guard_size_oblivious((expr), __FILE__, __LINE__)
109 
110 } // namespace c10
111