xref: /aosp_15_r20/external/pytorch/aten/src/ATen/cuda/llvm_basic.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
// This file is modified from LLVM; see the following copyright information.
//
// -*- C++ -*-
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
11 
12 #include <string>
13 #include <ATen/cuda/llvm_jit_strings.h>
14 
namespace at::cuda {

// The two strings in this file are preambles meant to be prepended to source
// code that is compiled at runtime (the consumer lives outside this file —
// presumably the NVRTC/jiterator path declared in llvm_jit_strings.h; confirm
// against the callers). They provide a small, hand-picked subset of
// <type_traits> and <cmath> inside namespace std.
//
// Each string is a function-local static instead of a namespace-scope object.
// A namespace-scope std::string is dynamically initialized, so a call made
// from another translation unit's static initializer could observe it before
// construction (static initialization order problem). A function-local static
// is constructed on first use, that construction is thread-safe since C++11,
// and the object lives until program exit, so the returned references stay
// valid.
//
// The raw-string contents are whitespace-sensitive runtime text: do not
// re-indent them.

// Returns the JIT preamble providing a minimal std type-traits subset,
// copy-pasted from some LLVM files:
// - https://github.com/llvm/llvm-project/blob/main/libcxx/include/type_traits
// - https://github.com/llvm/llvm-project/blob/main/clang/test/Headers/Inputs/include/type_traits
//
// Provided names: declval, integral_constant, true_type/false_type, is_same,
// is_integral, enable_if/enable_if_t, remove_const(_t), remove_volatile(_t),
// remove_cv(_t), is_floating_point, is_arithmetic/is_arithmetic_v, and the
// libc++ internals __numeric_type / __promote used for arithmetic promotion.
//
// NOTE(review): is_integral is deliberately partial. Only bool, char, short,
// int, long and long long are specialized, and cv-qualifiers are not stripped.
// So unsigned and cv-qualified integers report false (and is_arithmetic
// inherits that). This is kept unchanged so jitted code behaves as before.
//
// Every call returns a reference to the same string object.
const std::string &get_traits_string() {
  static const std::string traits = R"ESCAPE(

namespace std {

template <class _Tp>
_Tp&& __declval(int);
template <class _Tp>
_Tp __declval(long);
template <class _Tp>
decltype(__declval<_Tp>(0)) declval() noexcept;

template <class _Tp, _Tp __v>
struct integral_constant {
  static const _Tp value = __v;
  typedef _Tp value_type;
  typedef integral_constant type;
};

typedef integral_constant<bool, true> true_type;
typedef integral_constant<bool, false> false_type;

// is_same, functional
template <class _Tp, class _Up> struct is_same : public false_type {};
template <class _Tp> struct is_same<_Tp, _Tp> : public true_type {};

// is_integral, for some types.
template <class _Tp> struct is_integral
    : public integral_constant<bool, false> {};
template <> struct is_integral<bool>
    : public integral_constant<bool, true> {};
template <> struct is_integral<char>
    : public integral_constant<bool, true> {};
template <> struct is_integral<short>
    : public integral_constant<bool, true> {};
template <> struct is_integral<int>
    : public integral_constant<bool, true> {};
template <> struct is_integral<long>
    : public integral_constant<bool, true> {};
template <> struct is_integral<long long>
    : public integral_constant<bool, true> {};

// enable_if, functional
template <bool _C, typename _Tp> struct enable_if{};
template <typename _Tp> struct enable_if<true, _Tp>{
  using type = _Tp;
};
template <bool b, class T=void>
using enable_if_t = typename enable_if<b,T>::type;

template <class _Tp> struct remove_const            {typedef _Tp type;};
template <class _Tp> struct remove_const<const _Tp> {typedef _Tp type;};
template <class _Tp> using remove_const_t = typename remove_const<_Tp>::type;

template <class _Tp> struct remove_volatile               {typedef _Tp type;};
template <class _Tp> struct remove_volatile<volatile _Tp> {typedef _Tp type;};
template <class _Tp> using remove_volatile_t = typename remove_volatile<_Tp>::type;

template <class _Tp> struct remove_cv
{typedef typename remove_volatile<typename remove_const<_Tp>::type>::type type;};
template <class _Tp> using remove_cv_t = typename remove_cv<_Tp>::type;

template <class _Tp> struct __libcpp_is_floating_point              : public false_type {};
template <>          struct __libcpp_is_floating_point<float>       : public true_type {};
template <>          struct __libcpp_is_floating_point<double>      : public true_type {};
template <>          struct __libcpp_is_floating_point<long double> : public true_type {};

template <class _Tp> struct is_floating_point
    : public __libcpp_is_floating_point<typename remove_cv<_Tp>::type> {};

template <class _Tp> struct is_arithmetic
    : public integral_constant<bool, is_integral<_Tp>::value      ||
                                     is_floating_point<_Tp>::value> {};
template <class _Tp>
inline constexpr bool is_arithmetic_v = is_arithmetic<_Tp>::value;

template <class _Tp>
struct __numeric_type
{
   static void __test(...);
   static float __test(float);
   static double __test(char);
   static double __test(int);
   static double __test(unsigned);
   static double __test(long);
   static double __test(unsigned long);
   static double __test(long long);
   static double __test(unsigned long long);
   static double __test(double);
   static long double __test(long double);

   typedef decltype(__test(declval<_Tp>())) type;
   static const bool value = !is_same<type, void>::value;
};

template <>
struct __numeric_type<void>
{
   static const bool value = true;
};

// __promote

template <class _A1, class _A2 = void, class _A3 = void,
          bool = __numeric_type<_A1>::value &&
                 __numeric_type<_A2>::value &&
                 __numeric_type<_A3>::value>
class __promote_imp
{
public:
    static const bool value = false;
};

template <class _A1, class _A2, class _A3>
class __promote_imp<_A1, _A2, _A3, true>
{
private:
    typedef typename __promote_imp<_A1>::type __type1;
    typedef typename __promote_imp<_A2>::type __type2;
    typedef typename __promote_imp<_A3>::type __type3;
public:
    typedef decltype(__type1() + __type2() + __type3()) type;
    static const bool value = true;
};

template <class _A1, class _A2>
class __promote_imp<_A1, _A2, void, true>
{
private:
    typedef typename __promote_imp<_A1>::type __type1;
    typedef typename __promote_imp<_A2>::type __type2;
public:
    typedef decltype(__type1() + __type2()) type;
    static const bool value = true;
};

template <class _A1>
class __promote_imp<_A1, void, void, true>
{
public:
    typedef typename __numeric_type<_A1>::type type;
    static const bool value = true;
};

template <class _A1, class _A2 = void, class _A3 = void>
class __promote : public __promote_imp<_A1, _A2, _A3> {};

} // namespace std

)ESCAPE";
  return traits;
}

