1 #pragma once
2
#include <c10/util/BFloat16.h>
#include <c10/util/Half.h>

#include <cstdint>
#include <cstring>
5
6 C10_CLANG_DIAGNOSTIC_PUSH()
7 #if C10_CLANG_HAS_WARNING("-Wimplicit-float-conversion")
8 C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-float-conversion")
9 #endif
10
11 namespace std {
12
13 template <typename T>
14 struct is_reduced_floating_point
15 : std::integral_constant<
16 bool,
17 std::is_same_v<T, c10::Half> || std::is_same_v<T, c10::BFloat16>> {};
18
19 template <typename T>
20 constexpr bool is_reduced_floating_point_v =
21 is_reduced_floating_point<T>::value;
22
23 template <
24 typename T,
25 typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
acos(T a)26 inline T acos(T a) {
27 return std::acos(float(a));
28 }
29 template <
30 typename T,
31 typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
asin(T a)32 inline T asin(T a) {
33 return std::asin(float(a));
34 }
35 template <
36 typename T,
37 typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
atan(T a)38 inline T atan(T a) {
39 return std::atan(float(a));
40 }
41 template <
42 typename T,
43 typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
atanh(T a)44 inline T atanh(T a) {
45 return std::atanh(float(a));
46 }
47 template <
48 typename T,
49 typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
erf(T a)50 inline T erf(T a) {
51 return std::erf(float(a));
52 }
53 template <
54 typename T,
55 typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
erfc(T a)56 inline T erfc(T a) {
57 return std::erfc(float(a));
58 }
59 template <
60 typename T,
61 typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
exp(T a)62 inline T exp(T a) {
63 return std::exp(float(a));
64 }
65 template <
66 typename T,
67 typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
expm1(T a)68 inline T expm1(T a) {
69 return std::expm1(float(a));
70 }
71 template <
72 typename T,
73 typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
isfinite(T a)74 inline bool isfinite(T a) {
75 return std::isfinite(float(a));
76 }
77 template <
78 typename T,
79 typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
log(T a)80 inline T log(T a) {
81 return std::log(float(a));
82 }
83 template <
84 typename T,
85 typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
log10(T a)86 inline T log10(T a) {
87 return std::log10(float(a));
88 }
89 template <
90 typename T,
91 typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
log1p(T a)92 inline T log1p(T a) {
93 return std::log1p(float(a));
94 }
95 template <
96 typename T,
97 typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
log2(T a)98 inline T log2(T a) {
99 return std::log2(float(a));
100 }
101 template <
102 typename T,
103 typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
ceil(T a)104 inline T ceil(T a) {
105 return std::ceil(float(a));
106 }
107 template <
108 typename T,
109 typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
cos(T a)110 inline T cos(T a) {
111 return std::cos(float(a));
112 }
113 template <
114 typename T,
115 typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
floor(T a)116 inline T floor(T a) {
117 return std::floor(float(a));
118 }
119 template <
120 typename T,
121 typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
nearbyint(T a)122 inline T nearbyint(T a) {
123 return std::nearbyint(float(a));
124 }
125 template <
126 typename T,
127 typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
sin(T a)128 inline T sin(T a) {
129 return std::sin(float(a));
130 }
131 template <
132 typename T,
133 typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
tan(T a)134 inline T tan(T a) {
135 return std::tan(float(a));
136 }
137 template <
138 typename T,
139 typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
sinh(T a)140 inline T sinh(T a) {
141 return std::sinh(float(a));
142 }
143 template <
144 typename T,
145 typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
cosh(T a)146 inline T cosh(T a) {
147 return std::cosh(float(a));
148 }
149 template <
150 typename T,
151 typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
tanh(T a)152 inline T tanh(T a) {
153 return std::tanh(float(a));
154 }
155 template <
156 typename T,
157 typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
trunc(T a)158 inline T trunc(T a) {
159 return std::trunc(float(a));
160 }
161 template <
162 typename T,
163 typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
lgamma(T a)164 inline T lgamma(T a) {
165 return std::lgamma(float(a));
166 }
167 template <
168 typename T,
169 typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
sqrt(T a)170 inline T sqrt(T a) {
171 return std::sqrt(float(a));
172 }
173 template <
174 typename T,
175 typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
rsqrt(T a)176 inline T rsqrt(T a) {
177 return 1.0 / std::sqrt(float(a));
178 }
179 template <
180 typename T,
181 typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
abs(T a)182 inline T abs(T a) {
183 return std::abs(float(a));
184 }
185 #if defined(_MSC_VER) && defined(__CUDACC__)
186 template <
187 typename T,
188 typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
pow(T a,double b)189 inline T pow(T a, double b) {
190 return std::pow(float(a), float(b));
191 }
192 #else
193 template <
194 typename T,
195 typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
pow(T a,double b)196 inline T pow(T a, double b) {
197 return std::pow(float(a), b);
198 }
199 #endif
200 template <
201 typename T,
202 typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
pow(T a,T b)203 inline T pow(T a, T b) {
204 return std::pow(float(a), float(b));
205 }
206 template <
207 typename T,
208 typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
fmod(T a,T b)209 inline T fmod(T a, T b) {
210 return std::fmod(float(a), float(b));
211 }
212
213 /*
The following function is inspired by the implementation in `musl`
215 Link to License: https://git.musl-libc.org/cgit/musl/tree/COPYRIGHT
216 ----------------------------------------------------------------------
217 Copyright © 2005-2020 Rich Felker, et al.
218
219 Permission is hereby granted, free of charge, to any person obtaining
220 a copy of this software and associated documentation files (the
221 "Software"), to deal in the Software without restriction, including
222 without limitation the rights to use, copy, modify, merge, publish,
223 distribute, sublicense, and/or sell copies of the Software, and to
224 permit persons to whom the Software is furnished to do so, subject to
225 the following conditions:
226
227 The above copyright notice and this permission notice shall be
228 included in all copies or substantial portions of the Software.
229
230 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
231 EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
232 MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
233 IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
234 CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
235 TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
236 SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
237 ----------------------------------------------------------------------
238 */
239 template <
240 typename T,
241 typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
nextafter(T from,T to)242 C10_HOST_DEVICE inline T nextafter(T from, T to) {
243 // Reference:
244 // https://git.musl-libc.org/cgit/musl/tree/src/math/nextafter.c
245 using int_repr_t = uint16_t;
246 constexpr uint8_t bits = 16;
247 union {
248 T f;
249 int_repr_t i;
250 } ufrom = {from}, uto = {to};
251
252 // get a mask to get the sign bit i.e. MSB
253 int_repr_t sign_mask = int_repr_t{1} << (bits - 1);
254
255 // short-circuit: if either is NaN, return NaN
256 if (from != from || to != to) {
257 return from + to;
258 }
259
260 // short-circuit: if they are exactly the same.
261 if (ufrom.i == uto.i) {
262 return from;
263 }
264
265 // mask the sign-bit to zero i.e. positive
266 // equivalent to abs(x)
267 int_repr_t abs_from = ufrom.i & ~sign_mask;
268 int_repr_t abs_to = uto.i & ~sign_mask;
269 if (abs_from == 0) {
270 // if both are zero but with different sign,
271 // preserve the sign of `to`.
272 if (abs_to == 0) {
273 return to;
274 }
275 // smallest subnormal with sign of `to`.
276 ufrom.i = (uto.i & sign_mask) | int_repr_t{1};
277 return ufrom.f;
278 }
279
280 // if abs(from) > abs(to) or sign(from) != sign(to)
281 if (abs_from > abs_to || ((ufrom.i ^ uto.i) & sign_mask)) {
282 ufrom.i--;
283 } else {
284 ufrom.i++;
285 }
286
287 return ufrom.f;
288 }
289
290 } // namespace std
291
292 C10_CLANG_DIAGNOSTIC_POP()
293