xref: /aosp_15_r20/external/pytorch/c10/core/SymFloat.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <c10/core/SymFloat.h>
2 #include <c10/core/SymNodeImpl.h>
3 #include <array>
4 #include <cmath>
5 #include <utility>
6 
7 namespace c10 {
8 
toSymNodeImpl() const9 SymNode SymFloat::toSymNodeImpl() const {
10   TORCH_CHECK(is_symbolic());
11   return SymNode::reclaim_copy(toSymNodeImplUnowned());
12 }
13 
wrap_node(const SymNode & base) const14 SymNode SymFloat::wrap_node(const SymNode& base) const {
15   if (is_symbolic()) {
16     return toSymNodeImpl();
17   } else {
18     return base->wrap_float(as_float_unchecked());
19   }
20 }
21 
normalize_symfloats(const SymFloat & a_,const SymFloat & b_)22 static std::array<SymNode, 2> normalize_symfloats(
23     const SymFloat& a_,
24     const SymFloat& b_) {
25   SymNode a, b;
26   if (a_.is_symbolic())
27     a = a_.toSymNodeImpl();
28   if (b_.is_symbolic())
29     b = b_.toSymNodeImpl();
30 
31   SymNodeImpl* common = a ? a.get() : b.get();
32   if (!a) {
33     a = common->wrap_float(a_.as_float_unchecked());
34   }
35   if (!b) {
36     b = common->wrap_float(b_.as_float_unchecked());
37   }
38   return {std::move(a), std::move(b)};
39 }
40 
operator +(const SymFloat & sci) const41 SymFloat SymFloat::operator+(const SymFloat& sci) const {
42   if (!is_symbolic() && !sci.is_symbolic()) {
43     return SymFloat(data_ + sci.data_);
44   }
45   auto res = normalize_symfloats(*this, sci);
46   return SymFloat(res[0]->add(res[1]));
47 }
48 
operator -(const SymFloat & sci) const49 SymFloat SymFloat::operator-(const SymFloat& sci) const {
50   if (!is_symbolic() && !sci.is_symbolic()) {
51     return SymFloat(data_ - sci.data_);
52   }
53   auto res = normalize_symfloats(*this, sci);
54   return SymFloat(res[0]->sub(res[1]));
55 }
56 
operator *(const SymFloat & sci) const57 SymFloat SymFloat::operator*(const SymFloat& sci) const {
58   if (!is_symbolic() && !sci.is_symbolic()) {
59     return SymFloat(data_ * sci.data_);
60   }
61   auto res = normalize_symfloats(*this, sci);
62   return SymFloat(res[0]->mul(res[1]));
63 }
64 
operator /(const SymFloat & sci) const65 SymFloat SymFloat::operator/(const SymFloat& sci) const {
66   if (!is_symbolic() && !sci.is_symbolic()) {
67     return SymFloat(data_ / sci.data_);
68   }
69   auto res = normalize_symfloats(*this, sci);
70   return SymFloat(res[0]->truediv(res[1]));
71 }
72 
sym_eq(const SymFloat & sci) const73 SymBool SymFloat::sym_eq(const SymFloat& sci) const {
74   if (!is_symbolic() && !sci.is_symbolic()) {
75     return data_ == sci.data_;
76   }
77   auto res = normalize_symfloats(*this, sci);
78   return res[0]->eq(res[1]);
79 }
80 
sym_ne(const SymFloat & sci) const81 SymBool SymFloat::sym_ne(const SymFloat& sci) const {
82   if (!is_symbolic() && !sci.is_symbolic()) {
83     return data_ != sci.data_;
84   }
85   auto res = normalize_symfloats(*this, sci);
86   return res[0]->ne(res[1]);
87 }
88 
sym_lt(const SymFloat & sci) const89 SymBool SymFloat::sym_lt(const SymFloat& sci) const {
90   if (!is_symbolic() && !sci.is_symbolic()) {
91     return data_ < sci.data_;
92   }
93   auto res = normalize_symfloats(*this, sci);
94   return res[0]->lt(res[1]);
95 }
96 
sym_le(const SymFloat & sci) const97 SymBool SymFloat::sym_le(const SymFloat& sci) const {
98   if (!is_symbolic() && !sci.is_symbolic()) {
99     return data_ <= sci.data_;
100   }
101   auto res = normalize_symfloats(*this, sci);
102   return res[0]->le(res[1]);
103 }
104 
sym_gt(const SymFloat & sci) const105 SymBool SymFloat::sym_gt(const SymFloat& sci) const {
106   if (!is_symbolic() && !sci.is_symbolic()) {
107     return data_ > sci.data_;
108   }
109   auto res = normalize_symfloats(*this, sci);
110   return res[0]->gt(res[1]);
111 }
112 
sym_ge(const SymFloat & sci) const113 SymBool SymFloat::sym_ge(const SymFloat& sci) const {
114   if (!is_symbolic() && !sci.is_symbolic()) {
115     return data_ >= sci.data_;
116   }
117   auto res = normalize_symfloats(*this, sci);
118   return res[0]->ge(res[1]);
119 }
120 
min(const SymFloat & sci) const121 SymFloat SymFloat::min(const SymFloat& sci) const {
122   if (!is_symbolic() && !sci.is_symbolic()) {
123     return std::min(data_, sci.data_);
124   }
125   auto res = normalize_symfloats(*this, sci);
126   return SymFloat(res[0]->sym_min(res[1]));
127 }
max(const SymFloat & sci) const128 SymFloat SymFloat::max(const SymFloat& sci) const {
129   if (!is_symbolic() && !sci.is_symbolic()) {
130     return std::max(data_, sci.data_);
131   }
132   auto res = normalize_symfloats(*this, sci);
133   return SymFloat(res[0]->sym_max(res[1]));
134 }
135 
operator <<(std::ostream & os,const SymFloat & s)136 std::ostream& operator<<(std::ostream& os, const SymFloat& s) {
137   if (s.is_symbolic()) {
138     os << s.toSymNodeImpl()->str();
139   } else {
140     os << s.as_float_unchecked();
141   }
142   return os;
143 }
144 
sqrt() const145 SymFloat SymFloat::sqrt() const {
146   if (!is_symbolic()) {
147     return SymFloat(std::sqrt(data_));
148   }
149   auto other = SymFloat(-0.5);
150   auto res = normalize_symfloats(*this, other);
151   return SymFloat(res[0]->pow(res[1]));
152 }
153 
guard_float(const char * file,int64_t line) const154 double SymFloat::guard_float(const char* file, int64_t line) const {
155   if (!is_symbolic()) {
156     return data_;
157   }
158   SymNode a = toSymNodeImpl();
159   return a->guard_float(file, line);
160 }
161 
has_hint() const162 bool SymFloat::has_hint() const {
163   if (!is_symbolic()) {
164     return true;
165   }
166   return toSymNodeImpl()->has_hint();
167 }
168 
169 } // namespace c10
170