1 // This file is modified from LLVM, see the following copyright information
2 //
3 // -*- C++ -*-
4 //===----------------------------------------------------------------------===//
5 //
6 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
7 // See https://llvm.org/LICENSE.txt for license information.
8 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
9 //
10 //===----------------------------------------------------------------------===//
11
12 #include <string>
13 #include <ATen/cuda/llvm_jit_strings.h>
14
15 namespace at::cuda {
16
// Adapted (a reduced subset) from the following LLVM files:
// - https://github.com/llvm/llvm-project/blob/main/libcxx/include/type_traits
// - https://github.com/llvm/llvm-project/blob/main/clang/test/Headers/Inputs/include/type_traits
// Returns C++ source text declaring, in namespace std, a small subset of
// <type_traits>: declval, integral_constant/true_type/false_type, is_same,
// is_integral (bool, char, short, int, long, long long only), enable_if(_t),
// remove_const/remove_volatile/remove_cv(_t), is_floating_point,
// is_arithmetic(_v), and libc++'s internal __numeric_type/__promote helpers.
//
// The text is data, not code compiled into this translation unit. Judging by
// the header name (llvm_jit_strings.h) it is handed to a runtime (JIT)
// compiler. NOTE(review): confirm against the call sites.
//
// Returns: a reference that stays valid until static destruction at exit.
const std::string &get_traits_string() {
  // A function-local static rather than a namespace-scope global: it is
  // constructed on first call (thread-safe since C++11). Callers in other
  // translation units, including their static initializers, therefore can
  // never observe an unconstructed string. No allocation happens if this
  // getter is never called.
  static const std::string traits = R"ESCAPE(

namespace std {

template <class _Tp>
_Tp&& __declval(int);
template <class _Tp>
_Tp __declval(long);
template <class _Tp>
decltype(__declval<_Tp>(0)) declval() noexcept;

template <class _Tp, _Tp __v>
struct integral_constant {
static const _Tp value = __v;
typedef _Tp value_type;
typedef integral_constant type;
};

typedef integral_constant<bool, true> true_type;
typedef integral_constant<bool, false> false_type;

// is_same, functional
template <class _Tp, class _Up> struct is_same : public false_type {};
template <class _Tp> struct is_same<_Tp, _Tp> : public true_type {};

// is_integral, for some types.
template <class _Tp> struct is_integral
: public integral_constant<bool, false> {};
template <> struct is_integral<bool>
: public integral_constant<bool, true> {};
template <> struct is_integral<char>
: public integral_constant<bool, true> {};
template <> struct is_integral<short>
: public integral_constant<bool, true> {};
template <> struct is_integral<int>
: public integral_constant<bool, true> {};
template <> struct is_integral<long>
: public integral_constant<bool, true> {};
template <> struct is_integral<long long>
: public integral_constant<bool, true> {};

// enable_if, functional
template <bool _C, typename _Tp> struct enable_if{};
template <typename _Tp> struct enable_if<true, _Tp>{
using type = _Tp;
};
template <bool b, class T=void>
using enable_if_t = typename enable_if<b,T>::type;

template <class _Tp> struct remove_const {typedef _Tp type;};
template <class _Tp> struct remove_const<const _Tp> {typedef _Tp type;};
template <class _Tp> using remove_const_t = typename remove_const<_Tp>::type;

template <class _Tp> struct remove_volatile {typedef _Tp type;};
template <class _Tp> struct remove_volatile<volatile _Tp> {typedef _Tp type;};
template <class _Tp> using remove_volatile_t = typename remove_volatile<_Tp>::type;

template <class _Tp> struct remove_cv
{typedef typename remove_volatile<typename remove_const<_Tp>::type>::type type;};
template <class _Tp> using remove_cv_t = typename remove_cv<_Tp>::type;

template <class _Tp> struct __libcpp_is_floating_point : public false_type {};
template <> struct __libcpp_is_floating_point<float> : public true_type {};
template <> struct __libcpp_is_floating_point<double> : public true_type {};
template <> struct __libcpp_is_floating_point<long double> : public true_type {};

template <class _Tp> struct is_floating_point
: public __libcpp_is_floating_point<typename remove_cv<_Tp>::type> {};

template <class _Tp> struct is_arithmetic
: public integral_constant<bool, is_integral<_Tp>::value ||
is_floating_point<_Tp>::value> {};
template <class _Tp>
inline constexpr bool is_arithmetic_v = is_arithmetic<_Tp>::value;

template <class _Tp>
struct __numeric_type
{
static void __test(...);
static float __test(float);
static double __test(char);
static double __test(int);
static double __test(unsigned);
static double __test(long);
static double __test(unsigned long);
static double __test(long long);
static double __test(unsigned long long);
static double __test(double);
static long double __test(long double);

typedef decltype(__test(declval<_Tp>())) type;
static const bool value = !is_same<type, void>::value;
};

template <>
struct __numeric_type<void>
{
static const bool value = true;
};

// __promote

template <class _A1, class _A2 = void, class _A3 = void,
bool = __numeric_type<_A1>::value &&
__numeric_type<_A2>::value &&
__numeric_type<_A3>::value>
class __promote_imp
{
public:
static const bool value = false;
};

template <class _A1, class _A2, class _A3>
class __promote_imp<_A1, _A2, _A3, true>
{
private:
typedef typename __promote_imp<_A1>::type __type1;
typedef typename __promote_imp<_A2>::type __type2;
typedef typename __promote_imp<_A3>::type __type3;
public:
typedef decltype(__type1() + __type2() + __type3()) type;
static const bool value = true;
};

template <class _A1, class _A2>
class __promote_imp<_A1, _A2, void, true>
{
private:
typedef typename __promote_imp<_A1>::type __type1;
typedef typename __promote_imp<_A2>::type __type2;
public:
typedef decltype(__type1() + __type2()) type;
static const bool value = true;
};

template <class _A1>
class __promote_imp<_A1, void, void, true>
{
public:
typedef typename __numeric_type<_A1>::type type;
static const bool value = true;
};

template <class _A1, class _A2 = void, class _A3 = void>
class __promote : public __promote_imp<_A1, _A2, _A3> {};

} // namespace std

)ESCAPE";
  return traits;
}
173
// Adapted from the following LLVM file:
// - https://github.com/llvm/llvm-project/blob/main/libcxx/include/cmath
// Returns C++ source text that re-exports global-namespace math functions into
// namespace std through using-declarations (`using ::acos;`, `using ::acosf;`,
// ... `using ::trunc;`, `using ::truncf;`), mirroring libc++'s <cmath>.
//
// The text only names `::` functions; it does not declare them. The code that
// consumes this string must therefore already have those declarations in
// scope. NOTE(review): presumably they come from the JIT toolchain's math
// headers; confirm against the call sites.
//
// Returns: a reference that stays valid until static destruction at exit.
const std::string &get_cmath_string() {
  // A function-local static rather than a namespace-scope global: it is
  // constructed on first call (thread-safe since C++11). Callers in other
  // translation units, including their static initializers, therefore can
  // never observe an unconstructed string. No allocation happens if this
  // getter is never called.
  static const std::string cmath = R"ESCAPE(

namespace std {

using ::signbit;
using ::isfinite;
using ::isinf;
using ::isnan;

using ::abs;

using ::acos;
using ::acosf;
using ::asin;
using ::asinf;
using ::atan;
using ::atanf;
using ::atan2;
using ::atan2f;
using ::ceil;
using ::ceilf;
using ::cos;
using ::cosf;
using ::cosh;
using ::coshf;

using ::exp;
using ::expf;

using ::fabs;
using ::fabsf;
using ::floor;
using ::floorf;

using ::fmod;
using ::fmodf;

using ::frexp;
using ::frexpf;
using ::ldexp;
using ::ldexpf;

using ::log;
using ::logf;

using ::log10;
using ::log10f;
using ::modf;
using ::modff;

using ::pow;
using ::powf;

using ::sin;
using ::sinf;
using ::sinh;
using ::sinhf;

using ::sqrt;
using ::sqrtf;
using ::tan;
using ::tanf;

using ::tanh;
using ::tanhf;

using ::acosh;
using ::acoshf;
using ::asinh;
using ::asinhf;
using ::atanh;
using ::atanhf;
using ::cbrt;
using ::cbrtf;

using ::copysign;
using ::copysignf;

using ::erf;
using ::erff;
using ::erfc;
using ::erfcf;
using ::exp2;
using ::exp2f;
using ::expm1;
using ::expm1f;
using ::fdim;
using ::fdimf;
using ::fmaf;
using ::fma;
using ::fmax;
using ::fmaxf;
using ::fmin;
using ::fminf;
using ::hypot;
using ::hypotf;
using ::ilogb;
using ::ilogbf;
using ::lgamma;
using ::lgammaf;
using ::llrint;
using ::llrintf;
using ::llround;
using ::llroundf;
using ::log1p;
using ::log1pf;
using ::log2;
using ::log2f;
using ::logb;
using ::logbf;
using ::lrint;
using ::lrintf;
using ::lround;
using ::lroundf;

using ::nan;
using ::nanf;

using ::nearbyint;
using ::nearbyintf;
using ::nextafter;
using ::nextafterf;
using ::remainder;
using ::remainderf;
using ::remquo;
using ::remquof;
using ::rint;
using ::rintf;
using ::round;
using ::roundf;
using ::scalbln;
using ::scalblnf;
using ::scalbn;
using ::scalbnf;
using ::tgamma;
using ::tgammaf;
using ::trunc;
using ::truncf;

} // namespace std

)ESCAPE";
  return cmath;
}
322
323 } // namespace at::cuda
324