// Returns the JIT preamble that re-exports global-namespace math functions
// into namespace std via using-declarations. It is copy-pasted from:
// - https://github.com/llvm/llvm-project/blob/main/libcxx/include/cmath
//
// Every name listed (::signbit ... ::truncf) must already be declared in the
// global namespace wherever this text is compiled. Otherwise the
// using-declarations are ill-formed.
//
// Every call returns a reference to the same string object.
const std::string &get_cmath_string() {
  static const std::string cmath = R"ESCAPE(

namespace std {

using ::signbit;
using ::isfinite;
using ::isinf;
using ::isnan;

using ::abs;

using ::acos;
using ::acosf;
using ::asin;
using ::asinf;
using ::atan;
using ::atanf;
using ::atan2;
using ::atan2f;
using ::ceil;
using ::ceilf;
using ::cos;
using ::cosf;
using ::cosh;
using ::coshf;

using ::exp;
using ::expf;

using ::fabs;
using ::fabsf;
using ::floor;
using ::floorf;

using ::fmod;
using ::fmodf;

using ::frexp;
using ::frexpf;
using ::ldexp;
using ::ldexpf;

using ::log;
using ::logf;

using ::log10;
using ::log10f;
using ::modf;
using ::modff;

using ::pow;
using ::powf;

using ::sin;
using ::sinf;
using ::sinh;
using ::sinhf;

using ::sqrt;
using ::sqrtf;
using ::tan;
using ::tanf;

using ::tanh;
using ::tanhf;

using ::acosh;
using ::acoshf;
using ::asinh;
using ::asinhf;
using ::atanh;
using ::atanhf;
using ::cbrt;
using ::cbrtf;

using ::copysign;
using ::copysignf;

using ::erf;
using ::erff;
using ::erfc;
using ::erfcf;
using ::exp2;
using ::exp2f;
using ::expm1;
using ::expm1f;
using ::fdim;
using ::fdimf;
using ::fmaf;
using ::fma;
using ::fmax;
using ::fmaxf;
using ::fmin;
using ::fminf;
using ::hypot;
using ::hypotf;
using ::ilogb;
using ::ilogbf;
using ::lgamma;
using ::lgammaf;
using ::llrint;
using ::llrintf;
using ::llround;
using ::llroundf;
using ::log1p;
using ::log1pf;
using ::log2;
using ::log2f;
using ::logb;
using ::logbf;
using ::lrint;
using ::lrintf;
using ::lround;
using ::lroundf;

using ::nan;
using ::nanf;

using ::nearbyint;
using ::nearbyintf;
using ::nextafter;
using ::nextafterf;
using ::remainder;
using ::remainderf;
using ::remquo;
using ::remquof;
using ::rint;
using ::rintf;
using ::round;
using ::roundf;
using ::scalbln;
using ::scalblnf;
using ::scalbn;
using ::scalbnf;
using ::tgamma;
using ::tgammaf;
using ::trunc;
using ::truncf;

} // namespace std

)ESCAPE";
  return cmath;
}

} // namespace at::cuda