#include <c10/core/SymFloat.h>
#include <c10/core/SymNodeImpl.h>
#include <algorithm>
#include <array>
#include <cmath>
#include <ostream>
#include <utility>
6
7 namespace c10 {
8
toSymNodeImpl() const9 SymNode SymFloat::toSymNodeImpl() const {
10 TORCH_CHECK(is_symbolic());
11 return SymNode::reclaim_copy(toSymNodeImplUnowned());
12 }
13
// Produces a SymNode for this value that is compatible with `base`.
// A symbolic SymFloat returns its own node. A concrete one is lifted into
// `base`'s node implementation via wrap_float.
SymNode SymFloat::wrap_node(const SymNode& base) const {
  return is_symbolic() ? toSymNodeImpl()
                       : base->wrap_float(as_float_unchecked());
}
21
normalize_symfloats(const SymFloat & a_,const SymFloat & b_)22 static std::array<SymNode, 2> normalize_symfloats(
23 const SymFloat& a_,
24 const SymFloat& b_) {
25 SymNode a, b;
26 if (a_.is_symbolic())
27 a = a_.toSymNodeImpl();
28 if (b_.is_symbolic())
29 b = b_.toSymNodeImpl();
30
31 SymNodeImpl* common = a ? a.get() : b.get();
32 if (!a) {
33 a = common->wrap_float(a_.as_float_unchecked());
34 }
35 if (!b) {
36 b = common->wrap_float(b_.as_float_unchecked());
37 }
38 return {std::move(a), std::move(b)};
39 }
40
// Sum of two SymFloats. If either operand is symbolic, both are lifted to
// nodes and `add` is dispatched on the left one. Otherwise the doubles are
// added directly.
SymFloat SymFloat::operator+(const SymFloat& sci) const {
  if (is_symbolic() || sci.is_symbolic()) {
    auto nodes = normalize_symfloats(*this, sci);
    return SymFloat(nodes[0]->add(nodes[1]));
  }
  return SymFloat(data_ + sci.data_);
}
48
// Difference of two SymFloats. A symbolic operand on either side routes
// through `sub` on the lifted left node. Two concrete operands are subtracted
// eagerly.
SymFloat SymFloat::operator-(const SymFloat& sci) const {
  if (is_symbolic() || sci.is_symbolic()) {
    auto nodes = normalize_symfloats(*this, sci);
    return SymFloat(nodes[0]->sub(nodes[1]));
  }
  return SymFloat(data_ - sci.data_);
}
56
// Product of two SymFloats. A symbolic operand on either side routes through
// `mul` on the lifted left node. Two concrete operands are multiplied eagerly.
SymFloat SymFloat::operator*(const SymFloat& sci) const {
  if (is_symbolic() || sci.is_symbolic()) {
    auto nodes = normalize_symfloats(*this, sci);
    return SymFloat(nodes[0]->mul(nodes[1]));
  }
  return SymFloat(data_ * sci.data_);
}
64
// True division of two SymFloats. A symbolic operand on either side routes
// through `truediv` on the lifted left node. Two concrete operands are
// divided eagerly as doubles.
SymFloat SymFloat::operator/(const SymFloat& sci) const {
  if (is_symbolic() || sci.is_symbolic()) {
    auto nodes = normalize_symfloats(*this, sci);
    return SymFloat(nodes[0]->truediv(nodes[1]));
  }
  return SymFloat(data_ / sci.data_);
}
72
sym_eq(const SymFloat & sci) const73 SymBool SymFloat::sym_eq(const SymFloat& sci) const {
74 if (!is_symbolic() && !sci.is_symbolic()) {
75 return data_ == sci.data_;
76 }
77 auto res = normalize_symfloats(*this, sci);
78 return res[0]->eq(res[1]);
79 }
80
sym_ne(const SymFloat & sci) const81 SymBool SymFloat::sym_ne(const SymFloat& sci) const {
82 if (!is_symbolic() && !sci.is_symbolic()) {
83 return data_ != sci.data_;
84 }
85 auto res = normalize_symfloats(*this, sci);
86 return res[0]->ne(res[1]);
87 }
88
sym_lt(const SymFloat & sci) const89 SymBool SymFloat::sym_lt(const SymFloat& sci) const {
90 if (!is_symbolic() && !sci.is_symbolic()) {
91 return data_ < sci.data_;
92 }
93 auto res = normalize_symfloats(*this, sci);
94 return res[0]->lt(res[1]);
95 }
96
sym_le(const SymFloat & sci) const97 SymBool SymFloat::sym_le(const SymFloat& sci) const {
98 if (!is_symbolic() && !sci.is_symbolic()) {
99 return data_ <= sci.data_;
100 }
101 auto res = normalize_symfloats(*this, sci);
102 return res[0]->le(res[1]);
103 }
104
sym_gt(const SymFloat & sci) const105 SymBool SymFloat::sym_gt(const SymFloat& sci) const {
106 if (!is_symbolic() && !sci.is_symbolic()) {
107 return data_ > sci.data_;
108 }
109 auto res = normalize_symfloats(*this, sci);
110 return res[0]->gt(res[1]);
111 }
112
sym_ge(const SymFloat & sci) const113 SymBool SymFloat::sym_ge(const SymFloat& sci) const {
114 if (!is_symbolic() && !sci.is_symbolic()) {
115 return data_ >= sci.data_;
116 }
117 auto res = normalize_symfloats(*this, sci);
118 return res[0]->ge(res[1]);
119 }
120
// Minimum of two SymFloats. Symbolic inputs go through the node's `sym_min`.
// Two concrete inputs use std::min on the raw doubles.
SymFloat SymFloat::min(const SymFloat& sci) const {
  if (is_symbolic() || sci.is_symbolic()) {
    auto nodes = normalize_symfloats(*this, sci);
    return SymFloat(nodes[0]->sym_min(nodes[1]));
  }
  return std::min(data_, sci.data_);
}

// Maximum of two SymFloats. Symbolic inputs go through the node's `sym_max`.
// Two concrete inputs use std::max on the raw doubles.
SymFloat SymFloat::max(const SymFloat& sci) const {
  if (is_symbolic() || sci.is_symbolic()) {
    auto nodes = normalize_symfloats(*this, sci);
    return SymFloat(nodes[0]->sym_max(nodes[1]));
  }
  return std::max(data_, sci.data_);
}
135
operator <<(std::ostream & os,const SymFloat & s)136 std::ostream& operator<<(std::ostream& os, const SymFloat& s) {
137 if (s.is_symbolic()) {
138 os << s.toSymNodeImpl()->str();
139 } else {
140 os << s.as_float_unchecked();
141 }
142 return os;
143 }
144
sqrt() const145 SymFloat SymFloat::sqrt() const {
146 if (!is_symbolic()) {
147 return SymFloat(std::sqrt(data_));
148 }
149 auto other = SymFloat(-0.5);
150 auto res = normalize_symfloats(*this, other);
151 return SymFloat(res[0]->pow(res[1]));
152 }
153
guard_float(const char * file,int64_t line) const154 double SymFloat::guard_float(const char* file, int64_t line) const {
155 if (!is_symbolic()) {
156 return data_;
157 }
158 SymNode a = toSymNodeImpl();
159 return a->guard_float(file, line);
160 }
161
has_hint() const162 bool SymFloat::has_hint() const {
163 if (!is_symbolic()) {
164 return true;
165 }
166 return toSymNodeImpl()->has_hint();
167 }
168
169 } // namespace c10
170