1 #pragma once
2
3 #ifdef __HIPCC__
4 #include <hip/hip_runtime.h>
5 #endif
6
7 #include <c10/macros/Macros.h>
8 #include <c10/util/BFloat16.h>
9 #include <c10/util/Float8_e4m3fn.h>
10 #include <c10/util/Float8_e4m3fnuz.h>
11 #include <c10/util/Float8_e5m2.h>
12 #include <c10/util/Float8_e5m2fnuz.h>
13 #include <c10/util/Half.h>
14 #include <c10/util/complex.h>
15
16 #include <cmath>
17 #include <type_traits>
18
19 namespace at {
20
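// std::isnan isn't performant to use on integral types; it will
// (uselessly) convert to floating point and then do the test.
// The _isnan overloads below avoid that: integral types short-circuit to
// false, and each reduced-precision type gets its own dedicated overload.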
24
// NaN test for integral types. Integers have no NaN representation, so the
// answer is a constant false and the argument is never converted to floating
// point.
template <typename T, std::enable_if_t<std::is_integral_v<T>, int> = 0>
inline C10_HOST_DEVICE bool _isnan(T /*val*/) {
  constexpr bool kIntegralIsNaN = false;
  return kIntegralIsNaN;
}
29
// NaN test for built-in floating-point types (float, double, long double).
// When compiled by nvcc or hipcc (__CUDACC__ / __HIPCC__ defined) the global
// ::isnan is used; plain host compilers use std::isnan.
template <typename T, std::enable_if_t<std::is_floating_point_v<T>, int> = 0>
inline C10_HOST_DEVICE bool _isnan(T val) {
#if !defined(__CUDACC__) && !defined(__HIPCC__)
  return std::isnan(val);
#else
  return ::isnan(val);
#endif
}
38
39 template <typename T, std::enable_if_t<c10::is_complex<T>::value, int> = 0>
_isnan(T val)40 inline C10_HOST_DEVICE bool _isnan(T val) {
41 return std::isnan(val.real()) || std::isnan(val.imag());
42 }
43
44 template <typename T, std::enable_if_t<std::is_same_v<T, at::Half>, int> = 0>
_isnan(T val)45 inline C10_HOST_DEVICE bool _isnan(T val) {
46 return at::_isnan(static_cast<float>(val));
47 }
48
49 template <
50 typename T,
51 std::enable_if_t<std::is_same_v<T, at::BFloat16>, int> = 0>
_isnan(at::BFloat16 val)52 inline C10_HOST_DEVICE bool _isnan(at::BFloat16 val) {
53 return at::_isnan(static_cast<float>(val));
54 }
55
_isnan(at::BFloat16 val)56 inline C10_HOST_DEVICE bool _isnan(at::BFloat16 val) {
57 return at::_isnan(static_cast<float>(val));
58 }
59
60 template <
61 typename T,
62 std::enable_if_t<std::is_same_v<T, at::Float8_e5m2>, int> = 0>
_isnan(T val)63 inline C10_HOST_DEVICE bool _isnan(T val) {
64 return val.isnan();
65 }
66
67 template <
68 typename T,
69 std::enable_if_t<std::is_same_v<T, at::Float8_e4m3fn>, int> = 0>
_isnan(T val)70 inline C10_HOST_DEVICE bool _isnan(T val) {
71 return val.isnan();
72 }
73
74 template <
75 typename T,
76 std::enable_if_t<std::is_same_v<T, at::Float8_e5m2fnuz>, int> = 0>
_isnan(T val)77 inline C10_HOST_DEVICE bool _isnan(T val) {
78 return val.isnan();
79 }
80
81 template <
82 typename T,
83 std::enable_if_t<std::is_same_v<T, at::Float8_e4m3fnuz>, int> = 0>
_isnan(T val)84 inline C10_HOST_DEVICE bool _isnan(T val) {
85 return val.isnan();
86 }
87
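// std::isinf isn't performant to use on integral types; it will
// (uselessly) convert to floating point and then do the test.
// The _isinf overloads below avoid that: integral types short-circuit to
// false, and each reduced-precision type gets its own dedicated overload.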
91
// Infinity test for integral types. Integers cannot represent infinity, so
// the answer is a constant false and the argument is never converted to
// floating point.
template <typename T, std::enable_if_t<std::is_integral_v<T>, int> = 0>
inline C10_HOST_DEVICE bool _isinf(T /*val*/) {
  constexpr bool kIntegralIsInf = false;
  return kIntegralIsInf;
}
96
// Infinity test for built-in floating-point types. When compiled by nvcc or
// hipcc (__CUDACC__ / __HIPCC__ defined) the global ::isinf is used; plain
// host compilers use std::isinf.
template <typename T, std::enable_if_t<std::is_floating_point_v<T>, int> = 0>
inline C10_HOST_DEVICE bool _isinf(T val) {
#if !defined(__CUDACC__) && !defined(__HIPCC__)
  return std::isinf(val);
#else
  return ::isinf(val);
#endif
}
105
_isinf(at::Half val)106 inline C10_HOST_DEVICE bool _isinf(at::Half val) {
107 return at::_isinf(static_cast<float>(val));
108 }
109
_isinf(at::BFloat16 val)110 inline C10_HOST_DEVICE bool _isinf(at::BFloat16 val) {
111 return at::_isinf(static_cast<float>(val));
112 }
113
_isinf(at::Float8_e5m2 val)114 inline C10_HOST_DEVICE bool _isinf(at::Float8_e5m2 val) {
115 return val.isinf();
116 }
117
_isinf(at::Float8_e4m3fn val)118 inline C10_HOST_DEVICE bool _isinf(at::Float8_e4m3fn val [[maybe_unused]]) {
119 return false;
120 }
121
_isinf(at::Float8_e5m2fnuz val)122 inline C10_HOST_DEVICE bool _isinf(at::Float8_e5m2fnuz val [[maybe_unused]]) {
123 return false;
124 }
125
_isinf(at::Float8_e4m3fnuz val)126 inline C10_HOST_DEVICE bool _isinf(at::Float8_e4m3fnuz val [[maybe_unused]]) {
127 return false;
128 }
129
// exp() for float and less precise element types.
//
// Device passes (__CUDA_ARCH__ or __HIP_ARCH__ defined) call the __expf
// fast-approximation intrinsic for peak bandwidth; all other passes call the
// C library ::exp. Instantiating this primary template with double is a
// compile-time error; double goes through the explicit specialization below.
template <typename T>
C10_HOST_DEVICE inline T exp(T x) {
  static_assert(
      !std::is_same_v<T, double>,
      "this template must be used with float or less precise type");
#if !defined(__CUDA_ARCH__) && !defined(__HIP_ARCH__)
  return ::exp(x);
#else
  return __expf(x);
#endif
}

// double never takes the fast-intrinsic path: it always calls ::exp.
template <>
C10_HOST_DEVICE inline double exp<double>(double x) {
  return ::exp(x);
}
147
// Natural log for float and less precise element types.
//
// Device passes (__CUDA_ARCH__ or __HIP_ARCH__ defined) call the __logf
// fast-approximation intrinsic for peak bandwidth; all other passes call the
// C library ::log. Instantiating this primary template with double is a
// compile-time error; double goes through the explicit specialization below.
template <typename T>
C10_HOST_DEVICE inline T log(T x) {
  static_assert(
      !std::is_same_v<T, double>,
      "this template must be used with float or less precise type");
#if !defined(__CUDA_ARCH__) && !defined(__HIP_ARCH__)
  return ::log(x);
#else
  return __logf(x);
#endif
}

// double never takes the fast-intrinsic path: it always calls ::log.
template <>
C10_HOST_DEVICE inline double log<double>(double x) {
  return ::log(x);
}
165
// log(1 + x) for float and less precise element types.
//
// Host passes call the C library ::log1p. Device passes (__CUDA_ARCH__ or
// __HIP_ARCH__ defined) use the __logf fast-approximation intrinsic for peak
// bandwidth. There is no __log1pf intrinsic, so the device path evaluates
// __logf(1.0f + x) and unfortunately loses precision. Instantiating this
// primary template with double is a compile-time error; double goes through
// the explicit specialization below.
template <typename T>
C10_HOST_DEVICE inline T log1p(T x) {
  static_assert(
      !std::is_same_v<T, double>,
      "this template must be used with float or less precise type");
#if !defined(__CUDA_ARCH__) && !defined(__HIP_ARCH__)
  return ::log1p(x);
#else
  return __logf(1.0f + x);
#endif
}

// double never takes the fast-intrinsic path: it always calls ::log1p.
template <>
C10_HOST_DEVICE inline double log1p<double>(double x) {
  return ::log1p(x);
}
184
// tan() for float and less precise element types.
//
// Device passes (__CUDA_ARCH__ or __HIP_ARCH__ defined) call the __tanf
// fast-approximation intrinsic for peak bandwidth; all other passes call the
// C library ::tan. Instantiating this primary template with double is a
// compile-time error; double goes through the explicit specialization below.
template <typename T>
C10_HOST_DEVICE inline T tan(T x) {
  static_assert(
      !std::is_same_v<T, double>,
      "this template must be used with float or less precise type");
#if !defined(__CUDA_ARCH__) && !defined(__HIP_ARCH__)
  return ::tan(x);
#else
  return __tanf(x);
#endif
}

// double never takes the fast-intrinsic path: it always calls ::tan.
template <>
C10_HOST_DEVICE inline double tan<double>(double x) {
  return ::tan(x);
}
202
203 } // namespace at
204