1 #pragma once 2 3 #include <c10/macros/Macros.h> 4 #include <c10/util/BFloat16.h> 5 #include <c10/util/Half.h> 6 7 C10_CLANG_DIAGNOSTIC_PUSH() 8 #if C10_CLANG_HAS_WARNING("-Wimplicit-float-conversion") 9 C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-float-conversion") 10 #endif 11 12 namespace c10 { 13 // TODO: Replace me with inline constexpr variable when C++17 becomes available 14 namespace detail { 15 template <typename T> e()16C10_HOST_DEVICE inline constexpr T e() { 17 return static_cast<T>(2.718281828459045235360287471352662); 18 } 19 20 template <typename T> euler()21C10_HOST_DEVICE inline constexpr T euler() { 22 return static_cast<T>(0.577215664901532860606512090082402); 23 } 24 25 template <typename T> frac_1_pi()26C10_HOST_DEVICE inline constexpr T frac_1_pi() { 27 return static_cast<T>(0.318309886183790671537767526745028); 28 } 29 30 template <typename T> frac_1_sqrt_pi()31C10_HOST_DEVICE inline constexpr T frac_1_sqrt_pi() { 32 return static_cast<T>(0.564189583547756286948079451560772); 33 } 34 35 template <typename T> frac_sqrt_2()36C10_HOST_DEVICE inline constexpr T frac_sqrt_2() { 37 return static_cast<T>(0.707106781186547524400844362104849); 38 } 39 40 template <typename T> frac_sqrt_3()41C10_HOST_DEVICE inline constexpr T frac_sqrt_3() { 42 return static_cast<T>(0.577350269189625764509148780501957); 43 } 44 45 template <typename T> golden_ratio()46C10_HOST_DEVICE inline constexpr T golden_ratio() { 47 return static_cast<T>(1.618033988749894848204586834365638); 48 } 49 50 template <typename T> ln_10()51C10_HOST_DEVICE inline constexpr T ln_10() { 52 return static_cast<T>(2.302585092994045684017991454684364); 53 } 54 55 template <typename T> ln_2()56C10_HOST_DEVICE inline constexpr T ln_2() { 57 return static_cast<T>(0.693147180559945309417232121458176); 58 } 59 60 template <typename T> log_10_e()61C10_HOST_DEVICE inline constexpr T log_10_e() { 62 return static_cast<T>(0.434294481903251827651128918916605); 63 } 64 65 template <typename T> log_2_e()66C10_HOST_DEVICE inline constexpr T log_2_e() { 67 return static_cast<T>(1.442695040888963407359924681001892); 68 } 69 70 template <typename T> pi()71C10_HOST_DEVICE inline constexpr T pi() { 72 return static_cast<T>(3.141592653589793238462643383279502); 73 } 74 75 template <typename T> sqrt_2()76C10_HOST_DEVICE inline constexpr T sqrt_2() { 77 return static_cast<T>(1.414213562373095048801688724209698); 78 } 79 80 template <typename T> sqrt_3()81C10_HOST_DEVICE inline constexpr T sqrt_3() { 82 return static_cast<T>(1.732050807568877293527446341505872); 83 } 84 85 template <> 86 C10_HOST_DEVICE inline constexpr BFloat16 pi<BFloat16>() { 87 // According to 88 // https://en.wikipedia.org/wiki/Bfloat16_floating-point_format#Special_values 89 // pi is encoded as 4049 90 return BFloat16(0x4049, BFloat16::from_bits()); 91 } 92 93 template <> 94 C10_HOST_DEVICE inline constexpr Half pi<Half>() { 95 return Half(0x4248, Half::from_bits()); 96 } 97 } // namespace detail 98 99 template <typename T> 100 constexpr T e = c10::detail::e<T>(); 101 102 template <typename T> 103 constexpr T euler = c10::detail::euler<T>(); 104 105 template <typename T> 106 constexpr T frac_1_pi = c10::detail::frac_1_pi<T>(); 107 108 template <typename T> 109 constexpr T frac_1_sqrt_pi = c10::detail::frac_1_sqrt_pi<T>(); 110 111 template <typename T> 112 constexpr T frac_sqrt_2 = c10::detail::frac_sqrt_2<T>(); 113 114 template <typename T> 115 constexpr T frac_sqrt_3 = c10::detail::frac_sqrt_3<T>(); 116 117 template <typename T> 118 constexpr T golden_ratio = c10::detail::golden_ratio<T>(); 119 120 template <typename T> 121 constexpr T ln_10 = c10::detail::ln_10<T>(); 122 123 template <typename T> 124 constexpr T ln_2 = c10::detail::ln_2<T>(); 125 126 template <typename T> 127 constexpr T log_10_e = c10::detail::log_10_e<T>(); 128 129 template <typename T> 130 constexpr T log_2_e = c10::detail::log_2_e<T>(); 131 132 template <typename T> 133 constexpr T pi = c10::detail::pi<T>(); 134 135 template <typename T> 136 constexpr T sqrt_2 = c10::detail::sqrt_2<T>(); 137 138 template <typename T> 139 constexpr T sqrt_3 = c10::detail::sqrt_3<T>(); 140 } // namespace c10 141 142 C10_CLANG_DIAGNOSTIC_POP() 143