xref: /aosp_15_r20/external/pytorch/c10/util/MathConstants.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/macros/Macros.h>
4 #include <c10/util/BFloat16.h>
5 #include <c10/util/Half.h>
6 
7 C10_CLANG_DIAGNOSTIC_PUSH()
8 #if C10_CLANG_HAS_WARNING("-Wimplicit-float-conversion")
9 C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-float-conversion")
10 #endif
11 
12 namespace c10 {
13 // TODO: Replace me with inline constexpr variable when C++17 becomes available
14 namespace detail {
15 template <typename T>
e()16 C10_HOST_DEVICE inline constexpr T e() {
17   return static_cast<T>(2.718281828459045235360287471352662);
18 }
19 
20 template <typename T>
euler()21 C10_HOST_DEVICE inline constexpr T euler() {
22   return static_cast<T>(0.577215664901532860606512090082402);
23 }
24 
25 template <typename T>
frac_1_pi()26 C10_HOST_DEVICE inline constexpr T frac_1_pi() {
27   return static_cast<T>(0.318309886183790671537767526745028);
28 }
29 
30 template <typename T>
frac_1_sqrt_pi()31 C10_HOST_DEVICE inline constexpr T frac_1_sqrt_pi() {
32   return static_cast<T>(0.564189583547756286948079451560772);
33 }
34 
35 template <typename T>
frac_sqrt_2()36 C10_HOST_DEVICE inline constexpr T frac_sqrt_2() {
37   return static_cast<T>(0.707106781186547524400844362104849);
38 }
39 
40 template <typename T>
frac_sqrt_3()41 C10_HOST_DEVICE inline constexpr T frac_sqrt_3() {
42   return static_cast<T>(0.577350269189625764509148780501957);
43 }
44 
45 template <typename T>
golden_ratio()46 C10_HOST_DEVICE inline constexpr T golden_ratio() {
47   return static_cast<T>(1.618033988749894848204586834365638);
48 }
49 
50 template <typename T>
ln_10()51 C10_HOST_DEVICE inline constexpr T ln_10() {
52   return static_cast<T>(2.302585092994045684017991454684364);
53 }
54 
55 template <typename T>
ln_2()56 C10_HOST_DEVICE inline constexpr T ln_2() {
57   return static_cast<T>(0.693147180559945309417232121458176);
58 }
59 
60 template <typename T>
log_10_e()61 C10_HOST_DEVICE inline constexpr T log_10_e() {
62   return static_cast<T>(0.434294481903251827651128918916605);
63 }
64 
65 template <typename T>
log_2_e()66 C10_HOST_DEVICE inline constexpr T log_2_e() {
67   return static_cast<T>(1.442695040888963407359924681001892);
68 }
69 
70 template <typename T>
pi()71 C10_HOST_DEVICE inline constexpr T pi() {
72   return static_cast<T>(3.141592653589793238462643383279502);
73 }
74 
75 template <typename T>
sqrt_2()76 C10_HOST_DEVICE inline constexpr T sqrt_2() {
77   return static_cast<T>(1.414213562373095048801688724209698);
78 }
79 
80 template <typename T>
sqrt_3()81 C10_HOST_DEVICE inline constexpr T sqrt_3() {
82   return static_cast<T>(1.732050807568877293527446341505872);
83 }
84 
85 template <>
86 C10_HOST_DEVICE inline constexpr BFloat16 pi<BFloat16>() {
87   // According to
88   // https://en.wikipedia.org/wiki/Bfloat16_floating-point_format#Special_values
89   // pi is encoded as 4049
90   return BFloat16(0x4049, BFloat16::from_bits());
91 }
92 
93 template <>
94 C10_HOST_DEVICE inline constexpr Half pi<Half>() {
95   return Half(0x4248, Half::from_bits());
96 }
97 } // namespace detail
98 
99 template <typename T>
100 constexpr T e = c10::detail::e<T>();
101 
102 template <typename T>
103 constexpr T euler = c10::detail::euler<T>();
104 
105 template <typename T>
106 constexpr T frac_1_pi = c10::detail::frac_1_pi<T>();
107 
108 template <typename T>
109 constexpr T frac_1_sqrt_pi = c10::detail::frac_1_sqrt_pi<T>();
110 
111 template <typename T>
112 constexpr T frac_sqrt_2 = c10::detail::frac_sqrt_2<T>();
113 
114 template <typename T>
115 constexpr T frac_sqrt_3 = c10::detail::frac_sqrt_3<T>();
116 
117 template <typename T>
118 constexpr T golden_ratio = c10::detail::golden_ratio<T>();
119 
120 template <typename T>
121 constexpr T ln_10 = c10::detail::ln_10<T>();
122 
123 template <typename T>
124 constexpr T ln_2 = c10::detail::ln_2<T>();
125 
126 template <typename T>
127 constexpr T log_10_e = c10::detail::log_10_e<T>();
128 
129 template <typename T>
130 constexpr T log_2_e = c10::detail::log_2_e<T>();
131 
132 template <typename T>
133 constexpr T pi = c10::detail::pi<T>();
134 
135 template <typename T>
136 constexpr T sqrt_2 = c10::detail::sqrt_2<T>();
137 
138 template <typename T>
139 constexpr T sqrt_3 = c10::detail::sqrt_3<T>();
140 } // namespace c10
141 
142 C10_CLANG_DIAGNOSTIC_POP()
143