xref: /aosp_15_r20/external/pytorch/aten/src/ATen/core/DistributionsHelper.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/core/Array.h>
4 #include <ATen/core/TransformationHelper.h>
5 #include <c10/util/Half.h>
6 #include <c10/util/BFloat16.h>
7 #include <c10/util/MathConstants.h>
8 #include <c10/macros/Macros.h>
9 
10 #include <cmath>
11 #include <limits>
12 #include <optional>
13 #include <type_traits>
14 
/**
 * Distribution kernels adapted from THRandom.cpp.
 * The kernels try to follow the signatures of the standard <random>
 * distributions; the main difference is that ATen distributions take the
 * generator by pointer. For instance, in ATen:
 *      auto gen = at::detail::createCPUGenerator();
 *      at::uniform_real_distribution<double> uniform(0, 1);
 *      auto sample = uniform(gen.get());
 *
 *      vs. the standard library (<random>):
 *
 *      std::mt19937 gen;
 *      std::uniform_real_distribution<double> uniform(0, 1);
 *      auto sample = uniform(gen);
 */
29 
30 
31 namespace at {
32 namespace {
33 
34 /**
35  * Samples a discrete uniform distribution in the range [base, base+range) of type T
36  */
37 template <typename T>
38 struct uniform_int_from_to_distribution {
39 
uniform_int_from_to_distributionuniform_int_from_to_distribution40   C10_HOST_DEVICE inline uniform_int_from_to_distribution(uint64_t range, int64_t base) : range_(range), base_(base) {}
41 
42   template <typename RNG>
operatoruniform_int_from_to_distribution43   C10_HOST_DEVICE inline T operator()(RNG generator) {
44     if ((
45       std::is_same<T, int64_t>::value ||
46       std::is_same<T, double>::value ||
47       std::is_same<T, float>::value ||
48       std::is_same<T, at::BFloat16>::value) && range_ >= 1ULL << 32)
49     {
50       return transformation::uniform_int_from_to<T>(generator->random64(), range_, base_);
51     } else {
52       return transformation::uniform_int_from_to<T>(generator->random(), range_, base_);
53     }
54   }
55 
56   private:
57     uint64_t range_;
58     int64_t base_;
59 };
60 
61 /**
62  * Samples a discrete uniform distribution in the range [min_value(int64_t), max_value(int64_t)]
63  */
64 template <typename T>
65 struct uniform_int_full_range_distribution {
66 
67   template <typename RNG>
operatoruniform_int_full_range_distribution68   C10_HOST_DEVICE inline T operator()(RNG generator) {
69     return transformation::uniform_int_full_range<T>(generator->random64());
70   }
71 
72 };
73 
74 /**
75  * Samples a discrete uniform distribution in the range [0, max_value(T)] for integral types
76  * and [0, 2^mantissa] for floating-point types.
77  */
78 template <typename T>
79 struct uniform_int_distribution {
80 
81   template <typename RNG>
operatoruniform_int_distribution82   C10_HOST_DEVICE inline T operator()(RNG generator) {
83     if constexpr (std::is_same_v<T, double> || std::is_same_v<T, int64_t>) {
84       return transformation::uniform_int<T>(generator->random64());
85     } else {
86       return transformation::uniform_int<T>(generator->random());
87     }
88   }
89 
90 };
91 
92 /**
93  * Samples a uniform distribution in the range [from, to) of type T
94  */
95 template <typename T>
96 struct uniform_real_distribution {
97 
uniform_real_distributionuniform_real_distribution98   C10_HOST_DEVICE inline uniform_real_distribution(T from, T to) {
99     TORCH_CHECK_IF_NOT_ON_CUDA(from <= to);
100     TORCH_CHECK_IF_NOT_ON_CUDA(to - from <= std::numeric_limits<T>::max());
101     from_ = from;
102     to_ = to;
103   }
104 
105   template <typename RNG>
operatoruniform_real_distribution106   C10_HOST_DEVICE inline dist_acctype<T> operator()(RNG generator){
107     if constexpr (std::is_same_v<T, double>) {
108       return transformation::uniform_real<T>(generator->random64(), from_, to_);
109     } else {
110       return transformation::uniform_real<T>(generator->random(), from_, to_);
111     }
112   }
113 
114   private:
115     T from_;
116     T to_;
117 };
118 
// DISTRIBUTION_HELPER_GENERATE_HAS_MEMBER(member) defines a trait
// `has_member_<member><T>` whose `::value` is true exactly when the expression
// `&T::member` is well-formed, i.e. T has an accessible data member or a single
// (non-overloaded) member function named `member`. It is false otherwise,
// including when T is not a class type.
//
// This uses the std::void_t detection idiom, which replaces the sizeof-based
// overload check introduced in #39816
// (see https://github.com/pytorch/pytorch/issues/40052). The traits derive
// from std::true_type / std::false_type, so `::value` is still a constexpr bool.
// Invoke the macro at namespace scope, followed by a semicolon.
#define DISTRIBUTION_HELPER_GENERATE_HAS_MEMBER(member)                 \
template <typename T, typename = void>                                  \
struct has_member_##member : std::false_type {};                        \
                                                                        \
template <typename T>                                                   \
struct has_member_##member<T, std::void_t<decltype(&T::member)>>        \
    : std::true_type {}

// Traits used to detect the optional normal-sample cache on a generator.
DISTRIBUTION_HELPER_GENERATE_HAS_MEMBER(next_double_normal_sample);
DISTRIBUTION_HELPER_GENERATE_HAS_MEMBER(set_next_double_normal_sample);
DISTRIBUTION_HELPER_GENERATE_HAS_MEMBER(next_float_normal_sample);
DISTRIBUTION_HELPER_GENERATE_HAS_MEMBER(set_next_float_normal_sample);
136 
137 #define DISTRIBUTION_HELPER_GENERATE_NEXT_NORMAL_METHODS(TYPE)                                      \
138                                                                                                     \
139 template <typename RNG, typename ret_type,                                                          \
140           typename std::enable_if_t<(                                                               \
141             has_member_next_##TYPE##_normal_sample<RNG>::value &&                                   \
142             has_member_set_next_##TYPE##_normal_sample<RNG>::value                                  \
143           ), int> = 0>                                                                              \
144 C10_HOST_DEVICE inline bool maybe_get_next_##TYPE##_normal_sample(RNG* generator, ret_type* ret) {  \
145   if (generator->next_##TYPE##_normal_sample()) {                                                   \
146     *ret = *(generator->next_##TYPE##_normal_sample());                                             \
147     generator->set_next_##TYPE##_normal_sample(std::optional<TYPE>());                              \
148     return true;                                                                                    \
149   }                                                                                                 \
150   return false;                                                                                     \
151 }                                                                                                   \
152                                                                                                     \
153 template <typename RNG, typename ret_type,                                                          \
154           typename std::enable_if_t<(                                                               \
155             !has_member_next_##TYPE##_normal_sample<RNG>::value ||                                  \
156             !has_member_set_next_##TYPE##_normal_sample<RNG>::value                                 \
157           ), int> = 0>                                                                              \
158 C10_HOST_DEVICE inline bool maybe_get_next_##TYPE##_normal_sample(RNG* /*generator*/, ret_type* /*ret*/) {  \
159   return false;                                                                                     \
160 }                                                                                                   \
161                                                                                                     \
162 template <typename RNG, typename ret_type,                                                          \
163           typename std::enable_if_t<(                                                               \
164             has_member_set_next_##TYPE##_normal_sample<RNG>::value                                  \
165           ), int> = 0>                                                                              \
166 C10_HOST_DEVICE inline void maybe_set_next_##TYPE##_normal_sample(RNG* generator, ret_type cache) { \
167   generator->set_next_##TYPE##_normal_sample(cache);                                                \
168 }                                                                                                   \
169                                                                                                     \
170 template <typename RNG, typename ret_type,                                                          \
171           typename std::enable_if_t<(                                                               \
172             !has_member_set_next_##TYPE##_normal_sample<RNG>::value                                 \
173           ), int> = 0>                                                                              \
174 C10_HOST_DEVICE inline void maybe_set_next_##TYPE##_normal_sample(RNG* /*generator*/, ret_type /*cache*/) { \
175 }
176 
177 DISTRIBUTION_HELPER_GENERATE_NEXT_NORMAL_METHODS(double);
178 DISTRIBUTION_HELPER_GENERATE_NEXT_NORMAL_METHODS(float);
179 
180 /**
181  * Samples a normal distribution using the Box-Muller method
182  * Takes mean and standard deviation as inputs
183  * Note that Box-muller method returns two samples at a time.
184  * Hence, we cache the "next" sample in the CPUGeneratorImpl class.
185  */
186 template <typename T>
187 struct normal_distribution {
188 
normal_distributionnormal_distribution189   C10_HOST_DEVICE inline normal_distribution(T mean_in, T stdv_in) {
190     TORCH_CHECK_IF_NOT_ON_CUDA(stdv_in >= 0, "stdv_in must be positive: ", stdv_in);
191     mean = mean_in;
192     stdv = stdv_in;
193   }
194 
195   template <typename RNG>
operatornormal_distribution196   C10_HOST_DEVICE inline dist_acctype<T> operator()(RNG generator){
197     dist_acctype<T> ret;
198     // return cached values if available
199     if constexpr (std::is_same_v<T, double>) {
200       if (maybe_get_next_double_normal_sample(generator, &ret)) {
201         return transformation::normal(ret, mean, stdv);
202       }
203     } else {
204       if (maybe_get_next_float_normal_sample(generator, &ret)) {
205         return transformation::normal(ret, mean, stdv);
206       }
207     }
208     // otherwise generate new normal values
209     uniform_real_distribution<T> uniform(0.0, 1.0);
210     const dist_acctype<T> u1 = uniform(generator);
211     const dist_acctype<T> u2 = uniform(generator);
212     const dist_acctype<T> r = ::sqrt(static_cast<T>(-2.0) * ::log1p(-u2));
213     const dist_acctype<T> theta = static_cast<T>(2.0) * c10::pi<T> * u1;
214     if constexpr (std::is_same_v<T, double>) {
215       maybe_set_next_double_normal_sample(generator, r * ::sin(theta));
216     } else {
217       maybe_set_next_float_normal_sample(generator, r * ::sin(theta));
218     }
219     ret = r * ::cos(theta);
220     return transformation::normal(ret, mean, stdv);
221   }
222 
223   private:
224     T mean;
225     T stdv;
226 };
227 
// Floating-point type used by the discrete distributions for element type T:
// double when T is exactly double, float for every other T.
template <typename T>
struct DiscreteDistributionType {
  using type = std::conditional_t<std::is_same_v<T, double>, double, float>;
};
232 
233 /**
234  * Samples a bernoulli distribution given a probability input
235  */
236 template <typename T>
237 struct bernoulli_distribution {
238 
239   C10_HOST_DEVICE inline bernoulli_distribution(T p_in) {
240     TORCH_CHECK_IF_NOT_ON_CUDA(p_in >= 0 && p_in <= 1);
241     p = p_in;
242   }
243 
244   template <typename RNG>
245   C10_HOST_DEVICE inline T operator()(RNG generator) {
246     uniform_real_distribution<T> uniform(0.0, 1.0);
247     return transformation::bernoulli<T>(uniform(generator), p);
248   }
249 
250   private:
251     T p;
252 };
253 
254 /**
255  * Samples a geometric distribution given a probability input
256  */
257 template <typename T>
258 struct geometric_distribution {
259 
260   C10_HOST_DEVICE inline geometric_distribution(T p_in) {
261     TORCH_CHECK_IF_NOT_ON_CUDA(p_in > 0 && p_in < 1);
262     p = p_in;
263   }
264 
265   template <typename RNG>
266   C10_HOST_DEVICE inline T operator()(RNG generator) {
267     uniform_real_distribution<T> uniform(0.0, 1.0);
268     return transformation::geometric<T>(uniform(generator), p);
269   }
270 
271   private:
272     T p;
273 };
274 
275 /**
276  * Samples an exponential distribution given a lambda input
277  */
278 template <typename T>
279 struct exponential_distribution {
280 
281   C10_HOST_DEVICE inline exponential_distribution(T lambda_in) : lambda(lambda_in) {}
282 
283   template <typename RNG>
284   C10_HOST_DEVICE inline T operator()(RNG generator) {
285     uniform_real_distribution<T> uniform(0.0, 1.0);
286     return transformation::exponential<T>(uniform(generator), lambda);
287   }
288 
289   private:
290     T lambda;
291 };
292 
293 /**
294  * Samples a cauchy distribution given median and sigma as inputs
295  */
296 template <typename T>
297 struct cauchy_distribution {
298 
299   C10_HOST_DEVICE inline cauchy_distribution(T median_in, T sigma_in) : median(median_in), sigma(sigma_in) {}
300 
301   template <typename RNG>
302   C10_HOST_DEVICE inline T operator()(RNG generator) {
303     uniform_real_distribution<T> uniform(0.0, 1.0);
304     return transformation::cauchy<T>(uniform(generator), median, sigma);
305   }
306 
307   private:
308     T median;
309     T sigma;
310 };
311 
312 /**
313  * Samples a lognormal distribution
314  * Takes mean and standard deviation as inputs
315  * Outputs two samples at a time
316  */
317 template <typename T>
318 struct lognormal_distribution {
319 
320   C10_HOST_DEVICE inline lognormal_distribution(T mean_in, T stdv_in) {
321     TORCH_CHECK_IF_NOT_ON_CUDA(stdv_in > 0);
322     mean = mean_in;
323     stdv = stdv_in;
324   }
325 
326   template<typename RNG>
327   C10_HOST_DEVICE inline T operator()(RNG generator){
328     normal_distribution<T> normal(mean, stdv);
329     return transformation::log_normal<T>(normal(generator));
330   }
331 
332   private:
333     T mean;
334     T stdv;
335 };
336 }
337 } // namespace at
338