// xref: /aosp_15_r20/external/pytorch/c10/core/SymFloat.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once

#include <c10/core/SymBool.h>
#include <c10/core/SymNodeImpl.h>
#include <c10/macros/Export.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <c10/util/intrusive_ptr.h>

#include <cstdint>
#include <limits>
#include <ostream>
#include <utility>

15 namespace c10 {
16 
17 // NB: this is actually double precision; we're using the Python naming here
18 class C10_API SymFloat {
19  public:
SymFloat(double d)20   /*implicit*/ SymFloat(double d) : data_(d){};
SymFloat(SymNode ptr)21   SymFloat(SymNode ptr)
22       : data_(std::numeric_limits<double>::quiet_NaN()), ptr_(std::move(ptr)) {
23     TORCH_CHECK(ptr_->is_float());
24   };
SymFloat()25   SymFloat() : data_(0.0) {}
26 
toSymNodeImplUnowned()27   SymNodeImpl* toSymNodeImplUnowned() const {
28     return ptr_.get();
29   }
30 
release()31   SymNodeImpl* release() && {
32     return std::move(ptr_).release();
33   }
34 
35   // Only valid if is_symbolic()
36   SymNode toSymNodeImpl() const;
37 
38   // Guaranteed to return a SymNode, wrapping using base if necessary
39   SymNode wrap_node(const SymNode& base) const;
40 
expect_float()41   double expect_float() const {
42     TORCH_CHECK(!is_symbolic());
43     return data_;
44   }
45 
46   SymFloat operator+(const SymFloat&) const;
47   SymFloat operator-(const SymFloat&) const;
48   SymFloat operator*(const SymFloat&) const;
49   SymFloat operator/(const SymFloat&) const;
50 
51   SymBool sym_eq(const SymFloat&) const;
52   SymBool sym_ne(const SymFloat&) const;
53   SymBool sym_lt(const SymFloat&) const;
54   SymBool sym_le(const SymFloat&) const;
55   SymBool sym_gt(const SymFloat&) const;
56   SymBool sym_ge(const SymFloat&) const;
57 
58   bool operator==(const SymFloat& o) const {
59     return sym_eq(o).guard_bool(__FILE__, __LINE__);
60   }
61   bool operator!=(const SymFloat& o) const {
62     return sym_ne(o).guard_bool(__FILE__, __LINE__);
63   }
64   bool operator<(const SymFloat& o) const {
65     return sym_lt(o).guard_bool(__FILE__, __LINE__);
66   }
67   bool operator<=(const SymFloat& o) const {
68     return sym_le(o).guard_bool(__FILE__, __LINE__);
69   }
70   bool operator>(const SymFloat& o) const {
71     return sym_gt(o).guard_bool(__FILE__, __LINE__);
72   }
73   bool operator>=(const SymFloat& o) const {
74     return sym_ge(o).guard_bool(__FILE__, __LINE__);
75   }
76 
77   SymFloat min(const SymFloat& sci) const;
78   SymFloat max(const SymFloat& sci) const;
79 
80   // Need guidance on where to put this code
81   SymFloat sqrt() const;
82 
83   // Insert a guard for the float to be its concrete value, and then return
84   // that value.  This operation always works, even if the float is symbolic,
85   // so long as we know what the underlying value is. Don't blindly put this
86   // everywhere; you can cause overspecialization of PyTorch programs with
87   // this method.
88   //
89   // It should be called as guard_float(__FILE__, __LINE__).  The file and line
90   // number can be used to diagnose overspecialization.
91   double guard_float(const char* file, int64_t line) const;
92 
93   bool has_hint() const;
94 
95   // N.B. It's important to keep this definition in the header
96   // as we expect if checks to be folded for mobile builds
97   // where `is_symbolic` is always false
is_symbolic()98   C10_ALWAYS_INLINE bool is_symbolic() const {
99     return ptr_;
100   }
101 
as_float_unchecked()102   double as_float_unchecked() const {
103     return data_;
104   }
105 
106  private:
107   // TODO: optimize to union
108   double data_;
109   SymNode ptr_;
110 };
111 
112 C10_API std::ostream& operator<<(std::ostream& os, const SymFloat& s);
113 } // namespace c10
114