xref: /aosp_15_r20/external/pytorch/aten/src/ATen/NumericUtils.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #ifdef __HIPCC__
4 #include <hip/hip_runtime.h>
5 #endif
6 
7 #include <c10/macros/Macros.h>
8 #include <c10/util/BFloat16.h>
9 #include <c10/util/Float8_e4m3fn.h>
10 #include <c10/util/Float8_e4m3fnuz.h>
11 #include <c10/util/Float8_e5m2.h>
12 #include <c10/util/Float8_e5m2fnuz.h>
13 #include <c10/util/Half.h>
14 #include <c10/util/complex.h>
15 
16 #include <cmath>
17 #include <type_traits>
18 
19 namespace at {
20 
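// std::isnan is not efficient on integral types: it (uselessly) converts the
// argument to floating point and then performs the test. The at::_isnan
// overloads below avoid that: integral types short-circuit to `false`, and
// the reduced-precision and complex types get dedicated handling.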
24 
25 template <typename T, std::enable_if_t<std::is_integral_v<T>, int> = 0>
_isnan(T)26 inline C10_HOST_DEVICE bool _isnan(T /*val*/) {
27   return false;
28 }
29 
30 template <typename T, std::enable_if_t<std::is_floating_point_v<T>, int> = 0>
_isnan(T val)31 inline C10_HOST_DEVICE bool _isnan(T val) {
32 #if defined(__CUDACC__) || defined(__HIPCC__)
33   return ::isnan(val);
34 #else
35   return std::isnan(val);
36 #endif
37 }
38 
39 template <typename T, std::enable_if_t<c10::is_complex<T>::value, int> = 0>
_isnan(T val)40 inline C10_HOST_DEVICE bool _isnan(T val) {
41   return std::isnan(val.real()) || std::isnan(val.imag());
42 }
43 
44 template <typename T, std::enable_if_t<std::is_same_v<T, at::Half>, int> = 0>
_isnan(T val)45 inline C10_HOST_DEVICE bool _isnan(T val) {
46   return at::_isnan(static_cast<float>(val));
47 }
48 
49 template <
50     typename T,
51     std::enable_if_t<std::is_same_v<T, at::BFloat16>, int> = 0>
_isnan(at::BFloat16 val)52 inline C10_HOST_DEVICE bool _isnan(at::BFloat16 val) {
53   return at::_isnan(static_cast<float>(val));
54 }
55 
_isnan(at::BFloat16 val)56 inline C10_HOST_DEVICE bool _isnan(at::BFloat16 val) {
57   return at::_isnan(static_cast<float>(val));
58 }
59 
60 template <
61     typename T,
62     std::enable_if_t<std::is_same_v<T, at::Float8_e5m2>, int> = 0>
_isnan(T val)63 inline C10_HOST_DEVICE bool _isnan(T val) {
64   return val.isnan();
65 }
66 
67 template <
68     typename T,
69     std::enable_if_t<std::is_same_v<T, at::Float8_e4m3fn>, int> = 0>
_isnan(T val)70 inline C10_HOST_DEVICE bool _isnan(T val) {
71   return val.isnan();
72 }
73 
74 template <
75     typename T,
76     std::enable_if_t<std::is_same_v<T, at::Float8_e5m2fnuz>, int> = 0>
_isnan(T val)77 inline C10_HOST_DEVICE bool _isnan(T val) {
78   return val.isnan();
79 }
80 
81 template <
82     typename T,
83     std::enable_if_t<std::is_same_v<T, at::Float8_e4m3fnuz>, int> = 0>
_isnan(T val)84 inline C10_HOST_DEVICE bool _isnan(T val) {
85   return val.isnan();
86 }
87 
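// std::isinf is not efficient on integral types: it (uselessly) converts the
// argument to floating point and then performs the test. The at::_isinf
// overloads below avoid that: integral types short-circuit to `false`, and
// the reduced-precision types get dedicated handling.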
91 
92 template <typename T, std::enable_if_t<std::is_integral_v<T>, int> = 0>
_isinf(T)93 inline C10_HOST_DEVICE bool _isinf(T /*val*/) {
94   return false;
95 }
96 
97 template <typename T, std::enable_if_t<std::is_floating_point_v<T>, int> = 0>
_isinf(T val)98 inline C10_HOST_DEVICE bool _isinf(T val) {
99 #if defined(__CUDACC__) || defined(__HIPCC__)
100   return ::isinf(val);
101 #else
102   return std::isinf(val);
103 #endif
104 }
105 
_isinf(at::Half val)106 inline C10_HOST_DEVICE bool _isinf(at::Half val) {
107   return at::_isinf(static_cast<float>(val));
108 }
109 
_isinf(at::BFloat16 val)110 inline C10_HOST_DEVICE bool _isinf(at::BFloat16 val) {
111   return at::_isinf(static_cast<float>(val));
112 }
113 
_isinf(at::Float8_e5m2 val)114 inline C10_HOST_DEVICE bool _isinf(at::Float8_e5m2 val) {
115   return val.isinf();
116 }
117 
_isinf(at::Float8_e4m3fn val)118 inline C10_HOST_DEVICE bool _isinf(at::Float8_e4m3fn val [[maybe_unused]]) {
119   return false;
120 }
121 
_isinf(at::Float8_e5m2fnuz val)122 inline C10_HOST_DEVICE bool _isinf(at::Float8_e5m2fnuz val [[maybe_unused]]) {
123   return false;
124 }
125 
_isinf(at::Float8_e4m3fnuz val)126 inline C10_HOST_DEVICE bool _isinf(at::Float8_e4m3fnuz val [[maybe_unused]]) {
127   return false;
128 }
129 
130 template <typename T>
exp(T x)131 C10_HOST_DEVICE inline T exp(T x) {
132   static_assert(
133       !std::is_same_v<T, double>,
134       "this template must be used with float or less precise type");
135 #if defined(__CUDA_ARCH__) || defined(__HIP_ARCH__)
136   // use __expf fast approximation for peak bandwidth
137   return __expf(x);
138 #else
139   return ::exp(x);
140 #endif
141 }
142 
143 template <>
144 C10_HOST_DEVICE inline double exp<double>(double x) {
145   return ::exp(x);
146 }
147 
148 template <typename T>
log(T x)149 C10_HOST_DEVICE inline T log(T x) {
150   static_assert(
151       !std::is_same_v<T, double>,
152       "this template must be used with float or less precise type");
153 #if defined(__CUDA_ARCH__) || defined(__HIP_ARCH__)
154   // use __logf fast approximation for peak bandwidth
155   return __logf(x);
156 #else
157   return ::log(x);
158 #endif
159 }
160 
161 template <>
162 C10_HOST_DEVICE inline double log<double>(double x) {
163   return ::log(x);
164 }
165 
166 template <typename T>
log1p(T x)167 C10_HOST_DEVICE inline T log1p(T x) {
168   static_assert(
169       !std::is_same_v<T, double>,
170       "this template must be used with float or less precise type");
171 #if defined(__CUDA_ARCH__) || defined(__HIP_ARCH__)
172   // use __logf fast approximation for peak bandwidth
173   // NOTE: There is no __log1pf so unfortunately we lose precision.
174   return __logf(1.0f + x);
175 #else
176   return ::log1p(x);
177 #endif
178 }
179 
180 template <>
181 C10_HOST_DEVICE inline double log1p<double>(double x) {
182   return ::log1p(x);
183 }
184 
185 template <typename T>
tan(T x)186 C10_HOST_DEVICE inline T tan(T x) {
187   static_assert(
188       !std::is_same_v<T, double>,
189       "this template must be used with float or less precise type");
190 #if defined(__CUDA_ARCH__) || defined(__HIP_ARCH__)
191   // use __tanf fast approximation for peak bandwidth
192   return __tanf(x);
193 #else
194   return ::tan(x);
195 #endif
196 }
197 
198 template <>
199 C10_HOST_DEVICE inline double tan<double>(double x) {
200   return ::tan(x);
201 }
202 
203 } // namespace at
204