1 #pragma once
2
/* This file defines math functions that are compatible across different GPU
 * platforms (currently CUDA and HIP).
 */
6 #if defined(__CUDACC__) || defined(__HIPCC__)
7
8 #include <c10/macros/Macros.h>
9 #include <c10/util/Exception.h>
10
// Select the declaration specifiers applied to every wrapper in this file:
//   - HIP compiler:               inline, device-only
//   - __CUDACC_RTC__ defined:     host+device, without `inline`
//   - any other CUDA compilation: inline, host+device
#if defined(__HIPCC__)
#define __MATH_FUNCTIONS_DECL__ inline C10_DEVICE
#elif defined(__CUDACC_RTC__)
#define __MATH_FUNCTIONS_DECL__ C10_HOST_DEVICE
#else
#define __MATH_FUNCTIONS_DECL__ inline C10_HOST_DEVICE
#endif /* __HIPCC__ / __CUDACC_RTC__ */
20
21 namespace c10::cuda::compat {
22
abs(float x)23 __MATH_FUNCTIONS_DECL__ float abs(float x) {
24 return ::fabsf(x);
25 }
abs(double x)26 __MATH_FUNCTIONS_DECL__ double abs(double x) {
27 return ::fabs(x);
28 }
29
exp(float x)30 __MATH_FUNCTIONS_DECL__ float exp(float x) {
31 return ::expf(x);
32 }
exp(double x)33 __MATH_FUNCTIONS_DECL__ double exp(double x) {
34 return ::exp(x);
35 }
36
ceil(float x)37 __MATH_FUNCTIONS_DECL__ float ceil(float x) {
38 return ::ceilf(x);
39 }
ceil(double x)40 __MATH_FUNCTIONS_DECL__ double ceil(double x) {
41 return ::ceil(x);
42 }
43
copysign(float x,float y)44 __MATH_FUNCTIONS_DECL__ float copysign(float x, float y) {
45 #if defined(__CUDA_ARCH__) || defined(__HIPCC__)
46 return ::copysignf(x, y);
47 #else
48 // std::copysign gets ICE/Segfaults with gcc 7.5/8 on arm64
49 // (e.g. Jetson), see PyTorch PR #51834
50 // This host function needs to be here for the compiler but is never used
51 TORCH_INTERNAL_ASSERT(
52 false, "CUDAMathCompat copysign should not run on the CPU");
53 #endif
54 }
copysign(double x,double y)55 __MATH_FUNCTIONS_DECL__ double copysign(double x, double y) {
56 #if defined(__CUDA_ARCH__) || defined(__HIPCC__)
57 return ::copysign(x, y);
58 #else
59 // see above
60 TORCH_INTERNAL_ASSERT(
61 false, "CUDAMathCompat copysign should not run on the CPU");
62 #endif
63 }
64
floor(float x)65 __MATH_FUNCTIONS_DECL__ float floor(float x) {
66 return ::floorf(x);
67 }
floor(double x)68 __MATH_FUNCTIONS_DECL__ double floor(double x) {
69 return ::floor(x);
70 }
71
log(float x)72 __MATH_FUNCTIONS_DECL__ float log(float x) {
73 return ::logf(x);
74 }
log(double x)75 __MATH_FUNCTIONS_DECL__ double log(double x) {
76 return ::log(x);
77 }
78
log1p(float x)79 __MATH_FUNCTIONS_DECL__ float log1p(float x) {
80 return ::log1pf(x);
81 }
82
log1p(double x)83 __MATH_FUNCTIONS_DECL__ double log1p(double x) {
84 return ::log1p(x);
85 }
86
max(float x,float y)87 __MATH_FUNCTIONS_DECL__ float max(float x, float y) {
88 return ::fmaxf(x, y);
89 }
max(double x,double y)90 __MATH_FUNCTIONS_DECL__ double max(double x, double y) {
91 return ::fmax(x, y);
92 }
93
min(float x,float y)94 __MATH_FUNCTIONS_DECL__ float min(float x, float y) {
95 return ::fminf(x, y);
96 }
min(double x,double y)97 __MATH_FUNCTIONS_DECL__ double min(double x, double y) {
98 return ::fmin(x, y);
99 }
100
pow(float x,float y)101 __MATH_FUNCTIONS_DECL__ float pow(float x, float y) {
102 return ::powf(x, y);
103 }
pow(double x,double y)104 __MATH_FUNCTIONS_DECL__ double pow(double x, double y) {
105 return ::pow(x, y);
106 }
107
sincos(float x,float * sptr,float * cptr)108 __MATH_FUNCTIONS_DECL__ void sincos(float x, float* sptr, float* cptr) {
109 return ::sincosf(x, sptr, cptr);
110 }
sincos(double x,double * sptr,double * cptr)111 __MATH_FUNCTIONS_DECL__ void sincos(double x, double* sptr, double* cptr) {
112 return ::sincos(x, sptr, cptr);
113 }
114
sqrt(float x)115 __MATH_FUNCTIONS_DECL__ float sqrt(float x) {
116 return ::sqrtf(x);
117 }
sqrt(double x)118 __MATH_FUNCTIONS_DECL__ double sqrt(double x) {
119 return ::sqrt(x);
120 }
121
rsqrt(float x)122 __MATH_FUNCTIONS_DECL__ float rsqrt(float x) {
123 return ::rsqrtf(x);
124 }
rsqrt(double x)125 __MATH_FUNCTIONS_DECL__ double rsqrt(double x) {
126 return ::rsqrt(x);
127 }
128
tan(float x)129 __MATH_FUNCTIONS_DECL__ float tan(float x) {
130 return ::tanf(x);
131 }
tan(double x)132 __MATH_FUNCTIONS_DECL__ double tan(double x) {
133 return ::tan(x);
134 }
135
tanh(float x)136 __MATH_FUNCTIONS_DECL__ float tanh(float x) {
137 return ::tanhf(x);
138 }
tanh(double x)139 __MATH_FUNCTIONS_DECL__ double tanh(double x) {
140 return ::tanh(x);
141 }
142
normcdf(float x)143 __MATH_FUNCTIONS_DECL__ float normcdf(float x) {
144 return ::normcdff(x);
145 }
normcdf(double x)146 __MATH_FUNCTIONS_DECL__ double normcdf(double x) {
147 return ::normcdf(x);
148 }
149
150 } // namespace c10::cuda::compat
151
152 #endif
153