1*da0073e9SAndroid Build Coastguard Worker #pragma once
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Worker #ifdef __HIPCC__
4*da0073e9SAndroid Build Coastguard Worker #include <hip/hip_runtime.h>
5*da0073e9SAndroid Build Coastguard Worker #endif
6*da0073e9SAndroid Build Coastguard Worker
7*da0073e9SAndroid Build Coastguard Worker #include <c10/macros/Macros.h>
8*da0073e9SAndroid Build Coastguard Worker #include <c10/util/BFloat16.h>
9*da0073e9SAndroid Build Coastguard Worker #include <c10/util/Float8_e4m3fn.h>
10*da0073e9SAndroid Build Coastguard Worker #include <c10/util/Float8_e4m3fnuz.h>
11*da0073e9SAndroid Build Coastguard Worker #include <c10/util/Float8_e5m2.h>
12*da0073e9SAndroid Build Coastguard Worker #include <c10/util/Float8_e5m2fnuz.h>
13*da0073e9SAndroid Build Coastguard Worker #include <c10/util/Half.h>
14*da0073e9SAndroid Build Coastguard Worker #include <c10/util/complex.h>
15*da0073e9SAndroid Build Coastguard Worker
16*da0073e9SAndroid Build Coastguard Worker #include <cmath>
17*da0073e9SAndroid Build Coastguard Worker #include <type_traits>
18*da0073e9SAndroid Build Coastguard Worker
19*da0073e9SAndroid Build Coastguard Worker namespace at {
20*da0073e9SAndroid Build Coastguard Worker
// std::isnan isn't performant to use on integral types; it will
// (uselessly) convert to floating point and then do the test.
// at::_isnan is: its integral overload returns false without
// performing any conversion.
24*da0073e9SAndroid Build Coastguard Worker
// Integral overload of at::_isnan, enabled only when std::is_integral_v<T>
// holds. The argument is never inspected and the result is always false, so
// no integer-to-floating-point conversion takes place.
template <typename T, std::enable_if_t<std::is_integral_v<T>, int> = 0>
C10_HOST_DEVICE inline bool _isnan(T /*unused*/) {
  constexpr bool integral_is_nan = false;
  return integral_is_nan;
}
29*da0073e9SAndroid Build Coastguard Worker
// Floating-point overload of at::_isnan, enabled only when
// std::is_floating_point_v<T> holds.
// When neither __CUDACC__ nor __HIPCC__ is defined the check goes through
// std::isnan; otherwise the global-namespace ::isnan is called instead.
template <typename T, std::enable_if_t<std::is_floating_point_v<T>, int> = 0>
C10_HOST_DEVICE inline bool _isnan(T val) {
#if !defined(__CUDACC__) && !defined(__HIPCC__)
  return std::isnan(val);
#else
  return ::isnan(val);
#endif
}
38*da0073e9SAndroid Build Coastguard Worker
39*da0073e9SAndroid Build Coastguard Worker template <typename T, std::enable_if_t<c10::is_complex<T>::value, int> = 0>
_isnan(T val)40*da0073e9SAndroid Build Coastguard Worker inline C10_HOST_DEVICE bool _isnan(T val) {
41*da0073e9SAndroid Build Coastguard Worker return std::isnan(val.real()) || std::isnan(val.imag());
42*da0073e9SAndroid Build Coastguard Worker }
43*da0073e9SAndroid Build Coastguard Worker
44*da0073e9SAndroid Build Coastguard Worker template <typename T, std::enable_if_t<std::is_same_v<T, at::Half>, int> = 0>
_isnan(T val)45*da0073e9SAndroid Build Coastguard Worker inline C10_HOST_DEVICE bool _isnan(T val) {
46*da0073e9SAndroid Build Coastguard Worker return at::_isnan(static_cast<float>(val));
47*da0073e9SAndroid Build Coastguard Worker }
48*da0073e9SAndroid Build Coastguard Worker
49*da0073e9SAndroid Build Coastguard Worker template <
50*da0073e9SAndroid Build Coastguard Worker typename T,
51*da0073e9SAndroid Build Coastguard Worker std::enable_if_t<std::is_same_v<T, at::BFloat16>, int> = 0>
_isnan(at::BFloat16 val)52*da0073e9SAndroid Build Coastguard Worker inline C10_HOST_DEVICE bool _isnan(at::BFloat16 val) {
53*da0073e9SAndroid Build Coastguard Worker return at::_isnan(static_cast<float>(val));
54*da0073e9SAndroid Build Coastguard Worker }
55*da0073e9SAndroid Build Coastguard Worker
_isnan(at::BFloat16 val)56*da0073e9SAndroid Build Coastguard Worker inline C10_HOST_DEVICE bool _isnan(at::BFloat16 val) {
57*da0073e9SAndroid Build Coastguard Worker return at::_isnan(static_cast<float>(val));
58*da0073e9SAndroid Build Coastguard Worker }
59*da0073e9SAndroid Build Coastguard Worker
60*da0073e9SAndroid Build Coastguard Worker template <
61*da0073e9SAndroid Build Coastguard Worker typename T,
62*da0073e9SAndroid Build Coastguard Worker std::enable_if_t<std::is_same_v<T, at::Float8_e5m2>, int> = 0>
_isnan(T val)63*da0073e9SAndroid Build Coastguard Worker inline C10_HOST_DEVICE bool _isnan(T val) {
64*da0073e9SAndroid Build Coastguard Worker return val.isnan();
65*da0073e9SAndroid Build Coastguard Worker }
66*da0073e9SAndroid Build Coastguard Worker
67*da0073e9SAndroid Build Coastguard Worker template <
68*da0073e9SAndroid Build Coastguard Worker typename T,
69*da0073e9SAndroid Build Coastguard Worker std::enable_if_t<std::is_same_v<T, at::Float8_e4m3fn>, int> = 0>
_isnan(T val)70*da0073e9SAndroid Build Coastguard Worker inline C10_HOST_DEVICE bool _isnan(T val) {
71*da0073e9SAndroid Build Coastguard Worker return val.isnan();
72*da0073e9SAndroid Build Coastguard Worker }
73*da0073e9SAndroid Build Coastguard Worker
74*da0073e9SAndroid Build Coastguard Worker template <
75*da0073e9SAndroid Build Coastguard Worker typename T,
76*da0073e9SAndroid Build Coastguard Worker std::enable_if_t<std::is_same_v<T, at::Float8_e5m2fnuz>, int> = 0>
_isnan(T val)77*da0073e9SAndroid Build Coastguard Worker inline C10_HOST_DEVICE bool _isnan(T val) {
78*da0073e9SAndroid Build Coastguard Worker return val.isnan();
79*da0073e9SAndroid Build Coastguard Worker }
80*da0073e9SAndroid Build Coastguard Worker
81*da0073e9SAndroid Build Coastguard Worker template <
82*da0073e9SAndroid Build Coastguard Worker typename T,
83*da0073e9SAndroid Build Coastguard Worker std::enable_if_t<std::is_same_v<T, at::Float8_e4m3fnuz>, int> = 0>
_isnan(T val)84*da0073e9SAndroid Build Coastguard Worker inline C10_HOST_DEVICE bool _isnan(T val) {
85*da0073e9SAndroid Build Coastguard Worker return val.isnan();
86*da0073e9SAndroid Build Coastguard Worker }
87*da0073e9SAndroid Build Coastguard Worker
// std::isinf isn't performant to use on integral types; it will
// (uselessly) convert to floating point and then do the test.
// at::_isinf is: its integral overload returns false without
// performing any conversion.
91*da0073e9SAndroid Build Coastguard Worker
// Integral overload of at::_isinf, enabled only when std::is_integral_v<T>
// holds. The argument is never inspected and the result is always false, so
// no integer-to-floating-point conversion takes place.
template <typename T, std::enable_if_t<std::is_integral_v<T>, int> = 0>
C10_HOST_DEVICE inline bool _isinf(T /*unused*/) {
  constexpr bool integral_is_inf = false;
  return integral_is_inf;
}
96*da0073e9SAndroid Build Coastguard Worker
// Floating-point overload of at::_isinf, enabled only when
// std::is_floating_point_v<T> holds.
// When neither __CUDACC__ nor __HIPCC__ is defined the check goes through
// std::isinf; otherwise the global-namespace ::isinf is called instead.
template <typename T, std::enable_if_t<std::is_floating_point_v<T>, int> = 0>
C10_HOST_DEVICE inline bool _isinf(T val) {
#if !defined(__CUDACC__) && !defined(__HIPCC__)
  return std::isinf(val);
#else
  return ::isinf(val);
#endif
}
105*da0073e9SAndroid Build Coastguard Worker
_isinf(at::Half val)106*da0073e9SAndroid Build Coastguard Worker inline C10_HOST_DEVICE bool _isinf(at::Half val) {
107*da0073e9SAndroid Build Coastguard Worker return at::_isinf(static_cast<float>(val));
108*da0073e9SAndroid Build Coastguard Worker }
109*da0073e9SAndroid Build Coastguard Worker
_isinf(at::BFloat16 val)110*da0073e9SAndroid Build Coastguard Worker inline C10_HOST_DEVICE bool _isinf(at::BFloat16 val) {
111*da0073e9SAndroid Build Coastguard Worker return at::_isinf(static_cast<float>(val));
112*da0073e9SAndroid Build Coastguard Worker }
113*da0073e9SAndroid Build Coastguard Worker
_isinf(at::Float8_e5m2 val)114*da0073e9SAndroid Build Coastguard Worker inline C10_HOST_DEVICE bool _isinf(at::Float8_e5m2 val) {
115*da0073e9SAndroid Build Coastguard Worker return val.isinf();
116*da0073e9SAndroid Build Coastguard Worker }
117*da0073e9SAndroid Build Coastguard Worker
_isinf(at::Float8_e4m3fn val)118*da0073e9SAndroid Build Coastguard Worker inline C10_HOST_DEVICE bool _isinf(at::Float8_e4m3fn val [[maybe_unused]]) {
119*da0073e9SAndroid Build Coastguard Worker return false;
120*da0073e9SAndroid Build Coastguard Worker }
121*da0073e9SAndroid Build Coastguard Worker
_isinf(at::Float8_e5m2fnuz val)122*da0073e9SAndroid Build Coastguard Worker inline C10_HOST_DEVICE bool _isinf(at::Float8_e5m2fnuz val [[maybe_unused]]) {
123*da0073e9SAndroid Build Coastguard Worker return false;
124*da0073e9SAndroid Build Coastguard Worker }
125*da0073e9SAndroid Build Coastguard Worker
_isinf(at::Float8_e4m3fnuz val)126*da0073e9SAndroid Build Coastguard Worker inline C10_HOST_DEVICE bool _isinf(at::Float8_e4m3fnuz val [[maybe_unused]]) {
127*da0073e9SAndroid Build Coastguard Worker return false;
128*da0073e9SAndroid Build Coastguard Worker }
129*da0073e9SAndroid Build Coastguard Worker
// at::exp - exponential of x, returned as the same type T.
//
// Primary template: rejects T == double at compile time (double is served by
// the explicit specialization below). When neither __CUDA_ARCH__ nor
// __HIP_ARCH__ is defined it forwards to the global-namespace ::exp;
// otherwise it uses the __expf fast approximation for peak bandwidth.
template <typename T>
inline C10_HOST_DEVICE T exp(T x) {
  static_assert(
      !std::is_same_v<T, double>,
      "this template must be used with float or less precise type");
#if !defined(__CUDA_ARCH__) && !defined(__HIP_ARCH__)
  return ::exp(x);
#else
  // Fast-approximation intrinsic; the result is converted back to T.
  return __expf(x);
#endif
}

// double specialization: always forwards to ::exp, never to the fast
// approximation.
template <>
inline C10_HOST_DEVICE double exp<double>(double x) {
  return ::exp(x);
}
147*da0073e9SAndroid Build Coastguard Worker
// at::log - natural logarithm of x, returned as the same type T.
//
// Primary template: rejects T == double at compile time (double is served by
// the explicit specialization below). When neither __CUDA_ARCH__ nor
// __HIP_ARCH__ is defined it forwards to the global-namespace ::log;
// otherwise it uses the __logf fast approximation for peak bandwidth.
template <typename T>
inline C10_HOST_DEVICE T log(T x) {
  static_assert(
      !std::is_same_v<T, double>,
      "this template must be used with float or less precise type");
#if !defined(__CUDA_ARCH__) && !defined(__HIP_ARCH__)
  return ::log(x);
#else
  // Fast-approximation intrinsic; the result is converted back to T.
  return __logf(x);
#endif
}

// double specialization: always forwards to ::log, never to the fast
// approximation.
template <>
inline C10_HOST_DEVICE double log<double>(double x) {
  return ::log(x);
}
165*da0073e9SAndroid Build Coastguard Worker
// at::log1p - log(1 + x), returned as the same type T.
//
// Primary template: rejects T == double at compile time (double is served by
// the explicit specialization below). When neither __CUDA_ARCH__ nor
// __HIP_ARCH__ is defined it forwards to the global-namespace ::log1p.
// Otherwise it evaluates __logf(1.0f + x): the fast approximation is used
// for peak bandwidth, and since there is no __log1pf, precision is lost
// compared with a true log1p.
template <typename T>
inline C10_HOST_DEVICE T log1p(T x) {
  static_assert(
      !std::is_same_v<T, double>,
      "this template must be used with float or less precise type");
#if !defined(__CUDA_ARCH__) && !defined(__HIP_ARCH__)
  return ::log1p(x);
#else
  // NOTE: no __log1pf exists, so the sum is formed explicitly first.
  return __logf(1.0f + x);
#endif
}

// double specialization: always forwards to ::log1p, never to the fast
// approximation.
template <>
inline C10_HOST_DEVICE double log1p<double>(double x) {
  return ::log1p(x);
}
184*da0073e9SAndroid Build Coastguard Worker
// at::tan - tangent of x, returned as the same type T.
//
// Primary template: rejects T == double at compile time (double is served by
// the explicit specialization below). When neither __CUDA_ARCH__ nor
// __HIP_ARCH__ is defined it forwards to the global-namespace ::tan;
// otherwise it uses the __tanf fast approximation for peak bandwidth.
template <typename T>
inline C10_HOST_DEVICE T tan(T x) {
  static_assert(
      !std::is_same_v<T, double>,
      "this template must be used with float or less precise type");
#if !defined(__CUDA_ARCH__) && !defined(__HIP_ARCH__)
  return ::tan(x);
#else
  // Fast-approximation intrinsic; the result is converted back to T.
  return __tanf(x);
#endif
}

// double specialization: always forwards to ::tan, never to the fast
// approximation.
template <>
inline C10_HOST_DEVICE double tan<double>(double x) {
  return ::tan(x);
}
202*da0073e9SAndroid Build Coastguard Worker
203*da0073e9SAndroid Build Coastguard Worker } // namespace at
204