1 #pragma once 2 3 #include <ATen/core/Array.h> 4 #include <ATen/core/TransformationHelper.h> 5 #include <c10/util/Half.h> 6 #include <c10/util/BFloat16.h> 7 #include <c10/util/MathConstants.h> 8 #include <c10/macros/Macros.h> 9 10 #include <cmath> 11 #include <limits> 12 #include <optional> 13 #include <type_traits> 14 15 /** 16 * Distributions kernel adapted from THRandom.cpp 17 * The kernels try to follow std::random distributions signature 18 * For instance: in ATen 19 * auto gen = at::detail::createCPUGenerator(); 20 * at::uniform_real_distribution<double> uniform(0, 1); 21 * auto sample = uniform(gen.get()); 22 * 23 * vs std::random 24 * 25 * std::mt19937 gen; 26 * std::uniform_real_distribution uniform(0, 1); 27 * auto sample = uniform(gen); 28 */ 29 30 31 namespace at { 32 namespace { 33 34 /** 35 * Samples a discrete uniform distribution in the range [base, base+range) of type T 36 */ 37 template <typename T> 38 struct uniform_int_from_to_distribution { 39 uniform_int_from_to_distributionuniform_int_from_to_distribution40 C10_HOST_DEVICE inline uniform_int_from_to_distribution(uint64_t range, int64_t base) : range_(range), base_(base) {} 41 42 template <typename RNG> operatoruniform_int_from_to_distribution43 C10_HOST_DEVICE inline T operator()(RNG generator) { 44 if (( 45 std::is_same<T, int64_t>::value || 46 std::is_same<T, double>::value || 47 std::is_same<T, float>::value || 48 std::is_same<T, at::BFloat16>::value) && range_ >= 1ULL << 32) 49 { 50 return transformation::uniform_int_from_to<T>(generator->random64(), range_, base_); 51 } else { 52 return transformation::uniform_int_from_to<T>(generator->random(), range_, base_); 53 } 54 } 55 56 private: 57 uint64_t range_; 58 int64_t base_; 59 }; 60 61 /** 62 * Samples a discrete uniform distribution in the range [min_value(int64_t), max_value(int64_t)] 63 */ 64 template <typename T> 65 struct uniform_int_full_range_distribution { 66 67 template <typename RNG> operatoruniform_int_full_range_distribution68 C10_HOST_DEVICE inline T operator()(RNG generator) { 69 return transformation::uniform_int_full_range<T>(generator->random64()); 70 } 71 72 }; 73 74 /** 75 * Samples a discrete uniform distribution in the range [0, max_value(T)] for integral types 76 * and [0, 2^mantissa] for floating-point types. 77 */ 78 template <typename T> 79 struct uniform_int_distribution { 80 81 template <typename RNG> operatoruniform_int_distribution82 C10_HOST_DEVICE inline T operator()(RNG generator) { 83 if constexpr (std::is_same_v<T, double> || std::is_same_v<T, int64_t>) { 84 return transformation::uniform_int<T>(generator->random64()); 85 } else { 86 return transformation::uniform_int<T>(generator->random()); 87 } 88 } 89 90 }; 91 92 /** 93 * Samples a uniform distribution in the range [from, to) of type T 94 */ 95 template <typename T> 96 struct uniform_real_distribution { 97 uniform_real_distributionuniform_real_distribution98 C10_HOST_DEVICE inline uniform_real_distribution(T from, T to) { 99 TORCH_CHECK_IF_NOT_ON_CUDA(from <= to); 100 TORCH_CHECK_IF_NOT_ON_CUDA(to - from <= std::numeric_limits<T>::max()); 101 from_ = from; 102 to_ = to; 103 } 104 105 template <typename RNG> operatoruniform_real_distribution106 C10_HOST_DEVICE inline dist_acctype<T> operator()(RNG generator){ 107 if constexpr (std::is_same_v<T, double>) { 108 return transformation::uniform_real<T>(generator->random64(), from_, to_); 109 } else { 110 return transformation::uniform_real<T>(generator->random(), from_, to_); 111 } 112 } 113 114 private: 115 T from_; 116 T to_; 117 }; 118 119 // The SFINAE checks introduced in #39816 looks overcomplicated and must revisited 120 // https://github.com/pytorch/pytorch/issues/40052 121 #define DISTRIBUTION_HELPER_GENERATE_HAS_MEMBER(member) \ 122 template <typename T> \ 123 struct has_member_##member \ 124 { \ 125 typedef char yes; \ 126 typedef long no; \ 127 template <typename U> static yes test(decltype(&U::member)); \ 128 template <typename U> static no test(...); \ 129 static constexpr bool value = sizeof(test<T>(0)) == sizeof(yes); \ 130 } 131 132 DISTRIBUTION_HELPER_GENERATE_HAS_MEMBER(next_double_normal_sample); 133 DISTRIBUTION_HELPER_GENERATE_HAS_MEMBER(set_next_double_normal_sample); 134 DISTRIBUTION_HELPER_GENERATE_HAS_MEMBER(next_float_normal_sample); 135 DISTRIBUTION_HELPER_GENERATE_HAS_MEMBER(set_next_float_normal_sample); 136 137 #define DISTRIBUTION_HELPER_GENERATE_NEXT_NORMAL_METHODS(TYPE) \ 138 \ 139 template <typename RNG, typename ret_type, \ 140 typename std::enable_if_t<( \ 141 has_member_next_##TYPE##_normal_sample<RNG>::value && \ 142 has_member_set_next_##TYPE##_normal_sample<RNG>::value \ 143 ), int> = 0> \ 144 C10_HOST_DEVICE inline bool maybe_get_next_##TYPE##_normal_sample(RNG* generator, ret_type* ret) { \ 145 if (generator->next_##TYPE##_normal_sample()) { \ 146 *ret = *(generator->next_##TYPE##_normal_sample()); \ 147 generator->set_next_##TYPE##_normal_sample(std::optional<TYPE>()); \ 148 return true; \ 149 } \ 150 return false; \ 151 } \ 152 \ 153 template <typename RNG, typename ret_type, \ 154 typename std::enable_if_t<( \ 155 !has_member_next_##TYPE##_normal_sample<RNG>::value || \ 156 !has_member_set_next_##TYPE##_normal_sample<RNG>::value \ 157 ), int> = 0> \ 158 C10_HOST_DEVICE inline bool maybe_get_next_##TYPE##_normal_sample(RNG* /*generator*/, ret_type* /*ret*/) { \ 159 return false; \ 160 } \ 161 \ 162 template <typename RNG, typename ret_type, \ 163 typename std::enable_if_t<( \ 164 has_member_set_next_##TYPE##_normal_sample<RNG>::value \ 165 ), int> = 0> \ 166 C10_HOST_DEVICE inline void maybe_set_next_##TYPE##_normal_sample(RNG* generator, ret_type cache) { \ 167 generator->set_next_##TYPE##_normal_sample(cache); \ 168 } \ 169 \ 170 template <typename RNG, typename ret_type, \ 171 typename std::enable_if_t<( \ 172 !has_member_set_next_##TYPE##_normal_sample<RNG>::value \ 173 ), int> = 0> \ 174 C10_HOST_DEVICE inline void maybe_set_next_##TYPE##_normal_sample(RNG* /*generator*/, ret_type /*cache*/) { \ 175 } 176 177 DISTRIBUTION_HELPER_GENERATE_NEXT_NORMAL_METHODS(double); 178 DISTRIBUTION_HELPER_GENERATE_NEXT_NORMAL_METHODS(float); 179 180 /** 181 * Samples a normal distribution using the Box-Muller method 182 * Takes mean and standard deviation as inputs 183 * Note that Box-muller method returns two samples at a time. 184 * Hence, we cache the "next" sample in the CPUGeneratorImpl class. 185 */ 186 template <typename T> 187 struct normal_distribution { 188 normal_distributionnormal_distribution189 C10_HOST_DEVICE inline normal_distribution(T mean_in, T stdv_in) { 190 TORCH_CHECK_IF_NOT_ON_CUDA(stdv_in >= 0, "stdv_in must be positive: ", stdv_in); 191 mean = mean_in; 192 stdv = stdv_in; 193 } 194 195 template <typename RNG> operatornormal_distribution196 C10_HOST_DEVICE inline dist_acctype<T> operator()(RNG generator){ 197 dist_acctype<T> ret; 198 // return cached values if available 199 if constexpr (std::is_same_v<T, double>) { 200 if (maybe_get_next_double_normal_sample(generator, &ret)) { 201 return transformation::normal(ret, mean, stdv); 202 } 203 } else { 204 if (maybe_get_next_float_normal_sample(generator, &ret)) { 205 return transformation::normal(ret, mean, stdv); 206 } 207 } 208 // otherwise generate new normal values 209 uniform_real_distribution<T> uniform(0.0, 1.0); 210 const dist_acctype<T> u1 = uniform(generator); 211 const dist_acctype<T> u2 = uniform(generator); 212 const dist_acctype<T> r = ::sqrt(static_cast<T>(-2.0) * ::log1p(-u2)); 213 const dist_acctype<T> theta = static_cast<T>(2.0) * c10::pi<T> * u1; 214 if constexpr (std::is_same_v<T, double>) { 215 maybe_set_next_double_normal_sample(generator, r * ::sin(theta)); 216 } else { 217 maybe_set_next_float_normal_sample(generator, r * ::sin(theta)); 218 } 219 ret = r * ::cos(theta); 220 return transformation::normal(ret, mean, stdv); 221 } 222 223 private: 224 T mean; 225 T stdv; 226 }; 227 228 template <typename T> 229 struct DiscreteDistributionType { using type = float; }; 230 231 template <> struct DiscreteDistributionType<double> { using type = double; }; 232 233 /** 234 * Samples a bernoulli distribution given a probability input 235 */ 236 template <typename T> 237 struct bernoulli_distribution { 238 239 C10_HOST_DEVICE inline bernoulli_distribution(T p_in) { 240 TORCH_CHECK_IF_NOT_ON_CUDA(p_in >= 0 && p_in <= 1); 241 p = p_in; 242 } 243 244 template <typename RNG> 245 C10_HOST_DEVICE inline T operator()(RNG generator) { 246 uniform_real_distribution<T> uniform(0.0, 1.0); 247 return transformation::bernoulli<T>(uniform(generator), p); 248 } 249 250 private: 251 T p; 252 }; 253 254 /** 255 * Samples a geometric distribution given a probability input 256 */ 257 template <typename T> 258 struct geometric_distribution { 259 260 C10_HOST_DEVICE inline geometric_distribution(T p_in) { 261 TORCH_CHECK_IF_NOT_ON_CUDA(p_in > 0 && p_in < 1); 262 p = p_in; 263 } 264 265 template <typename RNG> 266 C10_HOST_DEVICE inline T operator()(RNG generator) { 267 uniform_real_distribution<T> uniform(0.0, 1.0); 268 return transformation::geometric<T>(uniform(generator), p); 269 } 270 271 private: 272 T p; 273 }; 274 275 /** 276 * Samples an exponential distribution given a lambda input 277 */ 278 template <typename T> 279 struct exponential_distribution { 280 281 C10_HOST_DEVICE inline exponential_distribution(T lambda_in) : lambda(lambda_in) {} 282 283 template <typename RNG> 284 C10_HOST_DEVICE inline T operator()(RNG generator) { 285 uniform_real_distribution<T> uniform(0.0, 1.0); 286 return transformation::exponential<T>(uniform(generator), lambda); 287 } 288 289 private: 290 T lambda; 291 }; 292 293 /** 294 * Samples a cauchy distribution given median and sigma as inputs 295 */ 296 template <typename T> 297 struct cauchy_distribution { 298 299 C10_HOST_DEVICE inline cauchy_distribution(T median_in, T sigma_in) : median(median_in), sigma(sigma_in) {} 300 301 template <typename RNG> 302 C10_HOST_DEVICE inline T operator()(RNG generator) { 303 uniform_real_distribution<T> uniform(0.0, 1.0); 304 return transformation::cauchy<T>(uniform(generator), median, sigma); 305 } 306 307 private: 308 T median; 309 T sigma; 310 }; 311 312 /** 313 * Samples a lognormal distribution 314 * Takes mean and standard deviation as inputs 315 * Outputs two samples at a time 316 */ 317 template <typename T> 318 struct lognormal_distribution { 319 320 C10_HOST_DEVICE inline lognormal_distribution(T mean_in, T stdv_in) { 321 TORCH_CHECK_IF_NOT_ON_CUDA(stdv_in > 0); 322 mean = mean_in; 323 stdv = stdv_in; 324 } 325 326 template<typename RNG> 327 C10_HOST_DEVICE inline T operator()(RNG generator){ 328 normal_distribution<T> normal(mean, stdv); 329 return transformation::log_normal<T>(normal(generator)); 330 } 331 332 private: 333 T mean; 334 T stdv; 335 }; 336 } 337 } // namespace at 338