#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/TensorIterator.h>
#include <ATen/TensorOperators.h>
#include <c10/util/Exception.h>
#include <optional>

#include <ATen/CPUGeneratorImpl.h>
#include <ATen/core/DistributionsHelper.h>
#include <ATen/native/Distributions.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/UnaryOps.h>
#include <ATen/native/DistributionTemplates.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/native/cpu/Loops.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_dirichlet_grad_native.h>
#include <ATen/ops/_sample_dirichlet_native.h>
#include <ATen/ops/_standard_gamma_grad_native.h>
#include <ATen/ops/_standard_gamma_native.h>
#include <ATen/ops/_assert_async.h>
#include <ATen/ops/argmax.h>
#include <ATen/ops/bernoulli_native.h>
#include <ATen/ops/binomial_native.h>
#include <ATen/ops/cauchy_native.h>
#include <ATen/ops/div.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/exponential_native.h>
#include <ATen/ops/geometric_native.h>
#include <ATen/ops/log_normal_native.h>
#include <ATen/ops/multinomial_native.h>
#include <ATen/ops/normal_native.h>
#include <ATen/ops/poisson_native.h>
#include <ATen/ops/random_native.h>
#include <ATen/ops/topk.h>
#include <ATen/ops/uniform_native.h>
#include <ATen/ops/zeros.h>
#endif
// Standard headers used directly below (FLT_MANT_DIG, std::lgamma,
// std::numeric_limits, std::lock_guard, std::max, std::move).
#include <algorithm>
#include <cfloat>
#include <cmath>
#include <limits>
#include <mutex>
#include <utility>

namespace {
/*
 * This section is a counterpart to Distributions.cu
 */

// The function `sample_poisson` is adapted from Numpy's distributions.c
// implementation. It is MIT licensed, so here is the copyright:

/* Copyright 2005 Robert Kern ([email protected])
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the
 * "Software"), to deal in the Software without restriction, including
 * without limitation the rights to use, copy, modify, merge, publish,
 * distribute, sublicense, and/or sell copies of the Software, and to
 * permit persons to whom the Software is furnished to do so, subject to
 * the following conditions:
 *
 * The above copyright notice and this permission notice shall be included
 * in all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
 * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
 * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
 * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 */

int64_t sample_poisson(double lambda, at::CPUGeneratorImpl* generator) {
  TORCH_CHECK(lambda >= 0, "invalid Poisson rate, expected rate to be non-negative");
  at::uniform_real_distribution<double> standard_uniform(0.0, 1.0);
  if (lambda >= 10) {
    // transformed rejection method (Hoermann, 1993)

    double slam = std::sqrt(lambda);
    double loglam = std::log(lambda);
    double b = 0.931 + 2.53 * slam;
    double a = -0.059 + 0.02483 * b;
    double invalpha = 1.1239 + 1.1328 / (b - 3.4);
    double vr = 0.9277 - 3.6224 / (b - 2);

    while (true) {
      double U = standard_uniform(generator) - 0.5;
      double V = standard_uniform(generator);
      double us = 0.5 - std::fabs(U);
      auto k = std::floor((2 * a / us + b) * U + lambda + 0.43);
      if ((us >= 0.07) && (V <= vr)) {
        return static_cast<int64_t>(k);
      }
      if ((k < 0) || ((us < 0.013) && (V > us))) {
        continue;
      }
      if ((std::log(V) + std::log(invalpha) - std::log(a / (us * us) + b)) <=
          (-lambda + k * loglam - std::lgamma(k + 1))) {
        return static_cast<int64_t>(k);
      }
    }
  } else if (lambda == 0) {
    return 0;
  } else {
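    // Small-lambda fallback: multiply uniform draws and count how many times
    // the running product stays above exp(-lambda) (Knuth's multiplication
    // method for Poisson sampling).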
    auto enlam = std::exp(-lambda);
    int64_t X = 0;
    auto prod = 1.0;
    while (true) {
      auto U = standard_uniform(generator);
      prod *= U;
      if (prod > enlam) {
        X += 1;
      } else {
        return X;
      }
    }
  }
}

} // namespace

namespace at::native {

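// Per-device entry points for the distribution kernels. The concrete CPU/CUDA
// implementations are registered against these stubs in the backend kernel
// files (via REGISTER_DISPATCH); see ATen/native/DispatchStub.h.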
DEFINE_DISPATCH(bernoulli_tensor_stub);
DEFINE_DISPATCH(bernoulli_scalar_stub);
DEFINE_DISPATCH(cauchy_stub);
DEFINE_DISPATCH(exponential_stub);
DEFINE_DISPATCH(multinomial_with_replacement_stub);
DEFINE_DISPATCH(geometric_stub);
DEFINE_DISPATCH(log_normal_stub);
DEFINE_DISPATCH(uniform_stub);
DEFINE_DISPATCH(normal_stub);
DEFINE_DISPATCH(random_stub);
DEFINE_DISPATCH(random_from_to_stub);
DEFINE_DISPATCH(random_full_64_bits_range_stub);

// ==================================================== Bernoulli =====================================================

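// The *Stub structs below are thin adapters: they forward to the device
// dispatch stubs above so that the generic algorithms in
// ATen/native/DistributionTemplates.h can be reused across backends.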
template<typename RNG>
struct BernoulliStub {
  void operator()(Tensor& self, const Tensor& p_, std::optional<Generator> gen) {
    bernoulli_tensor_stub(self.device().type(), self, p_, gen);
  }

  void operator()(Tensor& self, double p, std::optional<Generator> gen) {
    bernoulli_scalar_stub(self.device().type(), self, p, gen);
  }
};

Tensor bernoulli(const Tensor& self, std::optional<Generator> gen) {
  Tensor result = at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  result.bernoulli_(self, std::move(gen));
  return result;
}

Tensor bernoulli(const Tensor& self, double p, std::optional<Generator> gen) {
  Tensor result = at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  result.bernoulli_(p, std::move(gen));
  return result;
}

Tensor& bernoulli_out(const Tensor& self, std::optional<Generator> gen, Tensor& result) {
  return at::native::templates::bernoulli_out_impl<BernoulliStub, Generator>(result, self, std::move(gen));
}

Tensor& bernoulli_(Tensor& self, const Tensor& p_, std::optional<Generator> gen) {
  return at::native::templates::bernoulli_impl_<BernoulliStub, Generator>(self, p_, std::move(gen));
}

Tensor& bernoulli_(Tensor& self, double p, std::optional<Generator> gen) {
  return at::native::templates::bernoulli_impl_<BernoulliStub, Generator>(self, p, std::move(gen));
}

// ================================================== LogNormal =======================================================

template<typename RNG>
struct LogNormalStub {
  void operator()(TensorIteratorBase& iter, double mean, double std, std::optional<Generator> gen) {
    log_normal_stub(iter.device_type(), iter, mean, std, gen);
  }
};

Tensor& log_normal_(Tensor& self, double mean, double std, std::optional<Generator> gen) {
  return at::native::templates::log_normal_impl_<LogNormalStub, Generator>(self, mean, std, std::move(gen));
}

// ==================================================== Cauchy ========================================================

template<typename RNG>
struct CauchyStub {
  void operator()(TensorIteratorBase& iter, double median, double sigma, std::optional<Generator> gen) {
    cauchy_stub(iter.device_type(), iter, median, sigma, gen);
  }
};

Tensor& cauchy_(Tensor& self, double median, double sigma, std::optional<Generator> gen) {
  return at::native::templates::cauchy_impl_<CauchyStub, Generator>(self, median, sigma, std::move(gen));
}

// ================================================== Exponential =====================================================

template<typename RNG>
struct ExponentialStub {
  void operator()(TensorIteratorBase& iter, double lambda, std::optional<Generator> gen) {
    exponential_stub(iter.device_type(), iter, lambda, gen);
  }
};

Tensor& exponential_(Tensor& self, double lambda, std::optional<Generator> gen) {
  return at::native::templates::exponential_impl_<ExponentialStub, Generator>(self, lambda, std::move(gen));
}

// =================================================== Geometric ======================================================

template<typename RNG>
struct GeometricStub {
  void operator()(TensorIteratorBase& iter, double p, std::optional<Generator> gen) {
    geometric_stub(iter.device_type(), iter, p, gen);
  }
};

Tensor& geometric_(Tensor& self, double p, std::optional<Generator> gen) {
  return at::native::templates::geometric_impl_<GeometricStub, Generator>(self, p, std::move(gen));
}

// ==================================================== Uniform =======================================================

template<typename RNG>
struct UniformStub {
  void operator()(TensorIteratorBase& iter, double from, double to, std::optional<Generator> gen) {
    uniform_stub(iter.device_type(), iter, from, to, gen);
  }
};

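// Meta-device variants are deliberate no-ops: meta tensors carry only shape
// and dtype information, so there is nothing to sample.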
template<typename RNG>
struct UniformMeta {
  // No-op!
  void operator()(TensorIteratorBase& iter, double from, double to, std::optional<Generator> gen) {
  }
};

Tensor& uniform_(Tensor& self, double from, double to, std::optional<Generator> gen) {
  return at::native::templates::uniform_impl_<UniformStub, Generator>(self, from, to, std::move(gen));
}

Tensor& uniform_meta_(Tensor& self, double from, double to, std::optional<Generator> gen) {
  return at::native::templates::uniform_impl_<UniformMeta, Generator>(self, from, to, std::move(gen));
}

// ==================================================== Normal ========================================================

template<typename RNG>
struct NormalStub {
  void operator()(Tensor& self, double mean, double std, std::optional<Generator> gen) {
    normal_stub(self.device().type(), self, mean, std, gen);
  }
};

template<typename RNG>
struct NormalMeta {
  // No-op!
  void operator()(Tensor& self, double mean, double std, std::optional<Generator> gen) {
  }
};

// inplace
Tensor& normal_(Tensor& self, double mean, double std, std::optional<Generator> gen) {
  return at::native::templates::normal_impl_<NormalStub, Generator>(self, mean, std, std::move(gen));
}

Tensor& normal_meta_(Tensor& self, double mean, double std, std::optional<Generator> gen) {
  return at::native::templates::normal_impl_<NormalMeta, Generator>(self, mean, std, std::move(gen));
}

// out tensor float
Tensor& normal_out(const Tensor& mean, double std, std::optional<Generator> gen, Tensor& output) {
  return at::native::templates::normal_out_impl<NormalStub, Generator>(output, mean, std, std::move(gen));
}

Tensor& normal_out_meta(const Tensor& mean, double std, std::optional<Generator> gen, Tensor& output) {
  return at::native::templates::normal_out_impl<NormalMeta, Generator>(output, mean, std, std::move(gen));
}

// out float tensor
Tensor& normal_out(double mean, const Tensor& std, std::optional<Generator> gen, Tensor& output) {
  return at::native::templates::normal_out_impl<NormalStub, Generator>(output, mean, std, std::move(gen));
}

Tensor& normal_out_meta(double mean, const Tensor& std, std::optional<Generator> gen, Tensor& output) {
  return at::native::templates::normal_out_impl<NormalMeta, Generator>(output, mean, std, std::move(gen));
}

// out tensor tensor
Tensor& normal_out(const Tensor& mean, const Tensor& std, std::optional<Generator> gen, Tensor& output) {
  return at::native::templates::normal_out_impl<NormalStub, Generator>(output, mean, std, std::move(gen));
}

Tensor& normal_out_meta(const Tensor& mean, const Tensor& std, std::optional<Generator> gen, Tensor& output) {
  return at::native::templates::normal_out_impl<NormalMeta, Generator>(output, mean, std, std::move(gen));
}

// functional tensor float
Tensor normal(const Tensor& mean, double std, std::optional<Generator> gen) {
  return at::native::templates::normal_impl<NormalStub, Generator>(mean, std, std::move(gen));
}

Tensor normal_meta(const Tensor& mean, double std, std::optional<Generator> gen) {
  return at::native::templates::normal_impl<NormalMeta, Generator>(mean, std, std::move(gen));
}

// functional float tensor
Tensor normal(double mean, const Tensor& std, std::optional<Generator> gen) {
  return at::native::templates::normal_impl<NormalStub, Generator>(mean, std, std::move(gen));
}

Tensor normal_meta(double mean, const Tensor& std, std::optional<Generator> gen) {
  return at::native::templates::normal_impl<NormalMeta, Generator>(mean, std, std::move(gen));
}

// functional tensor tensor
Tensor normal(const Tensor& mean, const Tensor& std, std::optional<Generator> gen) {
  return at::native::templates::normal_impl<NormalStub, Generator>(mean, std, std::move(gen));
}

Tensor normal_meta(const Tensor& mean, const Tensor& std, std::optional<Generator> gen) {
  return at::native::templates::normal_impl<NormalMeta, Generator>(mean, std, std::move(gen));
}

// functional variant, only used by the functionalization pass.
Tensor normal_functional(const Tensor& self, double mean, double std, std::optional<at::Generator> generator) {
  return self.clone().normal_(mean, std, std::move(generator));
}

// ==================================================== Random ========================================================

template<typename RNG>
struct RandomStub {
  void operator()(TensorIteratorBase& iter, std::optional<Generator> gen) {
    random_stub(iter.device_type(), iter, gen);
  }
};

Tensor& random_(Tensor& self, std::optional<Generator> gen) {
  return at::native::templates::random_impl<RandomStub, Generator>(self, std::move(gen));
}

template<typename RNG>
struct RandomFromToStub {
  void operator()(TensorIteratorBase& iter, uint64_t range, int64_t from, std::optional<Generator> gen) {
    random_from_to_stub(iter.device_type(), iter, range, from, gen);
  }
  void operator()(TensorIteratorBase& iter, std::optional<Generator> gen) {
    random_full_64_bits_range_stub(iter.device_type(), iter, gen);
  }
};

Tensor& random_(Tensor& self, int64_t from, std::optional<int64_t> to, std::optional<Generator> gen) {
  return at::native::templates::random_from_to_impl<RandomFromToStub, Generator>(self, from, to, std::move(gen));
}

Tensor& random_(Tensor& self, int64_t to, std::optional<Generator> gen) {
  return random_(self, 0, to, std::move(gen));
}

Tensor& random_meta_(Tensor& self, std::optional<Generator> gen) {
  // No error checking yay
  return self;
}

Tensor& random_meta_(Tensor& self, int64_t from, std::optional<int64_t> to, std::optional<Generator> gen) {
  // No error checking yay
  return self;
}

Tensor& random_meta_(Tensor& self, int64_t to, std::optional<Generator> gen) {
  // No error checking yay
  return self;
}

// ====================================================================================================================

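// Backward helpers for the reparameterized gamma/Dirichlet samplers. The
// per-element formulas (standard_gamma_grad_one, dirichlet_grad_one) live in
// ATen/native/Distributions.h, included above.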
Tensor _standard_gamma_grad_cpu(const Tensor& self, const Tensor& output) {
  Tensor ret = at::empty(self.sizes(), self.options());
  auto iter = TensorIteratorConfig()
    .add_output(ret)
    .add_input(self)
    .add_input(output)
    .build();
  AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "_standard_gamma_grad_cpu", [&] {
    cpu_serial_kernel(iter, [](scalar_t self_val, scalar_t output_val) -> scalar_t{
      return standard_gamma_grad_one<scalar_t, double>(self_val, output_val);
    });
  });
  return ret;
}

Tensor _dirichlet_grad_cpu(const Tensor& x, const Tensor& alpha, const Tensor& total) {
  Tensor ret = at::empty(x.sizes(), x.options());
  auto iter = TensorIteratorConfig()
    .add_output(ret)
    .add_input(x)
    .add_input(alpha)
    .add_input(total)
    .build();
  AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "_dirichlet_grad_cpu", [&] {
    cpu_serial_kernel(iter, [](scalar_t x_val, scalar_t alpha_val, scalar_t total_val) -> scalar_t{
      return dirichlet_grad_one<scalar_t, double>(x_val, alpha_val, total_val);
    });
  });
  return ret;
}

/*
 * This section is a counterpart to Distributions.cu
 */

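// The samplers below share one stateful CPUGeneratorImpl: they run under
// cpu_serial_kernel (single-threaded) and hold the generator's mutex for the
// duration of the loop; see Note [Acquire lock when using random generators].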
Tensor _s_binomial_cpu(const Tensor& count, const Tensor& prob, std::optional<Generator> gen) {
  Tensor ret = at::zeros(count.sizes(), count.options());
  auto iter = TensorIteratorConfig()
    .add_output(ret)
    .add_input(count)
    .add_input(prob)
    .build();
  AT_DISPATCH_FLOATING_TYPES(ret.scalar_type(), "binomial_cpu", [&] {
    CPUGeneratorImpl* generator = get_generator_or_default<CPUGeneratorImpl>(gen, detail::getDefaultCPUGenerator());
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> lock(generator->mutex_);
    cpu_serial_kernel(iter, [generator](scalar_t count_val, scalar_t prob_val) -> scalar_t{
      auto uniform_lambda = [generator] () {
        at::uniform_real_distribution<double> standard_uniform(0.0, 1.0);
        return standard_uniform(generator);
      };
      BaseSampler<double, decltype(uniform_lambda)> standard_uniform(uniform_lambda);

      auto sample = sample_binomial<scalar_t, double, decltype(uniform_lambda)>(count_val, prob_val, standard_uniform);
      return static_cast<scalar_t>(sample);
    });
  });
  return ret;
}

Tensor _s_poisson_cpu(const Tensor& lambda, std::optional<Generator> gen) {
  Tensor ret = at::zeros(lambda.sizes(), lambda.options());
  auto iter = TensorIteratorConfig()
    .add_output(ret)
    .add_input(lambda)
    .build();
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::BFloat16, at::ScalarType::Half, ret.scalar_type(), "poisson_cpu", [&] {
    CPUGeneratorImpl* generator = get_generator_or_default<CPUGeneratorImpl>(gen, detail::getDefaultCPUGenerator());
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> lock(generator->mutex_);
    cpu_serial_kernel(iter, [generator](scalar_t lambda_val) -> scalar_t{
      return static_cast<scalar_t>(sample_poisson(static_cast<double>(lambda_val), generator));
    });
  });
  return ret;
}

Tensor _s_gamma_cpu(const Tensor& alpha, std::optional<Generator> gen) {
  Tensor ret = at::zeros(alpha.sizes(), alpha.options());
  auto iter = TensorIteratorConfig()
    .add_output(ret)
    .add_input(alpha)
    .build();
  AT_DISPATCH_FLOATING_TYPES(ret.scalar_type(), "gamma_cpu", [&] {
    CPUGeneratorImpl* generator = get_generator_or_default<CPUGeneratorImpl>(gen, detail::getDefaultCPUGenerator());
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> lock(generator->mutex_);
    cpu_serial_kernel(iter, [generator](scalar_t alpha_val) -> scalar_t{
      auto uniform_lambda = [generator] () {
        at::uniform_real_distribution<double> standard_uniform(0.0, 1.0);
        return standard_uniform(generator);
      };
      BaseSampler<double, decltype(uniform_lambda)> standard_uniform(uniform_lambda);

      auto normal_lambda = [generator] () {
        at::normal_distribution<double> normal(0.0, 1.0);
        return normal(generator);
      };
      BaseSampler<double, decltype(normal_lambda)> standard_normal(normal_lambda);
      auto sample = sample_gamma<scalar_t, double, decltype(uniform_lambda), decltype(normal_lambda)>(alpha_val, standard_uniform, standard_normal);
      // Clamp to the smallest positive normal value so the sample is never exactly zero.
      return std::max(std::numeric_limits<scalar_t>::min(), (scalar_t) sample);
    });
  });

  return ret;
}

Tensor _s_dirichlet_cpu(const Tensor& alpha, std::optional<Generator> gen) {
  Tensor ret = at::zeros(alpha.sizes(), alpha.options());
  AT_DISPATCH_FLOATING_TYPES(ret.scalar_type(), "dirichlet", [&] {
    Tensor gamma = at::zeros(alpha.sizes(), alpha.options().dtype(ScalarType::Double));
    CPUGeneratorImpl* generator = get_generator_or_default<CPUGeneratorImpl>(gen, detail::getDefaultCPUGenerator());
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> lock(generator->mutex_);
    /* Generate gamma sample by casting alpha to double to prevent underflow. */
    auto iter1 = TensorIteratorConfig()
      .add_output(gamma)
      .add_input(alpha)
      .check_all_same_dtype(false)
      .build();
    cpu_serial_kernel(iter1, [generator](scalar_t alpha_val) -> double{
      auto uniform_lambda = [generator] () {
        at::uniform_real_distribution<double> standard_uniform(0.0, 1.0);
        return standard_uniform(generator);
      };
      BaseSampler<double, decltype(uniform_lambda)> standard_uniform(uniform_lambda);

      auto normal_lambda = [generator] () {
        at::normal_distribution<double> normal(0.0, 1.0);
        return normal(generator);
      };
      BaseSampler<double, decltype(normal_lambda)> standard_normal(normal_lambda);
      auto sample = sample_gamma<double, double, decltype(uniform_lambda), decltype(normal_lambda)>
        (alpha_val, standard_uniform, standard_normal);
      return std::max(std::numeric_limits<double>::min(), sample);
    });
    /* Normalize and cast back to scalar_t. */
    Tensor gamma_sum = gamma.sum(-1, true).expand(alpha.sizes());
    auto iter2 = TensorIteratorConfig()
      .add_output(ret)
      .add_input(gamma)
      .add_input(gamma_sum)
      .check_all_same_dtype(false)
      .build();
    cpu_serial_kernel(iter2, [](double gamma_val, double gamma_sum_val) -> scalar_t{
      auto ret_val = gamma_val / gamma_sum_val;
      auto min_val = std::numeric_limits<scalar_t>::min();
      auto max_val = std::nexttoward(static_cast<scalar_t>(1.0f), 0.0f);
      return std::min(max_val, std::max(min_val, static_cast<scalar_t>(ret_val)));
    });
  });
  return ret;
}

/* The largest consecutive integer representable in float32 (2^24) */
constexpr int64_t FLOAT32_MAX_CONSECUTIVE_INT = 1 << (FLT_MANT_DIG);

Tensor& multinomial_out(const Tensor& self,
    int64_t n_sample,
    bool with_replacement,
    std::optional<Generator> gen,
    Tensor& result) {
  TORCH_CHECK(
      result.device() == self.device(),
      "multinomial arguments must have the same device");
  TORCH_CHECK(
      self.dim() > 0 && self.dim() <= 2, "prob_dist must be 1 or 2 dim");
  TORCH_CHECK(
      at::isFloatingType(self.scalar_type()),
      "multinomial only supports floating-point dtypes for input, got: ",
      self.scalar_type());
  TORCH_CHECK(result.scalar_type() == ScalarType::Long,
      "multinomial expects Long tensor out, got: ", result.scalar_type());
  TORCH_CHECK(n_sample > 0, "cannot sample n_sample <= 0 samples");
  int64_t n_categories = self.size(-1);
  TORCH_CHECK(with_replacement || (n_sample <= n_categories),
      "cannot sample n_sample > prob_dist.size(-1) samples without replacement");
  // Since the index tensor is float, numCategories cannot exceed max
  // float integer precision
  TORCH_CHECK(
      n_categories <= FLOAT32_MAX_CONSECUTIVE_INT,
      "number of categories cannot exceed 2^24");

  if (self.dim() == 1) {
    result.resize_({n_sample});
  } else {
    const int64_t n_dist = self.size(0);
    result.resize_({n_dist, n_sample});
  }
  if (result.numel() == 0) {
    return result;
  }

  // Fast-path when sampling without replacement or when only one sample is drawn.
  // Reference:
  // https://github.com/pytorch/pytorch/issues/11931#issuecomment-625882503
  if (!with_replacement || n_sample == 1) {
    // Sanity checks on `self`.
    auto is_valid = ((self.max() < INFINITY) & (self.min() >= 0));
    at::_assert_async(is_valid, "probability tensor contains either `inf`, `nan` or element < 0");
    at::Tensor zero_prob_condition;
    if (self.dim() == 1){
      zero_prob_condition = (self.sum() == 0);
    } else {
      zero_prob_condition = (self.sum(1) == 0).any();
    }
    at::_assert_async(~zero_prob_condition, "invalid multinomial distribution (sum of probabilities <= 0)");

    // The algorithm is based on the Gumbel-softmax trick:
    //   s = argmax( logp - log(-log(eps)) ) where eps ~ U(0, 1).
    // Applying exp to the expression does not change the result of argmax or
    // topk, which gives
    //   s = argmax( p / (-log(eps)) ) where eps ~ U(0, 1),
    // and this simplifies to
    //   s = argmax( p / q ) where q ~ Exp(1).
    Tensor q = at::empty_like(self).exponential_(1, std::move(gen));
    // In theory the probability of drawing an exact 0 from the exponential
    // distribution is 0. The CUDA kernel guards against 0s, but on the CPU side
    // exponential<double> can still return 0 with probability about
    // 2^(-DBL_MANT_DIG). We ignore it here, so there is a small risk of invalid
    // output on CPU.
    at::div_out(q, self, q);
    if (n_sample == 1) {
      at::argmax_out(result, q, /*dim=*/-1, /*keepdim=*/true);
    } else {
      Tensor vals = at::empty(result.sizes(), self.options());
      at::topk_out(vals, result, q, n_sample);
    }
    return result;
  }

  multinomial_with_replacement_stub(
      result.device().type(), result, self, n_sample, gen);
  return result;
}

Tensor multinomial(
    const Tensor& self,
    int64_t n_sample,
    bool with_replacement,
    std::optional<Generator> gen) {
  Tensor result = at::empty({0}, self.options().dtype(kLong));
  native::multinomial_out(self, n_sample, with_replacement, std::move(gen), result);
  return result;
}

} // namespace at::native