1 #pragma once 2 3 #include <c10/core/SymBool.h> 4 #include <c10/core/SymNodeImpl.h> 5 #include <c10/macros/Export.h> 6 #include <c10/macros/Macros.h> 7 #include <c10/util/Exception.h> 8 #include <c10/util/intrusive_ptr.h> 9 10 #include <cstdint> 11 #include <limits> 12 #include <ostream> 13 #include <utility> 14 15 namespace c10 { 16 17 // NB: this is actually double precision; we're using the Python naming here 18 class C10_API SymFloat { 19 public: SymFloat(double d)20 /*implicit*/ SymFloat(double d) : data_(d){}; SymFloat(SymNode ptr)21 SymFloat(SymNode ptr) 22 : data_(std::numeric_limits<double>::quiet_NaN()), ptr_(std::move(ptr)) { 23 TORCH_CHECK(ptr_->is_float()); 24 }; SymFloat()25 SymFloat() : data_(0.0) {} 26 toSymNodeImplUnowned()27 SymNodeImpl* toSymNodeImplUnowned() const { 28 return ptr_.get(); 29 } 30 release()31 SymNodeImpl* release() && { 32 return std::move(ptr_).release(); 33 } 34 35 // Only valid if is_symbolic() 36 SymNode toSymNodeImpl() const; 37 38 // Guaranteed to return a SymNode, wrapping using base if necessary 39 SymNode wrap_node(const SymNode& base) const; 40 expect_float()41 double expect_float() const { 42 TORCH_CHECK(!is_symbolic()); 43 return data_; 44 } 45 46 SymFloat operator+(const SymFloat&) const; 47 SymFloat operator-(const SymFloat&) const; 48 SymFloat operator*(const SymFloat&) const; 49 SymFloat operator/(const SymFloat&) const; 50 51 SymBool sym_eq(const SymFloat&) const; 52 SymBool sym_ne(const SymFloat&) const; 53 SymBool sym_lt(const SymFloat&) const; 54 SymBool sym_le(const SymFloat&) const; 55 SymBool sym_gt(const SymFloat&) const; 56 SymBool sym_ge(const SymFloat&) const; 57 58 bool operator==(const SymFloat& o) const { 59 return sym_eq(o).guard_bool(__FILE__, __LINE__); 60 } 61 bool operator!=(const SymFloat& o) const { 62 return sym_ne(o).guard_bool(__FILE__, __LINE__); 63 } 64 bool operator<(const SymFloat& o) const { 65 return sym_lt(o).guard_bool(__FILE__, __LINE__); 66 } 67 bool operator<=(const SymFloat& o) const { 68 return sym_le(o).guard_bool(__FILE__, __LINE__); 69 } 70 bool operator>(const SymFloat& o) const { 71 return sym_gt(o).guard_bool(__FILE__, __LINE__); 72 } 73 bool operator>=(const SymFloat& o) const { 74 return sym_ge(o).guard_bool(__FILE__, __LINE__); 75 } 76 77 SymFloat min(const SymFloat& sci) const; 78 SymFloat max(const SymFloat& sci) const; 79 80 // Need guidance on where to put this code 81 SymFloat sqrt() const; 82 83 // Insert a guard for the float to be its concrete value, and then return 84 // that value. This operation always works, even if the float is symbolic, 85 // so long as we know what the underlying value is. Don't blindly put this 86 // everywhere; you can cause overspecialization of PyTorch programs with 87 // this method. 88 // 89 // It should be called as guard_float(__FILE__, __LINE__). The file and line 90 // number can be used to diagnose overspecialization. 91 double guard_float(const char* file, int64_t line) const; 92 93 bool has_hint() const; 94 95 // N.B. It's important to keep this definition in the header 96 // as we expect if checks to be folded for mobile builds 97 // where `is_symbolic` is always false is_symbolic()98 C10_ALWAYS_INLINE bool is_symbolic() const { 99 return ptr_; 100 } 101 as_float_unchecked()102 double as_float_unchecked() const { 103 return data_; 104 } 105 106 private: 107 // TODO: optimize to union 108 double data_; 109 SymNode ptr_; 110 }; 111 112 C10_API std::ostream& operator<<(std::ostream& os, const SymFloat& s); 113 } // namespace c10 114