#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/TensorIterator.h>
#include <ATen/TensorOperators.h>
#include <c10/util/Exception.h>
#include <optional>

#include <ATen/CPUGeneratorImpl.h>
#include <ATen/core/DistributionsHelper.h>
#include <ATen/native/Distributions.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/UnaryOps.h>
#include <ATen/native/DistributionTemplates.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/native/cpu/Loops.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_dirichlet_grad_native.h>
#include <ATen/ops/_sample_dirichlet_native.h>
#include <ATen/ops/_standard_gamma_grad_native.h>
#include <ATen/ops/_standard_gamma_native.h>
#include <ATen/ops/_assert_async.h>
#include <ATen/ops/argmax.h>
#include <ATen/ops/bernoulli_native.h>
#include <ATen/ops/binomial_native.h>
#include <ATen/ops/cauchy_native.h>
#include <ATen/ops/div.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/exponential_native.h>
#include <ATen/ops/geometric_native.h>
#include <ATen/ops/log_normal_native.h>
#include <ATen/ops/multinomial_native.h>
#include <ATen/ops/normal_native.h>
#include <ATen/ops/poisson_native.h>
#include <ATen/ops/random_native.h>
#include <ATen/ops/topk.h>
#include <ATen/ops/uniform_native.h>
#include <ATen/ops/zeros.h>
#endif

#include <utility>

namespace {
/*
 * This section is a counterpart to Distributions.cu
 *
 */

// The function `sample_poisson`
// is adapted from Numpy's distributions.c implementation.
// It is MIT licensed, so here is the copyright:

/* Copyright 2005 Robert Kern ([email protected])
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the
 * "Software"), to deal in the Software without restriction, including
 * without limitation the rights to use, copy, modify, merge, publish,
 * distribute, sublicense, and/or sell copies of the Software, and to
 * permit persons to whom the Software is furnished to do so, subject to
 * the following conditions:
 *
 * The above copyright notice and this permission notice shall be included
 * in all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
 * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
 * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
 * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 */


int64_t sample_poisson(double lambda, at::CPUGeneratorImpl* generator) {
  TORCH_CHECK(lambda >= 0, "invalid Poisson rate, expected rate to be non-negative");
  at::uniform_real_distribution<double> standard_uniform(0.0, 1.0);
  if (lambda >= 10) {
    // transformed rejection method, (Hoermann, 1993)
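    // Proposal: a candidate k is computed from a transformed uniform U; the
    // cheap squeeze test (us >= 0.07 && V <= vr) accepts most candidates
    // immediately, and the remaining candidates are accepted only when
    // log(V) clears the exact log-pmf comparison at the end of the loop.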

    double slam = std::sqrt(lambda);
    double loglam = std::log(lambda);
    double b = 0.931 + 2.53 * slam;
    double a = -0.059 + 0.02483 * b;
    double invalpha = 1.1239 + 1.1328 / (b - 3.4);
    double vr = 0.9277 - 3.6224 / (b - 2);

    while (true) {
      double U = standard_uniform(generator) - 0.5;
      double V = standard_uniform(generator);
      double us = 0.5 - std::fabs(U);
      auto k = std::floor((2 * a / us + b) * U + lambda + 0.43);
      if ((us >= 0.07) && (V <= vr)) {
        return static_cast<int64_t>(k);
      }
      if ((k < 0) || ((us < 0.013) && (V > us))) {
        continue;
      }
      if ((std::log(V) + std::log(invalpha) - std::log(a / (us * us) + b)) <=
          (-lambda + k * loglam - std::lgamma(k + 1))) {
        return static_cast<int64_t>(k);
      }
    }
  } else if (lambda == 0) {
    return 0;
  } else {
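    // Small lambda: Knuth's multiplicative method. Keep multiplying uniform
    // draws into `prod`; the number of draws taken before the product falls
    // to exp(-lambda) or below, minus one, is the Poisson sample.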
    auto enlam = std::exp(-lambda);
    int64_t X = 0;
    auto prod = 1.0;
    while (true) {
      auto U = standard_uniform(generator);
      prod *= U;
      if (prod > enlam) {
        X += 1;
      } else {
        return X;
      }
    }
  }
}

} // namespace

namespace at::native {

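// Each DEFINE_DISPATCH below instantiates a dispatch stub declared with
// DECLARE_DISPATCH; device-specific kernels attach to these stubs via
// REGISTER_DISPATCH.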
DEFINE_DISPATCH(bernoulli_tensor_stub);
DEFINE_DISPATCH(bernoulli_scalar_stub);
DEFINE_DISPATCH(cauchy_stub);
DEFINE_DISPATCH(exponential_stub);
DEFINE_DISPATCH(multinomial_with_replacement_stub);
DEFINE_DISPATCH(geometric_stub);
DEFINE_DISPATCH(log_normal_stub);
DEFINE_DISPATCH(uniform_stub);
DEFINE_DISPATCH(normal_stub);
DEFINE_DISPATCH(random_stub);
DEFINE_DISPATCH(random_from_to_stub);
DEFINE_DISPATCH(random_full_64_bits_range_stub);

// ==================================================== Bernoulli =====================================================

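// Pattern used throughout this file: a small functor (e.g. BernoulliStub)
// forwards to the matching dispatch stub, and the shared templates in
// ATen/native/DistributionTemplates.h implement the common argument checks
// and the in-place/out-of-place plumbing on top of that functor.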
template<typename RNG>
struct BernoulliStub {
  void operator()(Tensor& self, const Tensor& p_, std::optional<Generator> gen) {
    bernoulli_tensor_stub(self.device().type(), self, p_, gen);
  }

  void operator()(Tensor& self, double p, std::optional<Generator> gen) {
    bernoulli_scalar_stub(self.device().type(), self, p, gen);
  }
};

Tensor bernoulli(const Tensor& self, std::optional<Generator> gen) {
  Tensor result = at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  result.bernoulli_(self, std::move(gen));
  return result;
}

Tensor bernoulli(const Tensor& self, double p, std::optional<Generator> gen) {
  Tensor result = at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  result.bernoulli_(p, std::move(gen));
  return result;
}

Tensor& bernoulli_out(const Tensor& self, std::optional<Generator> gen, Tensor& result) {
  return at::native::templates::bernoulli_out_impl<BernoulliStub, Generator>(result, self, std::move(gen));
}

Tensor& bernoulli_(Tensor& self, const Tensor& p_, std::optional<Generator> gen) {
  return at::native::templates::bernoulli_impl_<BernoulliStub, Generator>(self, p_, std::move(gen));
}

Tensor& bernoulli_(Tensor& self, double p, std::optional<Generator> gen) {
  return at::native::templates::bernoulli_impl_<BernoulliStub, Generator>(self, p, std::move(gen));
}

// ================================================== LogNormal =======================================================

template<typename RNG>
struct LogNormalStub {
  void operator()(TensorIteratorBase& iter, double mean, double std, std::optional<Generator> gen) {
    log_normal_stub(iter.device_type(), iter, mean, std, gen);
  }
};

Tensor& log_normal_(Tensor& self, double mean, double std, std::optional<Generator> gen) {
  return at::native::templates::log_normal_impl_<LogNormalStub, Generator>(self, mean, std, std::move(gen));
}

// ==================================================== Cauchy ========================================================

template<typename RNG>
struct CauchyStub {
  void operator()(TensorIteratorBase& iter, double median, double sigma, std::optional<Generator> gen) {
    cauchy_stub(iter.device_type(), iter, median, sigma, gen);
  }
};

Tensor& cauchy_(Tensor& self, double median, double sigma, std::optional<Generator> gen) {
  return at::native::templates::cauchy_impl_<CauchyStub, Generator>(self, median, sigma, std::move(gen));
}

// ================================================== Exponential =====================================================

template<typename RNG>
struct ExponentialStub {
  void operator()(TensorIteratorBase& iter, double lambda, std::optional<Generator> gen) {
    exponential_stub(iter.device_type(), iter, lambda, gen);
  }
};

Tensor& exponential_(Tensor& self, double lambda, std::optional<Generator> gen) {
  return at::native::templates::exponential_impl_<ExponentialStub, Generator>(self, lambda, std::move(gen));
}

// =================================================== Geometric ======================================================

template<typename RNG>
struct GeometricStub {
  void operator()(TensorIteratorBase& iter, double p, std::optional<Generator> gen) {
    geometric_stub(iter.device_type(), iter, p, gen);
  }
};

Tensor& geometric_(Tensor& self, double p, std::optional<Generator> gen) {
  return at::native::templates::geometric_impl_<GeometricStub, Generator>(self, p, std::move(gen));
}

// ==================================================== Uniform =======================================================

template<typename RNG>
struct UniformStub {
  void operator()(TensorIteratorBase& iter, double from, double to, std::optional<Generator> gen) {
    uniform_stub(iter.device_type(), iter, from, to, gen);
  }
};

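// The Meta variants below (UniformMeta, NormalMeta, and the random_meta_
// overloads) go through the same templates purely for the shape/dtype checks
// and propagation on meta tensors; they intentionally generate no values.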
template<typename RNG>
struct UniformMeta {
  // No-op!
  void operator()(TensorIteratorBase& iter, double from, double to, std::optional<Generator> gen) {
  }
};

Tensor& uniform_(Tensor& self, double from, double to, std::optional<Generator> gen) {
  return at::native::templates::uniform_impl_<UniformStub, Generator>(self, from, to, std::move(gen));
}

Tensor& uniform_meta_(Tensor& self, double from, double to, std::optional<Generator> gen) {
  return at::native::templates::uniform_impl_<UniformMeta, Generator>(self, from, to, std::move(gen));
}

// ==================================================== Normal ========================================================

template<typename RNG>
struct NormalStub {
  void operator()(Tensor& self, double mean, double std, std::optional<Generator> gen) {
    normal_stub(self.device().type(), self, mean, std, gen);
  }
};

template<typename RNG>
struct NormalMeta {
  // No-op!
  void operator()(Tensor& self, double mean, double std, std::optional<Generator> gen) {
  }
};

// inplace
Tensor& normal_(Tensor& self, double mean, double std, std::optional<Generator> gen) {
  return at::native::templates::normal_impl_<NormalStub, Generator>(self, mean, std, std::move(gen));
}

Tensor& normal_meta_(Tensor& self, double mean, double std, std::optional<Generator> gen) {
  return at::native::templates::normal_impl_<NormalMeta, Generator>(self, mean, std, std::move(gen));
}

// out tensor float
Tensor& normal_out(const Tensor& mean, double std, std::optional<Generator> gen, Tensor& output) {
  return at::native::templates::normal_out_impl<NormalStub, Generator>(output, mean, std, std::move(gen));
}

Tensor& normal_out_meta(const Tensor& mean, double std, std::optional<Generator> gen, Tensor& output) {
  return at::native::templates::normal_out_impl<NormalMeta, Generator>(output, mean, std, std::move(gen));
}

// out float tensor
Tensor& normal_out(double mean, const Tensor& std, std::optional<Generator> gen, Tensor& output) {
  return at::native::templates::normal_out_impl<NormalStub, Generator>(output, mean, std, std::move(gen));
}

Tensor& normal_out_meta(double mean, const Tensor& std, std::optional<Generator> gen, Tensor& output) {
  return at::native::templates::normal_out_impl<NormalMeta, Generator>(output, mean, std, std::move(gen));
}

// out tensor tensor
Tensor& normal_out(const Tensor& mean, const Tensor& std, std::optional<Generator> gen, Tensor& output) {
  return at::native::templates::normal_out_impl<NormalStub, Generator>(output, mean, std, std::move(gen));
}

Tensor& normal_out_meta(const Tensor& mean, const Tensor& std, std::optional<Generator> gen, Tensor& output) {
  return at::native::templates::normal_out_impl<NormalMeta, Generator>(output, mean, std, std::move(gen));
}

// functional tensor float
Tensor normal(const Tensor& mean, double std, std::optional<Generator> gen) {
  return at::native::templates::normal_impl<NormalStub, Generator>(mean, std, std::move(gen));
}

Tensor normal_meta(const Tensor& mean, double std, std::optional<Generator> gen) {
  return at::native::templates::normal_impl<NormalMeta, Generator>(mean, std, std::move(gen));
}

// functional float tensor
Tensor normal(double mean, const Tensor& std, std::optional<Generator> gen) {
  return at::native::templates::normal_impl<NormalStub, Generator>(mean, std, std::move(gen));
}

Tensor normal_meta(double mean, const Tensor& std, std::optional<Generator> gen) {
  return at::native::templates::normal_impl<NormalMeta, Generator>(mean, std, std::move(gen));
}

// functional tensor tensor
Tensor normal(const Tensor& mean, const Tensor& std, std::optional<Generator> gen) {
  return at::native::templates::normal_impl<NormalStub, Generator>(mean, std, std::move(gen));
}

Tensor normal_meta(const Tensor& mean, const Tensor& std, std::optional<Generator> gen) {
  return at::native::templates::normal_impl<NormalMeta, Generator>(mean, std, std::move(gen));
}

// functional variant, only used by the functionalization pass.
Tensor normal_functional(const Tensor& self, double mean, double std, std::optional<at::Generator> generator) {
  return self.clone().normal_(mean, std, std::move(generator));
}

// ==================================================== Random ========================================================

template<typename RNG>
struct RandomStub {
  void operator()(TensorIteratorBase& iter, std::optional<Generator> gen) {
    random_stub(iter.device_type(), iter, gen);
  }
};

Tensor& random_(Tensor& self, std::optional<Generator> gen) {
  return at::native::templates::random_impl<RandomStub, Generator>(self, std::move(gen));
}

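// RandomFromToStub has two overloads: the first draws from [from, from + range);
// the second is used when the requested interval spans the full 64-bit range
// and therefore cannot be expressed as a uint64_t range.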
template<typename RNG>
struct RandomFromToStub {
  void operator()(TensorIteratorBase& iter, uint64_t range, int64_t from, std::optional<Generator> gen) {
    random_from_to_stub(iter.device_type(), iter, range, from, gen);
  }
  void operator()(TensorIteratorBase& iter, std::optional<Generator> gen) {
    random_full_64_bits_range_stub(iter.device_type(), iter, gen);
  }
};

Tensor& random_(Tensor& self, int64_t from, std::optional<int64_t> to, std::optional<Generator> gen) {
  return at::native::templates::random_from_to_impl<RandomFromToStub, Generator>(self, from, to, std::move(gen));
}

Tensor& random_(Tensor& self, int64_t to, std::optional<Generator> gen) {
  return random_(self, 0, to, std::move(gen));
}

Tensor& random_meta_(Tensor& self, std::optional<Generator> gen) {
  // No error checking yay
  return self;
}

Tensor& random_meta_(Tensor& self, int64_t from, std::optional<int64_t> to, std::optional<Generator> gen) {
  // No error checking yay
  return self;
}

Tensor& random_meta_(Tensor& self, int64_t to, std::optional<Generator> gen) {
  // No error checking yay
  return self;
}

// ====================================================================================================================

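// Gradient of a reparameterized standard gamma sample with respect to its
// concentration alpha, computed elementwise via standard_gamma_grad_one
// (used in the backward of _standard_gamma).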
Tensor _standard_gamma_grad_cpu(const Tensor& self, const Tensor& output) {
  Tensor ret = at::empty(self.sizes(), self.options());
  auto iter = TensorIteratorConfig()
    .add_output(ret)
    .add_input(self)
    .add_input(output)
    .build();
  AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "_standard_gamma_grad_cpu", [&] {
    cpu_serial_kernel(iter, [](scalar_t self_val, scalar_t output_val) -> scalar_t{
      return standard_gamma_grad_one<scalar_t, double>(self_val, output_val);
    });
  });
  return ret;
}

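// Gradient of a Dirichlet sample x with respect to its concentration alpha,
// given the total of the underlying gamma draws; applied elementwise via
// dirichlet_grad_one.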
Tensor _dirichlet_grad_cpu(const Tensor& x, const Tensor& alpha, const Tensor& total) {
  Tensor ret = at::empty(x.sizes(), x.options());
  auto iter = TensorIteratorConfig()
    .add_output(ret)
    .add_input(x)
    .add_input(alpha)
    .add_input(total)
    .build();
  AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "_dirichlet_grad_cpu", [&] {
    cpu_serial_kernel(iter, [](scalar_t x_val, scalar_t alpha_val, scalar_t total_val) -> scalar_t{
      return dirichlet_grad_one<scalar_t, double>(x_val, alpha_val, total_val);
    });
  });
  return ret;
}

/*
 * This section is a counterpart to Distributions.cu
 */

Tensor _s_binomial_cpu(const Tensor& count, const Tensor& prob, std::optional<Generator> gen) {
  Tensor ret = at::zeros(count.sizes(), count.options());
  auto iter = TensorIteratorConfig()
    .add_output(ret)
    .add_input(count)
    .add_input(prob)
    .build();
  AT_DISPATCH_FLOATING_TYPES(ret.scalar_type(), "binomial_cpu", [&] {
    CPUGeneratorImpl* generator = get_generator_or_default<CPUGeneratorImpl>(gen, detail::getDefaultCPUGenerator());
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> lock(generator->mutex_);
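    // cpu_serial_kernel (rather than a parallel kernel) is used because the
    // generator mutex is held and every element mutates the generator state;
    // a parallel loop would race on it.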
    cpu_serial_kernel(iter, [generator](scalar_t count_val, scalar_t prob_val) -> scalar_t{
      auto uniform_lambda = [generator] () {
        at::uniform_real_distribution<double> standard_uniform(0.0, 1.0);
        return standard_uniform(generator);
      };
      BaseSampler<double, decltype(uniform_lambda)> standard_uniform(uniform_lambda);

      auto sample = sample_binomial<scalar_t, double, decltype(uniform_lambda)>(count_val, prob_val, standard_uniform);
      return static_cast<scalar_t>(sample);
    });
  });
  return ret;
}

Tensor _s_poisson_cpu(const Tensor& lambda, std::optional<Generator> gen) {
  Tensor ret = at::zeros(lambda.sizes(), lambda.options());
  auto iter = TensorIteratorConfig()
    .add_output(ret)
    .add_input(lambda)
    .build();
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::BFloat16, at::ScalarType::Half, ret.scalar_type(), "poisson_cpu", [&] {
    CPUGeneratorImpl* generator = get_generator_or_default<CPUGeneratorImpl>(gen, detail::getDefaultCPUGenerator());
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> lock(generator->mutex_);
    cpu_serial_kernel(iter, [generator](scalar_t lambda_val) -> scalar_t{
      return static_cast<scalar_t>(sample_poisson(static_cast<double>(lambda_val), generator));
    });
  });
  return ret;
}

Tensor _s_gamma_cpu(const Tensor& alpha, std::optional<Generator> gen) {
  Tensor ret = at::zeros(alpha.sizes(), alpha.options());
  auto iter = TensorIteratorConfig()
    .add_output(ret)
    .add_input(alpha)
    .build();
  AT_DISPATCH_FLOATING_TYPES(ret.scalar_type(), "gamma_cpu", [&] {
    CPUGeneratorImpl* generator = get_generator_or_default<CPUGeneratorImpl>(gen, detail::getDefaultCPUGenerator());
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> lock(generator->mutex_);
    cpu_serial_kernel(iter, [generator](scalar_t alpha_val) -> scalar_t{
      auto uniform_lambda = [generator] () {
        at::uniform_real_distribution<double> standard_uniform(0.0, 1.0);
        return standard_uniform(generator);
      };
      BaseSampler<double, decltype(uniform_lambda)> standard_uniform(uniform_lambda);

      auto normal_lambda = [generator] () {
        at::normal_distribution<double> normal(0.0, 1.0);
        return normal(generator);
      };
      BaseSampler<double, decltype(normal_lambda)> standard_normal(normal_lambda);
      auto sample = sample_gamma<scalar_t, double, decltype(uniform_lambda), decltype(normal_lambda)>(alpha_val, standard_uniform, standard_normal);
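      // Keep the sample strictly positive: clamp up to the smallest positive
      // normal value representable in scalar_t.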
      return std::max(std::numeric_limits<scalar_t>::min(), (scalar_t) sample);
    });
  });

  return ret;
}

Tensor _s_dirichlet_cpu(const Tensor& alpha, std::optional<Generator> gen) {
  Tensor ret = at::zeros(alpha.sizes(), alpha.options());
  AT_DISPATCH_FLOATING_TYPES(ret.scalar_type(), "dirichlet", [&] {
    Tensor gamma = at::zeros(alpha.sizes(), alpha.options().dtype(ScalarType::Double));
    CPUGeneratorImpl* generator = get_generator_or_default<CPUGeneratorImpl>(gen, detail::getDefaultCPUGenerator());
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> lock(generator->mutex_);
    /* Generate gamma sample by casting alpha to double to prevent underflow. */
    auto iter1 = TensorIteratorConfig()
      .add_output(gamma)
      .add_input(alpha)
      .check_all_same_dtype(false)
      .build();
    cpu_serial_kernel(iter1, [generator](scalar_t alpha_val) -> double{
      auto uniform_lambda = [generator] () {
        at::uniform_real_distribution<double> standard_uniform(0.0, 1.0);
        return standard_uniform(generator);
      };
      BaseSampler<double, decltype(uniform_lambda)> standard_uniform(uniform_lambda);

      auto normal_lambda = [generator] () {
        at::normal_distribution<double> normal(0.0, 1.0);
        return normal(generator);
      };
      BaseSampler<double, decltype(normal_lambda)> standard_normal(normal_lambda);
      auto sample = sample_gamma<double, double, decltype(uniform_lambda), decltype(normal_lambda)>
        (alpha_val, standard_uniform, standard_normal);
      return std::max(std::numeric_limits<double>::min(), sample);
    });
    /* Normalize and cast back to scalar_t. */
    Tensor gamma_sum = gamma.sum(-1, true).expand(alpha.sizes());
    auto iter2 = TensorIteratorConfig()
      .add_output(ret)
      .add_input(gamma)
      .add_input(gamma_sum)
      .check_all_same_dtype(false)
      .build();
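    // The kernel below divides each gamma draw by the row sum and clamps the
    // result into the open interval (0, 1): at least the smallest positive
    // normal scalar_t, at most the largest representable value below 1.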
    cpu_serial_kernel(iter2, [](double gamma_val, double gamma_sum_val) -> scalar_t{
      auto ret_val = gamma_val / gamma_sum_val;
      auto min_val = std::numeric_limits<scalar_t>::min();
      auto max_val = std::nexttoward(static_cast<scalar_t>(1.0f), 0.0f);
      return std::min(max_val, std::max(min_val, static_cast<scalar_t>(ret_val)));
    });
  });
  return ret;
}

/* The largest consecutive integer representable in float32 (2^24) */
constexpr int64_t FLOAT32_MAX_CONSECUTIVE_INT = 1 << (FLT_MANT_DIG);

Tensor& multinomial_out(const Tensor& self,
    int64_t n_sample,
    bool with_replacement,
    std::optional<Generator> gen,
    Tensor& result) {
  TORCH_CHECK(
      result.device() == self.device(),
      "multinomial arguments must have the same device");
  TORCH_CHECK(
      self.dim() > 0 && self.dim() <= 2, "prob_dist must be 1 or 2 dim");
  TORCH_CHECK(
      at::isFloatingType(self.scalar_type()),
      "multinomial only supports floating-point dtypes for input, got: ",
      self.scalar_type());
  TORCH_CHECK(result.scalar_type() == ScalarType::Long,
      "multinomial expects Long tensor out, got: ", result.scalar_type());
  TORCH_CHECK(n_sample > 0, "cannot sample n_sample <= 0 samples");
  int64_t n_categories = self.size(-1);
  TORCH_CHECK(with_replacement || (n_sample <= n_categories),
      "cannot sample n_sample > prob_dist.size(-1) samples without replacement");
  // Since the index tensor is float, numCategories cannot exceed max
  // float integer precision
  TORCH_CHECK(
      n_categories <= FLOAT32_MAX_CONSECUTIVE_INT,
      "number of categories cannot exceed 2^24");

  if (self.dim() == 1) {
    result.resize_({n_sample});
  } else {
    const int64_t n_dist = self.size(0);
    result.resize_({n_dist, n_sample});
  }
  if (result.numel() == 0) {
    return result;
  }

  // Fast-path for no replacement or if only one sample is drawn.
  // Reference:
  // https://github.com/pytorch/pytorch/issues/11931#issuecomment-625882503
  if (!with_replacement || n_sample == 1) {
    // Sanity checks on `self`.
    auto is_valid = ((self.max() < INFINITY) & (self.min() >= 0));
    at::_assert_async(is_valid, "probability tensor contains either `inf`, `nan` or element < 0");
    at::Tensor zero_prob_condition;
    if (self.dim() == 1) {
      zero_prob_condition = (self.sum() == 0);
    } else {
      zero_prob_condition = (self.sum(1) == 0).any();
    }
    at::_assert_async(~zero_prob_condition, "invalid multinomial distribution (sum of probabilities <= 0)");

    // The algorithm is from gumbel softmax.
    // s = argmax( logp - log(-log(eps)) ) where eps ~ U(0, 1)
    // Here we can apply exp to the formula which will not affect result of
    // argmax or topk. Then we have
    // s = argmax( p / (-log(eps)) ) where eps ~ U(0, 1).
    // We can also simplify the formula above by
    // s = argmax( p / q ) where q ~ Exp(1)
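    // Taking the top n_sample values of p / q (rather than just the argmax)
    // yields n_sample distinct categories, i.e. a sample without replacement
    // (the exponential-race / Gumbel-top-k trick).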
    Tensor q = at::empty_like(self).exponential_(1, std::move(gen));
    // In theory the probability of drawing an exact 0 from the exponential
    // distribution is 0. On the CUDA side there is a protection against 0s,
    // but on the CPU side exponential<double> can produce 0 with probability
    // about 2^(-DBL_MANT_DIG). We ignore that here, so there is a small risk
    // of invalid output on CPU.
    at::div_out(q, self, q);
    if (n_sample == 1) {
      at::argmax_out(result, q, /*dim=*/-1, /*keepdim=*/true);
    } else {
      Tensor vals = at::empty(result.sizes(), self.options());
      at::topk_out(vals, result, q, n_sample);
    }
    return result;
  }

  multinomial_with_replacement_stub(
      result.device().type(), result, self, n_sample, gen);
  return result;
}

Tensor multinomial(
    const Tensor& self,
    int64_t n_sample,
    bool with_replacement,
    std::optional<Generator> gen) {
  Tensor result = at::empty({0}, self.options().dtype(kLong));
  native::multinomial_out(self, n_sample, with_replacement, std::move(gen), result);
  return result;
}

} // namespace at::native