xref: /aosp_15_r20/external/pytorch/test/distributions/test_distributions.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: distributions"]
2
3"""
4Note [Randomized statistical tests]
5-----------------------------------
6
7This note describes how to maintain tests in this file as random sources
8change. This file contains two types of randomized tests:
9
101. The easier type of randomized test are tests that should always pass but are
11   initialized with random data. If these fail something is wrong, but it's
12   fine to use a fixed seed by inheriting from common.TestCase.
13
142. The trickier tests are statistical tests. These tests explicitly call
15   set_rng_seed(n) and are marked "see Note [Randomized statistical tests]".
16   These statistical tests have a known positive failure rate
17   (we set failure_rate=1e-3 by default). We need to balance strength of these
18   tests with annoyance of false alarms. One way that works is to specifically
19   set seeds in each of the randomized tests. When a random generator
20   occasionally changes (as in #4312 vectorizing the Box-Muller sampler), some
21   of these statistical tests may (rarely) fail. If one fails in this case,
22   it's fine to increment the seed of the failing test (but you shouldn't need
23   to increment it more than once; otherwise something is probably actually
24   wrong).
25
263. `test_geometric_sample`, `test_binomial_sample` and `test_poisson_sample`
27   are validated against `scipy.stats.` which are not guaranteed to be identical
28   across different versions of scipy (namely, they yield invalid results in 1.7+)
29"""
30
31import math
32import numbers
33import unittest
34from collections import namedtuple
35from itertools import product
36from random import shuffle
37
38from packaging import version
39
40import torch
41import torch.autograd.forward_ad as fwAD
42from torch import inf, nan
43from torch.autograd import grad
44from torch.autograd.functional import jacobian
45from torch.distributions import (
46    Bernoulli,
47    Beta,
48    Binomial,
49    Categorical,
50    Cauchy,
51    Chi2,
52    constraints,
53    ContinuousBernoulli,
54    Dirichlet,
55    Distribution,
56    Exponential,
57    ExponentialFamily,
58    FisherSnedecor,
59    Gamma,
60    Geometric,
61    Gumbel,
62    HalfCauchy,
63    HalfNormal,
64    Independent,
65    InverseGamma,
66    kl_divergence,
67    Kumaraswamy,
68    Laplace,
69    LKJCholesky,
70    LogisticNormal,
71    LogNormal,
72    LowRankMultivariateNormal,
73    MixtureSameFamily,
74    Multinomial,
75    MultivariateNormal,
76    NegativeBinomial,
77    Normal,
78    OneHotCategorical,
79    OneHotCategoricalStraightThrough,
80    Pareto,
81    Poisson,
82    RelaxedBernoulli,
83    RelaxedOneHotCategorical,
84    StudentT,
85    TransformedDistribution,
86    Uniform,
87    VonMises,
88    Weibull,
89    Wishart,
90)
91from torch.distributions.constraint_registry import transform_to
92from torch.distributions.constraints import Constraint, is_dependent
93from torch.distributions.dirichlet import _Dirichlet_backward
94from torch.distributions.kl import _kl_expfamily_expfamily
95from torch.distributions.transforms import (
96    AffineTransform,
97    CatTransform,
98    ExpTransform,
99    identity_transform,
100    StackTransform,
101)
102from torch.distributions.utils import (
103    lazy_property,
104    probs_to_logits,
105    tril_matrix_to_vec,
106    vec_to_tril_matrix,
107)
108from torch.nn.functional import softmax
109from torch.testing._internal.common_cuda import TEST_CUDA
110from torch.testing._internal.common_utils import (
111    gradcheck,
112    load_tests,
113    run_tests,
114    set_default_dtype,
115    set_rng_seed,
116    skipIfTorchDynamo,
117    TestCase,
118)
119
120
121# load_tests from torch.testing._internal.common_utils is used to automatically filter tests for
122# sharding on sandcastle. This line silences flake warnings
123load_tests = load_tests
124
125TEST_NUMPY = True
126try:
127    import numpy as np
128    import scipy.special
129    import scipy.stats
130except ImportError:
131    TEST_NUMPY = False
132
133
134def pairwise(Dist, *params):
135    """
136    Creates a pair of distributions `Dist` initialized to test each element of
137    param with each other.
138    """
139    params1 = [torch.tensor([p] * len(p)) for p in params]
140    params2 = [p.transpose(0, 1) for p in params1]
141    return Dist(*params1), Dist(*params2)
142
143
144def is_all_nan(tensor):
145    """
146    Checks if all entries of a tensor is nan.
147    """
148    return (tensor != tensor).all()
149
150
151Example = namedtuple("Example", ["Dist", "params"])
152
153
154# Register all distributions for generic tests.
155def _get_examples():
156    return [
157        Example(
158            Bernoulli,
159            [
160                {"probs": torch.tensor([0.7, 0.2, 0.4], requires_grad=True)},
161                {"probs": torch.tensor([0.3], requires_grad=True)},
162                {"probs": 0.3},
163                {"logits": torch.tensor([0.0], requires_grad=True)},
164            ],
165        ),
166        Example(
167            Geometric,
168            [
169                {"probs": torch.tensor([0.7, 0.2, 0.4], requires_grad=True)},
170                {"probs": torch.tensor([0.3], requires_grad=True)},
171                {"probs": 0.3},
172            ],
173        ),
174        Example(
175            Beta,
176            [
177                {
178                    "concentration1": torch.randn(2, 3).exp().requires_grad_(),
179                    "concentration0": torch.randn(2, 3).exp().requires_grad_(),
180                },
181                {
182                    "concentration1": torch.randn(4).exp().requires_grad_(),
183                    "concentration0": torch.randn(4).exp().requires_grad_(),
184                },
185            ],
186        ),
187        Example(
188            Categorical,
189            [
190                {
191                    "probs": torch.tensor(
192                        [[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True
193                    )
194                },
195                {"probs": torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True)},
196                {"logits": torch.tensor([[0.0, 0.0], [0.0, 0.0]], requires_grad=True)},
197            ],
198        ),
199        Example(
200            Binomial,
201            [
202                {
203                    "probs": torch.tensor(
204                        [[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True
205                    ),
206                    "total_count": 10,
207                },
208                {
209                    "probs": torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True),
210                    "total_count": 10,
211                },
212                {
213                    "probs": torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True),
214                    "total_count": torch.tensor([10]),
215                },
216                {
217                    "probs": torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True),
218                    "total_count": torch.tensor([10, 8]),
219                },
220                {
221                    "probs": torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True),
222                    "total_count": torch.tensor([[10.0, 8.0], [5.0, 3.0]]),
223                },
224                {
225                    "probs": torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True),
226                    "total_count": torch.tensor(0.0),
227                },
228            ],
229        ),
230        Example(
231            NegativeBinomial,
232            [
233                {
234                    "probs": torch.tensor(
235                        [[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True
236                    ),
237                    "total_count": 10,
238                },
239                {
240                    "probs": torch.tensor([[0.9, 0.0], [0.0, 0.9]], requires_grad=True),
241                    "total_count": 10,
242                },
243                {
244                    "probs": torch.tensor([[0.9, 0.0], [0.0, 0.9]], requires_grad=True),
245                    "total_count": torch.tensor([10]),
246                },
247                {
248                    "probs": torch.tensor([[0.9, 0.0], [0.0, 0.9]], requires_grad=True),
249                    "total_count": torch.tensor([10, 8]),
250                },
251                {
252                    "probs": torch.tensor([[0.9, 0.0], [0.0, 0.9]], requires_grad=True),
253                    "total_count": torch.tensor([[10.0, 8.0], [5.0, 3.0]]),
254                },
255                {
256                    "probs": torch.tensor([[0.9, 0.0], [0.0, 0.9]], requires_grad=True),
257                    "total_count": torch.tensor(0.0),
258                },
259            ],
260        ),
261        Example(
262            Multinomial,
263            [
264                {
265                    "probs": torch.tensor(
266                        [[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True
267                    ),
268                    "total_count": 10,
269                },
270                {
271                    "probs": torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True),
272                    "total_count": 10,
273                },
274            ],
275        ),
276        Example(
277            Cauchy,
278            [
279                {"loc": 0.0, "scale": 1.0},
280                {"loc": torch.tensor([0.0]), "scale": 1.0},
281                {
282                    "loc": torch.tensor([[0.0], [0.0]]),
283                    "scale": torch.tensor([[1.0], [1.0]]),
284                },
285            ],
286        ),
287        Example(
288            Chi2,
289            [
290                {"df": torch.randn(2, 3).exp().requires_grad_()},
291                {"df": torch.randn(1).exp().requires_grad_()},
292            ],
293        ),
294        Example(
295            StudentT,
296            [
297                {"df": torch.randn(2, 3).exp().requires_grad_()},
298                {"df": torch.randn(1).exp().requires_grad_()},
299            ],
300        ),
301        Example(
302            Dirichlet,
303            [
304                {"concentration": torch.randn(2, 3).exp().requires_grad_()},
305                {"concentration": torch.randn(4).exp().requires_grad_()},
306            ],
307        ),
308        Example(
309            Exponential,
310            [
311                {"rate": torch.randn(5, 5).abs().requires_grad_()},
312                {"rate": torch.randn(1).abs().requires_grad_()},
313            ],
314        ),
315        Example(
316            FisherSnedecor,
317            [
318                {
319                    "df1": torch.randn(5, 5).abs().requires_grad_(),
320                    "df2": torch.randn(5, 5).abs().requires_grad_(),
321                },
322                {
323                    "df1": torch.randn(1).abs().requires_grad_(),
324                    "df2": torch.randn(1).abs().requires_grad_(),
325                },
326                {
327                    "df1": torch.tensor([1.0]),
328                    "df2": 1.0,
329                },
330            ],
331        ),
332        Example(
333            Gamma,
334            [
335                {
336                    "concentration": torch.randn(2, 3).exp().requires_grad_(),
337                    "rate": torch.randn(2, 3).exp().requires_grad_(),
338                },
339                {
340                    "concentration": torch.randn(1).exp().requires_grad_(),
341                    "rate": torch.randn(1).exp().requires_grad_(),
342                },
343            ],
344        ),
345        Example(
346            Gumbel,
347            [
348                {
349                    "loc": torch.randn(5, 5, requires_grad=True),
350                    "scale": torch.randn(5, 5).abs().requires_grad_(),
351                },
352                {
353                    "loc": torch.randn(1, requires_grad=True),
354                    "scale": torch.randn(1).abs().requires_grad_(),
355                },
356            ],
357        ),
358        Example(HalfCauchy, [{"scale": 1.0}, {"scale": torch.tensor([[1.0], [1.0]])}]),
359        Example(
360            HalfNormal,
361            [
362                {"scale": torch.randn(5, 5).abs().requires_grad_()},
363                {"scale": torch.randn(1).abs().requires_grad_()},
364                {"scale": torch.tensor([1e-5, 1e-5], requires_grad=True)},
365            ],
366        ),
367        Example(
368            Independent,
369            [
370                {
371                    "base_distribution": Normal(
372                        torch.randn(2, 3, requires_grad=True),
373                        torch.randn(2, 3).abs().requires_grad_(),
374                    ),
375                    "reinterpreted_batch_ndims": 0,
376                },
377                {
378                    "base_distribution": Normal(
379                        torch.randn(2, 3, requires_grad=True),
380                        torch.randn(2, 3).abs().requires_grad_(),
381                    ),
382                    "reinterpreted_batch_ndims": 1,
383                },
384                {
385                    "base_distribution": Normal(
386                        torch.randn(2, 3, requires_grad=True),
387                        torch.randn(2, 3).abs().requires_grad_(),
388                    ),
389                    "reinterpreted_batch_ndims": 2,
390                },
391                {
392                    "base_distribution": Normal(
393                        torch.randn(2, 3, 5, requires_grad=True),
394                        torch.randn(2, 3, 5).abs().requires_grad_(),
395                    ),
396                    "reinterpreted_batch_ndims": 2,
397                },
398                {
399                    "base_distribution": Normal(
400                        torch.randn(2, 3, 5, requires_grad=True),
401                        torch.randn(2, 3, 5).abs().requires_grad_(),
402                    ),
403                    "reinterpreted_batch_ndims": 3,
404                },
405            ],
406        ),
407        Example(
408            Kumaraswamy,
409            [
410                {
411                    "concentration1": torch.empty(2, 3).uniform_(1, 2).requires_grad_(),
412                    "concentration0": torch.empty(2, 3).uniform_(1, 2).requires_grad_(),
413                },
414                {
415                    "concentration1": torch.rand(4).uniform_(1, 2).requires_grad_(),
416                    "concentration0": torch.rand(4).uniform_(1, 2).requires_grad_(),
417                },
418            ],
419        ),
420        Example(
421            LKJCholesky,
422            [
423                {"dim": 2, "concentration": 0.5},
424                {
425                    "dim": 3,
426                    "concentration": torch.tensor([0.5, 1.0, 2.0]),
427                },
428                {"dim": 100, "concentration": 4.0},
429            ],
430        ),
431        Example(
432            Laplace,
433            [
434                {
435                    "loc": torch.randn(5, 5, requires_grad=True),
436                    "scale": torch.randn(5, 5).abs().requires_grad_(),
437                },
438                {
439                    "loc": torch.randn(1, requires_grad=True),
440                    "scale": torch.randn(1).abs().requires_grad_(),
441                },
442                {
443                    "loc": torch.tensor([1.0, 0.0], requires_grad=True),
444                    "scale": torch.tensor([1e-5, 1e-5], requires_grad=True),
445                },
446            ],
447        ),
448        Example(
449            LogNormal,
450            [
451                {
452                    "loc": torch.randn(5, 5, requires_grad=True),
453                    "scale": torch.randn(5, 5).abs().requires_grad_(),
454                },
455                {
456                    "loc": torch.randn(1, requires_grad=True),
457                    "scale": torch.randn(1).abs().requires_grad_(),
458                },
459                {
460                    "loc": torch.tensor([1.0, 0.0], requires_grad=True),
461                    "scale": torch.tensor([1e-5, 1e-5], requires_grad=True),
462                },
463            ],
464        ),
465        Example(
466            LogisticNormal,
467            [
468                {
469                    "loc": torch.randn(5, 5).requires_grad_(),
470                    "scale": torch.randn(5, 5).abs().requires_grad_(),
471                },
472                {
473                    "loc": torch.randn(1).requires_grad_(),
474                    "scale": torch.randn(1).abs().requires_grad_(),
475                },
476                {
477                    "loc": torch.tensor([1.0, 0.0], requires_grad=True),
478                    "scale": torch.tensor([1e-5, 1e-5], requires_grad=True),
479                },
480            ],
481        ),
482        Example(
483            LowRankMultivariateNormal,
484            [
485                {
486                    "loc": torch.randn(5, 2, requires_grad=True),
487                    "cov_factor": torch.randn(5, 2, 1, requires_grad=True),
488                    "cov_diag": torch.tensor([2.0, 0.25], requires_grad=True),
489                },
490                {
491                    "loc": torch.randn(4, 3, requires_grad=True),
492                    "cov_factor": torch.randn(3, 2, requires_grad=True),
493                    "cov_diag": torch.tensor([5.0, 1.5, 3.0], requires_grad=True),
494                },
495            ],
496        ),
497        Example(
498            MultivariateNormal,
499            [
500                {
501                    "loc": torch.randn(5, 2, requires_grad=True),
502                    "covariance_matrix": torch.tensor(
503                        [[2.0, 0.3], [0.3, 0.25]], requires_grad=True
504                    ),
505                },
506                {
507                    "loc": torch.randn(2, 3, requires_grad=True),
508                    "precision_matrix": torch.tensor(
509                        [[2.0, 0.1, 0.0], [0.1, 0.25, 0.0], [0.0, 0.0, 0.3]],
510                        requires_grad=True,
511                    ),
512                },
513                {
514                    "loc": torch.randn(5, 3, 2, requires_grad=True),
515                    "scale_tril": torch.tensor(
516                        [
517                            [[2.0, 0.0], [-0.5, 0.25]],
518                            [[2.0, 0.0], [0.3, 0.25]],
519                            [[5.0, 0.0], [-0.5, 1.5]],
520                        ],
521                        requires_grad=True,
522                    ),
523                },
524                {
525                    "loc": torch.tensor([1.0, -1.0]),
526                    "covariance_matrix": torch.tensor([[5.0, -0.5], [-0.5, 1.5]]),
527                },
528            ],
529        ),
530        Example(
531            Normal,
532            [
533                {
534                    "loc": torch.randn(5, 5, requires_grad=True),
535                    "scale": torch.randn(5, 5).abs().requires_grad_(),
536                },
537                {
538                    "loc": torch.randn(1, requires_grad=True),
539                    "scale": torch.randn(1).abs().requires_grad_(),
540                },
541                {
542                    "loc": torch.tensor([1.0, 0.0], requires_grad=True),
543                    "scale": torch.tensor([1e-5, 1e-5], requires_grad=True),
544                },
545            ],
546        ),
547        Example(
548            OneHotCategorical,
549            [
550                {
551                    "probs": torch.tensor(
552                        [[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True
553                    )
554                },
555                {"probs": torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True)},
556                {"logits": torch.tensor([[0.0, 0.0], [0.0, 0.0]], requires_grad=True)},
557            ],
558        ),
559        Example(
560            OneHotCategoricalStraightThrough,
561            [
562                {
563                    "probs": torch.tensor(
564                        [[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True
565                    )
566                },
567                {"probs": torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True)},
568                {"logits": torch.tensor([[0.0, 0.0], [0.0, 0.0]], requires_grad=True)},
569            ],
570        ),
571        Example(
572            Pareto,
573            [
574                {"scale": 1.0, "alpha": 1.0},
575                {
576                    "scale": (torch.randn(5, 5).abs() + 0.1).requires_grad_(),
577                    "alpha": (torch.randn(5, 5).abs() + 0.1).requires_grad_(),
578                },
579                {"scale": torch.tensor([1.0]), "alpha": 1.0},
580            ],
581        ),
582        Example(
583            Poisson,
584            [
585                {
586                    "rate": torch.randn(5, 5).abs().requires_grad_(),
587                },
588                {
589                    "rate": torch.randn(3).abs().requires_grad_(),
590                },
591                {
592                    "rate": 0.2,
593                },
594                {
595                    "rate": torch.tensor([0.0], requires_grad=True),
596                },
597                {
598                    "rate": 0.0,
599                },
600            ],
601        ),
602        Example(
603            RelaxedBernoulli,
604            [
605                {
606                    "temperature": torch.tensor([0.5], requires_grad=True),
607                    "probs": torch.tensor([0.7, 0.2, 0.4], requires_grad=True),
608                },
609                {
610                    "temperature": torch.tensor([2.0]),
611                    "probs": torch.tensor([0.3]),
612                },
613                {
614                    "temperature": torch.tensor([7.2]),
615                    "logits": torch.tensor([-2.0, 2.0, 1.0, 5.0]),
616                },
617            ],
618        ),
619        Example(
620            RelaxedOneHotCategorical,
621            [
622                {
623                    "temperature": torch.tensor([0.5], requires_grad=True),
624                    "probs": torch.tensor(
625                        [[0.1, 0.2, 0.7], [0.5, 0.3, 0.2]], requires_grad=True
626                    ),
627                },
628                {
629                    "temperature": torch.tensor([2.0]),
630                    "probs": torch.tensor([[1.0, 0.0], [0.0, 1.0]]),
631                },
632                {
633                    "temperature": torch.tensor([7.2]),
634                    "logits": torch.tensor([[-2.0, 2.0], [1.0, 5.0]]),
635                },
636            ],
637        ),
638        Example(
639            TransformedDistribution,
640            [
641                {
642                    "base_distribution": Normal(
643                        torch.randn(2, 3, requires_grad=True),
644                        torch.randn(2, 3).abs().requires_grad_(),
645                    ),
646                    "transforms": [],
647                },
648                {
649                    "base_distribution": Normal(
650                        torch.randn(2, 3, requires_grad=True),
651                        torch.randn(2, 3).abs().requires_grad_(),
652                    ),
653                    "transforms": ExpTransform(),
654                },
655                {
656                    "base_distribution": Normal(
657                        torch.randn(2, 3, 5, requires_grad=True),
658                        torch.randn(2, 3, 5).abs().requires_grad_(),
659                    ),
660                    "transforms": [
661                        AffineTransform(torch.randn(3, 5), torch.randn(3, 5)),
662                        ExpTransform(),
663                    ],
664                },
665                {
666                    "base_distribution": Normal(
667                        torch.randn(2, 3, 5, requires_grad=True),
668                        torch.randn(2, 3, 5).abs().requires_grad_(),
669                    ),
670                    "transforms": AffineTransform(1, 2),
671                },
672                {
673                    "base_distribution": Uniform(
674                        torch.tensor(1e8).log(), torch.tensor(1e10).log()
675                    ),
676                    "transforms": ExpTransform(),
677                },
678            ],
679        ),
680        Example(
681            Uniform,
682            [
683                {
684                    "low": torch.zeros(5, 5, requires_grad=True),
685                    "high": torch.ones(5, 5, requires_grad=True),
686                },
687                {
688                    "low": torch.zeros(1, requires_grad=True),
689                    "high": torch.ones(1, requires_grad=True),
690                },
691                {
692                    "low": torch.tensor([1.0, 1.0], requires_grad=True),
693                    "high": torch.tensor([2.0, 3.0], requires_grad=True),
694                },
695            ],
696        ),
697        Example(
698            Weibull,
699            [
700                {
701                    "scale": torch.randn(5, 5).abs().requires_grad_(),
702                    "concentration": torch.randn(1).abs().requires_grad_(),
703                }
704            ],
705        ),
706        Example(
707            Wishart,
708            [
709                {
710                    "covariance_matrix": torch.tensor(
711                        [[2.0, 0.3], [0.3, 0.25]], requires_grad=True
712                    ),
713                    "df": torch.tensor([3.0], requires_grad=True),
714                },
715                {
716                    "precision_matrix": torch.tensor(
717                        [[2.0, 0.1, 0.0], [0.1, 0.25, 0.0], [0.0, 0.0, 0.3]],
718                        requires_grad=True,
719                    ),
720                    "df": torch.tensor([5.0, 4], requires_grad=True),
721                },
722                {
723                    "scale_tril": torch.tensor(
724                        [
725                            [[2.0, 0.0], [-0.5, 0.25]],
726                            [[2.0, 0.0], [0.3, 0.25]],
727                            [[5.0, 0.0], [-0.5, 1.5]],
728                        ],
729                        requires_grad=True,
730                    ),
731                    "df": torch.tensor([5.0, 3.5, 3], requires_grad=True),
732                },
733                {
734                    "covariance_matrix": torch.tensor([[5.0, -0.5], [-0.5, 1.5]]),
735                    "df": torch.tensor([3.0]),
736                },
737                {
738                    "covariance_matrix": torch.tensor([[5.0, -0.5], [-0.5, 1.5]]),
739                    "df": 3.0,
740                },
741            ],
742        ),
743        Example(
744            MixtureSameFamily,
745            [
746                {
747                    "mixture_distribution": Categorical(
748                        torch.rand(5, requires_grad=True)
749                    ),
750                    "component_distribution": Normal(
751                        torch.randn(5, requires_grad=True),
752                        torch.rand(5, requires_grad=True),
753                    ),
754                },
755                {
756                    "mixture_distribution": Categorical(
757                        torch.rand(5, requires_grad=True)
758                    ),
759                    "component_distribution": MultivariateNormal(
760                        loc=torch.randn(5, 2, requires_grad=True),
761                        covariance_matrix=torch.tensor(
762                            [[2.0, 0.3], [0.3, 0.25]], requires_grad=True
763                        ),
764                    ),
765                },
766            ],
767        ),
768        Example(
769            VonMises,
770            [
771                {
772                    "loc": torch.tensor(1.0, requires_grad=True),
773                    "concentration": torch.tensor(10.0, requires_grad=True),
774                },
775                {
776                    "loc": torch.tensor([0.0, math.pi / 2], requires_grad=True),
777                    "concentration": torch.tensor([1.0, 10.0], requires_grad=True),
778                },
779            ],
780        ),
781        Example(
782            ContinuousBernoulli,
783            [
784                {"probs": torch.tensor([0.7, 0.2, 0.4], requires_grad=True)},
785                {"probs": torch.tensor([0.3], requires_grad=True)},
786                {"probs": 0.3},
787                {"logits": torch.tensor([0.0], requires_grad=True)},
788            ],
789        ),
790        Example(
791            InverseGamma,
792            [
793                {
794                    "concentration": torch.randn(2, 3).exp().requires_grad_(),
795                    "rate": torch.randn(2, 3).exp().requires_grad_(),
796                },
797                {
798                    "concentration": torch.randn(1).exp().requires_grad_(),
799                    "rate": torch.randn(1).exp().requires_grad_(),
800                },
801            ],
802        ),
803    ]
804
805
806def _get_bad_examples():
807    return [
808        Example(
809            Bernoulli,
810            [
811                {"probs": torch.tensor([1.1, 0.2, 0.4], requires_grad=True)},
812                {"probs": torch.tensor([-0.5], requires_grad=True)},
813                {"probs": 1.00001},
814            ],
815        ),
816        Example(
817            Beta,
818            [
819                {
820                    "concentration1": torch.tensor([0.0], requires_grad=True),
821                    "concentration0": torch.tensor([0.0], requires_grad=True),
822                },
823                {
824                    "concentration1": torch.tensor([-1.0], requires_grad=True),
825                    "concentration0": torch.tensor([-2.0], requires_grad=True),
826                },
827            ],
828        ),
829        Example(
830            Geometric,
831            [
832                {"probs": torch.tensor([1.1, 0.2, 0.4], requires_grad=True)},
833                {"probs": torch.tensor([-0.3], requires_grad=True)},
834                {"probs": 1.00000001},
835            ],
836        ),
837        Example(
838            Categorical,
839            [
840                {
841                    "probs": torch.tensor(
842                        [[-0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True
843                    )
844                },
845                {
846                    "probs": torch.tensor(
847                        [[-1.0, 10.0], [0.0, -1.0]], requires_grad=True
848                    )
849                },
850            ],
851        ),
852        Example(
853            Binomial,
854            [
855                {
856                    "probs": torch.tensor(
857                        [[-0.0000001, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True
858                    ),
859                    "total_count": 10,
860                },
861                {
862                    "probs": torch.tensor([[1.0, 0.0], [0.0, 2.0]], requires_grad=True),
863                    "total_count": 10,
864                },
865            ],
866        ),
867        Example(
868            NegativeBinomial,
869            [
870                {
871                    "probs": torch.tensor(
872                        [[-0.0000001, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True
873                    ),
874                    "total_count": 10,
875                },
876                {
877                    "probs": torch.tensor([[1.0, 0.0], [0.0, 2.0]], requires_grad=True),
878                    "total_count": 10,
879                },
880            ],
881        ),
882        Example(
883            Cauchy,
884            [
885                {"loc": 0.0, "scale": -1.0},
886                {"loc": torch.tensor([0.0]), "scale": 0.0},
887                {
888                    "loc": torch.tensor([[0.0], [-2.0]]),
889                    "scale": torch.tensor([[-0.000001], [1.0]]),
890                },
891            ],
892        ),
893        Example(
894            Chi2,
895            [
896                {"df": torch.tensor([0.0], requires_grad=True)},
897                {"df": torch.tensor([-2.0], requires_grad=True)},
898            ],
899        ),
900        Example(
901            StudentT,
902            [
903                {"df": torch.tensor([0.0], requires_grad=True)},
904                {"df": torch.tensor([-2.0], requires_grad=True)},
905            ],
906        ),
907        Example(
908            Dirichlet,
909            [
910                {"concentration": torch.tensor([0.0], requires_grad=True)},
911                {"concentration": torch.tensor([-2.0], requires_grad=True)},
912            ],
913        ),
914        Example(
915            Exponential,
916            [
917                {"rate": torch.tensor([0.0, 0.0], requires_grad=True)},
918                {"rate": torch.tensor([-2.0], requires_grad=True)},
919            ],
920        ),
921        Example(
922            FisherSnedecor,
923            [
924                {
925                    "df1": torch.tensor([0.0, 0.0], requires_grad=True),
926                    "df2": torch.tensor([-1.0, -100.0], requires_grad=True),
927                },
928                {
929                    "df1": torch.tensor([1.0, 1.0], requires_grad=True),
930                    "df2": torch.tensor([0.0, 0.0], requires_grad=True),
931                },
932            ],
933        ),
934        Example(
935            Gamma,
936            [
937                {
938                    "concentration": torch.tensor([0.0, 0.0], requires_grad=True),
939                    "rate": torch.tensor([-1.0, -100.0], requires_grad=True),
940                },
941                {
942                    "concentration": torch.tensor([1.0, 1.0], requires_grad=True),
943                    "rate": torch.tensor([0.0, 0.0], requires_grad=True),
944                },
945            ],
946        ),
947        Example(
948            Gumbel,
949            [
950                {
951                    "loc": torch.tensor([1.0, 1.0], requires_grad=True),
952                    "scale": torch.tensor([0.0, 1.0], requires_grad=True),
953                },
954                {
955                    "loc": torch.tensor([1.0, 1.0], requires_grad=True),
956                    "scale": torch.tensor([1.0, -1.0], requires_grad=True),
957                },
958            ],
959        ),
960        Example(
961            HalfCauchy,
962            [
963                {"scale": -1.0},
964                {"scale": 0.0},
965                {"scale": torch.tensor([[-0.000001], [1.0]])},
966            ],
967        ),
968        Example(
969            HalfNormal,
970            [
971                {"scale": torch.tensor([0.0, 1.0], requires_grad=True)},
972                {"scale": torch.tensor([1.0, -1.0], requires_grad=True)},
973            ],
974        ),
975        Example(
976            LKJCholesky,
977            [
978                {"dim": -2, "concentration": 0.1},
979                {
980                    "dim": 1,
981                    "concentration": 2.0,
982                },
983                {
984                    "dim": 2,
985                    "concentration": 0.0,
986                },
987            ],
988        ),
989        Example(
990            Laplace,
991            [
992                {
993                    "loc": torch.tensor([1.0, 1.0], requires_grad=True),
994                    "scale": torch.tensor([0.0, 1.0], requires_grad=True),
995                },
996                {
997                    "loc": torch.tensor([1.0, 1.0], requires_grad=True),
998                    "scale": torch.tensor([1.0, -1.0], requires_grad=True),
999                },
1000            ],
1001        ),
1002        Example(
1003            LogNormal,
1004            [
1005                {
1006                    "loc": torch.tensor([1.0, 1.0], requires_grad=True),
1007                    "scale": torch.tensor([0.0, 1.0], requires_grad=True),
1008                },
1009                {
1010                    "loc": torch.tensor([1.0, 1.0], requires_grad=True),
1011                    "scale": torch.tensor([1.0, -1.0], requires_grad=True),
1012                },
1013            ],
1014        ),
1015        Example(
1016            MultivariateNormal,
1017            [
1018                {
1019                    "loc": torch.tensor([1.0, 1.0], requires_grad=True),
1020                    "covariance_matrix": torch.tensor(
1021                        [[1.0, 0.0], [0.0, -2.0]], requires_grad=True
1022                    ),
1023                },
1024            ],
1025        ),
1026        Example(
1027            Normal,
1028            [
1029                {
1030                    "loc": torch.tensor([1.0, 1.0], requires_grad=True),
1031                    "scale": torch.tensor([0.0, 1.0], requires_grad=True),
1032                },
1033                {
1034                    "loc": torch.tensor([1.0, 1.0], requires_grad=True),
1035                    "scale": torch.tensor([1.0, -1.0], requires_grad=True),
1036                },
1037                {
1038                    "loc": torch.tensor([1.0, 0.0], requires_grad=True),
1039                    "scale": torch.tensor([1e-5, -1e-5], requires_grad=True),
1040                },
1041            ],
1042        ),
1043        Example(
1044            OneHotCategorical,
1045            [
1046                {
1047                    "probs": torch.tensor(
1048                        [[0.1, 0.2, 0.3], [0.1, -10.0, 0.2]], requires_grad=True
1049                    )
1050                },
1051                {"probs": torch.tensor([[0.0, 0.0], [0.0, 0.0]], requires_grad=True)},
1052            ],
1053        ),
1054        Example(
1055            OneHotCategoricalStraightThrough,
1056            [
1057                {
1058                    "probs": torch.tensor(
1059                        [[0.1, 0.2, 0.3], [0.1, -10.0, 0.2]], requires_grad=True
1060                    )
1061                },
1062                {"probs": torch.tensor([[0.0, 0.0], [0.0, 0.0]], requires_grad=True)},
1063            ],
1064        ),
1065        Example(
1066            Pareto,
1067            [
1068                {"scale": 0.0, "alpha": 0.0},
1069                {
1070                    "scale": torch.tensor([0.0, 0.0], requires_grad=True),
1071                    "alpha": torch.tensor([-1e-5, 0.0], requires_grad=True),
1072                },
1073                {"scale": torch.tensor([1.0]), "alpha": -1.0},
1074            ],
1075        ),
1076        Example(
1077            Poisson,
1078            [
1079                {
1080                    "rate": torch.tensor([-0.1], requires_grad=True),
1081                },
1082                {
1083                    "rate": -1.0,
1084                },
1085            ],
1086        ),
1087        Example(
1088            RelaxedBernoulli,
1089            [
1090                {
1091                    "temperature": torch.tensor([1.5], requires_grad=True),
1092                    "probs": torch.tensor([1.7, 0.2, 0.4], requires_grad=True),
1093                },
1094                {
1095                    "temperature": torch.tensor([2.0]),
1096                    "probs": torch.tensor([-1.0]),
1097                },
1098            ],
1099        ),
1100        Example(
1101            RelaxedOneHotCategorical,
1102            [
1103                {
1104                    "temperature": torch.tensor([0.5], requires_grad=True),
1105                    "probs": torch.tensor(
1106                        [[-0.1, 0.2, 0.7], [0.5, 0.3, 0.2]], requires_grad=True
1107                    ),
1108                },
1109                {
1110                    "temperature": torch.tensor([2.0]),
1111                    "probs": torch.tensor([[-1.0, 0.0], [-1.0, 1.1]]),
1112                },
1113            ],
1114        ),
1115        Example(
1116            TransformedDistribution,
1117            [
1118                {
1119                    "base_distribution": Normal(0, 1),
1120                    "transforms": lambda x: x,
1121                },
1122                {
1123                    "base_distribution": Normal(0, 1),
1124                    "transforms": [lambda x: x],
1125                },
1126            ],
1127        ),
1128        Example(
1129            Uniform,
1130            [
1131                {
1132                    "low": torch.tensor([2.0], requires_grad=True),
1133                    "high": torch.tensor([2.0], requires_grad=True),
1134                },
1135                {
1136                    "low": torch.tensor([0.0], requires_grad=True),
1137                    "high": torch.tensor([0.0], requires_grad=True),
1138                },
1139                {
1140                    "low": torch.tensor([1.0], requires_grad=True),
1141                    "high": torch.tensor([0.0], requires_grad=True),
1142                },
1143            ],
1144        ),
1145        Example(
1146            Weibull,
1147            [
1148                {
1149                    "scale": torch.tensor([0.0], requires_grad=True),
1150                    "concentration": torch.tensor([0.0], requires_grad=True),
1151                },
1152                {
1153                    "scale": torch.tensor([1.0], requires_grad=True),
1154                    "concentration": torch.tensor([-1.0], requires_grad=True),
1155                },
1156            ],
1157        ),
1158        Example(
1159            Wishart,
1160            [
1161                {
1162                    "covariance_matrix": torch.tensor(
1163                        [[1.0, 0.0], [0.0, -2.0]], requires_grad=True
1164                    ),
1165                    "df": torch.tensor([1.5], requires_grad=True),
1166                },
1167                {
1168                    "covariance_matrix": torch.tensor(
1169                        [[1.0, 1.0], [1.0, -2.0]], requires_grad=True
1170                    ),
1171                    "df": torch.tensor([3.0], requires_grad=True),
1172                },
1173                {
1174                    "covariance_matrix": torch.tensor(
1175                        [[1.0, 1.0], [1.0, -2.0]], requires_grad=True
1176                    ),
1177                    "df": 3.0,
1178                },
1179            ],
1180        ),
1181        Example(
1182            ContinuousBernoulli,
1183            [
1184                {"probs": torch.tensor([1.1, 0.2, 0.4], requires_grad=True)},
1185                {"probs": torch.tensor([-0.5], requires_grad=True)},
1186                {"probs": 1.00001},
1187            ],
1188        ),
1189        Example(
1190            InverseGamma,
1191            [
1192                {
1193                    "concentration": torch.tensor([0.0, 0.0], requires_grad=True),
1194                    "rate": torch.tensor([-1.0, -100.0], requires_grad=True),
1195                },
1196                {
1197                    "concentration": torch.tensor([1.0, 1.0], requires_grad=True),
1198                    "rate": torch.tensor([0.0, 0.0], requires_grad=True),
1199                },
1200            ],
1201        ),
1202    ]
1203
1204
1205class DistributionsTestCase(TestCase):
1206    def setUp(self):
1207        """The tests assume that the validation flag is set."""
1208        torch.distributions.Distribution.set_default_validate_args(True)
1209        super().setUp()
1210
1211
1212@skipIfTorchDynamo("Not a TorchDynamo suitable test")
1213class TestDistributions(DistributionsTestCase):
1214    _do_cuda_memory_leak_check = True
1215    _do_cuda_non_default_stream = True
1216
1217    def _gradcheck_log_prob(self, dist_ctor, ctor_params):
1218        # performs gradient checks on log_prob
1219        distribution = dist_ctor(*ctor_params)
1220        s = distribution.sample()
1221        if not distribution.support.is_discrete:
1222            s = s.detach().requires_grad_()
1223
1224        expected_shape = distribution.batch_shape + distribution.event_shape
1225        self.assertEqual(s.size(), expected_shape)
1226
1227        def apply_fn(s, *params):
1228            return dist_ctor(*params).log_prob(s)
1229
1230        gradcheck(apply_fn, (s,) + tuple(ctor_params), raise_exception=True)
1231
1232    def _check_forward_ad(self, fn):
1233        with fwAD.dual_level():
1234            x = torch.tensor(1.0)
1235            t = torch.tensor(1.0)
1236            dual = fwAD.make_dual(x, t)
1237            dual_out = fn(dual)
1238            self.assertEqual(
1239                torch.count_nonzero(fwAD.unpack_dual(dual_out).tangent).item(), 0
1240            )
1241
1242    def _check_log_prob(self, dist, asset_fn):
1243        # checks that the log_prob matches a reference function
1244        s = dist.sample()
1245        log_probs = dist.log_prob(s)
1246        log_probs_data_flat = log_probs.view(-1)
1247        s_data_flat = s.view(len(log_probs_data_flat), -1)
1248        for i, (val, log_prob) in enumerate(zip(s_data_flat, log_probs_data_flat)):
1249            asset_fn(i, val.squeeze(), log_prob)
1250
1251    def _check_sampler_sampler(
1252        self,
1253        torch_dist,
1254        ref_dist,
1255        message,
1256        multivariate=False,
1257        circular=False,
1258        num_samples=10000,
1259        failure_rate=1e-3,
1260    ):
1261        # Checks that the .sample() method matches a reference function.
1262        torch_samples = torch_dist.sample((num_samples,)).squeeze()
1263        torch_samples = torch_samples.cpu().numpy()
1264        ref_samples = ref_dist.rvs(num_samples).astype(np.float64)
1265        if multivariate:
1266            # Project onto a random axis.
1267            axis = np.random.normal(size=(1,) + torch_samples.shape[1:])
1268            axis /= np.linalg.norm(axis)
1269            torch_samples = (axis * torch_samples).reshape(num_samples, -1).sum(-1)
1270            ref_samples = (axis * ref_samples).reshape(num_samples, -1).sum(-1)
1271        samples = [(x, +1) for x in torch_samples] + [(x, -1) for x in ref_samples]
1272        if circular:
1273            samples = [(np.cos(x), v) for (x, v) in samples]
1274        shuffle(
1275            samples
1276        )  # necessary to prevent stable sort from making uneven bins for discrete
1277        samples.sort(key=lambda x: x[0])
1278        samples = np.array(samples)[:, 1]
1279
1280        # Aggregate into bins filled with roughly zero-mean unit-variance RVs.
1281        num_bins = 10
1282        samples_per_bin = len(samples) // num_bins
1283        bins = samples.reshape((num_bins, samples_per_bin)).mean(axis=1)
1284        stddev = samples_per_bin**-0.5
1285        threshold = stddev * scipy.special.erfinv(1 - 2 * failure_rate / num_bins)
1286        message = f"{message}.sample() is biased:\n{bins}"
1287        for bias in bins:
1288            self.assertLess(-threshold, bias, message)
1289            self.assertLess(bias, threshold, message)
1290
1291    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
1292    def _check_sampler_discrete(
1293        self, torch_dist, ref_dist, message, num_samples=10000, failure_rate=1e-3
1294    ):
1295        """Runs a Chi2-test for the support, but ignores tail instead of combining"""
1296        torch_samples = torch_dist.sample((num_samples,)).squeeze()
1297        torch_samples = (
1298            torch_samples.float()
1299            if torch_samples.dtype == torch.bfloat16
1300            else torch_samples
1301        )
1302        torch_samples = torch_samples.cpu().numpy()
1303        unique, counts = np.unique(torch_samples, return_counts=True)
1304        pmf = ref_dist.pmf(unique)
1305        pmf = pmf / pmf.sum()  # renormalize to 1.0 for chisq test
1306        msk = (counts > 5) & ((pmf * num_samples) > 5)
1307        self.assertGreater(
1308            pmf[msk].sum(),
1309            0.9,
1310            "Distribution is too sparse for test; try increasing num_samples",
1311        )
1312        # Add a remainder bucket that combines counts for all values
1313        # below threshold, if such values exist (i.e. mask has False entries).
1314        if not msk.all():
1315            counts = np.concatenate([counts[msk], np.sum(counts[~msk], keepdims=True)])
1316            pmf = np.concatenate([pmf[msk], np.sum(pmf[~msk], keepdims=True)])
1317        chisq, p = scipy.stats.chisquare(counts, pmf * num_samples)
1318        self.assertGreater(p, failure_rate, message)
1319
1320    def _check_enumerate_support(self, dist, examples):
1321        for params, expected in examples:
1322            params = {k: torch.tensor(v) for k, v in params.items()}
1323            d = dist(**params)
1324            actual = d.enumerate_support(expand=False)
1325            expected = torch.tensor(expected, dtype=actual.dtype)
1326            self.assertEqual(actual, expected)
1327            actual = d.enumerate_support(expand=True)
1328            expected_with_expand = expected.expand(
1329                (-1,) + d.batch_shape + d.event_shape
1330            )
1331            self.assertEqual(actual, expected_with_expand)
1332
1333    def test_repr(self):
1334        for Dist, params in _get_examples():
1335            for param in params:
1336                dist = Dist(**param)
1337                self.assertTrue(repr(dist).startswith(dist.__class__.__name__))
1338
1339    def test_sample_detached(self):
1340        for Dist, params in _get_examples():
1341            for i, param in enumerate(params):
1342                variable_params = [
1343                    p for p in param.values() if getattr(p, "requires_grad", False)
1344                ]
1345                if not variable_params:
1346                    continue
1347                dist = Dist(**param)
1348                sample = dist.sample()
1349                self.assertFalse(
1350                    sample.requires_grad,
1351                    msg=f"{Dist.__name__} example {i + 1}/{len(params)}, .sample() is not detached",
1352                )
1353
1354    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
1355    def test_rsample_requires_grad(self):
1356        for Dist, params in _get_examples():
1357            for i, param in enumerate(params):
1358                if not any(getattr(p, "requires_grad", False) for p in param.values()):
1359                    continue
1360                dist = Dist(**param)
1361                if not dist.has_rsample:
1362                    continue
1363                sample = dist.rsample()
1364                self.assertTrue(
1365                    sample.requires_grad,
1366                    msg=f"{Dist.__name__} example {i + 1}/{len(params)}, .rsample() does not require grad",
1367                )
1368
1369    def test_enumerate_support_type(self):
1370        for Dist, params in _get_examples():
1371            for i, param in enumerate(params):
1372                dist = Dist(**param)
1373                try:
1374                    self.assertTrue(
1375                        type(dist.sample()) is type(dist.enumerate_support()),
1376                        msg=(
1377                            "{} example {}/{}, return type mismatch between "
1378                            + "sample and enumerate_support."
1379                        ).format(Dist.__name__, i + 1, len(params)),
1380                    )
1381                except NotImplementedError:
1382                    pass
1383
1384    def test_lazy_property_grad(self):
1385        x = torch.randn(1, requires_grad=True)
1386
1387        class Dummy:
1388            @lazy_property
1389            def y(self):
1390                return x + 1
1391
1392        def test():
1393            x.grad = None
1394            Dummy().y.backward()
1395            self.assertEqual(x.grad, torch.ones(1))
1396
1397        test()
1398        with torch.no_grad():
1399            test()
1400
1401        mean = torch.randn(2)
1402        cov = torch.eye(2, requires_grad=True)
1403        distn = MultivariateNormal(mean, cov)
1404        with torch.no_grad():
1405            distn.scale_tril
1406        distn.scale_tril.sum().backward()
1407        self.assertIsNotNone(cov.grad)
1408
1409    def test_has_examples(self):
1410        distributions_with_examples = {e.Dist for e in _get_examples()}
1411        for Dist in globals().values():
1412            if (
1413                isinstance(Dist, type)
1414                and issubclass(Dist, Distribution)
1415                and Dist is not Distribution
1416                and Dist is not ExponentialFamily
1417            ):
1418                self.assertIn(
1419                    Dist,
1420                    distributions_with_examples,
1421                    f"Please add {Dist.__name__} to the _get_examples list in test_distributions.py",
1422                )
1423
1424    def test_support_attributes(self):
1425        for Dist, params in _get_examples():
1426            for param in params:
1427                d = Dist(**param)
1428                event_dim = len(d.event_shape)
1429                self.assertEqual(d.support.event_dim, event_dim)
1430                try:
1431                    self.assertEqual(Dist.support.event_dim, event_dim)
1432                except NotImplementedError:
1433                    pass
1434                is_discrete = d.support.is_discrete
1435                try:
1436                    self.assertEqual(Dist.support.is_discrete, is_discrete)
1437                except NotImplementedError:
1438                    pass
1439
1440    def test_distribution_expand(self):
1441        shapes = [torch.Size(), torch.Size((2,)), torch.Size((2, 1))]
1442        for Dist, params in _get_examples():
1443            for param in params:
1444                for shape in shapes:
1445                    d = Dist(**param)
1446                    expanded_shape = shape + d.batch_shape
1447                    original_shape = d.batch_shape + d.event_shape
1448                    expected_shape = shape + original_shape
1449                    expanded = d.expand(batch_shape=list(expanded_shape))
1450                    sample = expanded.sample()
1451                    actual_shape = expanded.sample().shape
1452                    self.assertEqual(expanded.__class__, d.__class__)
1453                    self.assertEqual(d.sample().shape, original_shape)
1454                    self.assertEqual(expanded.log_prob(sample), d.log_prob(sample))
1455                    self.assertEqual(actual_shape, expected_shape)
1456                    self.assertEqual(expanded.batch_shape, expanded_shape)
1457                    try:
1458                        self.assertEqual(
1459                            expanded.mean, d.mean.expand(expanded_shape + d.event_shape)
1460                        )
1461                        self.assertEqual(
1462                            expanded.variance,
1463                            d.variance.expand(expanded_shape + d.event_shape),
1464                        )
1465                    except NotImplementedError:
1466                        pass
1467
1468    def test_distribution_subclass_expand(self):
1469        expand_by = torch.Size((2,))
1470        for Dist, params in _get_examples():
1471
1472            class SubClass(Dist):
1473                pass
1474
1475            for param in params:
1476                d = SubClass(**param)
1477                expanded_shape = expand_by + d.batch_shape
1478                original_shape = d.batch_shape + d.event_shape
1479                expected_shape = expand_by + original_shape
1480                expanded = d.expand(batch_shape=expanded_shape)
1481                sample = expanded.sample()
1482                actual_shape = expanded.sample().shape
1483                self.assertEqual(expanded.__class__, d.__class__)
1484                self.assertEqual(d.sample().shape, original_shape)
1485                self.assertEqual(expanded.log_prob(sample), d.log_prob(sample))
1486                self.assertEqual(actual_shape, expected_shape)
1487
1488    @set_default_dtype(torch.double)
1489    def test_bernoulli(self):
1490        p = torch.tensor([0.7, 0.2, 0.4], requires_grad=True)
1491        r = torch.tensor(0.3, requires_grad=True)
1492        s = 0.3
1493        self.assertEqual(Bernoulli(p).sample((8,)).size(), (8, 3))
1494        self.assertFalse(Bernoulli(p).sample().requires_grad)
1495        self.assertEqual(Bernoulli(r).sample((8,)).size(), (8,))
1496        self.assertEqual(Bernoulli(r).sample().size(), ())
1497        self.assertEqual(
1498            Bernoulli(r).sample((3, 2)).size(),
1499            (
1500                3,
1501                2,
1502            ),
1503        )
1504        self.assertEqual(Bernoulli(s).sample().size(), ())
1505        self._gradcheck_log_prob(Bernoulli, (p,))
1506
1507        def ref_log_prob(idx, val, log_prob):
1508            prob = p[idx]
1509            self.assertEqual(log_prob, math.log(prob if val else 1 - prob))
1510
1511        self._check_log_prob(Bernoulli(p), ref_log_prob)
1512        self._check_log_prob(Bernoulli(logits=p.log() - (-p).log1p()), ref_log_prob)
1513        self.assertRaises(NotImplementedError, Bernoulli(r).rsample)
1514
1515        # check entropy computation
1516        self.assertEqual(
1517            Bernoulli(p).entropy(),
1518            torch.tensor([0.6108, 0.5004, 0.6730]),
1519            atol=1e-4,
1520            rtol=0,
1521        )
1522        self.assertEqual(Bernoulli(torch.tensor([0.0])).entropy(), torch.tensor([0.0]))
1523        self.assertEqual(
1524            Bernoulli(s).entropy(), torch.tensor(0.6108), atol=1e-4, rtol=0
1525        )
1526
1527        self._check_forward_ad(torch.bernoulli)
1528        self._check_forward_ad(lambda x: x.bernoulli_())
1529        self._check_forward_ad(lambda x: x.bernoulli_(x.clone().detach()))
1530        self._check_forward_ad(lambda x: x.bernoulli_(x))
1531
1532    def test_bernoulli_enumerate_support(self):
1533        examples = [
1534            ({"probs": [0.1]}, [[0], [1]]),
1535            ({"probs": [0.1, 0.9]}, [[0], [1]]),
1536            ({"probs": [[0.1, 0.2], [0.3, 0.4]]}, [[[0]], [[1]]]),
1537        ]
1538        self._check_enumerate_support(Bernoulli, examples)
1539
1540    def test_bernoulli_3d(self):
1541        p = torch.full((2, 3, 5), 0.5).requires_grad_()
1542        self.assertEqual(Bernoulli(p).sample().size(), (2, 3, 5))
1543        self.assertEqual(
1544            Bernoulli(p).sample(sample_shape=(2, 5)).size(), (2, 5, 2, 3, 5)
1545        )
1546        self.assertEqual(Bernoulli(p).sample((2,)).size(), (2, 2, 3, 5))
1547
1548    @set_default_dtype(torch.double)
1549    def test_geometric(self):
1550        p = torch.tensor([0.7, 0.2, 0.4], requires_grad=True)
1551        r = torch.tensor(0.3, requires_grad=True)
1552        s = 0.3
1553        self.assertEqual(Geometric(p).sample((8,)).size(), (8, 3))
1554        self.assertEqual(Geometric(1).sample(), 0)
1555        self.assertEqual(Geometric(1).log_prob(torch.tensor(1.0)), -inf)
1556        self.assertEqual(Geometric(1).log_prob(torch.tensor(0.0)), 0)
1557        self.assertFalse(Geometric(p).sample().requires_grad)
1558        self.assertEqual(Geometric(r).sample((8,)).size(), (8,))
1559        self.assertEqual(Geometric(r).sample().size(), ())
1560        self.assertEqual(Geometric(r).sample((3, 2)).size(), (3, 2))
1561        self.assertEqual(Geometric(s).sample().size(), ())
1562        self._gradcheck_log_prob(Geometric, (p,))
1563        self.assertRaises(ValueError, lambda: Geometric(0))
1564        self.assertRaises(NotImplementedError, Geometric(r).rsample)
1565
1566        self._check_forward_ad(lambda x: x.geometric_(0.2))
1567
1568    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
1569    @set_default_dtype(torch.double)
1570    def test_geometric_log_prob_and_entropy(self):
1571        p = torch.tensor([0.7, 0.2, 0.4], requires_grad=True)
1572        s = 0.3
1573
1574        def ref_log_prob(idx, val, log_prob):
1575            prob = p[idx].detach()
1576            self.assertEqual(log_prob, scipy.stats.geom(prob, loc=-1).logpmf(val))
1577
1578        self._check_log_prob(Geometric(p), ref_log_prob)
1579        self._check_log_prob(Geometric(logits=p.log() - (-p).log1p()), ref_log_prob)
1580
1581        # check entropy computation
1582        self.assertEqual(
1583            Geometric(p).entropy(),
1584            scipy.stats.geom(p.detach().numpy(), loc=-1).entropy(),
1585            atol=1e-3,
1586            rtol=0,
1587        )
1588        self.assertEqual(
1589            float(Geometric(s).entropy()),
1590            scipy.stats.geom(s, loc=-1).entropy().item(),
1591            atol=1e-3,
1592            rtol=0,
1593        )
1594
1595    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
1596    def test_geometric_sample(self):
1597        set_rng_seed(0)  # see Note [Randomized statistical tests]
1598        for prob in [0.01, 0.18, 0.8]:
1599            self._check_sampler_discrete(
1600                Geometric(prob),
1601                scipy.stats.geom(p=prob, loc=-1),
1602                f"Geometric(prob={prob})",
1603            )
1604
1605    @set_default_dtype(torch.double)
1606    def test_binomial(self):
1607        p = torch.arange(0.05, 1, 0.1).requires_grad_()
1608        for total_count in [1, 2, 10]:
1609            self._gradcheck_log_prob(lambda p: Binomial(total_count, p), [p])
1610            self._gradcheck_log_prob(
1611                lambda p: Binomial(total_count, None, p.log()), [p]
1612            )
1613        self.assertRaises(NotImplementedError, Binomial(10, p).rsample)
1614
1615    test_binomial_half = set_default_dtype(torch.float16)(test_binomial)
1616    test_binomial_bfloat16 = set_default_dtype(torch.bfloat16)(test_binomial)
1617
1618    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
1619    def test_binomial_sample(self):
1620        set_rng_seed(0)  # see Note [Randomized statistical tests]
1621        for prob in [0.01, 0.1, 0.5, 0.8, 0.9]:
1622            for count in [2, 10, 100, 500]:
1623                self._check_sampler_discrete(
1624                    Binomial(total_count=count, probs=prob),
1625                    scipy.stats.binom(count, prob),
1626                    f"Binomial(total_count={count}, probs={prob})",
1627                )
1628
1629    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
1630    @set_default_dtype(torch.double)
1631    def test_binomial_log_prob_and_entropy(self):
1632        probs = torch.arange(0.05, 1, 0.1)
1633        for total_count in [1, 2, 10]:
1634
1635            def ref_log_prob(idx, x, log_prob):
1636                p = probs.view(-1)[idx].item()
1637                expected = scipy.stats.binom(total_count, p).logpmf(x)
1638                self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)
1639
1640            self._check_log_prob(Binomial(total_count, probs), ref_log_prob)
1641            logits = probs_to_logits(probs, is_binary=True)
1642            self._check_log_prob(Binomial(total_count, logits=logits), ref_log_prob)
1643
1644            bin = Binomial(total_count, logits=logits)
1645            self.assertEqual(
1646                bin.entropy(),
1647                scipy.stats.binom(
1648                    total_count, bin.probs.detach().numpy(), loc=-1
1649                ).entropy(),
1650                atol=1e-3,
1651                rtol=0,
1652            )
1653
1654    def test_binomial_stable(self):
1655        logits = torch.tensor([-100.0, 100.0], dtype=torch.float)
1656        total_count = 1.0
1657        x = torch.tensor([0.0, 0.0], dtype=torch.float)
1658        log_prob = Binomial(total_count, logits=logits).log_prob(x)
1659        self.assertTrue(torch.isfinite(log_prob).all())
1660
1661        # make sure that the grad at logits=0, value=0 is 0.5
1662        x = torch.tensor(0.0, requires_grad=True)
1663        y = Binomial(total_count, logits=x).log_prob(torch.tensor(0.0))
1664        self.assertEqual(grad(y, x)[0], torch.tensor(-0.5))
1665
1666    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
1667    @set_default_dtype(torch.double)
1668    def test_binomial_log_prob_vectorized_count(self):
1669        probs = torch.tensor([0.2, 0.7, 0.9])
1670        for total_count, sample in [
1671            (torch.tensor([10]), torch.tensor([7.0, 3.0, 9.0])),
1672            (torch.tensor([1, 2, 10]), torch.tensor([0.0, 1.0, 9.0])),
1673        ]:
1674            log_prob = Binomial(total_count, probs).log_prob(sample)
1675            expected = scipy.stats.binom(
1676                total_count.cpu().numpy(), probs.cpu().numpy()
1677            ).logpmf(sample)
1678            self.assertEqual(log_prob, expected, atol=1e-4, rtol=0)
1679
1680    def test_binomial_enumerate_support(self):
1681        examples = [
1682            ({"probs": [0.1], "total_count": 2}, [[0], [1], [2]]),
1683            ({"probs": [0.1, 0.9], "total_count": 2}, [[0], [1], [2]]),
1684            (
1685                {"probs": [[0.1, 0.2], [0.3, 0.4]], "total_count": 3},
1686                [[[0]], [[1]], [[2]], [[3]]],
1687            ),
1688        ]
1689        self._check_enumerate_support(Binomial, examples)
1690
1691    @set_default_dtype(torch.double)
1692    def test_binomial_extreme_vals(self):
1693        total_count = 100
1694        bin0 = Binomial(total_count, 0)
1695        self.assertEqual(bin0.sample(), 0)
1696        self.assertEqual(bin0.log_prob(torch.tensor([0.0]))[0], 0, atol=1e-3, rtol=0)
1697        self.assertEqual(float(bin0.log_prob(torch.tensor([1.0])).exp()), 0)
1698        bin1 = Binomial(total_count, 1)
1699        self.assertEqual(bin1.sample(), total_count)
1700        self.assertEqual(
1701            bin1.log_prob(torch.tensor([float(total_count)]))[0], 0, atol=1e-3, rtol=0
1702        )
1703        self.assertEqual(
1704            float(bin1.log_prob(torch.tensor([float(total_count - 1)])).exp()), 0
1705        )
1706        zero_counts = torch.zeros(torch.Size((2, 2)))
1707        bin2 = Binomial(zero_counts, 1)
1708        self.assertEqual(bin2.sample(), zero_counts)
1709        self.assertEqual(bin2.log_prob(zero_counts), zero_counts)
1710
1711    @set_default_dtype(torch.double)
1712    def test_binomial_vectorized_count(self):
1713        set_rng_seed(1)  # see Note [Randomized statistical tests]
1714        total_count = torch.tensor([[4, 7], [3, 8]], dtype=torch.float64)
1715        bin0 = Binomial(total_count, torch.tensor(1.0))
1716        self.assertEqual(bin0.sample(), total_count)
1717        bin1 = Binomial(total_count, torch.tensor(0.5))
1718        samples = bin1.sample(torch.Size((100000,)))
1719        self.assertTrue((samples <= total_count.type_as(samples)).all())
1720        self.assertEqual(samples.mean(dim=0), bin1.mean, atol=0.02, rtol=0)
1721        self.assertEqual(samples.var(dim=0), bin1.variance, atol=0.02, rtol=0)
1722
1723    @set_default_dtype(torch.double)
1724    def test_negative_binomial(self):
1725        p = torch.arange(0.05, 1, 0.1).requires_grad_()
1726        for total_count in [1, 2, 10]:
1727            self._gradcheck_log_prob(lambda p: NegativeBinomial(total_count, p), [p])
1728            self._gradcheck_log_prob(
1729                lambda p: NegativeBinomial(total_count, None, p.log()), [p]
1730            )
1731        self.assertRaises(NotImplementedError, NegativeBinomial(10, p).rsample)
1732        self.assertRaises(NotImplementedError, NegativeBinomial(10, p).entropy)
1733
1734    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
1735    def test_negative_binomial_log_prob(self):
1736        probs = torch.arange(0.05, 1, 0.1)
1737        for total_count in [1, 2, 10]:
1738
1739            def ref_log_prob(idx, x, log_prob):
1740                p = probs.view(-1)[idx].item()
1741                expected = scipy.stats.nbinom(total_count, 1 - p).logpmf(x)
1742                self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)
1743
1744            self._check_log_prob(NegativeBinomial(total_count, probs), ref_log_prob)
1745            logits = probs_to_logits(probs, is_binary=True)
1746            self._check_log_prob(
1747                NegativeBinomial(total_count, logits=logits), ref_log_prob
1748            )
1749
1750    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
1751    @set_default_dtype(torch.double)
1752    def test_negative_binomial_log_prob_vectorized_count(self):
1753        probs = torch.tensor([0.2, 0.7, 0.9])
1754        for total_count, sample in [
1755            (torch.tensor([10]), torch.tensor([7.0, 3.0, 9.0])),
1756            (torch.tensor([1, 2, 10]), torch.tensor([0.0, 1.0, 9.0])),
1757        ]:
1758            log_prob = NegativeBinomial(total_count, probs).log_prob(sample)
1759            expected = scipy.stats.nbinom(
1760                total_count.cpu().numpy(), 1 - probs.cpu().numpy()
1761            ).logpmf(sample)
1762            self.assertEqual(log_prob, expected, atol=1e-4, rtol=0)
1763
1764    @unittest.skipIf(not TEST_CUDA, "CUDA not found")
1765    def test_zero_excluded_binomial(self):
1766        vals = Binomial(
1767            total_count=torch.tensor(1.0).cuda(), probs=torch.tensor(0.9).cuda()
1768        ).sample(torch.Size((100000000,)))
1769        self.assertTrue((vals >= 0).all())
1770        vals = Binomial(
1771            total_count=torch.tensor(1.0).cuda(), probs=torch.tensor(0.1).cuda()
1772        ).sample(torch.Size((100000000,)))
1773        self.assertTrue((vals < 2).all())
1774        vals = Binomial(
1775            total_count=torch.tensor(1.0).cuda(), probs=torch.tensor(0.5).cuda()
1776        ).sample(torch.Size((10000,)))
1777        # vals should be roughly half zeroes, half ones
1778        assert (vals == 0.0).sum() > 4000
1779        assert (vals == 1.0).sum() > 4000
1780
1781    @set_default_dtype(torch.double)
1782    def test_multinomial_1d(self):
1783        total_count = 10
1784        p = torch.tensor([0.1, 0.2, 0.3], requires_grad=True)
1785        self.assertEqual(Multinomial(total_count, p).sample().size(), (3,))
1786        self.assertEqual(Multinomial(total_count, p).sample((2, 2)).size(), (2, 2, 3))
1787        self.assertEqual(Multinomial(total_count, p).sample((1,)).size(), (1, 3))
1788        self._gradcheck_log_prob(lambda p: Multinomial(total_count, p), [p])
1789        self._gradcheck_log_prob(lambda p: Multinomial(total_count, None, p.log()), [p])
1790        self.assertRaises(NotImplementedError, Multinomial(10, p).rsample)
1791
1792    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
1793    @set_default_dtype(torch.double)
1794    def test_multinomial_1d_log_prob_and_entropy(self):
1795        total_count = 10
1796        p = torch.tensor([0.1, 0.2, 0.3], requires_grad=True)
1797        dist = Multinomial(total_count, probs=p)
1798        x = dist.sample()
1799        log_prob = dist.log_prob(x)
1800        expected = torch.tensor(
1801            scipy.stats.multinomial.logpmf(
1802                x.numpy(), n=total_count, p=dist.probs.detach().numpy()
1803            )
1804        )
1805        self.assertEqual(log_prob, expected)
1806
1807        dist = Multinomial(total_count, logits=p.log())
1808        x = dist.sample()
1809        log_prob = dist.log_prob(x)
1810        expected = torch.tensor(
1811            scipy.stats.multinomial.logpmf(
1812                x.numpy(), n=total_count, p=dist.probs.detach().numpy()
1813            )
1814        )
1815        self.assertEqual(log_prob, expected)
1816
1817        expected = scipy.stats.multinomial.entropy(
1818            total_count, dist.probs.detach().numpy()
1819        )
1820        self.assertEqual(dist.entropy(), expected, atol=1e-3, rtol=0)
1821
1822    @set_default_dtype(torch.double)
1823    def test_multinomial_2d(self):
1824        total_count = 10
1825        probabilities = [[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]]
1826        probabilities_1 = [[1.0, 0.0], [0.0, 1.0]]
1827        p = torch.tensor(probabilities, requires_grad=True)
1828        s = torch.tensor(probabilities_1, requires_grad=True)
1829        self.assertEqual(Multinomial(total_count, p).sample().size(), (2, 3))
1830        self.assertEqual(
1831            Multinomial(total_count, p).sample(sample_shape=(3, 4)).size(), (3, 4, 2, 3)
1832        )
1833        self.assertEqual(Multinomial(total_count, p).sample((6,)).size(), (6, 2, 3))
1834        set_rng_seed(0)
1835        self._gradcheck_log_prob(lambda p: Multinomial(total_count, p), [p])
1836        self._gradcheck_log_prob(lambda p: Multinomial(total_count, None, p.log()), [p])
1837
1838        # sample check for extreme value of probs
1839        self.assertEqual(
1840            Multinomial(total_count, s).sample(),
1841            torch.tensor([[total_count, 0], [0, total_count]], dtype=torch.float64),
1842        )
1843
1844    @set_default_dtype(torch.double)
1845    def test_categorical_1d(self):
1846        p = torch.tensor([0.1, 0.2, 0.3], requires_grad=True)
1847        self.assertTrue(is_all_nan(Categorical(p).mean))
1848        self.assertTrue(is_all_nan(Categorical(p).variance))
1849        self.assertEqual(Categorical(p).sample().size(), ())
1850        self.assertFalse(Categorical(p).sample().requires_grad)
1851        self.assertEqual(Categorical(p).sample((2, 2)).size(), (2, 2))
1852        self.assertEqual(Categorical(p).sample((1,)).size(), (1,))
1853        self._gradcheck_log_prob(Categorical, (p,))
1854        self.assertRaises(NotImplementedError, Categorical(p).rsample)
1855
1856    @set_default_dtype(torch.double)
1857    def test_categorical_2d(self):
1858        probabilities = [[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]]
1859        probabilities_1 = [[1.0, 0.0], [0.0, 1.0]]
1860        p = torch.tensor(probabilities, requires_grad=True)
1861        s = torch.tensor(probabilities_1, requires_grad=True)
1862        self.assertEqual(Categorical(p).mean.size(), (2,))
1863        self.assertEqual(Categorical(p).variance.size(), (2,))
1864        self.assertTrue(is_all_nan(Categorical(p).mean))
1865        self.assertTrue(is_all_nan(Categorical(p).variance))
1866        self.assertEqual(Categorical(p).sample().size(), (2,))
1867        self.assertEqual(Categorical(p).sample(sample_shape=(3, 4)).size(), (3, 4, 2))
1868        self.assertEqual(Categorical(p).sample((6,)).size(), (6, 2))
1869        self._gradcheck_log_prob(Categorical, (p,))
1870
1871        # sample check for extreme value of probs
1872        set_rng_seed(0)
1873        self.assertEqual(
1874            Categorical(s).sample(sample_shape=(2,)), torch.tensor([[0, 1], [0, 1]])
1875        )
1876
1877        def ref_log_prob(idx, val, log_prob):
1878            sample_prob = p[idx][val] / p[idx].sum()
1879            self.assertEqual(log_prob, math.log(sample_prob))
1880
1881        self._check_log_prob(Categorical(p), ref_log_prob)
1882        self._check_log_prob(Categorical(logits=p.log()), ref_log_prob)
1883
1884        # check entropy computation
1885        self.assertEqual(
1886            Categorical(p).entropy(), torch.tensor([1.0114, 1.0297]), atol=1e-4, rtol=0
1887        )
1888        self.assertEqual(Categorical(s).entropy(), torch.tensor([0.0, 0.0]))
1889        # issue gh-40553
1890        logits = p.log()
1891        logits[1, 1] = logits[0, 2] = float("-inf")
1892        e = Categorical(logits=logits).entropy()
1893        self.assertEqual(e, torch.tensor([0.6365, 0.5983]), atol=1e-4, rtol=0)
1894
1895    def test_categorical_enumerate_support(self):
1896        examples = [
1897            ({"probs": [0.1, 0.2, 0.7]}, [0, 1, 2]),
1898            ({"probs": [[0.1, 0.9], [0.3, 0.7]]}, [[0], [1]]),
1899        ]
1900        self._check_enumerate_support(Categorical, examples)
1901
1902    @set_default_dtype(torch.double)
1903    def test_one_hot_categorical_1d(self):
1904        p = torch.tensor([0.1, 0.2, 0.3], requires_grad=True)
1905        self.assertEqual(OneHotCategorical(p).sample().size(), (3,))
1906        self.assertFalse(OneHotCategorical(p).sample().requires_grad)
1907        self.assertEqual(OneHotCategorical(p).sample((2, 2)).size(), (2, 2, 3))
1908        self.assertEqual(OneHotCategorical(p).sample((1,)).size(), (1, 3))
1909        self._gradcheck_log_prob(OneHotCategorical, (p,))
1910        self.assertRaises(NotImplementedError, OneHotCategorical(p).rsample)
1911
1912    @set_default_dtype(torch.double)
1913    def test_one_hot_categorical_2d(self):
1914        probabilities = [[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]]
1915        probabilities_1 = [[1.0, 0.0], [0.0, 1.0]]
1916        p = torch.tensor(probabilities, requires_grad=True)
1917        s = torch.tensor(probabilities_1, requires_grad=True)
1918        self.assertEqual(OneHotCategorical(p).sample().size(), (2, 3))
1919        self.assertEqual(
1920            OneHotCategorical(p).sample(sample_shape=(3, 4)).size(), (3, 4, 2, 3)
1921        )
1922        self.assertEqual(OneHotCategorical(p).sample((6,)).size(), (6, 2, 3))
1923        self._gradcheck_log_prob(OneHotCategorical, (p,))
1924
1925        dist = OneHotCategorical(p)
1926        x = dist.sample()
1927        self.assertEqual(dist.log_prob(x), Categorical(p).log_prob(x.max(-1)[1]))
1928
1929    def test_one_hot_categorical_enumerate_support(self):
1930        examples = [
1931            ({"probs": [0.1, 0.2, 0.7]}, [[1, 0, 0], [0, 1, 0], [0, 0, 1]]),
1932            ({"probs": [[0.1, 0.9], [0.3, 0.7]]}, [[[1, 0]], [[0, 1]]]),
1933        ]
1934        self._check_enumerate_support(OneHotCategorical, examples)
1935
1936    def test_poisson_forward_ad(self):
1937        self._check_forward_ad(torch.poisson)
1938
1939    def test_poisson_shape(self):
1940        rate = torch.randn(2, 3).abs().requires_grad_()
1941        rate_1d = torch.randn(1).abs().requires_grad_()
1942        self.assertEqual(Poisson(rate).sample().size(), (2, 3))
1943        self.assertEqual(Poisson(rate).sample((7,)).size(), (7, 2, 3))
1944        self.assertEqual(Poisson(rate_1d).sample().size(), (1,))
1945        self.assertEqual(Poisson(rate_1d).sample((1,)).size(), (1, 1))
1946        self.assertEqual(Poisson(2.0).sample((2,)).size(), (2,))
1947
1948    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
1949    @set_default_dtype(torch.double)
1950    def test_poisson_log_prob(self):
1951        rate = torch.randn(2, 3).abs().requires_grad_()
1952        rate_1d = torch.randn(1).abs().requires_grad_()
1953        rate_zero = torch.zeros([], requires_grad=True)
1954
1955        def ref_log_prob(ref_rate, idx, x, log_prob):
1956            l = ref_rate.view(-1)[idx].detach()
1957            expected = scipy.stats.poisson.logpmf(x, l)
1958            self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)
1959
1960        set_rng_seed(0)
1961        self._check_log_prob(Poisson(rate), lambda *args: ref_log_prob(rate, *args))
1962        self._check_log_prob(
1963            Poisson(rate_zero), lambda *args: ref_log_prob(rate_zero, *args)
1964        )
1965        self._gradcheck_log_prob(Poisson, (rate,))
1966        self._gradcheck_log_prob(Poisson, (rate_1d,))
1967
1968        # We cannot check gradients automatically for zero rates because the finite difference
1969        # approximation enters the forbidden parameter space. We instead compare with the
1970        # theoretical results.
1971        dist = Poisson(rate_zero)
1972        dist.log_prob(torch.ones_like(rate_zero)).backward()
1973        self.assertEqual(rate_zero.grad, torch.inf)
1974
1975    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
1976    def test_poisson_sample(self):
1977        set_rng_seed(1)  # see Note [Randomized statistical tests]
1978        saved_dtype = torch.get_default_dtype()
1979        for dtype in [torch.float, torch.double, torch.bfloat16, torch.half]:
1980            torch.set_default_dtype(dtype)
1981            for rate in [0.1, 1.0, 5.0]:
1982                self._check_sampler_discrete(
1983                    Poisson(rate),
1984                    scipy.stats.poisson(rate),
1985                    f"Poisson(lambda={rate})",
1986                    failure_rate=1e-3,
1987                )
1988        torch.set_default_dtype(saved_dtype)
1989
1990    @unittest.skipIf(not TEST_CUDA, "CUDA not found")
1991    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
1992    def test_poisson_gpu_sample(self):
1993        set_rng_seed(1)
1994        for rate in [0.12, 0.9, 4.0]:
1995            self._check_sampler_discrete(
1996                Poisson(torch.tensor([rate]).cuda()),
1997                scipy.stats.poisson(rate),
1998                f"Poisson(lambda={rate}, cuda)",
1999                failure_rate=1e-3,
2000            )
2001
2002    @set_default_dtype(torch.double)
2003    def test_relaxed_bernoulli(self):
2004        p = torch.tensor([0.7, 0.2, 0.4], requires_grad=True)
2005        r = torch.tensor(0.3, requires_grad=True)
2006        s = 0.3
2007        temp = torch.tensor(0.67, requires_grad=True)
2008        self.assertEqual(RelaxedBernoulli(temp, p).sample((8,)).size(), (8, 3))
2009        self.assertFalse(RelaxedBernoulli(temp, p).sample().requires_grad)
2010        self.assertEqual(RelaxedBernoulli(temp, r).sample((8,)).size(), (8,))
2011        self.assertEqual(RelaxedBernoulli(temp, r).sample().size(), ())
2012        self.assertEqual(
2013            RelaxedBernoulli(temp, r).sample((3, 2)).size(),
2014            (
2015                3,
2016                2,
2017            ),
2018        )
2019        self.assertEqual(RelaxedBernoulli(temp, s).sample().size(), ())
2020        self._gradcheck_log_prob(RelaxedBernoulli, (temp, p))
2021        self._gradcheck_log_prob(RelaxedBernoulli, (temp, r))
2022
2023        # test that rsample doesn't fail
2024        s = RelaxedBernoulli(temp, p).rsample()
2025        s.backward(torch.ones_like(s))
2026
2027    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
2028    def test_rounded_relaxed_bernoulli(self):
2029        set_rng_seed(0)  # see Note [Randomized statistical tests]
2030
2031        class Rounded:
2032            def __init__(self, dist):
2033                self.dist = dist
2034
2035            def sample(self, *args, **kwargs):
2036                return torch.round(self.dist.sample(*args, **kwargs))
2037
2038        for probs, temp in product([0.1, 0.2, 0.8], [0.1, 1.0, 10.0]):
2039            self._check_sampler_discrete(
2040                Rounded(RelaxedBernoulli(temp, probs)),
2041                scipy.stats.bernoulli(probs),
2042                f"Rounded(RelaxedBernoulli(temp={temp}, probs={probs}))",
2043                failure_rate=1e-3,
2044            )
2045
2046        for probs in [0.001, 0.2, 0.999]:
2047            equal_probs = torch.tensor(0.5)
2048            dist = RelaxedBernoulli(1e10, probs)
2049            s = dist.rsample()
2050            self.assertEqual(equal_probs, s)
2051
2052    @set_default_dtype(torch.double)
2053    def test_relaxed_one_hot_categorical_1d(self):
2054        p = torch.tensor([0.1, 0.2, 0.3], requires_grad=True)
2055        temp = torch.tensor(0.67, requires_grad=True)
2056        self.assertEqual(
2057            RelaxedOneHotCategorical(probs=p, temperature=temp).sample().size(), (3,)
2058        )
2059        self.assertFalse(
2060            RelaxedOneHotCategorical(probs=p, temperature=temp).sample().requires_grad
2061        )
2062        self.assertEqual(
2063            RelaxedOneHotCategorical(probs=p, temperature=temp).sample((2, 2)).size(),
2064            (2, 2, 3),
2065        )
2066        self.assertEqual(
2067            RelaxedOneHotCategorical(probs=p, temperature=temp).sample((1,)).size(),
2068            (1, 3),
2069        )
2070        self._gradcheck_log_prob(
2071            lambda t, p: RelaxedOneHotCategorical(t, p, validate_args=False), (temp, p)
2072        )
2073
2074    @set_default_dtype(torch.double)
2075    def test_relaxed_one_hot_categorical_2d(self):
2076        probabilities = [[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]]
2077        probabilities_1 = [[1.0, 0.0], [0.0, 1.0]]
2078        temp = torch.tensor([3.0], requires_grad=True)
2079        # The lower the temperature, the more unstable the log_prob gradcheck is
2080        # w.r.t. the sample. Values below 0.25 empirically fail the default tol.
2081        temp_2 = torch.tensor([0.25], requires_grad=True)
2082        p = torch.tensor(probabilities, requires_grad=True)
2083        s = torch.tensor(probabilities_1, requires_grad=True)
2084        self.assertEqual(RelaxedOneHotCategorical(temp, p).sample().size(), (2, 3))
2085        self.assertEqual(
2086            RelaxedOneHotCategorical(temp, p).sample(sample_shape=(3, 4)).size(),
2087            (3, 4, 2, 3),
2088        )
2089        self.assertEqual(
2090            RelaxedOneHotCategorical(temp, p).sample((6,)).size(), (6, 2, 3)
2091        )
2092        self._gradcheck_log_prob(
2093            lambda t, p: RelaxedOneHotCategorical(t, p, validate_args=False), (temp, p)
2094        )
2095        self._gradcheck_log_prob(
2096            lambda t, p: RelaxedOneHotCategorical(t, p, validate_args=False),
2097            (temp_2, p),
2098        )
2099
2100    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
2101    def test_argmax_relaxed_categorical(self):
2102        set_rng_seed(0)  # see Note [Randomized statistical tests]
2103
2104        class ArgMax:
2105            def __init__(self, dist):
2106                self.dist = dist
2107
2108            def sample(self, *args, **kwargs):
2109                s = self.dist.sample(*args, **kwargs)
2110                _, idx = torch.max(s, -1)
2111                return idx
2112
2113        class ScipyCategorical:
2114            def __init__(self, dist):
2115                self.dist = dist
2116
2117            def pmf(self, samples):
2118                new_samples = np.zeros(samples.shape + self.dist.p.shape)
2119                new_samples[np.arange(samples.shape[0]), samples] = 1
2120                return self.dist.pmf(new_samples)
2121
2122        for probs, temp in product(
2123            [torch.tensor([0.1, 0.9]), torch.tensor([0.2, 0.2, 0.6])], [0.1, 1.0, 10.0]
2124        ):
2125            self._check_sampler_discrete(
2126                ArgMax(RelaxedOneHotCategorical(temp, probs)),
2127                ScipyCategorical(scipy.stats.multinomial(1, probs)),
2128                f"Rounded(RelaxedOneHotCategorical(temp={temp}, probs={probs}))",
2129                failure_rate=1e-3,
2130            )
2131
2132        for probs in [torch.tensor([0.1, 0.9]), torch.tensor([0.2, 0.2, 0.6])]:
2133            equal_probs = torch.ones(probs.size()) / probs.size()[0]
2134            dist = RelaxedOneHotCategorical(1e10, probs)
2135            s = dist.rsample()
2136            self.assertEqual(equal_probs, s)
2137
2138    @set_default_dtype(torch.double)
2139    def test_uniform(self):
2140        low = torch.zeros(5, 5, requires_grad=True)
2141        high = (torch.ones(5, 5) * 3).requires_grad_()
2142        low_1d = torch.zeros(1, requires_grad=True)
2143        high_1d = (torch.ones(1) * 3).requires_grad_()
2144        self.assertEqual(Uniform(low, high).sample().size(), (5, 5))
2145        self.assertEqual(Uniform(low, high).sample((7,)).size(), (7, 5, 5))
2146        self.assertEqual(Uniform(low_1d, high_1d).sample().size(), (1,))
2147        self.assertEqual(Uniform(low_1d, high_1d).sample((1,)).size(), (1, 1))
2148        self.assertEqual(Uniform(0.0, 1.0).sample((1,)).size(), (1,))
2149
2150        # Check log_prob computation when value outside range
2151        uniform = Uniform(low_1d, high_1d, validate_args=False)
2152        above_high = torch.tensor([4.0])
2153        below_low = torch.tensor([-1.0])
2154        self.assertEqual(uniform.log_prob(above_high).item(), -inf)
2155        self.assertEqual(uniform.log_prob(below_low).item(), -inf)
2156
2157        # check cdf computation when value outside range
2158        self.assertEqual(uniform.cdf(below_low).item(), 0)
2159        self.assertEqual(uniform.cdf(above_high).item(), 1)
2160
2161        set_rng_seed(1)
2162        self._gradcheck_log_prob(Uniform, (low, high))
2163        self._gradcheck_log_prob(Uniform, (low, 1.0))
2164        self._gradcheck_log_prob(Uniform, (0.0, high))
2165
2166        state = torch.get_rng_state()
2167        rand = low.new(low.size()).uniform_()
2168        torch.set_rng_state(state)
2169        u = Uniform(low, high).rsample()
2170        u.backward(torch.ones_like(u))
2171        self.assertEqual(low.grad, 1 - rand)
2172        self.assertEqual(high.grad, rand)
2173        low.grad.zero_()
2174        high.grad.zero_()
2175
2176        self._check_forward_ad(lambda x: x.uniform_())
2177
2178    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
2179    def test_vonmises_sample(self):
2180        for loc in [0.0, math.pi / 2.0]:
2181            for concentration in [0.03, 0.3, 1.0, 10.0, 100.0]:
2182                self._check_sampler_sampler(
2183                    VonMises(loc, concentration),
2184                    scipy.stats.vonmises(loc=loc, kappa=concentration),
2185                    f"VonMises(loc={loc}, concentration={concentration})",
2186                    num_samples=int(1e5),
2187                    circular=True,
2188                )
2189
2190    def test_vonmises_logprob(self):
2191        concentrations = [0.01, 0.03, 0.1, 0.3, 1.0, 3.0, 10.0, 30.0, 100.0]
2192        for concentration in concentrations:
2193            grid = torch.arange(0.0, 2 * math.pi, 1e-4)
2194            prob = VonMises(0.0, concentration).log_prob(grid).exp()
2195            norm = prob.mean().item() * 2 * math.pi
2196            self.assertLess(abs(norm - 1), 1e-3)
2197
2198    @set_default_dtype(torch.double)
2199    def test_cauchy(self):
2200        loc = torch.zeros(5, 5, requires_grad=True)
2201        scale = torch.ones(5, 5, requires_grad=True)
2202        loc_1d = torch.zeros(1, requires_grad=True)
2203        scale_1d = torch.ones(1, requires_grad=True)
2204        self.assertTrue(is_all_nan(Cauchy(loc_1d, scale_1d).mean))
2205        self.assertEqual(Cauchy(loc_1d, scale_1d).variance, inf)
2206        self.assertEqual(Cauchy(loc, scale).sample().size(), (5, 5))
2207        self.assertEqual(Cauchy(loc, scale).sample((7,)).size(), (7, 5, 5))
2208        self.assertEqual(Cauchy(loc_1d, scale_1d).sample().size(), (1,))
2209        self.assertEqual(Cauchy(loc_1d, scale_1d).sample((1,)).size(), (1, 1))
2210        self.assertEqual(Cauchy(0.0, 1.0).sample((1,)).size(), (1,))
2211
2212        set_rng_seed(1)
2213        self._gradcheck_log_prob(Cauchy, (loc, scale))
2214        self._gradcheck_log_prob(Cauchy, (loc, 1.0))
2215        self._gradcheck_log_prob(Cauchy, (0.0, scale))
2216
2217        state = torch.get_rng_state()
2218        eps = loc.new(loc.size()).cauchy_()
2219        torch.set_rng_state(state)
2220        c = Cauchy(loc, scale).rsample()
2221        c.backward(torch.ones_like(c))
2222        self.assertEqual(loc.grad, torch.ones_like(scale))
2223        self.assertEqual(scale.grad, eps)
2224        loc.grad.zero_()
2225        scale.grad.zero_()
2226
2227        self._check_forward_ad(lambda x: x.cauchy_())
2228
2229    @set_default_dtype(torch.double)
2230    def test_halfcauchy(self):
2231        scale = torch.ones(5, 5, requires_grad=True)
2232        scale_1d = torch.ones(1, requires_grad=True)
2233        self.assertTrue(torch.isinf(HalfCauchy(scale_1d).mean).all())
2234        self.assertEqual(HalfCauchy(scale_1d).variance, inf)
2235        self.assertEqual(HalfCauchy(scale).sample().size(), (5, 5))
2236        self.assertEqual(HalfCauchy(scale).sample((7,)).size(), (7, 5, 5))
2237        self.assertEqual(HalfCauchy(scale_1d).sample().size(), (1,))
2238        self.assertEqual(HalfCauchy(scale_1d).sample((1,)).size(), (1, 1))
2239        self.assertEqual(HalfCauchy(1.0).sample((1,)).size(), (1,))
2240
2241        set_rng_seed(1)
2242        self._gradcheck_log_prob(HalfCauchy, (scale,))
2243        self._gradcheck_log_prob(HalfCauchy, (1.0,))
2244
2245        state = torch.get_rng_state()
2246        eps = scale.new(scale.size()).cauchy_().abs_()
2247        torch.set_rng_state(state)
2248        c = HalfCauchy(scale).rsample()
2249        c.backward(torch.ones_like(c))
2250        self.assertEqual(scale.grad, eps)
2251        scale.grad.zero_()
2252
2253    @set_default_dtype(torch.double)
2254    def test_halfnormal(self):
2255        std = torch.randn(5, 5).abs().requires_grad_()
2256        std_1d = torch.randn(1).abs().requires_grad_()
2257        std_delta = torch.tensor([1e-5, 1e-5])
2258        self.assertEqual(HalfNormal(std).sample().size(), (5, 5))
2259        self.assertEqual(HalfNormal(std).sample((7,)).size(), (7, 5, 5))
2260        self.assertEqual(HalfNormal(std_1d).sample((1,)).size(), (1, 1))
2261        self.assertEqual(HalfNormal(std_1d).sample().size(), (1,))
2262        self.assertEqual(HalfNormal(0.6).sample((1,)).size(), (1,))
2263        self.assertEqual(HalfNormal(50.0).sample((1,)).size(), (1,))
2264
2265        # sample check for extreme value of std
2266        set_rng_seed(1)
2267        self.assertEqual(
2268            HalfNormal(std_delta).sample(sample_shape=(1, 2)),
2269            torch.tensor([[[0.0, 0.0], [0.0, 0.0]]]),
2270            atol=1e-4,
2271            rtol=0,
2272        )
2273
2274        self._gradcheck_log_prob(HalfNormal, (std,))
2275        self._gradcheck_log_prob(HalfNormal, (1.0,))
2276
2277        # check .log_prob() can broadcast.
2278        dist = HalfNormal(torch.ones(2, 1, 4))
2279        log_prob = dist.log_prob(torch.ones(3, 1))
2280        self.assertEqual(log_prob.shape, (2, 3, 4))
2281
2282    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
2283    def test_halfnormal_logprob(self):
2284        std = torch.randn(5, 1).abs().requires_grad_()
2285
2286        def ref_log_prob(idx, x, log_prob):
2287            s = std.view(-1)[idx].detach()
2288            expected = scipy.stats.halfnorm(scale=s).logpdf(x)
2289            self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)
2290
2291        self._check_log_prob(HalfNormal(std), ref_log_prob)
2292
2293    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
2294    def test_halfnormal_sample(self):
2295        set_rng_seed(0)  # see Note [Randomized statistical tests]
2296        for std in [0.1, 1.0, 10.0]:
2297            self._check_sampler_sampler(
2298                HalfNormal(std),
2299                scipy.stats.halfnorm(scale=std),
2300                f"HalfNormal(scale={std})",
2301            )
2302
2303    @set_default_dtype(torch.double)
2304    def test_inversegamma(self):
2305        alpha = torch.randn(2, 3).exp().requires_grad_()
2306        beta = torch.randn(2, 3).exp().requires_grad_()
2307        alpha_1d = torch.randn(1).exp().requires_grad_()
2308        beta_1d = torch.randn(1).exp().requires_grad_()
2309        self.assertEqual(InverseGamma(alpha, beta).sample().size(), (2, 3))
2310        self.assertEqual(InverseGamma(alpha, beta).sample((5,)).size(), (5, 2, 3))
2311        self.assertEqual(InverseGamma(alpha_1d, beta_1d).sample((1,)).size(), (1, 1))
2312        self.assertEqual(InverseGamma(alpha_1d, beta_1d).sample().size(), (1,))
2313        self.assertEqual(InverseGamma(0.5, 0.5).sample().size(), ())
2314        self.assertEqual(InverseGamma(0.5, 0.5).sample((1,)).size(), (1,))
2315
2316        self._gradcheck_log_prob(InverseGamma, (alpha, beta))
2317
2318        dist = InverseGamma(torch.ones(4), torch.ones(2, 1, 1))
2319        log_prob = dist.log_prob(torch.ones(3, 1))
2320        self.assertEqual(log_prob.shape, (2, 3, 4))
2321
2322    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
2323    def test_inversegamma_sample(self):
2324        set_rng_seed(0)  # see Note [Randomized statistical tests]
2325        for concentration, rate in product([2, 5], [0.1, 1.0, 10.0]):
2326            self._check_sampler_sampler(
2327                InverseGamma(concentration, rate),
2328                scipy.stats.invgamma(concentration, scale=rate),
2329                "InverseGamma()",
2330            )
2331
2332    @set_default_dtype(torch.double)
2333    def test_lognormal(self):
2334        mean = torch.randn(5, 5, requires_grad=True)
2335        std = torch.randn(5, 5).abs().requires_grad_()
2336        mean_1d = torch.randn(1, requires_grad=True)
2337        std_1d = torch.randn(1).abs().requires_grad_()
2338        mean_delta = torch.tensor([1.0, 0.0])
2339        std_delta = torch.tensor([1e-5, 1e-5])
2340        self.assertEqual(LogNormal(mean, std).sample().size(), (5, 5))
2341        self.assertEqual(LogNormal(mean, std).sample((7,)).size(), (7, 5, 5))
2342        self.assertEqual(LogNormal(mean_1d, std_1d).sample((1,)).size(), (1, 1))
2343        self.assertEqual(LogNormal(mean_1d, std_1d).sample().size(), (1,))
2344        self.assertEqual(LogNormal(0.2, 0.6).sample((1,)).size(), (1,))
2345        self.assertEqual(LogNormal(-0.7, 50.0).sample((1,)).size(), (1,))
2346
2347        # sample check for extreme value of mean, std
2348        set_rng_seed(1)
2349        self.assertEqual(
2350            LogNormal(mean_delta, std_delta).sample(sample_shape=(1, 2)),
2351            torch.tensor([[[math.exp(1), 1.0], [math.exp(1), 1.0]]]),
2352            atol=1e-4,
2353            rtol=0,
2354        )
2355
2356        self._gradcheck_log_prob(LogNormal, (mean, std))
2357        self._gradcheck_log_prob(LogNormal, (mean, 1.0))
2358        self._gradcheck_log_prob(LogNormal, (0.0, std))
2359
2360        # check .log_prob() can broadcast.
2361        dist = LogNormal(torch.zeros(4), torch.ones(2, 1, 1))
2362        log_prob = dist.log_prob(torch.ones(3, 1))
2363        self.assertEqual(log_prob.shape, (2, 3, 4))
2364
2365        self._check_forward_ad(lambda x: x.log_normal_())
2366
2367    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
2368    def test_lognormal_logprob(self):
2369        mean = torch.randn(5, 1, requires_grad=True)
2370        std = torch.randn(5, 1).abs().requires_grad_()
2371
2372        def ref_log_prob(idx, x, log_prob):
2373            m = mean.view(-1)[idx].detach()
2374            s = std.view(-1)[idx].detach()
2375            expected = scipy.stats.lognorm(s=s, scale=math.exp(m)).logpdf(x)
2376            self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)
2377
2378        self._check_log_prob(LogNormal(mean, std), ref_log_prob)
2379
2380    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
2381    def test_lognormal_sample(self):
2382        set_rng_seed(0)  # see Note [Randomized statistical tests]
2383        for mean, std in product([-1.0, 0.0, 1.0], [0.1, 1.0, 10.0]):
2384            self._check_sampler_sampler(
2385                LogNormal(mean, std),
2386                scipy.stats.lognorm(scale=math.exp(mean), s=std),
2387                f"LogNormal(loc={mean}, scale={std})",
2388            )
2389
2390    @set_default_dtype(torch.double)
2391    def test_logisticnormal(self):
2392        set_rng_seed(1)  # see Note [Randomized statistical tests]
2393        mean = torch.randn(5, 5).requires_grad_()
2394        std = torch.randn(5, 5).abs().requires_grad_()
2395        mean_1d = torch.randn(1).requires_grad_()
2396        std_1d = torch.randn(1).abs().requires_grad_()
2397        mean_delta = torch.tensor([1.0, 0.0])
2398        std_delta = torch.tensor([1e-5, 1e-5])
2399        self.assertEqual(LogisticNormal(mean, std).sample().size(), (5, 6))
2400        self.assertEqual(LogisticNormal(mean, std).sample((7,)).size(), (7, 5, 6))
2401        self.assertEqual(LogisticNormal(mean_1d, std_1d).sample((1,)).size(), (1, 2))
2402        self.assertEqual(LogisticNormal(mean_1d, std_1d).sample().size(), (2,))
2403        self.assertEqual(LogisticNormal(0.2, 0.6).sample().size(), (2,))
2404        self.assertEqual(LogisticNormal(-0.7, 50.0).sample().size(), (2,))
2405
2406        # sample check for extreme value of mean, std
2407        set_rng_seed(1)
2408        self.assertEqual(
2409            LogisticNormal(mean_delta, std_delta).sample(),
2410            torch.tensor(
2411                [
2412                    math.exp(1) / (1.0 + 1.0 + math.exp(1)),
2413                    1.0 / (1.0 + 1.0 + math.exp(1)),
2414                    1.0 / (1.0 + 1.0 + math.exp(1)),
2415                ]
2416            ),
2417            atol=1e-4,
2418            rtol=0,
2419        )
2420
2421        # TODO: gradcheck seems to mutate the sample values so that the simplex
2422        # constraint fails by a very small margin.
2423        self._gradcheck_log_prob(
2424            lambda m, s: LogisticNormal(m, s, validate_args=False), (mean, std)
2425        )
2426        self._gradcheck_log_prob(
2427            lambda m, s: LogisticNormal(m, s, validate_args=False), (mean, 1.0)
2428        )
2429        self._gradcheck_log_prob(
2430            lambda m, s: LogisticNormal(m, s, validate_args=False), (0.0, std)
2431        )
2432
2433    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
2434    def test_logisticnormal_logprob(self):
2435        mean = torch.randn(5, 7).requires_grad_()
2436        std = torch.randn(5, 7).abs().requires_grad_()
2437
2438        # Smoke test for now
2439        # TODO: Once _check_log_prob works with multidimensional distributions,
2440        #       add proper testing of the log probabilities.
2441        dist = LogisticNormal(mean, std)
2442        assert dist.log_prob(dist.sample()).detach().cpu().numpy().shape == (5,)
2443
2444    def _get_logistic_normal_ref_sampler(self, base_dist):
2445        def _sampler(num_samples):
2446            x = base_dist.rvs(num_samples)
2447            offset = np.log((x.shape[-1] + 1) - np.ones_like(x).cumsum(-1))
2448            z = 1.0 / (1.0 + np.exp(offset - x))
2449            z_cumprod = np.cumprod(1 - z, axis=-1)
2450            y1 = np.pad(z, ((0, 0), (0, 1)), mode="constant", constant_values=1.0)
2451            y2 = np.pad(
2452                z_cumprod, ((0, 0), (1, 0)), mode="constant", constant_values=1.0
2453            )
2454            return y1 * y2
2455
2456        return _sampler
2457
2458    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
2459    def test_logisticnormal_sample(self):
2460        set_rng_seed(0)  # see Note [Randomized statistical tests]
2461        means = map(np.asarray, [(-1.0, -1.0), (0.0, 0.0), (1.0, 1.0)])
2462        covs = map(np.diag, [(0.1, 0.1), (1.0, 1.0), (10.0, 10.0)])
2463        for mean, cov in product(means, covs):
2464            base_dist = scipy.stats.multivariate_normal(mean=mean, cov=cov)
2465            ref_dist = scipy.stats.multivariate_normal(mean=mean, cov=cov)
2466            ref_dist.rvs = self._get_logistic_normal_ref_sampler(base_dist)
2467            mean_th = torch.tensor(mean)
2468            std_th = torch.tensor(np.sqrt(np.diag(cov)))
2469            self._check_sampler_sampler(
2470                LogisticNormal(mean_th, std_th),
2471                ref_dist,
2472                f"LogisticNormal(loc={mean_th}, scale={std_th})",
2473                multivariate=True,
2474            )
2475
2476    def test_mixture_same_family_shape(self):
2477        normal_case_1d = MixtureSameFamily(
2478            Categorical(torch.rand(5)), Normal(torch.randn(5), torch.rand(5))
2479        )
2480        normal_case_1d_batch = MixtureSameFamily(
2481            Categorical(torch.rand(3, 5)), Normal(torch.randn(3, 5), torch.rand(3, 5))
2482        )
2483        normal_case_1d_multi_batch = MixtureSameFamily(
2484            Categorical(torch.rand(4, 3, 5)),
2485            Normal(torch.randn(4, 3, 5), torch.rand(4, 3, 5)),
2486        )
2487        normal_case_2d = MixtureSameFamily(
2488            Categorical(torch.rand(5)),
2489            Independent(Normal(torch.randn(5, 2), torch.rand(5, 2)), 1),
2490        )
2491        normal_case_2d_batch = MixtureSameFamily(
2492            Categorical(torch.rand(3, 5)),
2493            Independent(Normal(torch.randn(3, 5, 2), torch.rand(3, 5, 2)), 1),
2494        )
2495        normal_case_2d_multi_batch = MixtureSameFamily(
2496            Categorical(torch.rand(4, 3, 5)),
2497            Independent(Normal(torch.randn(4, 3, 5, 2), torch.rand(4, 3, 5, 2)), 1),
2498        )
2499
2500        self.assertEqual(normal_case_1d.sample().size(), ())
2501        self.assertEqual(normal_case_1d.sample((2,)).size(), (2,))
2502        self.assertEqual(normal_case_1d.sample((2, 7)).size(), (2, 7))
2503        self.assertEqual(normal_case_1d_batch.sample().size(), (3,))
2504        self.assertEqual(normal_case_1d_batch.sample((2,)).size(), (2, 3))
2505        self.assertEqual(normal_case_1d_batch.sample((2, 7)).size(), (2, 7, 3))
2506        self.assertEqual(normal_case_1d_multi_batch.sample().size(), (4, 3))
2507        self.assertEqual(normal_case_1d_multi_batch.sample((2,)).size(), (2, 4, 3))
2508        self.assertEqual(normal_case_1d_multi_batch.sample((2, 7)).size(), (2, 7, 4, 3))
2509
2510        self.assertEqual(normal_case_2d.sample().size(), (2,))
2511        self.assertEqual(normal_case_2d.sample((2,)).size(), (2, 2))
2512        self.assertEqual(normal_case_2d.sample((2, 7)).size(), (2, 7, 2))
2513        self.assertEqual(normal_case_2d_batch.sample().size(), (3, 2))
2514        self.assertEqual(normal_case_2d_batch.sample((2,)).size(), (2, 3, 2))
2515        self.assertEqual(normal_case_2d_batch.sample((2, 7)).size(), (2, 7, 3, 2))
2516        self.assertEqual(normal_case_2d_multi_batch.sample().size(), (4, 3, 2))
2517        self.assertEqual(normal_case_2d_multi_batch.sample((2,)).size(), (2, 4, 3, 2))
2518        self.assertEqual(
2519            normal_case_2d_multi_batch.sample((2, 7)).size(), (2, 7, 4, 3, 2)
2520        )
2521
2522    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
2523    def test_mixture_same_family_log_prob(self):
2524        probs = torch.rand(5, 5).softmax(dim=-1)
2525        loc = torch.randn(5, 5)
2526        scale = torch.rand(5, 5)
2527
2528        def ref_log_prob(idx, x, log_prob):
2529            p = probs[idx].numpy()
2530            m = loc[idx].numpy()
2531            s = scale[idx].numpy()
2532            mix = scipy.stats.multinomial(1, p)
2533            comp = scipy.stats.norm(m, s)
2534            expected = scipy.special.logsumexp(comp.logpdf(x) + np.log(mix.p))
2535            self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)
2536
2537        self._check_log_prob(
2538            MixtureSameFamily(Categorical(probs=probs), Normal(loc, scale)),
2539            ref_log_prob,
2540        )
2541
2542    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
2543    def test_mixture_same_family_sample(self):
2544        probs = torch.rand(5).softmax(dim=-1)
2545        loc = torch.randn(5)
2546        scale = torch.rand(5)
2547
2548        class ScipyMixtureNormal:
2549            def __init__(self, probs, mu, std):
2550                self.probs = probs
2551                self.mu = mu
2552                self.std = std
2553
2554            def rvs(self, n_sample):
2555                comp_samples = [
2556                    scipy.stats.norm(m, s).rvs(n_sample)
2557                    for m, s in zip(self.mu, self.std)
2558                ]
2559                mix_samples = scipy.stats.multinomial(1, self.probs).rvs(n_sample)
2560                samples = []
2561                for i in range(n_sample):
2562                    samples.append(comp_samples[mix_samples[i].argmax()][i])
2563                return np.asarray(samples)
2564
2565        self._check_sampler_sampler(
2566            MixtureSameFamily(Categorical(probs=probs), Normal(loc, scale)),
2567            ScipyMixtureNormal(probs.numpy(), loc.numpy(), scale.numpy()),
2568            f"""MixtureSameFamily(Categorical(probs={probs}),
2569            Normal(loc={loc}, scale={scale}))""",
2570        )
2571
2572    @set_default_dtype(torch.double)
2573    def test_normal(self):
2574        loc = torch.randn(5, 5, requires_grad=True)
2575        scale = torch.randn(5, 5).abs().requires_grad_()
2576        loc_1d = torch.randn(1, requires_grad=True)
2577        scale_1d = torch.randn(1).abs().requires_grad_()
2578        loc_delta = torch.tensor([1.0, 0.0])
2579        scale_delta = torch.tensor([1e-5, 1e-5])
2580        self.assertEqual(Normal(loc, scale).sample().size(), (5, 5))
2581        self.assertEqual(Normal(loc, scale).sample((7,)).size(), (7, 5, 5))
2582        self.assertEqual(Normal(loc_1d, scale_1d).sample((1,)).size(), (1, 1))
2583        self.assertEqual(Normal(loc_1d, scale_1d).sample().size(), (1,))
2584        self.assertEqual(Normal(0.2, 0.6).sample((1,)).size(), (1,))
2585        self.assertEqual(Normal(-0.7, 50.0).sample((1,)).size(), (1,))
2586
2587        # sample check for extreme value of mean, std
2588        set_rng_seed(1)
2589        self.assertEqual(
2590            Normal(loc_delta, scale_delta).sample(sample_shape=(1, 2)),
2591            torch.tensor([[[1.0, 0.0], [1.0, 0.0]]]),
2592            atol=1e-4,
2593            rtol=0,
2594        )
2595
2596        self._gradcheck_log_prob(Normal, (loc, scale))
2597        self._gradcheck_log_prob(Normal, (loc, 1.0))
2598        self._gradcheck_log_prob(Normal, (0.0, scale))
2599
2600        state = torch.get_rng_state()
2601        eps = torch.normal(torch.zeros_like(loc), torch.ones_like(scale))
2602        torch.set_rng_state(state)
2603        z = Normal(loc, scale).rsample()
2604        z.backward(torch.ones_like(z))
2605        self.assertEqual(loc.grad, torch.ones_like(loc))
2606        self.assertEqual(scale.grad, eps)
2607        loc.grad.zero_()
2608        scale.grad.zero_()
2609        self.assertEqual(z.size(), (5, 5))
2610
2611        def ref_log_prob(idx, x, log_prob):
2612            m = loc.view(-1)[idx]
2613            s = scale.view(-1)[idx]
2614            expected = math.exp(-((x - m) ** 2) / (2 * s**2)) / math.sqrt(
2615                2 * math.pi * s**2
2616            )
2617            self.assertEqual(log_prob, math.log(expected), atol=1e-3, rtol=0)
2618
2619        self._check_log_prob(Normal(loc, scale), ref_log_prob)
2620        self._check_forward_ad(torch.normal)
2621        self._check_forward_ad(lambda x: torch.normal(x, 0.5))
2622        self._check_forward_ad(lambda x: torch.normal(0.2, x))
2623        self._check_forward_ad(lambda x: torch.normal(x, x))
2624        self._check_forward_ad(lambda x: x.normal_())
2625
2626    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
2627    def test_normal_sample(self):
2628        set_rng_seed(0)  # see Note [Randomized statistical tests]
2629        for loc, scale in product([-1.0, 0.0, 1.0], [0.1, 1.0, 10.0]):
2630            self._check_sampler_sampler(
2631                Normal(loc, scale),
2632                scipy.stats.norm(loc=loc, scale=scale),
2633                f"Normal(mean={loc}, std={scale})",
2634            )
2635
2636    @set_default_dtype(torch.double)
2637    def test_lowrank_multivariate_normal_shape(self):
2638        mean = torch.randn(5, 3, requires_grad=True)
2639        mean_no_batch = torch.randn(3, requires_grad=True)
2640        mean_multi_batch = torch.randn(6, 5, 3, requires_grad=True)
2641
2642        # construct PSD covariance
2643        cov_factor = torch.randn(3, 1, requires_grad=True)
2644        cov_diag = torch.randn(3).abs().requires_grad_()
2645
2646        # construct batch of PSD covariances
2647        cov_factor_batched = torch.randn(6, 5, 3, 2, requires_grad=True)
2648        cov_diag_batched = torch.randn(6, 5, 3).abs().requires_grad_()
2649
2650        # ensure that sample, batch, event shapes all handled correctly
2651        self.assertEqual(
2652            LowRankMultivariateNormal(mean, cov_factor, cov_diag).sample().size(),
2653            (5, 3),
2654        )
2655        self.assertEqual(
2656            LowRankMultivariateNormal(mean_no_batch, cov_factor, cov_diag)
2657            .sample()
2658            .size(),
2659            (3,),
2660        )
2661        self.assertEqual(
2662            LowRankMultivariateNormal(mean_multi_batch, cov_factor, cov_diag)
2663            .sample()
2664            .size(),
2665            (6, 5, 3),
2666        )
2667        self.assertEqual(
2668            LowRankMultivariateNormal(mean, cov_factor, cov_diag).sample((2,)).size(),
2669            (2, 5, 3),
2670        )
2671        self.assertEqual(
2672            LowRankMultivariateNormal(mean_no_batch, cov_factor, cov_diag)
2673            .sample((2,))
2674            .size(),
2675            (2, 3),
2676        )
2677        self.assertEqual(
2678            LowRankMultivariateNormal(mean_multi_batch, cov_factor, cov_diag)
2679            .sample((2,))
2680            .size(),
2681            (2, 6, 5, 3),
2682        )
2683        self.assertEqual(
2684            LowRankMultivariateNormal(mean, cov_factor, cov_diag).sample((2, 7)).size(),
2685            (2, 7, 5, 3),
2686        )
2687        self.assertEqual(
2688            LowRankMultivariateNormal(mean_no_batch, cov_factor, cov_diag)
2689            .sample((2, 7))
2690            .size(),
2691            (2, 7, 3),
2692        )
2693        self.assertEqual(
2694            LowRankMultivariateNormal(mean_multi_batch, cov_factor, cov_diag)
2695            .sample((2, 7))
2696            .size(),
2697            (2, 7, 6, 5, 3),
2698        )
2699        self.assertEqual(
2700            LowRankMultivariateNormal(mean, cov_factor_batched, cov_diag_batched)
2701            .sample((2, 7))
2702            .size(),
2703            (2, 7, 6, 5, 3),
2704        )
2705        self.assertEqual(
2706            LowRankMultivariateNormal(
2707                mean_no_batch, cov_factor_batched, cov_diag_batched
2708            )
2709            .sample((2, 7))
2710            .size(),
2711            (2, 7, 6, 5, 3),
2712        )
2713        self.assertEqual(
2714            LowRankMultivariateNormal(
2715                mean_multi_batch, cov_factor_batched, cov_diag_batched
2716            )
2717            .sample((2, 7))
2718            .size(),
2719            (2, 7, 6, 5, 3),
2720        )
2721
2722        # check gradients
2723        self._gradcheck_log_prob(
2724            LowRankMultivariateNormal, (mean, cov_factor, cov_diag)
2725        )
2726        self._gradcheck_log_prob(
2727            LowRankMultivariateNormal, (mean_multi_batch, cov_factor, cov_diag)
2728        )
2729        self._gradcheck_log_prob(
2730            LowRankMultivariateNormal,
2731            (mean_multi_batch, cov_factor_batched, cov_diag_batched),
2732        )
2733
2734    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
2735    def test_lowrank_multivariate_normal_log_prob(self):
2736        mean = torch.randn(3, requires_grad=True)
2737        cov_factor = torch.randn(3, 1, requires_grad=True)
2738        cov_diag = torch.randn(3).abs().requires_grad_()
2739        cov = cov_factor.matmul(cov_factor.t()) + cov_diag.diag()
2740
2741        # check that logprob values match scipy logpdf,
2742        # and that covariance and scale_tril parameters are equivalent
2743        dist1 = LowRankMultivariateNormal(mean, cov_factor, cov_diag)
2744        ref_dist = scipy.stats.multivariate_normal(
2745            mean.detach().numpy(), cov.detach().numpy()
2746        )
2747
2748        x = dist1.sample((10,))
2749        expected = ref_dist.logpdf(x.numpy())
2750
2751        self.assertEqual(
2752            0.0,
2753            np.mean((dist1.log_prob(x).detach().numpy() - expected) ** 2),
2754            atol=1e-3,
2755            rtol=0,
2756        )
2757
2758        # Double-check that batched versions behave the same as unbatched
2759        mean = torch.randn(5, 3, requires_grad=True)
2760        cov_factor = torch.randn(5, 3, 2, requires_grad=True)
2761        cov_diag = torch.randn(5, 3).abs().requires_grad_()
2762
2763        dist_batched = LowRankMultivariateNormal(mean, cov_factor, cov_diag)
2764        dist_unbatched = [
2765            LowRankMultivariateNormal(mean[i], cov_factor[i], cov_diag[i])
2766            for i in range(mean.size(0))
2767        ]
2768
2769        x = dist_batched.sample((10,))
2770        batched_prob = dist_batched.log_prob(x)
2771        unbatched_prob = torch.stack(
2772            [dist_unbatched[i].log_prob(x[:, i]) for i in range(5)]
2773        ).t()
2774
2775        self.assertEqual(batched_prob.shape, unbatched_prob.shape)
2776        self.assertEqual(batched_prob, unbatched_prob, atol=1e-3, rtol=0)
2777
2778    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
2779    def test_lowrank_multivariate_normal_sample(self):
2780        set_rng_seed(0)  # see Note [Randomized statistical tests]
2781        mean = torch.randn(5, requires_grad=True)
2782        cov_factor = torch.randn(5, 1, requires_grad=True)
2783        cov_diag = torch.randn(5).abs().requires_grad_()
2784        cov = cov_factor.matmul(cov_factor.t()) + cov_diag.diag()
2785
2786        self._check_sampler_sampler(
2787            LowRankMultivariateNormal(mean, cov_factor, cov_diag),
2788            scipy.stats.multivariate_normal(
2789                mean.detach().numpy(), cov.detach().numpy()
2790            ),
2791            f"LowRankMultivariateNormal(loc={mean}, cov_factor={cov_factor}, cov_diag={cov_diag})",
2792            multivariate=True,
2793        )
2794
2795    def test_lowrank_multivariate_normal_properties(self):
2796        loc = torch.randn(5)
2797        cov_factor = torch.randn(5, 2)
2798        cov_diag = torch.randn(5).abs()
2799        cov = cov_factor.matmul(cov_factor.t()) + cov_diag.diag()
2800        m1 = LowRankMultivariateNormal(loc, cov_factor, cov_diag)
2801        m2 = MultivariateNormal(loc=loc, covariance_matrix=cov)
2802        self.assertEqual(m1.mean, m2.mean)
2803        self.assertEqual(m1.variance, m2.variance)
2804        self.assertEqual(m1.covariance_matrix, m2.covariance_matrix)
2805        self.assertEqual(m1.scale_tril, m2.scale_tril)
2806        self.assertEqual(m1.precision_matrix, m2.precision_matrix)
2807        self.assertEqual(m1.entropy(), m2.entropy())
2808
2809    def test_lowrank_multivariate_normal_moments(self):
2810        set_rng_seed(0)  # see Note [Randomized statistical tests]
2811        mean = torch.randn(5)
2812        cov_factor = torch.randn(5, 2)
2813        cov_diag = torch.randn(5).abs()
2814        d = LowRankMultivariateNormal(mean, cov_factor, cov_diag)
2815        samples = d.rsample((100000,))
2816        empirical_mean = samples.mean(0)
2817        self.assertEqual(d.mean, empirical_mean, atol=0.01, rtol=0)
2818        empirical_var = samples.var(0)
2819        self.assertEqual(d.variance, empirical_var, atol=0.02, rtol=0)
2820
2821    @set_default_dtype(torch.double)
2822    def test_multivariate_normal_shape(self):
2823        mean = torch.randn(5, 3, requires_grad=True)
2824        mean_no_batch = torch.randn(3, requires_grad=True)
2825        mean_multi_batch = torch.randn(6, 5, 3, requires_grad=True)
2826
2827        # construct PSD covariance
2828        tmp = torch.randn(3, 10)
2829        cov = (torch.matmul(tmp, tmp.t()) / tmp.size(-1)).requires_grad_()
2830        prec = cov.inverse().requires_grad_()
2831        scale_tril = torch.linalg.cholesky(cov).requires_grad_()
2832
2833        # construct batch of PSD covariances
2834        tmp = torch.randn(6, 5, 3, 10)
2835        cov_batched = (tmp.unsqueeze(-2) * tmp.unsqueeze(-3)).mean(-1).requires_grad_()
2836        prec_batched = cov_batched.inverse()
2837        scale_tril_batched = torch.linalg.cholesky(cov_batched)
2838
2839        # ensure that sample, batch, event shapes all handled correctly
2840        self.assertEqual(MultivariateNormal(mean, cov).sample().size(), (5, 3))
2841        self.assertEqual(MultivariateNormal(mean_no_batch, cov).sample().size(), (3,))
2842        self.assertEqual(
2843            MultivariateNormal(mean_multi_batch, cov).sample().size(), (6, 5, 3)
2844        )
2845        self.assertEqual(MultivariateNormal(mean, cov).sample((2,)).size(), (2, 5, 3))
2846        self.assertEqual(
2847            MultivariateNormal(mean_no_batch, cov).sample((2,)).size(), (2, 3)
2848        )
2849        self.assertEqual(
2850            MultivariateNormal(mean_multi_batch, cov).sample((2,)).size(), (2, 6, 5, 3)
2851        )
2852        self.assertEqual(
2853            MultivariateNormal(mean, cov).sample((2, 7)).size(), (2, 7, 5, 3)
2854        )
2855        self.assertEqual(
2856            MultivariateNormal(mean_no_batch, cov).sample((2, 7)).size(), (2, 7, 3)
2857        )
2858        self.assertEqual(
2859            MultivariateNormal(mean_multi_batch, cov).sample((2, 7)).size(),
2860            (2, 7, 6, 5, 3),
2861        )
2862        self.assertEqual(
2863            MultivariateNormal(mean, cov_batched).sample((2, 7)).size(), (2, 7, 6, 5, 3)
2864        )
2865        self.assertEqual(
2866            MultivariateNormal(mean_no_batch, cov_batched).sample((2, 7)).size(),
2867            (2, 7, 6, 5, 3),
2868        )
2869        self.assertEqual(
2870            MultivariateNormal(mean_multi_batch, cov_batched).sample((2, 7)).size(),
2871            (2, 7, 6, 5, 3),
2872        )
2873        self.assertEqual(
2874            MultivariateNormal(mean, precision_matrix=prec).sample((2, 7)).size(),
2875            (2, 7, 5, 3),
2876        )
2877        self.assertEqual(
2878            MultivariateNormal(mean, precision_matrix=prec_batched)
2879            .sample((2, 7))
2880            .size(),
2881            (2, 7, 6, 5, 3),
2882        )
2883        self.assertEqual(
2884            MultivariateNormal(mean, scale_tril=scale_tril).sample((2, 7)).size(),
2885            (2, 7, 5, 3),
2886        )
2887        self.assertEqual(
2888            MultivariateNormal(mean, scale_tril=scale_tril_batched)
2889            .sample((2, 7))
2890            .size(),
2891            (2, 7, 6, 5, 3),
2892        )
2893
2894        # check gradients
2895        # We write a custom gradcheck function to maintain the symmetry
2896        # of the perturbed covariances and their inverses (precision)
2897        def multivariate_normal_log_prob_gradcheck(
2898            mean, covariance=None, precision=None, scale_tril=None
2899        ):
2900            mvn_samples = (
2901                MultivariateNormal(mean, covariance, precision, scale_tril)
2902                .sample()
2903                .requires_grad_()
2904            )
2905
2906            def gradcheck_func(samples, mu, sigma, prec, scale_tril):
2907                if sigma is not None:
2908                    sigma = 0.5 * (sigma + sigma.mT)  # Ensure symmetry of covariance
2909                if prec is not None:
2910                    prec = 0.5 * (prec + prec.mT)  # Ensure symmetry of precision
2911                if scale_tril is not None:
2912                    scale_tril = scale_tril.tril()
2913                return MultivariateNormal(mu, sigma, prec, scale_tril).log_prob(samples)
2914
2915            gradcheck(
2916                gradcheck_func,
2917                (mvn_samples, mean, covariance, precision, scale_tril),
2918                raise_exception=True,
2919            )
2920
2921        multivariate_normal_log_prob_gradcheck(mean, cov)
2922        multivariate_normal_log_prob_gradcheck(mean_multi_batch, cov)
2923        multivariate_normal_log_prob_gradcheck(mean_multi_batch, cov_batched)
2924        multivariate_normal_log_prob_gradcheck(mean, None, prec)
2925        multivariate_normal_log_prob_gradcheck(mean_no_batch, None, prec_batched)
2926        multivariate_normal_log_prob_gradcheck(mean, None, None, scale_tril)
2927        multivariate_normal_log_prob_gradcheck(
2928            mean_no_batch, None, None, scale_tril_batched
2929        )
2930
2931    @set_default_dtype(torch.double)
2932    def test_multivariate_normal_stable_with_precision_matrix(self):
2933        x = torch.randn(10)
2934        P = torch.exp(-((x - x.unsqueeze(-1)) ** 2))  # RBF kernel
2935        MultivariateNormal(x.new_zeros(10), precision_matrix=P)
2936
2937    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
2938    def test_multivariate_normal_log_prob(self):
2939        mean = torch.randn(3, requires_grad=True)
2940        tmp = torch.randn(3, 10)
2941        cov = (torch.matmul(tmp, tmp.t()) / tmp.size(-1)).requires_grad_()
2942        prec = cov.inverse().requires_grad_()
2943        scale_tril = torch.linalg.cholesky(cov).requires_grad_()
2944
2945        # check that logprob values match scipy logpdf,
2946        # and that covariance and scale_tril parameters are equivalent
2947        dist1 = MultivariateNormal(mean, cov)
2948        dist2 = MultivariateNormal(mean, precision_matrix=prec)
2949        dist3 = MultivariateNormal(mean, scale_tril=scale_tril)
2950        ref_dist = scipy.stats.multivariate_normal(
2951            mean.detach().numpy(), cov.detach().numpy()
2952        )
2953
2954        x = dist1.sample((10,))
2955        expected = ref_dist.logpdf(x.numpy())
2956
2957        self.assertEqual(
2958            0.0,
2959            np.mean((dist1.log_prob(x).detach().numpy() - expected) ** 2),
2960            atol=1e-3,
2961            rtol=0,
2962        )
2963        self.assertEqual(
2964            0.0,
2965            np.mean((dist2.log_prob(x).detach().numpy() - expected) ** 2),
2966            atol=1e-3,
2967            rtol=0,
2968        )
2969        self.assertEqual(
2970            0.0,
2971            np.mean((dist3.log_prob(x).detach().numpy() - expected) ** 2),
2972            atol=1e-3,
2973            rtol=0,
2974        )
2975
2976        # Double-check that batched versions behave the same as unbatched
2977        mean = torch.randn(5, 3, requires_grad=True)
2978        tmp = torch.randn(5, 3, 10)
2979        cov = (tmp.unsqueeze(-2) * tmp.unsqueeze(-3)).mean(-1).requires_grad_()
2980
2981        dist_batched = MultivariateNormal(mean, cov)
2982        dist_unbatched = [
2983            MultivariateNormal(mean[i], cov[i]) for i in range(mean.size(0))
2984        ]
2985
2986        x = dist_batched.sample((10,))
2987        batched_prob = dist_batched.log_prob(x)
2988        unbatched_prob = torch.stack(
2989            [dist_unbatched[i].log_prob(x[:, i]) for i in range(5)]
2990        ).t()
2991
2992        self.assertEqual(batched_prob.shape, unbatched_prob.shape)
2993        self.assertEqual(batched_prob, unbatched_prob, atol=1e-3, rtol=0)
2994
2995    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
2996    def test_multivariate_normal_sample(self):
2997        set_rng_seed(0)  # see Note [Randomized statistical tests]
2998        mean = torch.randn(3, requires_grad=True)
2999        tmp = torch.randn(3, 10)
3000        cov = (torch.matmul(tmp, tmp.t()) / tmp.size(-1)).requires_grad_()
3001        prec = cov.inverse().requires_grad_()
3002        scale_tril = torch.linalg.cholesky(cov).requires_grad_()
3003
3004        self._check_sampler_sampler(
3005            MultivariateNormal(mean, cov),
3006            scipy.stats.multivariate_normal(
3007                mean.detach().numpy(), cov.detach().numpy()
3008            ),
3009            f"MultivariateNormal(loc={mean}, cov={cov})",
3010            multivariate=True,
3011        )
3012        self._check_sampler_sampler(
3013            MultivariateNormal(mean, precision_matrix=prec),
3014            scipy.stats.multivariate_normal(
3015                mean.detach().numpy(), cov.detach().numpy()
3016            ),
3017            f"MultivariateNormal(loc={mean}, atol={prec})",
3018            multivariate=True,
3019        )
3020        self._check_sampler_sampler(
3021            MultivariateNormal(mean, scale_tril=scale_tril),
3022            scipy.stats.multivariate_normal(
3023                mean.detach().numpy(), cov.detach().numpy()
3024            ),
3025            f"MultivariateNormal(loc={mean}, scale_tril={scale_tril})",
3026            multivariate=True,
3027        )
3028
3029    @set_default_dtype(torch.double)
3030    def test_multivariate_normal_properties(self):
3031        loc = torch.randn(5)
3032        scale_tril = transform_to(constraints.lower_cholesky)(torch.randn(5, 5))
3033        m = MultivariateNormal(loc=loc, scale_tril=scale_tril)
3034        self.assertEqual(m.covariance_matrix, m.scale_tril.mm(m.scale_tril.t()))
3035        self.assertEqual(
3036            m.covariance_matrix.mm(m.precision_matrix), torch.eye(m.event_shape[0])
3037        )
3038        self.assertEqual(m.scale_tril, torch.linalg.cholesky(m.covariance_matrix))
3039
3040    @set_default_dtype(torch.double)
3041    def test_multivariate_normal_moments(self):
3042        set_rng_seed(0)  # see Note [Randomized statistical tests]
3043        mean = torch.randn(5)
3044        scale_tril = transform_to(constraints.lower_cholesky)(torch.randn(5, 5))
3045        d = MultivariateNormal(mean, scale_tril=scale_tril)
3046        samples = d.rsample((100000,))
3047        empirical_mean = samples.mean(0)
3048        self.assertEqual(d.mean, empirical_mean, atol=0.01, rtol=0)
3049        empirical_var = samples.var(0)
3050        self.assertEqual(d.variance, empirical_var, atol=0.05, rtol=0)
3051
3052    # We applied same tests in Multivariate Normal distribution for Wishart distribution
3053    @set_default_dtype(torch.double)
3054    def test_wishart_shape(self):
3055        set_rng_seed(0)  # see Note [Randomized statistical tests]
3056        ndim = 3
3057
3058        df = torch.rand(5, requires_grad=True) + ndim
3059        df_no_batch = torch.rand([], requires_grad=True) + ndim
3060        df_multi_batch = torch.rand(6, 5, requires_grad=True) + ndim
3061
3062        # construct PSD covariance
3063        tmp = torch.randn(ndim, 10)
3064        cov = (torch.matmul(tmp, tmp.t()) / tmp.size(-1)).requires_grad_()
3065        prec = cov.inverse().requires_grad_()
3066        scale_tril = torch.linalg.cholesky(cov).requires_grad_()
3067
3068        # construct batch of PSD covariances
3069        tmp = torch.randn(6, 5, ndim, 10)
3070        cov_batched = (tmp.unsqueeze(-2) * tmp.unsqueeze(-3)).mean(-1).requires_grad_()
3071        prec_batched = cov_batched.inverse()
3072        scale_tril_batched = torch.linalg.cholesky(cov_batched)
3073
3074        # ensure that sample, batch, event shapes all handled correctly
3075        self.assertEqual(Wishart(df, cov).sample().size(), (5, ndim, ndim))
3076        self.assertEqual(Wishart(df_no_batch, cov).sample().size(), (ndim, ndim))
3077        self.assertEqual(
3078            Wishart(df_multi_batch, cov).sample().size(), (6, 5, ndim, ndim)
3079        )
3080        self.assertEqual(Wishart(df, cov).sample((2,)).size(), (2, 5, ndim, ndim))
3081        self.assertEqual(Wishart(df_no_batch, cov).sample((2,)).size(), (2, ndim, ndim))
3082        self.assertEqual(
3083            Wishart(df_multi_batch, cov).sample((2,)).size(), (2, 6, 5, ndim, ndim)
3084        )
3085        self.assertEqual(Wishart(df, cov).sample((2, 7)).size(), (2, 7, 5, ndim, ndim))
3086        self.assertEqual(
3087            Wishart(df_no_batch, cov).sample((2, 7)).size(), (2, 7, ndim, ndim)
3088        )
3089        self.assertEqual(
3090            Wishart(df_multi_batch, cov).sample((2, 7)).size(), (2, 7, 6, 5, ndim, ndim)
3091        )
3092        self.assertEqual(
3093            Wishart(df, cov_batched).sample((2, 7)).size(), (2, 7, 6, 5, ndim, ndim)
3094        )
3095        self.assertEqual(
3096            Wishart(df_no_batch, cov_batched).sample((2, 7)).size(),
3097            (2, 7, 6, 5, ndim, ndim),
3098        )
3099        self.assertEqual(
3100            Wishart(df_multi_batch, cov_batched).sample((2, 7)).size(),
3101            (2, 7, 6, 5, ndim, ndim),
3102        )
3103        self.assertEqual(
3104            Wishart(df, precision_matrix=prec).sample((2, 7)).size(),
3105            (2, 7, 5, ndim, ndim),
3106        )
3107        self.assertEqual(
3108            Wishart(df, precision_matrix=prec_batched).sample((2, 7)).size(),
3109            (2, 7, 6, 5, ndim, ndim),
3110        )
3111        self.assertEqual(
3112            Wishart(df, scale_tril=scale_tril).sample((2, 7)).size(),
3113            (2, 7, 5, ndim, ndim),
3114        )
3115        self.assertEqual(
3116            Wishart(df, scale_tril=scale_tril_batched).sample((2, 7)).size(),
3117            (2, 7, 6, 5, ndim, ndim),
3118        )
3119
3120        # check gradients
3121        # Modified and applied the same tests for multivariate_normal
3122        def wishart_log_prob_gradcheck(
3123            df=None, covariance=None, precision=None, scale_tril=None
3124        ):
3125            wishart_samples = (
3126                Wishart(df, covariance, precision, scale_tril).sample().requires_grad_()
3127            )
3128
3129            def gradcheck_func(samples, nu, sigma, prec, scale_tril):
3130                if sigma is not None:
3131                    sigma = 0.5 * (sigma + sigma.mT)  # Ensure symmetry of covariance
3132                if prec is not None:
3133                    prec = 0.5 * (prec + prec.mT)  # Ensure symmetry of precision
3134                if scale_tril is not None:
3135                    scale_tril = scale_tril.tril()
3136                return Wishart(nu, sigma, prec, scale_tril).log_prob(samples)
3137
3138            gradcheck(
3139                gradcheck_func,
3140                (wishart_samples, df, covariance, precision, scale_tril),
3141                raise_exception=True,
3142            )
3143
3144        wishart_log_prob_gradcheck(df, cov)
3145        wishart_log_prob_gradcheck(df_multi_batch, cov)
3146        wishart_log_prob_gradcheck(df_multi_batch, cov_batched)
3147        wishart_log_prob_gradcheck(df, None, prec)
3148        wishart_log_prob_gradcheck(df_no_batch, None, prec_batched)
3149        wishart_log_prob_gradcheck(df, None, None, scale_tril)
3150        wishart_log_prob_gradcheck(df_no_batch, None, None, scale_tril_batched)
3151
3152    def test_wishart_stable_with_precision_matrix(self):
3153        set_rng_seed(0)  # see Note [Randomized statistical tests]
3154        ndim = 10
3155        x = torch.randn(ndim)
3156        P = torch.exp(-((x - x.unsqueeze(-1)) ** 2))  # RBF kernel
3157        Wishart(torch.tensor(ndim), precision_matrix=P)
3158
3159    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
3160    @set_default_dtype(torch.double)
3161    def test_wishart_log_prob(self):
3162        set_rng_seed(0)  # see Note [Randomized statistical tests]
3163        ndim = 3
3164        df = torch.rand([], requires_grad=True) + ndim - 1
3165        # SciPy allowed ndim -1 < df < ndim for Wishar distribution after version 1.7.0
3166        if version.parse(scipy.__version__) < version.parse("1.7.0"):
3167            df += 1.0
3168        tmp = torch.randn(ndim, 10)
3169        cov = (torch.matmul(tmp, tmp.t()) / tmp.size(-1)).requires_grad_()
3170        prec = cov.inverse().requires_grad_()
3171        scale_tril = torch.linalg.cholesky(cov).requires_grad_()
3172
3173        # check that logprob values match scipy logpdf,
3174        # and that covariance and scale_tril parameters are equivalent
3175        dist1 = Wishart(df, cov)
3176        dist2 = Wishart(df, precision_matrix=prec)
3177        dist3 = Wishart(df, scale_tril=scale_tril)
3178        ref_dist = scipy.stats.wishart(df.item(), cov.detach().numpy())
3179
3180        x = dist1.sample((1000,))
3181        expected = ref_dist.logpdf(x.transpose(0, 2).numpy())
3182
3183        self.assertEqual(
3184            0.0,
3185            np.mean((dist1.log_prob(x).detach().numpy() - expected) ** 2),
3186            atol=1e-3,
3187            rtol=0,
3188        )
3189        self.assertEqual(
3190            0.0,
3191            np.mean((dist2.log_prob(x).detach().numpy() - expected) ** 2),
3192            atol=1e-3,
3193            rtol=0,
3194        )
3195        self.assertEqual(
3196            0.0,
3197            np.mean((dist3.log_prob(x).detach().numpy() - expected) ** 2),
3198            atol=1e-3,
3199            rtol=0,
3200        )
3201
3202        # Double-check that batched versions behave the same as unbatched
3203        df = torch.rand(5, requires_grad=True) + ndim - 1
3204        # SciPy allowed ndim -1 < df < ndim for Wishar distribution after version 1.7.0
3205        if version.parse(scipy.__version__) < version.parse("1.7.0"):
3206            df += 1.0
3207        tmp = torch.randn(5, ndim, 10)
3208        cov = (tmp.unsqueeze(-2) * tmp.unsqueeze(-3)).mean(-1).requires_grad_()
3209
3210        dist_batched = Wishart(df, cov)
3211        dist_unbatched = [Wishart(df[i], cov[i]) for i in range(df.size(0))]
3212
3213        x = dist_batched.sample((1000,))
3214        batched_prob = dist_batched.log_prob(x)
3215        unbatched_prob = torch.stack(
3216            [dist_unbatched[i].log_prob(x[:, i]) for i in range(5)]
3217        ).t()
3218
3219        self.assertEqual(batched_prob.shape, unbatched_prob.shape)
3220        self.assertEqual(batched_prob, unbatched_prob, atol=1e-3, rtol=0)
3221
3222    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
3223    @set_default_dtype(torch.double)
3224    def test_wishart_sample(self):
3225        set_rng_seed(0)  # see Note [Randomized statistical tests]
3226        ndim = 3
3227        df = torch.rand([], requires_grad=True) + ndim - 1
3228        # SciPy allowed ndim -1 < df < ndim for Wishar distribution after version 1.7.0
3229        if version.parse(scipy.__version__) < version.parse("1.7.0"):
3230            df += 1.0
3231        tmp = torch.randn(ndim, 10)
3232        cov = (torch.matmul(tmp, tmp.t()) / tmp.size(-1)).requires_grad_()
3233        prec = cov.inverse().requires_grad_()
3234        scale_tril = torch.linalg.cholesky(cov).requires_grad_()
3235
3236        ref_dist = scipy.stats.wishart(df.item(), cov.detach().numpy())
3237
3238        self._check_sampler_sampler(
3239            Wishart(df, cov),
3240            ref_dist,
3241            f"Wishart(df={df}, covariance_matrix={cov})",
3242            multivariate=True,
3243        )
3244        self._check_sampler_sampler(
3245            Wishart(df, precision_matrix=prec),
3246            ref_dist,
3247            f"Wishart(df={df}, precision_matrix={prec})",
3248            multivariate=True,
3249        )
3250        self._check_sampler_sampler(
3251            Wishart(df, scale_tril=scale_tril),
3252            ref_dist,
3253            f"Wishart(df={df}, scale_tril={scale_tril})",
3254            multivariate=True,
3255        )
3256
3257    def test_wishart_properties(self):
3258        set_rng_seed(0)  # see Note [Randomized statistical tests]
3259        ndim = 5
3260        df = torch.rand([]) + ndim - 1
3261        scale_tril = transform_to(constraints.lower_cholesky)(torch.randn(ndim, ndim))
3262        m = Wishart(df=df, scale_tril=scale_tril)
3263        self.assertEqual(m.covariance_matrix, m.scale_tril.mm(m.scale_tril.t()))
3264        self.assertEqual(
3265            m.covariance_matrix.mm(m.precision_matrix), torch.eye(m.event_shape[0])
3266        )
3267        self.assertEqual(m.scale_tril, torch.linalg.cholesky(m.covariance_matrix))
3268
3269    def test_wishart_moments(self):
3270        set_rng_seed(0)  # see Note [Randomized statistical tests]
3271        ndim = 3
3272        df = torch.rand([]) + ndim - 1
3273        scale_tril = transform_to(constraints.lower_cholesky)(torch.randn(ndim, ndim))
3274        d = Wishart(df=df, scale_tril=scale_tril)
3275        samples = d.rsample((ndim * ndim * 100000,))
3276        empirical_mean = samples.mean(0)
3277        self.assertEqual(d.mean, empirical_mean, atol=0.5, rtol=0)
3278        empirical_var = samples.var(0)
3279        self.assertEqual(d.variance, empirical_var, atol=0.5, rtol=0)
3280
3281    @set_default_dtype(torch.double)
3282    def test_exponential(self):
3283        rate = torch.randn(5, 5).abs().requires_grad_()
3284        rate_1d = torch.randn(1).abs().requires_grad_()
3285        self.assertEqual(Exponential(rate).sample().size(), (5, 5))
3286        self.assertEqual(Exponential(rate).sample((7,)).size(), (7, 5, 5))
3287        self.assertEqual(Exponential(rate_1d).sample((1,)).size(), (1, 1))
3288        self.assertEqual(Exponential(rate_1d).sample().size(), (1,))
3289        self.assertEqual(Exponential(0.2).sample((1,)).size(), (1,))
3290        self.assertEqual(Exponential(50.0).sample((1,)).size(), (1,))
3291
3292        self._gradcheck_log_prob(Exponential, (rate,))
3293        state = torch.get_rng_state()
3294        eps = rate.new(rate.size()).exponential_()
3295        torch.set_rng_state(state)
3296        z = Exponential(rate).rsample()
3297        z.backward(torch.ones_like(z))
3298        self.assertEqual(rate.grad, -eps / rate**2)
3299        rate.grad.zero_()
3300        self.assertEqual(z.size(), (5, 5))
3301
3302        def ref_log_prob(idx, x, log_prob):
3303            m = rate.view(-1)[idx]
3304            expected = math.log(m) - m * x
3305            self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)
3306
3307        self._check_log_prob(Exponential(rate), ref_log_prob)
3308        self._check_forward_ad(lambda x: x.exponential_())
3309
3310        def mean_var(lambd, sample):
3311            sample.exponential_(lambd)
3312            mean = sample.float().mean()
3313            var = sample.float().var()
3314            self.assertEqual((1.0 / lambd), mean, atol=2e-2, rtol=2e-2)
3315            self.assertEqual((1.0 / lambd) ** 2, var, atol=2e-2, rtol=2e-2)
3316
3317        for dtype in [torch.float, torch.double, torch.bfloat16, torch.float16]:
3318            for lambd in [0.2, 0.5, 1.0, 1.5, 2.0, 5.0]:
3319                sample_len = 50000
3320                mean_var(lambd, torch.rand(sample_len, dtype=dtype))
3321
3322    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
3323    def test_exponential_sample(self):
3324        set_rng_seed(1)  # see Note [Randomized statistical tests]
3325        for rate in [1e-5, 1.0, 10.0]:
3326            self._check_sampler_sampler(
3327                Exponential(rate),
3328                scipy.stats.expon(scale=1.0 / rate),
3329                f"Exponential(rate={rate})",
3330            )
3331
3332    @set_default_dtype(torch.double)
3333    def test_laplace(self):
3334        loc = torch.randn(5, 5, requires_grad=True)
3335        scale = torch.randn(5, 5).abs().requires_grad_()
3336        loc_1d = torch.randn(1, requires_grad=True)
3337        scale_1d = torch.randn(1, requires_grad=True)
3338        loc_delta = torch.tensor([1.0, 0.0])
3339        scale_delta = torch.tensor([1e-5, 1e-5])
3340        self.assertEqual(Laplace(loc, scale).sample().size(), (5, 5))
3341        self.assertEqual(Laplace(loc, scale).sample((7,)).size(), (7, 5, 5))
3342        self.assertEqual(Laplace(loc_1d, scale_1d).sample((1,)).size(), (1, 1))
3343        self.assertEqual(Laplace(loc_1d, scale_1d).sample().size(), (1,))
3344        self.assertEqual(Laplace(0.2, 0.6).sample((1,)).size(), (1,))
3345        self.assertEqual(Laplace(-0.7, 50.0).sample((1,)).size(), (1,))
3346
3347        # sample check for extreme value of mean, std
3348        set_rng_seed(0)
3349        self.assertEqual(
3350            Laplace(loc_delta, scale_delta).sample(sample_shape=(1, 2)),
3351            torch.tensor([[[1.0, 0.0], [1.0, 0.0]]]),
3352            atol=1e-4,
3353            rtol=0,
3354        )
3355
3356        self._gradcheck_log_prob(Laplace, (loc, scale))
3357        self._gradcheck_log_prob(Laplace, (loc, 1.0))
3358        self._gradcheck_log_prob(Laplace, (0.0, scale))
3359
3360        state = torch.get_rng_state()
3361        eps = torch.ones_like(loc).uniform_(-0.5, 0.5)
3362        torch.set_rng_state(state)
3363        z = Laplace(loc, scale).rsample()
3364        z.backward(torch.ones_like(z))
3365        self.assertEqual(loc.grad, torch.ones_like(loc))
3366        self.assertEqual(scale.grad, -eps.sign() * torch.log1p(-2 * eps.abs()))
3367        loc.grad.zero_()
3368        scale.grad.zero_()
3369        self.assertEqual(z.size(), (5, 5))
3370
3371        def ref_log_prob(idx, x, log_prob):
3372            m = loc.view(-1)[idx]
3373            s = scale.view(-1)[idx]
3374            expected = -math.log(2 * s) - abs(x - m) / s
3375            self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)
3376
3377        self._check_log_prob(Laplace(loc, scale), ref_log_prob)
3378
3379    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
3380    @set_default_dtype(torch.double)
3381    def test_laplace_sample(self):
3382        set_rng_seed(1)  # see Note [Randomized statistical tests]
3383        for loc, scale in product([-1.0, 0.0, 1.0], [0.1, 1.0, 10.0]):
3384            self._check_sampler_sampler(
3385                Laplace(loc, scale),
3386                scipy.stats.laplace(loc=loc, scale=scale),
3387                f"Laplace(loc={loc}, scale={scale})",
3388            )
3389
3390    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
3391    def test_gamma_shape(self):
3392        alpha = torch.randn(2, 3).exp().requires_grad_()
3393        beta = torch.randn(2, 3).exp().requires_grad_()
3394        alpha_1d = torch.randn(1).exp().requires_grad_()
3395        beta_1d = torch.randn(1).exp().requires_grad_()
3396        self.assertEqual(Gamma(alpha, beta).sample().size(), (2, 3))
3397        self.assertEqual(Gamma(alpha, beta).sample((5,)).size(), (5, 2, 3))
3398        self.assertEqual(Gamma(alpha_1d, beta_1d).sample((1,)).size(), (1, 1))
3399        self.assertEqual(Gamma(alpha_1d, beta_1d).sample().size(), (1,))
3400        self.assertEqual(Gamma(0.5, 0.5).sample().size(), ())
3401        self.assertEqual(Gamma(0.5, 0.5).sample((1,)).size(), (1,))
3402
3403        def ref_log_prob(idx, x, log_prob):
3404            a = alpha.view(-1)[idx].detach()
3405            b = beta.view(-1)[idx].detach()
3406            expected = scipy.stats.gamma.logpdf(x, a, scale=1 / b)
3407            self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)
3408
3409        self._check_log_prob(Gamma(alpha, beta), ref_log_prob)
3410
3411    @unittest.skipIf(not TEST_CUDA, "CUDA not found")
3412    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
3413    def test_gamma_gpu_shape(self):
3414        alpha = torch.randn(2, 3).cuda().exp().requires_grad_()
3415        beta = torch.randn(2, 3).cuda().exp().requires_grad_()
3416        alpha_1d = torch.randn(1).cuda().exp().requires_grad_()
3417        beta_1d = torch.randn(1).cuda().exp().requires_grad_()
3418        self.assertEqual(Gamma(alpha, beta).sample().size(), (2, 3))
3419        self.assertEqual(Gamma(alpha, beta).sample((5,)).size(), (5, 2, 3))
3420        self.assertEqual(Gamma(alpha_1d, beta_1d).sample((1,)).size(), (1, 1))
3421        self.assertEqual(Gamma(alpha_1d, beta_1d).sample().size(), (1,))
3422        self.assertEqual(Gamma(0.5, 0.5).sample().size(), ())
3423        self.assertEqual(Gamma(0.5, 0.5).sample((1,)).size(), (1,))
3424
3425        def ref_log_prob(idx, x, log_prob):
3426            a = alpha.view(-1)[idx].detach().cpu()
3427            b = beta.view(-1)[idx].detach().cpu()
3428            expected = scipy.stats.gamma.logpdf(x.cpu(), a, scale=1 / b)
3429            self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)
3430
3431        self._check_log_prob(Gamma(alpha, beta), ref_log_prob)
3432
3433    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
3434    def test_gamma_sample(self):
3435        set_rng_seed(0)  # see Note [Randomized statistical tests]
3436        for alpha, beta in product([0.1, 1.0, 5.0], [0.1, 1.0, 10.0]):
3437            self._check_sampler_sampler(
3438                Gamma(alpha, beta),
3439                scipy.stats.gamma(alpha, scale=1.0 / beta),
3440                f"Gamma(concentration={alpha}, rate={beta})",
3441            )
3442
3443    @unittest.skipIf(not TEST_CUDA, "CUDA not found")
3444    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
3445    def test_gamma_gpu_sample(self):
3446        set_rng_seed(0)
3447        for alpha, beta in product([0.1, 1.0, 5.0], [0.1, 1.0, 10.0]):
3448            a, b = torch.tensor([alpha]).cuda(), torch.tensor([beta]).cuda()
3449            self._check_sampler_sampler(
3450                Gamma(a, b),
3451                scipy.stats.gamma(alpha, scale=1.0 / beta),
3452                f"Gamma(alpha={alpha}, beta={beta})",
3453                failure_rate=1e-4,
3454            )
3455
3456    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
3457    def test_pareto(self):
3458        scale = torch.randn(2, 3).abs().requires_grad_()
3459        alpha = torch.randn(2, 3).abs().requires_grad_()
3460        scale_1d = torch.randn(1).abs().requires_grad_()
3461        alpha_1d = torch.randn(1).abs().requires_grad_()
3462        self.assertEqual(Pareto(scale_1d, 0.5).mean, inf)
3463        self.assertEqual(Pareto(scale_1d, 0.5).variance, inf)
3464        self.assertEqual(Pareto(scale, alpha).sample().size(), (2, 3))
3465        self.assertEqual(Pareto(scale, alpha).sample((5,)).size(), (5, 2, 3))
3466        self.assertEqual(Pareto(scale_1d, alpha_1d).sample((1,)).size(), (1, 1))
3467        self.assertEqual(Pareto(scale_1d, alpha_1d).sample().size(), (1,))
3468        self.assertEqual(Pareto(1.0, 1.0).sample().size(), ())
3469        self.assertEqual(Pareto(1.0, 1.0).sample((1,)).size(), (1,))
3470
3471        def ref_log_prob(idx, x, log_prob):
3472            s = scale.view(-1)[idx].detach()
3473            a = alpha.view(-1)[idx].detach()
3474            expected = scipy.stats.pareto.logpdf(x, a, scale=s)
3475            self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)
3476
3477        self._check_log_prob(Pareto(scale, alpha), ref_log_prob)
3478
3479    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
3480    def test_pareto_sample(self):
3481        set_rng_seed(1)  # see Note [Randomized statistical tests]
3482        for scale, alpha in product([0.1, 1.0, 5.0], [0.1, 1.0, 10.0]):
3483            self._check_sampler_sampler(
3484                Pareto(scale, alpha),
3485                scipy.stats.pareto(alpha, scale=scale),
3486                f"Pareto(scale={scale}, alpha={alpha})",
3487            )
3488
3489    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
3490    def test_gumbel(self):
3491        loc = torch.randn(2, 3, requires_grad=True)
3492        scale = torch.randn(2, 3).abs().requires_grad_()
3493        loc_1d = torch.randn(1, requires_grad=True)
3494        scale_1d = torch.randn(1).abs().requires_grad_()
3495        self.assertEqual(Gumbel(loc, scale).sample().size(), (2, 3))
3496        self.assertEqual(Gumbel(loc, scale).sample((5,)).size(), (5, 2, 3))
3497        self.assertEqual(Gumbel(loc_1d, scale_1d).sample().size(), (1,))
3498        self.assertEqual(Gumbel(loc_1d, scale_1d).sample((1,)).size(), (1, 1))
3499        self.assertEqual(Gumbel(1.0, 1.0).sample().size(), ())
3500        self.assertEqual(Gumbel(1.0, 1.0).sample((1,)).size(), (1,))
3501        self.assertEqual(
3502            Gumbel(
3503                torch.tensor(0.0, dtype=torch.float32),
3504                torch.tensor(1.0, dtype=torch.float32),
3505                validate_args=False,
3506            ).cdf(20.0),
3507            1.0,
3508            atol=1e-4,
3509            rtol=0,
3510        )
3511        self.assertEqual(
3512            Gumbel(
3513                torch.tensor(0.0, dtype=torch.float64),
3514                torch.tensor(1.0, dtype=torch.float64),
3515                validate_args=False,
3516            ).cdf(50.0),
3517            1.0,
3518            atol=1e-4,
3519            rtol=0,
3520        )
3521        self.assertEqual(
3522            Gumbel(
3523                torch.tensor(0.0, dtype=torch.float32),
3524                torch.tensor(1.0, dtype=torch.float32),
3525                validate_args=False,
3526            ).cdf(-5.0),
3527            0.0,
3528            atol=1e-4,
3529            rtol=0,
3530        )
3531        self.assertEqual(
3532            Gumbel(
3533                torch.tensor(0.0, dtype=torch.float64),
3534                torch.tensor(1.0, dtype=torch.float64),
3535                validate_args=False,
3536            ).cdf(-10.0),
3537            0.0,
3538            atol=1e-8,
3539            rtol=0,
3540        )
3541
3542        def ref_log_prob(idx, x, log_prob):
3543            l = loc.view(-1)[idx].detach()
3544            s = scale.view(-1)[idx].detach()
3545            expected = scipy.stats.gumbel_r.logpdf(x, loc=l, scale=s)
3546            self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)
3547
3548        self._check_log_prob(Gumbel(loc, scale), ref_log_prob)
3549
3550    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
3551    @set_default_dtype(torch.double)
3552    def test_gumbel_sample(self):
3553        set_rng_seed(1)  # see note [Randomized statistical tests]
3554        for loc, scale in product([-5.0, -1.0, -0.1, 0.1, 1.0, 5.0], [0.1, 1.0, 10.0]):
3555            self._check_sampler_sampler(
3556                Gumbel(loc, scale),
3557                scipy.stats.gumbel_r(loc=loc, scale=scale),
3558                f"Gumbel(loc={loc}, scale={scale})",
3559            )
3560
3561    def test_kumaraswamy_shape(self):
3562        concentration1 = torch.randn(2, 3).abs().requires_grad_()
3563        concentration0 = torch.randn(2, 3).abs().requires_grad_()
3564        concentration1_1d = torch.randn(1).abs().requires_grad_()
3565        concentration0_1d = torch.randn(1).abs().requires_grad_()
3566        self.assertEqual(
3567            Kumaraswamy(concentration1, concentration0).sample().size(), (2, 3)
3568        )
3569        self.assertEqual(
3570            Kumaraswamy(concentration1, concentration0).sample((5,)).size(), (5, 2, 3)
3571        )
3572        self.assertEqual(
3573            Kumaraswamy(concentration1_1d, concentration0_1d).sample().size(), (1,)
3574        )
3575        self.assertEqual(
3576            Kumaraswamy(concentration1_1d, concentration0_1d).sample((1,)).size(),
3577            (1, 1),
3578        )
3579        self.assertEqual(Kumaraswamy(1.0, 1.0).sample().size(), ())
3580        self.assertEqual(Kumaraswamy(1.0, 1.0).sample((1,)).size(), (1,))
3581
3582    # Kumaraswamy distribution is not implemented in SciPy
3583    # Hence these tests are explicit
3584    def test_kumaraswamy_mean_variance(self):
3585        c1_1 = torch.randn(2, 3).abs().requires_grad_()
3586        c0_1 = torch.randn(2, 3).abs().requires_grad_()
3587        c1_2 = torch.randn(4).abs().requires_grad_()
3588        c0_2 = torch.randn(4).abs().requires_grad_()
3589        cases = [(c1_1, c0_1), (c1_2, c0_2)]
3590        for i, (a, b) in enumerate(cases):
3591            m = Kumaraswamy(a, b)
3592            samples = m.sample((60000,))
3593            expected = samples.mean(0)
3594            actual = m.mean
3595            error = (expected - actual).abs()
3596            max_error = max(error[error == error])
3597            self.assertLess(
3598                max_error,
3599                0.01,
3600                f"Kumaraswamy example {i + 1}/{len(cases)}, incorrect .mean",
3601            )
3602            expected = samples.var(0)
3603            actual = m.variance
3604            error = (expected - actual).abs()
3605            max_error = max(error[error == error])
3606            self.assertLess(
3607                max_error,
3608                0.01,
3609                f"Kumaraswamy example {i + 1}/{len(cases)}, incorrect .variance",
3610            )
3611
3612    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
3613    def test_fishersnedecor(self):
3614        df1 = torch.randn(2, 3).abs().requires_grad_()
3615        df2 = torch.randn(2, 3).abs().requires_grad_()
3616        df1_1d = torch.randn(1).abs()
3617        df2_1d = torch.randn(1).abs()
3618        self.assertTrue(is_all_nan(FisherSnedecor(1, 2).mean))
3619        self.assertTrue(is_all_nan(FisherSnedecor(1, 4).variance))
3620        self.assertEqual(FisherSnedecor(df1, df2).sample().size(), (2, 3))
3621        self.assertEqual(FisherSnedecor(df1, df2).sample((5,)).size(), (5, 2, 3))
3622        self.assertEqual(FisherSnedecor(df1_1d, df2_1d).sample().size(), (1,))
3623        self.assertEqual(FisherSnedecor(df1_1d, df2_1d).sample((1,)).size(), (1, 1))
3624        self.assertEqual(FisherSnedecor(1.0, 1.0).sample().size(), ())
3625        self.assertEqual(FisherSnedecor(1.0, 1.0).sample((1,)).size(), (1,))
3626
3627        def ref_log_prob(idx, x, log_prob):
3628            f1 = df1.view(-1)[idx].detach()
3629            f2 = df2.view(-1)[idx].detach()
3630            expected = scipy.stats.f.logpdf(x, f1, f2)
3631            self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)
3632
3633        self._check_log_prob(FisherSnedecor(df1, df2), ref_log_prob)
3634
3635    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
3636    def test_fishersnedecor_sample(self):
3637        set_rng_seed(1)  # see note [Randomized statistical tests]
3638        for df1, df2 in product([0.1, 0.5, 1.0, 5.0, 10.0], [0.1, 0.5, 1.0, 5.0, 10.0]):
3639            self._check_sampler_sampler(
3640                FisherSnedecor(df1, df2),
3641                scipy.stats.f(df1, df2),
3642                f"FisherSnedecor(loc={df1}, scale={df2})",
3643            )
3644
3645    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
3646    def test_chi2_shape(self):
3647        df = torch.randn(2, 3).exp().requires_grad_()
3648        df_1d = torch.randn(1).exp().requires_grad_()
3649        self.assertEqual(Chi2(df).sample().size(), (2, 3))
3650        self.assertEqual(Chi2(df).sample((5,)).size(), (5, 2, 3))
3651        self.assertEqual(Chi2(df_1d).sample((1,)).size(), (1, 1))
3652        self.assertEqual(Chi2(df_1d).sample().size(), (1,))
3653        self.assertEqual(
3654            Chi2(torch.tensor(0.5, requires_grad=True)).sample().size(), ()
3655        )
3656        self.assertEqual(Chi2(0.5).sample().size(), ())
3657        self.assertEqual(Chi2(0.5).sample((1,)).size(), (1,))
3658
3659        def ref_log_prob(idx, x, log_prob):
3660            d = df.view(-1)[idx].detach()
3661            expected = scipy.stats.chi2.logpdf(x, d)
3662            self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)
3663
3664        self._check_log_prob(Chi2(df), ref_log_prob)
3665
3666    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
3667    def test_chi2_sample(self):
3668        set_rng_seed(0)  # see Note [Randomized statistical tests]
3669        for df in [0.1, 1.0, 5.0]:
3670            self._check_sampler_sampler(
3671                Chi2(df), scipy.stats.chi2(df), f"Chi2(df={df})"
3672            )
3673
3674    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
3675    def test_studentT(self):
3676        df = torch.randn(2, 3).exp().requires_grad_()
3677        df_1d = torch.randn(1).exp().requires_grad_()
3678        self.assertTrue(is_all_nan(StudentT(1).mean))
3679        self.assertTrue(is_all_nan(StudentT(1).variance))
3680        self.assertEqual(StudentT(2).variance, inf)
3681        self.assertEqual(StudentT(df).sample().size(), (2, 3))
3682        self.assertEqual(StudentT(df).sample((5,)).size(), (5, 2, 3))
3683        self.assertEqual(StudentT(df_1d).sample((1,)).size(), (1, 1))
3684        self.assertEqual(StudentT(df_1d).sample().size(), (1,))
3685        self.assertEqual(
3686            StudentT(torch.tensor(0.5, requires_grad=True)).sample().size(), ()
3687        )
3688        self.assertEqual(StudentT(0.5).sample().size(), ())
3689        self.assertEqual(StudentT(0.5).sample((1,)).size(), (1,))
3690
3691        def ref_log_prob(idx, x, log_prob):
3692            d = df.view(-1)[idx].detach()
3693            expected = scipy.stats.t.logpdf(x, d)
3694            self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)
3695
3696        self._check_log_prob(StudentT(df), ref_log_prob)
3697
3698    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
3699    @set_default_dtype(torch.double)
3700    def test_studentT_sample(self):
3701        set_rng_seed(11)  # see Note [Randomized statistical tests]
3702        for df, loc, scale in product(
3703            [0.1, 1.0, 5.0, 10.0], [-1.0, 0.0, 1.0], [0.1, 1.0, 10.0]
3704        ):
3705            self._check_sampler_sampler(
3706                StudentT(df=df, loc=loc, scale=scale),
3707                scipy.stats.t(df=df, loc=loc, scale=scale),
3708                f"StudentT(df={df}, loc={loc}, scale={scale})",
3709            )
3710
3711    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
3712    def test_studentT_log_prob(self):
3713        set_rng_seed(0)  # see Note [Randomized statistical tests]
3714        num_samples = 10
3715        for df, loc, scale in product(
3716            [0.1, 1.0, 5.0, 10.0], [-1.0, 0.0, 1.0], [0.1, 1.0, 10.0]
3717        ):
3718            dist = StudentT(df=df, loc=loc, scale=scale)
3719            x = dist.sample((num_samples,))
3720            actual_log_prob = dist.log_prob(x)
3721            for i in range(num_samples):
3722                expected_log_prob = scipy.stats.t.logpdf(
3723                    x[i], df=df, loc=loc, scale=scale
3724                )
3725                self.assertEqual(
3726                    float(actual_log_prob[i]),
3727                    float(expected_log_prob),
3728                    atol=1e-3,
3729                    rtol=0,
3730                )
3731
3732    def test_dirichlet_shape(self):
3733        alpha = torch.randn(2, 3).exp().requires_grad_()
3734        alpha_1d = torch.randn(4).exp().requires_grad_()
3735        self.assertEqual(Dirichlet(alpha).sample().size(), (2, 3))
3736        self.assertEqual(Dirichlet(alpha).sample((5,)).size(), (5, 2, 3))
3737        self.assertEqual(Dirichlet(alpha_1d).sample().size(), (4,))
3738        self.assertEqual(Dirichlet(alpha_1d).sample((1,)).size(), (1, 4))
3739
3740    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
3741    @set_default_dtype(torch.double)
3742    def test_dirichlet_log_prob(self):
3743        num_samples = 10
3744        alpha = torch.exp(torch.randn(5))
3745        dist = Dirichlet(alpha)
3746        x = dist.sample((num_samples,))
3747        actual_log_prob = dist.log_prob(x)
3748        for i in range(num_samples):
3749            expected_log_prob = scipy.stats.dirichlet.logpdf(
3750                x[i].numpy(), alpha.numpy()
3751            )
3752            self.assertEqual(actual_log_prob[i], expected_log_prob, atol=1e-3, rtol=0)
3753
3754    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
3755    def test_dirichlet_log_prob_zero(self):
3756        # Specifically test the special case where x=0 and alpha=1.  The PDF is
3757        # proportional to x**(alpha-1), which in this case works out to 0**0=1.
3758        # The log PDF of this term should therefore be 0.  However, it's easy
3759        # to accidentally introduce NaNs by calculating log(x) without regard
3760        # for the value of alpha-1.
3761        alpha = torch.tensor([1, 2])
3762        dist = Dirichlet(alpha)
3763        x = torch.tensor([0, 1])
3764        actual_log_prob = dist.log_prob(x)
3765        expected_log_prob = scipy.stats.dirichlet.logpdf(x.numpy(), alpha.numpy())
3766        self.assertEqual(actual_log_prob, expected_log_prob, atol=1e-3, rtol=0)
3767
3768    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
3769    def test_dirichlet_sample(self):
3770        set_rng_seed(0)  # see Note [Randomized statistical tests]
3771        alpha = torch.exp(torch.randn(3))
3772        self._check_sampler_sampler(
3773            Dirichlet(alpha),
3774            scipy.stats.dirichlet(alpha.numpy()),
3775            f"Dirichlet(alpha={list(alpha)})",
3776            multivariate=True,
3777        )
3778
3779    def test_dirichlet_mode(self):
3780        # Test a few edge cases for the Dirichlet distribution mode. This also covers beta distributions.
3781        concentrations_and_modes = [
3782            ([2, 2, 1], [0.5, 0.5, 0.0]),
3783            ([3, 2, 1], [2 / 3, 1 / 3, 0]),
3784            ([0.5, 0.2, 0.2], [1.0, 0.0, 0.0]),
3785            ([1, 1, 1], [nan, nan, nan]),
3786        ]
3787        for concentration, mode in concentrations_and_modes:
3788            dist = Dirichlet(torch.tensor(concentration))
3789            self.assertEqual(dist.mode, torch.tensor(mode))
3790
3791    def test_beta_shape(self):
3792        con1 = torch.randn(2, 3).exp().requires_grad_()
3793        con0 = torch.randn(2, 3).exp().requires_grad_()
3794        con1_1d = torch.randn(4).exp().requires_grad_()
3795        con0_1d = torch.randn(4).exp().requires_grad_()
3796        self.assertEqual(Beta(con1, con0).sample().size(), (2, 3))
3797        self.assertEqual(Beta(con1, con0).sample((5,)).size(), (5, 2, 3))
3798        self.assertEqual(Beta(con1_1d, con0_1d).sample().size(), (4,))
3799        self.assertEqual(Beta(con1_1d, con0_1d).sample((1,)).size(), (1, 4))
3800        self.assertEqual(Beta(0.1, 0.3).sample().size(), ())
3801        self.assertEqual(Beta(0.1, 0.3).sample((5,)).size(), (5,))
3802
3803    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
3804    def test_beta_log_prob(self):
3805        for _ in range(100):
3806            con1 = np.exp(np.random.normal())
3807            con0 = np.exp(np.random.normal())
3808            dist = Beta(con1, con0)
3809            x = dist.sample()
3810            actual_log_prob = dist.log_prob(x).sum()
3811            expected_log_prob = scipy.stats.beta.logpdf(x, con1, con0)
3812            self.assertEqual(
3813                float(actual_log_prob), float(expected_log_prob), atol=1e-3, rtol=0
3814            )
3815
3816    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
3817    @set_default_dtype(torch.double)
3818    def test_beta_sample(self):
3819        set_rng_seed(1)  # see Note [Randomized statistical tests]
3820        for con1, con0 in product([0.1, 1.0, 10.0], [0.1, 1.0, 10.0]):
3821            self._check_sampler_sampler(
3822                Beta(con1, con0),
3823                scipy.stats.beta(con1, con0),
3824                f"Beta(alpha={con1}, beta={con0})",
3825            )
3826        # Check that small alphas do not cause NANs.
3827        for Tensor in [torch.FloatTensor, torch.DoubleTensor]:
3828            x = Beta(Tensor([1e-6]), Tensor([1e-6])).sample()[0]
3829            self.assertTrue(np.isfinite(x) and x > 0, f"Invalid Beta.sample(): {x}")
3830
3831    def test_beta_underflow(self):
3832        # For low values of (alpha, beta), the gamma samples can underflow
3833        # with float32 and result in a spurious mode at 0.5. To prevent this,
3834        # torch._sample_dirichlet works with double precision for intermediate
3835        # calculations.
3836        set_rng_seed(1)
3837        num_samples = 50000
3838        for dtype in [torch.float, torch.double]:
3839            conc = torch.tensor(1e-2, dtype=dtype)
3840            beta_samples = Beta(conc, conc).sample([num_samples])
3841            self.assertEqual((beta_samples == 0).sum(), 0)
3842            self.assertEqual((beta_samples == 1).sum(), 0)
3843            # assert support is concentrated around 0 and 1
3844            frac_zeros = float((beta_samples < 0.1).sum()) / num_samples
3845            frac_ones = float((beta_samples > 0.9).sum()) / num_samples
3846            self.assertEqual(frac_zeros, 0.5, atol=0.05, rtol=0)
3847            self.assertEqual(frac_ones, 0.5, atol=0.05, rtol=0)
3848
3849    @unittest.skipIf(not TEST_CUDA, "CUDA not found")
3850    def test_beta_underflow_gpu(self):
3851        set_rng_seed(1)
3852        num_samples = 50000
3853        conc = torch.tensor(1e-2, dtype=torch.float64).cuda()
3854        beta_samples = Beta(conc, conc).sample([num_samples])
3855        self.assertEqual((beta_samples == 0).sum(), 0)
3856        self.assertEqual((beta_samples == 1).sum(), 0)
3857        # assert support is concentrated around 0 and 1
3858        frac_zeros = float((beta_samples < 0.1).sum()) / num_samples
3859        frac_ones = float((beta_samples > 0.9).sum()) / num_samples
3860        # TODO: increase precision once imbalance on GPU is fixed.
3861        self.assertEqual(frac_zeros, 0.5, atol=0.12, rtol=0)
3862        self.assertEqual(frac_ones, 0.5, atol=0.12, rtol=0)
3863
3864    @set_default_dtype(torch.double)
3865    def test_continuous_bernoulli(self):
3866        p = torch.tensor([0.7, 0.2, 0.4], requires_grad=True)
3867        r = torch.tensor(0.3, requires_grad=True)
3868        s = 0.3
3869        self.assertEqual(ContinuousBernoulli(p).sample((8,)).size(), (8, 3))
3870        self.assertFalse(ContinuousBernoulli(p).sample().requires_grad)
3871        self.assertEqual(ContinuousBernoulli(r).sample((8,)).size(), (8,))
3872        self.assertEqual(ContinuousBernoulli(r).sample().size(), ())
3873        self.assertEqual(
3874            ContinuousBernoulli(r).sample((3, 2)).size(),
3875            (
3876                3,
3877                2,
3878            ),
3879        )
3880        self.assertEqual(ContinuousBernoulli(s).sample().size(), ())
3881        self._gradcheck_log_prob(ContinuousBernoulli, (p,))
3882
3883        def ref_log_prob(idx, val, log_prob):
3884            prob = p[idx]
3885            if prob > 0.499 and prob < 0.501:  # using default value of lim here
3886                log_norm_const = (
3887                    math.log(2.0)
3888                    + 4.0 / 3.0 * math.pow(prob - 0.5, 2)
3889                    + 104.0 / 45.0 * math.pow(prob - 0.5, 4)
3890                )
3891            else:
3892                log_norm_const = math.log(
3893                    2.0 * math.atanh(1.0 - 2.0 * prob) / (1.0 - 2.0 * prob)
3894                )
3895            res = (
3896                val * math.log(prob) + (1.0 - val) * math.log1p(-prob) + log_norm_const
3897            )
3898            self.assertEqual(log_prob, res)
3899
3900        self._check_log_prob(ContinuousBernoulli(p), ref_log_prob)
3901        self._check_log_prob(
3902            ContinuousBernoulli(logits=p.log() - (-p).log1p()), ref_log_prob
3903        )
3904
3905        # check entropy computation
3906        self.assertEqual(
3907            ContinuousBernoulli(p).entropy(),
3908            torch.tensor([-0.02938, -0.07641, -0.00682]),
3909            atol=1e-4,
3910            rtol=0,
3911        )
3912        # entropy below corresponds to the clamped value of prob when using float 64
3913        # the value for float32 should be -1.76898
3914        self.assertEqual(
3915            ContinuousBernoulli(torch.tensor([0.0])).entropy(),
3916            torch.tensor([-2.58473]),
3917            atol=1e-5,
3918            rtol=0,
3919        )
3920        self.assertEqual(
3921            ContinuousBernoulli(s).entropy(), torch.tensor(-0.02938), atol=1e-4, rtol=0
3922        )
3923
3924    def test_continuous_bernoulli_3d(self):
3925        p = torch.full((2, 3, 5), 0.5).requires_grad_()
3926        self.assertEqual(ContinuousBernoulli(p).sample().size(), (2, 3, 5))
3927        self.assertEqual(
3928            ContinuousBernoulli(p).sample(sample_shape=(2, 5)).size(), (2, 5, 2, 3, 5)
3929        )
3930        self.assertEqual(ContinuousBernoulli(p).sample((2,)).size(), (2, 2, 3, 5))
3931
3932    def test_lkj_cholesky_log_prob(self):
3933        def tril_cholesky_to_tril_corr(x):
3934            x = vec_to_tril_matrix(x, -1)
3935            diag = (1 - (x * x).sum(-1)).sqrt().diag_embed()
3936            x = x + diag
3937            return tril_matrix_to_vec(x @ x.T, -1)
3938
3939        for dim in range(2, 5):
3940            log_probs = []
3941            lkj = LKJCholesky(dim, concentration=1.0, validate_args=True)
3942            for i in range(2):
3943                sample = lkj.sample()
3944                sample_tril = tril_matrix_to_vec(sample, diag=-1)
3945                log_prob = lkj.log_prob(sample)
3946                log_abs_det_jacobian = torch.slogdet(
3947                    jacobian(tril_cholesky_to_tril_corr, sample_tril)
3948                ).logabsdet
3949                log_probs.append(log_prob - log_abs_det_jacobian)
3950            # for concentration=1., the density is uniform over the space of all
3951            # correlation matrices.
3952            if dim == 2:
3953                # for dim=2, pdf = 0.5 (jacobian adjustment factor is 0.)
3954                self.assertTrue(
3955                    all(
3956                        torch.allclose(x, torch.tensor(0.5).log(), atol=1e-10)
3957                        for x in log_probs
3958                    )
3959                )
3960            self.assertEqual(log_probs[0], log_probs[1])
3961            invalid_sample = torch.cat([sample, sample.new_ones(1, dim)], dim=0)
3962            self.assertRaises(ValueError, lambda: lkj.log_prob(invalid_sample))
3963
3964    def test_independent_shape(self):
3965        for Dist, params in _get_examples():
3966            for param in params:
3967                base_dist = Dist(**param)
3968                x = base_dist.sample()
3969                base_log_prob_shape = base_dist.log_prob(x).shape
3970                for reinterpreted_batch_ndims in range(len(base_dist.batch_shape) + 1):
3971                    indep_dist = Independent(base_dist, reinterpreted_batch_ndims)
3972                    indep_log_prob_shape = base_log_prob_shape[
3973                        : len(base_log_prob_shape) - reinterpreted_batch_ndims
3974                    ]
3975                    self.assertEqual(indep_dist.log_prob(x).shape, indep_log_prob_shape)
3976                    self.assertEqual(
3977                        indep_dist.sample().shape, base_dist.sample().shape
3978                    )
3979                    self.assertEqual(indep_dist.has_rsample, base_dist.has_rsample)
3980                    if indep_dist.has_rsample:
3981                        self.assertEqual(
3982                            indep_dist.sample().shape, base_dist.sample().shape
3983                        )
3984                    try:
3985                        self.assertEqual(
3986                            indep_dist.enumerate_support().shape,
3987                            base_dist.enumerate_support().shape,
3988                        )
3989                        self.assertEqual(indep_dist.mean.shape, base_dist.mean.shape)
3990                    except NotImplementedError:
3991                        pass
3992                    try:
3993                        self.assertEqual(
3994                            indep_dist.variance.shape, base_dist.variance.shape
3995                        )
3996                    except NotImplementedError:
3997                        pass
3998                    try:
3999                        self.assertEqual(
4000                            indep_dist.entropy().shape, indep_log_prob_shape
4001                        )
4002                    except NotImplementedError:
4003                        pass
4004
4005    def test_independent_expand(self):
4006        for Dist, params in _get_examples():
4007            for param in params:
4008                base_dist = Dist(**param)
4009                for reinterpreted_batch_ndims in range(len(base_dist.batch_shape) + 1):
4010                    for s in [torch.Size(), torch.Size((2,)), torch.Size((2, 3))]:
4011                        indep_dist = Independent(base_dist, reinterpreted_batch_ndims)
4012                        expanded_shape = s + indep_dist.batch_shape
4013                        expanded = indep_dist.expand(expanded_shape)
4014                        expanded_sample = expanded.sample()
4015                        expected_shape = expanded_shape + indep_dist.event_shape
4016                        self.assertEqual(expanded_sample.shape, expected_shape)
4017                        self.assertEqual(
4018                            expanded.log_prob(expanded_sample),
4019                            indep_dist.log_prob(expanded_sample),
4020                        )
4021                        self.assertEqual(expanded.event_shape, indep_dist.event_shape)
4022                        self.assertEqual(expanded.batch_shape, expanded_shape)
4023
4024    @set_default_dtype(torch.double)
4025    def test_cdf_icdf_inverse(self):
4026        # Tests the invertibility property on the distributions
4027        for Dist, params in _get_examples():
4028            for i, param in enumerate(params):
4029                dist = Dist(**param)
4030                samples = dist.sample(sample_shape=(20,))
4031                try:
4032                    cdf = dist.cdf(samples)
4033                    actual = dist.icdf(cdf)
4034                except NotImplementedError:
4035                    continue
4036                rel_error = torch.abs(actual - samples) / (1e-10 + torch.abs(samples))
4037                self.assertLess(
4038                    rel_error.max(),
4039                    1e-4,
4040                    msg="\n".join(
4041                        [
4042                            f"{Dist.__name__} example {i + 1}/{len(params)}, icdf(cdf(x)) != x",
4043                            f"x = {samples}",
4044                            f"cdf(x) = {cdf}",
4045                            f"icdf(cdf(x)) = {actual}",
4046                        ]
4047                    ),
4048                )
4049
4050    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
4051    def test_gamma_log_prob_at_boundary(self):
4052        for concentration, log_prob in [(0.5, inf), (1, 0), (2, -inf)]:
4053            dist = Gamma(concentration, 1)
4054            scipy_dist = scipy.stats.gamma(concentration)
4055            self.assertAlmostEqual(dist.log_prob(0), log_prob)
4056            self.assertAlmostEqual(dist.log_prob(0), scipy_dist.logpdf(0))
4057
4058    @set_default_dtype(torch.double)
4059    def test_cdf_log_prob(self):
4060        # Tests if the differentiation of the CDF gives the PDF at a given value
4061        for Dist, params in _get_examples():
4062            for i, param in enumerate(params):
4063                # We do not need grads wrt params here, e.g. shape of gamma distribution.
4064                param = {
4065                    key: value.detach() if isinstance(value, torch.Tensor) else value
4066                    for key, value in param.items()
4067                }
4068                dist = Dist(**param)
4069                samples = dist.sample()
4070                if not dist.support.is_discrete:
4071                    samples.requires_grad_()
4072                try:
4073                    cdfs = dist.cdf(samples)
4074                    pdfs = dist.log_prob(samples).exp()
4075                except NotImplementedError:
4076                    continue
4077                cdfs_derivative = grad(cdfs.sum(), [samples])[
4078                    0
4079                ]  # this should not be wrapped in torch.abs()
4080                self.assertEqual(
4081                    cdfs_derivative,
4082                    pdfs,
4083                    msg="\n".join(
4084                        [
4085                            f"{Dist.__name__} example {i + 1}/{len(params)}, d(cdf)/dx != pdf(x)",
4086                            f"x = {samples}",
4087                            f"cdf = {cdfs}",
4088                            f"pdf = {pdfs}",
4089                            f"grad(cdf) = {cdfs_derivative}",
4090                        ]
4091                    ),
4092                )
4093
4094    def test_valid_parameter_broadcasting(self):
4095        # Test correct broadcasting of parameter sizes for distributions that have multiple
4096        # parameters.
4097        # example type (distribution instance, expected sample shape)
4098        valid_examples = [
4099            (Normal(loc=torch.tensor([0.0, 0.0]), scale=1), (2,)),
4100            (Normal(loc=0, scale=torch.tensor([1.0, 1.0])), (2,)),
4101            (Normal(loc=torch.tensor([0.0, 0.0]), scale=torch.tensor([1.0])), (2,)),
4102            (
4103                Normal(
4104                    loc=torch.tensor([0.0, 0.0]), scale=torch.tensor([[1.0], [1.0]])
4105                ),
4106                (2, 2),
4107            ),
4108            (Normal(loc=torch.tensor([0.0, 0.0]), scale=torch.tensor([[1.0]])), (1, 2)),
4109            (Normal(loc=torch.tensor([0.0]), scale=torch.tensor([[1.0]])), (1, 1)),
4110            (FisherSnedecor(df1=torch.tensor([1.0, 1.0]), df2=1), (2,)),
4111            (FisherSnedecor(df1=1, df2=torch.tensor([1.0, 1.0])), (2,)),
4112            (
4113                FisherSnedecor(df1=torch.tensor([1.0, 1.0]), df2=torch.tensor([1.0])),
4114                (2,),
4115            ),
4116            (
4117                FisherSnedecor(
4118                    df1=torch.tensor([1.0, 1.0]), df2=torch.tensor([[1.0], [1.0]])
4119                ),
4120                (2, 2),
4121            ),
4122            (
4123                FisherSnedecor(df1=torch.tensor([1.0, 1.0]), df2=torch.tensor([[1.0]])),
4124                (1, 2),
4125            ),
4126            (
4127                FisherSnedecor(df1=torch.tensor([1.0]), df2=torch.tensor([[1.0]])),
4128                (1, 1),
4129            ),
4130            (Gamma(concentration=torch.tensor([1.0, 1.0]), rate=1), (2,)),
4131            (Gamma(concentration=1, rate=torch.tensor([1.0, 1.0])), (2,)),
4132            (
4133                Gamma(
4134                    concentration=torch.tensor([1.0, 1.0]),
4135                    rate=torch.tensor([[1.0], [1.0], [1.0]]),
4136                ),
4137                (3, 2),
4138            ),
4139            (
4140                Gamma(
4141                    concentration=torch.tensor([1.0, 1.0]),
4142                    rate=torch.tensor([[1.0], [1.0]]),
4143                ),
4144                (2, 2),
4145            ),
4146            (
4147                Gamma(
4148                    concentration=torch.tensor([1.0, 1.0]), rate=torch.tensor([[1.0]])
4149                ),
4150                (1, 2),
4151            ),
4152            (
4153                Gamma(concentration=torch.tensor([1.0]), rate=torch.tensor([[1.0]])),
4154                (1, 1),
4155            ),
4156            (Gumbel(loc=torch.tensor([0.0, 0.0]), scale=1), (2,)),
4157            (Gumbel(loc=0, scale=torch.tensor([1.0, 1.0])), (2,)),
4158            (Gumbel(loc=torch.tensor([0.0, 0.0]), scale=torch.tensor([1.0])), (2,)),
4159            (
4160                Gumbel(
4161                    loc=torch.tensor([0.0, 0.0]), scale=torch.tensor([[1.0], [1.0]])
4162                ),
4163                (2, 2),
4164            ),
4165            (Gumbel(loc=torch.tensor([0.0, 0.0]), scale=torch.tensor([[1.0]])), (1, 2)),
4166            (Gumbel(loc=torch.tensor([0.0]), scale=torch.tensor([[1.0]])), (1, 1)),
4167            (
4168                Kumaraswamy(
4169                    concentration1=torch.tensor([1.0, 1.0]), concentration0=1.0
4170                ),
4171                (2,),
4172            ),
4173            (
4174                Kumaraswamy(concentration1=1, concentration0=torch.tensor([1.0, 1.0])),
4175                (2,),
4176            ),
4177            (
4178                Kumaraswamy(
4179                    concentration1=torch.tensor([1.0, 1.0]),
4180                    concentration0=torch.tensor([1.0]),
4181                ),
4182                (2,),
4183            ),
4184            (
4185                Kumaraswamy(
4186                    concentration1=torch.tensor([1.0, 1.0]),
4187                    concentration0=torch.tensor([[1.0], [1.0]]),
4188                ),
4189                (2, 2),
4190            ),
4191            (
4192                Kumaraswamy(
4193                    concentration1=torch.tensor([1.0, 1.0]),
4194                    concentration0=torch.tensor([[1.0]]),
4195                ),
4196                (1, 2),
4197            ),
4198            (
4199                Kumaraswamy(
4200                    concentration1=torch.tensor([1.0]),
4201                    concentration0=torch.tensor([[1.0]]),
4202                ),
4203                (1, 1),
4204            ),
4205            (Laplace(loc=torch.tensor([0.0, 0.0]), scale=1), (2,)),
4206            (Laplace(loc=0, scale=torch.tensor([1.0, 1.0])), (2,)),
4207            (Laplace(loc=torch.tensor([0.0, 0.0]), scale=torch.tensor([1.0])), (2,)),
4208            (
4209                Laplace(
4210                    loc=torch.tensor([0.0, 0.0]), scale=torch.tensor([[1.0], [1.0]])
4211                ),
4212                (2, 2),
4213            ),
4214            (
4215                Laplace(loc=torch.tensor([0.0, 0.0]), scale=torch.tensor([[1.0]])),
4216                (1, 2),
4217            ),
4218            (Laplace(loc=torch.tensor([0.0]), scale=torch.tensor([[1.0]])), (1, 1)),
4219            (Pareto(scale=torch.tensor([1.0, 1.0]), alpha=1), (2,)),
4220            (Pareto(scale=1, alpha=torch.tensor([1.0, 1.0])), (2,)),
4221            (Pareto(scale=torch.tensor([1.0, 1.0]), alpha=torch.tensor([1.0])), (2,)),
4222            (
4223                Pareto(
4224                    scale=torch.tensor([1.0, 1.0]), alpha=torch.tensor([[1.0], [1.0]])
4225                ),
4226                (2, 2),
4227            ),
4228            (
4229                Pareto(scale=torch.tensor([1.0, 1.0]), alpha=torch.tensor([[1.0]])),
4230                (1, 2),
4231            ),
4232            (Pareto(scale=torch.tensor([1.0]), alpha=torch.tensor([[1.0]])), (1, 1)),
4233            (StudentT(df=torch.tensor([1.0, 1.0]), loc=1), (2,)),
4234            (StudentT(df=1, scale=torch.tensor([1.0, 1.0])), (2,)),
4235            (StudentT(df=torch.tensor([1.0, 1.0]), loc=torch.tensor([1.0])), (2,)),
4236            (
4237                StudentT(
4238                    df=torch.tensor([1.0, 1.0]), scale=torch.tensor([[1.0], [1.0]])
4239                ),
4240                (2, 2),
4241            ),
4242            (StudentT(df=torch.tensor([1.0, 1.0]), loc=torch.tensor([[1.0]])), (1, 2)),
4243            (StudentT(df=torch.tensor([1.0]), scale=torch.tensor([[1.0]])), (1, 1)),
4244            (StudentT(df=1.0, loc=torch.zeros(5, 1), scale=torch.ones(3)), (5, 3)),
4245        ]
4246
4247        for dist, expected_size in valid_examples:
4248            actual_size = dist.sample().size()
4249            self.assertEqual(
4250                actual_size,
4251                expected_size,
4252                msg=f"{dist} actual size: {actual_size} != expected size: {expected_size}",
4253            )
4254
4255            sample_shape = torch.Size((2,))
4256            expected_size = sample_shape + expected_size
4257            actual_size = dist.sample(sample_shape).size()
4258            self.assertEqual(
4259                actual_size,
4260                expected_size,
4261                msg=f"{dist} actual size: {actual_size} != expected size: {expected_size}",
4262            )
4263
4264    def test_invalid_parameter_broadcasting(self):
4265        # invalid broadcasting cases; should throw error
4266        # example type (distribution class, distribution params)
4267        invalid_examples = [
4268            (
4269                Normal,
4270                {"loc": torch.tensor([[0, 0]]), "scale": torch.tensor([1, 1, 1, 1])},
4271            ),
4272            (
4273                Normal,
4274                {
4275                    "loc": torch.tensor([[[0, 0, 0], [0, 0, 0]]]),
4276                    "scale": torch.tensor([1, 1]),
4277                },
4278            ),
4279            (
4280                FisherSnedecor,
4281                {
4282                    "df1": torch.tensor([1, 1]),
4283                    "df2": torch.tensor([1, 1, 1]),
4284                },
4285            ),
4286            (
4287                Gumbel,
4288                {"loc": torch.tensor([[0, 0]]), "scale": torch.tensor([1, 1, 1, 1])},
4289            ),
4290            (
4291                Gumbel,
4292                {
4293                    "loc": torch.tensor([[[0, 0, 0], [0, 0, 0]]]),
4294                    "scale": torch.tensor([1, 1]),
4295                },
4296            ),
4297            (
4298                Gamma,
4299                {
4300                    "concentration": torch.tensor([0, 0]),
4301                    "rate": torch.tensor([1, 1, 1]),
4302                },
4303            ),
4304            (
4305                Kumaraswamy,
4306                {
4307                    "concentration1": torch.tensor([[1, 1]]),
4308                    "concentration0": torch.tensor([1, 1, 1, 1]),
4309                },
4310            ),
4311            (
4312                Kumaraswamy,
4313                {
4314                    "concentration1": torch.tensor([[[1, 1, 1], [1, 1, 1]]]),
4315                    "concentration0": torch.tensor([1, 1]),
4316                },
4317            ),
4318            (Laplace, {"loc": torch.tensor([0, 0]), "scale": torch.tensor([1, 1, 1])}),
4319            (Pareto, {"scale": torch.tensor([1, 1]), "alpha": torch.tensor([1, 1, 1])}),
4320            (
4321                StudentT,
4322                {
4323                    "df": torch.tensor([1.0, 1.0]),
4324                    "scale": torch.tensor([1.0, 1.0, 1.0]),
4325                },
4326            ),
4327            (
4328                StudentT,
4329                {"df": torch.tensor([1.0, 1.0]), "loc": torch.tensor([1.0, 1.0, 1.0])},
4330            ),
4331        ]
4332
4333        for dist, kwargs in invalid_examples:
4334            self.assertRaises(RuntimeError, dist, **kwargs)
4335
4336    def _test_discrete_distribution_mode(self, dist, sanitized_mode, batch_isfinite):
4337        # We cannot easily check the mode for discrete distributions, but we can look left and right
4338        # to ensure the log probability is smaller than at the mode.
4339        for step in [-1, 1]:
4340            log_prob_mode = dist.log_prob(sanitized_mode)
4341            if isinstance(dist, OneHotCategorical):
4342                idx = (dist._categorical.mode + 1) % dist.probs.shape[-1]
4343                other = torch.nn.functional.one_hot(
4344                    idx, num_classes=dist.probs.shape[-1]
4345                ).to(dist.mode)
4346            else:
4347                other = dist.mode + step
4348            mask = batch_isfinite & dist.support.check(other)
4349            self.assertTrue(mask.any() or dist.mode.unique().numel() == 1)
4350            # Add a dimension to the right if the event shape is not a scalar, e.g. OneHotCategorical.
4351            other = torch.where(
4352                mask[..., None] if mask.ndim < other.ndim else mask,
4353                other,
4354                dist.sample(),
4355            )
4356            log_prob_other = dist.log_prob(other)
4357            delta = log_prob_mode - log_prob_other
4358            self.assertTrue(
4359                (-1e-12 < delta[mask].detach()).all()
4360            )  # Allow up to 1e-12 rounding error.
4361
4362    def _test_continuous_distribution_mode(self, dist, sanitized_mode, batch_isfinite):
4363        # We perturb the mode in the unconstrained space and expect the log probability to decrease.
4364        num_points = 10
4365        transform = transform_to(dist.support)
4366        unconstrained_mode = transform.inv(sanitized_mode)
4367        perturbation = 1e-5 * (
4368            torch.rand((num_points,) + unconstrained_mode.shape) - 0.5
4369        )
4370        perturbed_mode = transform(perturbation + unconstrained_mode)
4371        log_prob_mode = dist.log_prob(sanitized_mode)
4372        log_prob_other = dist.log_prob(perturbed_mode)
4373        delta = log_prob_mode - log_prob_other
4374
4375        # We pass the test with a small tolerance to allow for rounding and manually set the
4376        # difference to zero if both log probs are infinite with the same sign.
4377        both_infinite_with_same_sign = (log_prob_mode == log_prob_other) & (
4378            log_prob_mode.abs() == inf
4379        )
4380        delta[both_infinite_with_same_sign] = 0.0
4381        ordering = (delta > -1e-12).all(axis=0)
4382        self.assertTrue(ordering[batch_isfinite].all())
4383
4384    @set_default_dtype(torch.double)
4385    def test_mode(self):
4386        discrete_distributions = (
4387            Bernoulli,
4388            Binomial,
4389            Categorical,
4390            Geometric,
4391            NegativeBinomial,
4392            OneHotCategorical,
4393            Poisson,
4394        )
4395        no_mode_available = (
4396            ContinuousBernoulli,
4397            LKJCholesky,
4398            LogisticNormal,
4399            MixtureSameFamily,
4400            Multinomial,
4401            RelaxedBernoulli,
4402            RelaxedOneHotCategorical,
4403        )
4404
4405        for dist_cls, params in _get_examples():
4406            for param in params:
4407                dist = dist_cls(**param)
4408                if (
4409                    isinstance(dist, no_mode_available)
4410                    or type(dist) is TransformedDistribution
4411                ):
4412                    with self.assertRaises(NotImplementedError):
4413                        dist.mode
4414                    continue
4415
4416                # Check that either all or no elements in the event shape are nan: the mode cannot be
4417                # defined for part of an event.
4418                isfinite = dist.mode.isfinite().reshape(
4419                    dist.batch_shape + (dist.event_shape.numel(),)
4420                )
4421                batch_isfinite = isfinite.all(axis=-1)
4422                self.assertTrue((batch_isfinite | ~isfinite.any(axis=-1)).all())
4423
4424                # We sanitize undefined modes by sampling from the distribution.
4425                sanitized_mode = torch.where(
4426                    ~dist.mode.isnan(), dist.mode, dist.sample()
4427                )
4428                if isinstance(dist, discrete_distributions):
4429                    self._test_discrete_distribution_mode(
4430                        dist, sanitized_mode, batch_isfinite
4431                    )
4432                else:
4433                    self._test_continuous_distribution_mode(
4434                        dist, sanitized_mode, batch_isfinite
4435                    )
4436
4437                self.assertFalse(dist.log_prob(sanitized_mode).isnan().any())
4438
4439
4440# These tests are only needed for a few distributions that implement custom
4441# reparameterized gradients. Most .rsample() implementations simply rely on
4442# the reparameterization trick and do not need to be tested for accuracy.
4443@skipIfTorchDynamo("Not a TorchDynamo suitable test")
4444class TestRsample(DistributionsTestCase):
4445    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
4446    def test_gamma(self):
4447        num_samples = 100
4448        for alpha in [1e-2, 1e-1, 1e0, 1e1, 1e2, 1e3, 1e4]:
4449            alphas = torch.tensor(
4450                [alpha] * num_samples, dtype=torch.float, requires_grad=True
4451            )
4452            betas = alphas.new_ones(num_samples)
4453            x = Gamma(alphas, betas).rsample()
4454            x.sum().backward()
4455            x, ind = x.sort()
4456            x = x.detach().numpy()
4457            actual_grad = alphas.grad[ind].numpy()
4458            # Compare with expected gradient dx/dalpha along constant cdf(x,alpha).
4459            cdf = scipy.stats.gamma.cdf
4460            pdf = scipy.stats.gamma.pdf
4461            eps = 0.01 * alpha / (1.0 + alpha**0.5)
4462            cdf_alpha = (cdf(x, alpha + eps) - cdf(x, alpha - eps)) / (2 * eps)
4463            cdf_x = pdf(x, alpha)
4464            expected_grad = -cdf_alpha / cdf_x
4465            rel_error = np.abs(actual_grad - expected_grad) / (expected_grad + 1e-30)
4466            self.assertLess(
4467                np.max(rel_error),
4468                0.0005,
4469                "\n".join(
4470                    [
4471                        f"Bad gradient dx/alpha for x ~ Gamma({alpha}, 1)",
4472                        f"x {x}",
4473                        f"expected {expected_grad}",
4474                        f"actual {actual_grad}",
4475                        f"rel error {rel_error}",
4476                        f"max error {rel_error.max()}",
4477                        f"at alpha={alpha}, x={x[rel_error.argmax()]}",
4478                    ]
4479                ),
4480            )
4481
4482    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
4483    def test_chi2(self):
4484        num_samples = 100
4485        for df in [1e-2, 1e-1, 1e0, 1e1, 1e2, 1e3, 1e4]:
4486            dfs = torch.tensor(
4487                [df] * num_samples, dtype=torch.float, requires_grad=True
4488            )
4489            x = Chi2(dfs).rsample()
4490            x.sum().backward()
4491            x, ind = x.sort()
4492            x = x.detach().numpy()
4493            actual_grad = dfs.grad[ind].numpy()
4494            # Compare with expected gradient dx/ddf along constant cdf(x,df).
4495            cdf = scipy.stats.chi2.cdf
4496            pdf = scipy.stats.chi2.pdf
4497            eps = 0.01 * df / (1.0 + df**0.5)
4498            cdf_df = (cdf(x, df + eps) - cdf(x, df - eps)) / (2 * eps)
4499            cdf_x = pdf(x, df)
4500            expected_grad = -cdf_df / cdf_x
4501            rel_error = np.abs(actual_grad - expected_grad) / (expected_grad + 1e-30)
4502            self.assertLess(
4503                np.max(rel_error),
4504                0.001,
4505                "\n".join(
4506                    [
4507                        f"Bad gradient dx/ddf for x ~ Chi2({df})",
4508                        f"x {x}",
4509                        f"expected {expected_grad}",
4510                        f"actual {actual_grad}",
4511                        f"rel error {rel_error}",
4512                        f"max error {rel_error.max()}",
4513                    ]
4514                ),
4515            )
4516
4517    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
4518    def test_dirichlet_on_diagonal(self):
4519        num_samples = 20
4520        grid = [1e-1, 1e0, 1e1]
4521        for a0, a1, a2 in product(grid, grid, grid):
4522            alphas = torch.tensor(
4523                [[a0, a1, a2]] * num_samples, dtype=torch.float, requires_grad=True
4524            )
4525            x = Dirichlet(alphas).rsample()[:, 0]
4526            x.sum().backward()
4527            x, ind = x.sort()
4528            x = x.detach().numpy()
4529            actual_grad = alphas.grad[ind].numpy()[:, 0]
4530            # Compare with expected gradient dx/dalpha0 along constant cdf(x,alpha).
4531            # This reduces to a distribution Beta(alpha[0], alpha[1] + alpha[2]).
4532            cdf = scipy.stats.beta.cdf
4533            pdf = scipy.stats.beta.pdf
4534            alpha, beta = a0, a1 + a2
4535            eps = 0.01 * alpha / (1.0 + np.sqrt(alpha))
4536            cdf_alpha = (cdf(x, alpha + eps, beta) - cdf(x, alpha - eps, beta)) / (
4537                2 * eps
4538            )
4539            cdf_x = pdf(x, alpha, beta)
4540            expected_grad = -cdf_alpha / cdf_x
4541            rel_error = np.abs(actual_grad - expected_grad) / (expected_grad + 1e-30)
4542            self.assertLess(
4543                np.max(rel_error),
4544                0.001,
4545                "\n".join(
4546                    [
4547                        f"Bad gradient dx[0]/dalpha[0] for Dirichlet([{a0}, {a1}, {a2}])",
4548                        f"x {x}",
4549                        f"expected {expected_grad}",
4550                        f"actual {actual_grad}",
4551                        f"rel error {rel_error}",
4552                        f"max error {rel_error.max()}",
4553                        f"at x={x[rel_error.argmax()]}",
4554                    ]
4555                ),
4556            )
4557
4558    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
4559    def test_beta_wrt_alpha(self):
4560        num_samples = 20
4561        grid = [1e-2, 1e-1, 1e0, 1e1, 1e2]
4562        for con1, con0 in product(grid, grid):
4563            con1s = torch.tensor(
4564                [con1] * num_samples, dtype=torch.float, requires_grad=True
4565            )
4566            con0s = con1s.new_tensor([con0] * num_samples)
4567            x = Beta(con1s, con0s).rsample()
4568            x.sum().backward()
4569            x, ind = x.sort()
4570            x = x.detach().numpy()
4571            actual_grad = con1s.grad[ind].numpy()
4572            # Compare with expected gradient dx/dcon1 along constant cdf(x,con1,con0).
4573            cdf = scipy.stats.beta.cdf
4574            pdf = scipy.stats.beta.pdf
4575            eps = 0.01 * con1 / (1.0 + np.sqrt(con1))
4576            cdf_alpha = (cdf(x, con1 + eps, con0) - cdf(x, con1 - eps, con0)) / (
4577                2 * eps
4578            )
4579            cdf_x = pdf(x, con1, con0)
4580            expected_grad = -cdf_alpha / cdf_x
4581            rel_error = np.abs(actual_grad - expected_grad) / (expected_grad + 1e-30)
4582            self.assertLess(
4583                np.max(rel_error),
4584                0.005,
4585                "\n".join(
4586                    [
4587                        f"Bad gradient dx/dcon1 for x ~ Beta({con1}, {con0})",
4588                        f"x {x}",
4589                        f"expected {expected_grad}",
4590                        f"actual {actual_grad}",
4591                        f"rel error {rel_error}",
4592                        f"max error {rel_error.max()}",
4593                        f"at x = {x[rel_error.argmax()]}",
4594                    ]
4595                ),
4596            )
4597
4598    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
4599    def test_beta_wrt_beta(self):
4600        num_samples = 20
4601        grid = [1e-2, 1e-1, 1e0, 1e1, 1e2]
4602        for con1, con0 in product(grid, grid):
4603            con0s = torch.tensor(
4604                [con0] * num_samples, dtype=torch.float, requires_grad=True
4605            )
4606            con1s = con0s.new_tensor([con1] * num_samples)
4607            x = Beta(con1s, con0s).rsample()
4608            x.sum().backward()
4609            x, ind = x.sort()
4610            x = x.detach().numpy()
4611            actual_grad = con0s.grad[ind].numpy()
4612            # Compare with expected gradient dx/dcon0 along constant cdf(x,con1,con0).
4613            cdf = scipy.stats.beta.cdf
4614            pdf = scipy.stats.beta.pdf
4615            eps = 0.01 * con0 / (1.0 + np.sqrt(con0))
4616            cdf_beta = (cdf(x, con1, con0 + eps) - cdf(x, con1, con0 - eps)) / (2 * eps)
4617            cdf_x = pdf(x, con1, con0)
4618            expected_grad = -cdf_beta / cdf_x
4619            rel_error = np.abs(actual_grad - expected_grad) / (expected_grad + 1e-30)
4620            self.assertLess(
4621                np.max(rel_error),
4622                0.005,
4623                "\n".join(
4624                    [
4625                        f"Bad gradient dx/dcon0 for x ~ Beta({con1}, {con0})",
4626                        f"x {x}",
4627                        f"expected {expected_grad}",
4628                        f"actual {actual_grad}",
4629                        f"rel error {rel_error}",
4630                        f"max error {rel_error.max()}",
4631                        f"at x = {x[rel_error.argmax()]!r}",
4632                    ]
4633                ),
4634            )
4635
4636    def test_dirichlet_multivariate(self):
4637        alpha_crit = 0.25 * (5.0**0.5 - 1.0)
4638        num_samples = 100000
4639        for shift in [-0.1, -0.05, -0.01, 0.0, 0.01, 0.05, 0.10]:
4640            alpha = alpha_crit + shift
4641            alpha = torch.tensor([alpha], dtype=torch.float, requires_grad=True)
4642            alpha_vec = torch.cat([alpha, alpha, alpha.new([1])])
4643            z = Dirichlet(alpha_vec.expand(num_samples, 3)).rsample()
4644            mean_z3 = 1.0 / (2.0 * alpha + 1.0)
4645            loss = torch.pow(z[:, 2] - mean_z3, 2.0).mean()
4646            actual_grad = grad(loss, [alpha])[0]
4647            # Compute expected gradient by hand.
4648            num = 1.0 - 2.0 * alpha - 4.0 * alpha**2
4649            den = (1.0 + alpha) ** 2 * (1.0 + 2.0 * alpha) ** 3
4650            expected_grad = num / den
4651            self.assertEqual(
4652                actual_grad,
4653                expected_grad,
4654                atol=0.002,
4655                rtol=0,
4656                msg="\n".join(
4657                    [
4658                        "alpha = alpha_c + %.2g" % shift,  # noqa: UP031
4659                        "expected_grad: %.5g" % expected_grad,  # noqa: UP031
4660                        "actual_grad: %.5g" % actual_grad,  # noqa: UP031
4661                        "error = %.2g"  # noqa: UP031
4662                        % torch.abs(expected_grad - actual_grad).max(),  # noqa: UP031
4663                    ]
4664                ),
4665            )
4666
4667    @set_default_dtype(torch.double)
4668    def test_dirichlet_tangent_field(self):
4669        num_samples = 20
4670        alpha_grid = [0.5, 1.0, 2.0]
4671
4672        # v = dx/dalpha[0] is the reparameterized gradient aka tangent field.
4673        def compute_v(x, alpha):
4674            return torch.stack(
4675                [
4676                    _Dirichlet_backward(x, alpha, torch.eye(3, 3)[i].expand_as(x))[:, 0]
4677                    for i in range(3)
4678                ],
4679                dim=-1,
4680            )
4681
4682        for a1, a2, a3 in product(alpha_grid, alpha_grid, alpha_grid):
4683            alpha = torch.tensor([a1, a2, a3], requires_grad=True).expand(
4684                num_samples, 3
4685            )
4686            x = Dirichlet(alpha).rsample()
4687            dlogp_da = grad(
4688                [Dirichlet(alpha).log_prob(x.detach()).sum()],
4689                [alpha],
4690                retain_graph=True,
4691            )[0][:, 0]
4692            dlogp_dx = grad(
4693                [Dirichlet(alpha.detach()).log_prob(x).sum()], [x], retain_graph=True
4694            )[0]
4695            v = torch.stack(
4696                [
4697                    grad([x[:, i].sum()], [alpha], retain_graph=True)[0][:, 0]
4698                    for i in range(3)
4699                ],
4700                dim=-1,
4701            )
4702            # Compute ramaining properties by finite difference.
4703            self.assertEqual(compute_v(x, alpha), v, msg="Bug in compute_v() helper")
4704            # dx is an arbitrary orthonormal basis tangent to the simplex.
4705            dx = torch.tensor([[2.0, -1.0, -1.0], [0.0, 1.0, -1.0]])
4706            dx /= dx.norm(2, -1, True)
4707            eps = 1e-2 * x.min(-1, True)[0]  # avoid boundary
4708            dv0 = (
4709                compute_v(x + eps * dx[0], alpha) - compute_v(x - eps * dx[0], alpha)
4710            ) / (2 * eps)
4711            dv1 = (
4712                compute_v(x + eps * dx[1], alpha) - compute_v(x - eps * dx[1], alpha)
4713            ) / (2 * eps)
4714            div_v = (dv0 * dx[0] + dv1 * dx[1]).sum(-1)
4715            # This is a modification of the standard continuity equation, using the product rule to allow
4716            # expression in terms of log_prob rather than the less numerically stable log_prob.exp().
4717            error = dlogp_da + (dlogp_dx * v).sum(-1) + div_v
4718            self.assertLess(
4719                torch.abs(error).max(),
4720                0.005,
4721                "\n".join(
4722                    [
4723                        f"Dirichlet([{a1}, {a2}, {a3}]) gradient violates continuity equation:",
4724                        f"error = {error}",
4725                    ]
4726                ),
4727            )
4728
4729
4730class TestDistributionShapes(DistributionsTestCase):
4731    def setUp(self):
4732        super().setUp()
4733        self.scalar_sample = 1
4734        self.tensor_sample_1 = torch.ones(3, 2)
4735        self.tensor_sample_2 = torch.ones(3, 2, 3)
4736
4737    def test_entropy_shape(self):
4738        for Dist, params in _get_examples():
4739            for i, param in enumerate(params):
4740                dist = Dist(validate_args=False, **param)
4741                try:
4742                    actual_shape = dist.entropy().size()
4743                    expected_shape = (
4744                        dist.batch_shape if dist.batch_shape else torch.Size()
4745                    )
4746                    message = f"{Dist.__name__} example {i + 1}/{len(params)}, shape mismatch. expected {expected_shape}, actual {actual_shape}"  # noqa: B950
4747                    self.assertEqual(actual_shape, expected_shape, msg=message)
4748                except NotImplementedError:
4749                    continue
4750
4751    def test_bernoulli_shape_scalar_params(self):
4752        bernoulli = Bernoulli(0.3)
4753        self.assertEqual(bernoulli._batch_shape, torch.Size())
4754        self.assertEqual(bernoulli._event_shape, torch.Size())
4755        self.assertEqual(bernoulli.sample().size(), torch.Size())
4756        self.assertEqual(bernoulli.sample((3, 2)).size(), torch.Size((3, 2)))
4757        self.assertRaises(ValueError, bernoulli.log_prob, self.scalar_sample)
4758        self.assertEqual(
4759            bernoulli.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
4760        )
4761        self.assertEqual(
4762            bernoulli.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
4763        )
4764
4765    def test_bernoulli_shape_tensor_params(self):
4766        bernoulli = Bernoulli(torch.tensor([[0.6, 0.3], [0.6, 0.3], [0.6, 0.3]]))
4767        self.assertEqual(bernoulli._batch_shape, torch.Size((3, 2)))
4768        self.assertEqual(bernoulli._event_shape, torch.Size(()))
4769        self.assertEqual(bernoulli.sample().size(), torch.Size((3, 2)))
4770        self.assertEqual(bernoulli.sample((3, 2)).size(), torch.Size((3, 2, 3, 2)))
4771        self.assertEqual(
4772            bernoulli.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
4773        )
4774        self.assertRaises(ValueError, bernoulli.log_prob, self.tensor_sample_2)
4775        self.assertEqual(
4776            bernoulli.log_prob(torch.ones(3, 1, 1)).size(), torch.Size((3, 3, 2))
4777        )
4778
4779    def test_geometric_shape_scalar_params(self):
4780        geometric = Geometric(0.3)
4781        self.assertEqual(geometric._batch_shape, torch.Size())
4782        self.assertEqual(geometric._event_shape, torch.Size())
4783        self.assertEqual(geometric.sample().size(), torch.Size())
4784        self.assertEqual(geometric.sample((3, 2)).size(), torch.Size((3, 2)))
4785        self.assertRaises(ValueError, geometric.log_prob, self.scalar_sample)
4786        self.assertEqual(
4787            geometric.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
4788        )
4789        self.assertEqual(
4790            geometric.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
4791        )
4792
4793    def test_geometric_shape_tensor_params(self):
4794        geometric = Geometric(torch.tensor([[0.6, 0.3], [0.6, 0.3], [0.6, 0.3]]))
4795        self.assertEqual(geometric._batch_shape, torch.Size((3, 2)))
4796        self.assertEqual(geometric._event_shape, torch.Size(()))
4797        self.assertEqual(geometric.sample().size(), torch.Size((3, 2)))
4798        self.assertEqual(geometric.sample((3, 2)).size(), torch.Size((3, 2, 3, 2)))
4799        self.assertEqual(
4800            geometric.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
4801        )
4802        self.assertRaises(ValueError, geometric.log_prob, self.tensor_sample_2)
4803        self.assertEqual(
4804            geometric.log_prob(torch.ones(3, 1, 1)).size(), torch.Size((3, 3, 2))
4805        )
4806
4807    def test_beta_shape_scalar_params(self):
4808        dist = Beta(0.1, 0.1)
4809        self.assertEqual(dist._batch_shape, torch.Size())
4810        self.assertEqual(dist._event_shape, torch.Size())
4811        self.assertEqual(dist.sample().size(), torch.Size())
4812        self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2)))
4813        self.assertRaises(ValueError, dist.log_prob, self.scalar_sample)
4814        self.assertEqual(dist.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
4815        self.assertEqual(
4816            dist.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
4817        )
4818
4819    def test_beta_shape_tensor_params(self):
4820        dist = Beta(
4821            torch.tensor([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]),
4822            torch.tensor([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]),
4823        )
4824        self.assertEqual(dist._batch_shape, torch.Size((3, 2)))
4825        self.assertEqual(dist._event_shape, torch.Size(()))
4826        self.assertEqual(dist.sample().size(), torch.Size((3, 2)))
4827        self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2, 3, 2)))
4828        self.assertEqual(dist.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
4829        self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
4830        self.assertEqual(
4831            dist.log_prob(torch.ones(3, 1, 1)).size(), torch.Size((3, 3, 2))
4832        )
4833
4834    def test_binomial_shape(self):
4835        dist = Binomial(10, torch.tensor([0.6, 0.3]))
4836        self.assertEqual(dist._batch_shape, torch.Size((2,)))
4837        self.assertEqual(dist._event_shape, torch.Size(()))
4838        self.assertEqual(dist.sample().size(), torch.Size((2,)))
4839        self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2, 2)))
4840        self.assertEqual(dist.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
4841        self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
4842
4843    def test_binomial_shape_vectorized_n(self):
4844        dist = Binomial(
4845            torch.tensor([[10, 3, 1], [4, 8, 4]]), torch.tensor([0.6, 0.3, 0.1])
4846        )
4847        self.assertEqual(dist._batch_shape, torch.Size((2, 3)))
4848        self.assertEqual(dist._event_shape, torch.Size(()))
4849        self.assertEqual(dist.sample().size(), torch.Size((2, 3)))
4850        self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2, 2, 3)))
4851        self.assertEqual(
4852            dist.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
4853        )
4854        self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_1)
4855
4856    def test_multinomial_shape(self):
4857        dist = Multinomial(10, torch.tensor([[0.6, 0.3], [0.6, 0.3], [0.6, 0.3]]))
4858        self.assertEqual(dist._batch_shape, torch.Size((3,)))
4859        self.assertEqual(dist._event_shape, torch.Size((2,)))
4860        self.assertEqual(dist.sample().size(), torch.Size((3, 2)))
4861        self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2, 3, 2)))
4862        self.assertEqual(dist.log_prob(self.tensor_sample_1).size(), torch.Size((3,)))
4863        self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
4864        self.assertEqual(dist.log_prob(torch.ones(3, 1, 2)).size(), torch.Size((3, 3)))
4865
4866    def test_categorical_shape(self):
4867        # unbatched
4868        dist = Categorical(torch.tensor([0.6, 0.3, 0.1]))
4869        self.assertEqual(dist._batch_shape, torch.Size(()))
4870        self.assertEqual(dist._event_shape, torch.Size(()))
4871        self.assertEqual(dist.sample().size(), torch.Size())
4872        self.assertEqual(
4873            dist.sample((3, 2)).size(),
4874            torch.Size(
4875                (
4876                    3,
4877                    2,
4878                )
4879            ),
4880        )
4881        self.assertEqual(dist.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
4882        self.assertEqual(
4883            dist.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
4884        )
4885        self.assertEqual(dist.log_prob(torch.ones(3, 1)).size(), torch.Size((3, 1)))
4886        # batched
4887        dist = Categorical(torch.tensor([[0.6, 0.3], [0.6, 0.3], [0.6, 0.3]]))
4888        self.assertEqual(dist._batch_shape, torch.Size((3,)))
4889        self.assertEqual(dist._event_shape, torch.Size(()))
4890        self.assertEqual(dist.sample().size(), torch.Size((3,)))
4891        self.assertEqual(
4892            dist.sample((3, 2)).size(),
4893            torch.Size(
4894                (
4895                    3,
4896                    2,
4897                    3,
4898                )
4899            ),
4900        )
4901        self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_1)
4902        self.assertEqual(
4903            dist.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
4904        )
4905        self.assertEqual(dist.log_prob(torch.ones(3, 1)).size(), torch.Size((3, 3)))
4906
4907    def test_one_hot_categorical_shape(self):
4908        # unbatched
4909        dist = OneHotCategorical(torch.tensor([0.6, 0.3, 0.1]))
4910        self.assertEqual(dist._batch_shape, torch.Size(()))
4911        self.assertEqual(dist._event_shape, torch.Size((3,)))
4912        self.assertEqual(dist.sample().size(), torch.Size((3,)))
4913        self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2, 3)))
4914        self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_1)
4915        sample = torch.tensor([0.0, 1.0, 0.0]).expand(3, 2, 3)
4916        self.assertEqual(
4917            dist.log_prob(sample).size(),
4918            torch.Size(
4919                (
4920                    3,
4921                    2,
4922                )
4923            ),
4924        )
4925        self.assertEqual(
4926            dist.log_prob(dist.enumerate_support()).size(), torch.Size((3,))
4927        )
4928        sample = torch.eye(3)
4929        self.assertEqual(dist.log_prob(sample).size(), torch.Size((3,)))
4930        # batched
4931        dist = OneHotCategorical(torch.tensor([[0.6, 0.3], [0.6, 0.3], [0.6, 0.3]]))
4932        self.assertEqual(dist._batch_shape, torch.Size((3,)))
4933        self.assertEqual(dist._event_shape, torch.Size((2,)))
4934        self.assertEqual(dist.sample().size(), torch.Size((3, 2)))
4935        self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2, 3, 2)))
4936        sample = torch.tensor([0.0, 1.0])
4937        self.assertEqual(dist.log_prob(sample).size(), torch.Size((3,)))
4938        self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
4939        self.assertEqual(
4940            dist.log_prob(dist.enumerate_support()).size(), torch.Size((2, 3))
4941        )
4942        sample = torch.tensor([0.0, 1.0]).expand(3, 1, 2)
4943        self.assertEqual(dist.log_prob(sample).size(), torch.Size((3, 3)))
4944
4945    def test_cauchy_shape_scalar_params(self):
4946        cauchy = Cauchy(0, 1)
4947        self.assertEqual(cauchy._batch_shape, torch.Size())
4948        self.assertEqual(cauchy._event_shape, torch.Size())
4949        self.assertEqual(cauchy.sample().size(), torch.Size())
4950        self.assertEqual(cauchy.sample(torch.Size((3, 2))).size(), torch.Size((3, 2)))
4951        self.assertRaises(ValueError, cauchy.log_prob, self.scalar_sample)
4952        self.assertEqual(
4953            cauchy.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
4954        )
4955        self.assertEqual(
4956            cauchy.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
4957        )
4958
4959    def test_cauchy_shape_tensor_params(self):
4960        cauchy = Cauchy(torch.tensor([0.0, 0.0]), torch.tensor([1.0, 1.0]))
4961        self.assertEqual(cauchy._batch_shape, torch.Size((2,)))
4962        self.assertEqual(cauchy._event_shape, torch.Size(()))
4963        self.assertEqual(cauchy.sample().size(), torch.Size((2,)))
4964        self.assertEqual(
4965            cauchy.sample(torch.Size((3, 2))).size(), torch.Size((3, 2, 2))
4966        )
4967        self.assertEqual(
4968            cauchy.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
4969        )
4970        self.assertRaises(ValueError, cauchy.log_prob, self.tensor_sample_2)
4971        self.assertEqual(cauchy.log_prob(torch.ones(2, 1)).size(), torch.Size((2, 2)))
4972
4973    def test_halfcauchy_shape_scalar_params(self):
4974        halfcauchy = HalfCauchy(1)
4975        self.assertEqual(halfcauchy._batch_shape, torch.Size())
4976        self.assertEqual(halfcauchy._event_shape, torch.Size())
4977        self.assertEqual(halfcauchy.sample().size(), torch.Size())
4978        self.assertEqual(
4979            halfcauchy.sample(torch.Size((3, 2))).size(), torch.Size((3, 2))
4980        )
4981        self.assertRaises(ValueError, halfcauchy.log_prob, self.scalar_sample)
4982        self.assertEqual(
4983            halfcauchy.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
4984        )
4985        self.assertEqual(
4986            halfcauchy.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
4987        )
4988
4989    def test_halfcauchy_shape_tensor_params(self):
4990        halfcauchy = HalfCauchy(torch.tensor([1.0, 1.0]))
4991        self.assertEqual(halfcauchy._batch_shape, torch.Size((2,)))
4992        self.assertEqual(halfcauchy._event_shape, torch.Size(()))
4993        self.assertEqual(halfcauchy.sample().size(), torch.Size((2,)))
4994        self.assertEqual(
4995            halfcauchy.sample(torch.Size((3, 2))).size(), torch.Size((3, 2, 2))
4996        )
4997        self.assertEqual(
4998            halfcauchy.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
4999        )
5000        self.assertRaises(ValueError, halfcauchy.log_prob, self.tensor_sample_2)
5001        self.assertEqual(
5002            halfcauchy.log_prob(torch.ones(2, 1)).size(), torch.Size((2, 2))
5003        )
5004
5005    def test_dirichlet_shape(self):
5006        dist = Dirichlet(torch.tensor([[0.6, 0.3], [1.6, 1.3], [2.6, 2.3]]))
5007        self.assertEqual(dist._batch_shape, torch.Size((3,)))
5008        self.assertEqual(dist._event_shape, torch.Size((2,)))
5009        self.assertEqual(dist.sample().size(), torch.Size((3, 2)))
5010        self.assertEqual(dist.sample((5, 4)).size(), torch.Size((5, 4, 3, 2)))
5011        simplex_sample = self.tensor_sample_1 / self.tensor_sample_1.sum(
5012            -1, keepdim=True
5013        )
5014        self.assertEqual(dist.log_prob(simplex_sample).size(), torch.Size((3,)))
5015        self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
5016        simplex_sample = torch.ones(3, 1, 2)
5017        simplex_sample = simplex_sample / simplex_sample.sum(-1).unsqueeze(-1)
5018        self.assertEqual(dist.log_prob(simplex_sample).size(), torch.Size((3, 3)))
5019
5020    def test_mixture_same_family_shape(self):
5021        dist = MixtureSameFamily(
5022            Categorical(torch.rand(5)), Normal(torch.randn(5), torch.rand(5))
5023        )
5024        self.assertEqual(dist._batch_shape, torch.Size())
5025        self.assertEqual(dist._event_shape, torch.Size())
5026        self.assertEqual(dist.sample().size(), torch.Size())
5027        self.assertEqual(dist.sample((5, 4)).size(), torch.Size((5, 4)))
5028        self.assertEqual(dist.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
5029        self.assertEqual(
5030            dist.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
5031        )
5032
5033    def test_gamma_shape_scalar_params(self):
5034        gamma = Gamma(1, 1)
5035        self.assertEqual(gamma._batch_shape, torch.Size())
5036        self.assertEqual(gamma._event_shape, torch.Size())
5037        self.assertEqual(gamma.sample().size(), torch.Size())
5038        self.assertEqual(gamma.sample((3, 2)).size(), torch.Size((3, 2)))
5039        self.assertEqual(gamma.log_prob(self.scalar_sample).size(), torch.Size())
5040        self.assertEqual(
5041            gamma.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
5042        )
5043        self.assertEqual(
5044            gamma.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
5045        )
5046
5047    def test_gamma_shape_tensor_params(self):
5048        gamma = Gamma(torch.tensor([1.0, 1.0]), torch.tensor([1.0, 1.0]))
5049        self.assertEqual(gamma._batch_shape, torch.Size((2,)))
5050        self.assertEqual(gamma._event_shape, torch.Size(()))
5051        self.assertEqual(gamma.sample().size(), torch.Size((2,)))
5052        self.assertEqual(gamma.sample((3, 2)).size(), torch.Size((3, 2, 2)))
5053        self.assertEqual(
5054            gamma.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
5055        )
5056        self.assertRaises(ValueError, gamma.log_prob, self.tensor_sample_2)
5057        self.assertEqual(gamma.log_prob(torch.ones(2, 1)).size(), torch.Size((2, 2)))
5058
5059    def test_chi2_shape_scalar_params(self):
5060        chi2 = Chi2(1)
5061        self.assertEqual(chi2._batch_shape, torch.Size())
5062        self.assertEqual(chi2._event_shape, torch.Size())
5063        self.assertEqual(chi2.sample().size(), torch.Size())
5064        self.assertEqual(chi2.sample((3, 2)).size(), torch.Size((3, 2)))
5065        self.assertEqual(chi2.log_prob(self.scalar_sample).size(), torch.Size())
5066        self.assertEqual(chi2.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
5067        self.assertEqual(
5068            chi2.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
5069        )
5070
5071    def test_chi2_shape_tensor_params(self):
5072        chi2 = Chi2(torch.tensor([1.0, 1.0]))
5073        self.assertEqual(chi2._batch_shape, torch.Size((2,)))
5074        self.assertEqual(chi2._event_shape, torch.Size(()))
5075        self.assertEqual(chi2.sample().size(), torch.Size((2,)))
5076        self.assertEqual(chi2.sample((3, 2)).size(), torch.Size((3, 2, 2)))
5077        self.assertEqual(chi2.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
5078        self.assertRaises(ValueError, chi2.log_prob, self.tensor_sample_2)
5079        self.assertEqual(chi2.log_prob(torch.ones(2, 1)).size(), torch.Size((2, 2)))
5080
5081    def test_studentT_shape_scalar_params(self):
5082        st = StudentT(1)
5083        self.assertEqual(st._batch_shape, torch.Size())
5084        self.assertEqual(st._event_shape, torch.Size())
5085        self.assertEqual(st.sample().size(), torch.Size())
5086        self.assertEqual(st.sample((3, 2)).size(), torch.Size((3, 2)))
5087        self.assertRaises(ValueError, st.log_prob, self.scalar_sample)
5088        self.assertEqual(st.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
5089        self.assertEqual(
5090            st.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
5091        )
5092
5093    def test_studentT_shape_tensor_params(self):
5094        st = StudentT(torch.tensor([1.0, 1.0]))
5095        self.assertEqual(st._batch_shape, torch.Size((2,)))
5096        self.assertEqual(st._event_shape, torch.Size(()))
5097        self.assertEqual(st.sample().size(), torch.Size((2,)))
5098        self.assertEqual(st.sample((3, 2)).size(), torch.Size((3, 2, 2)))
5099        self.assertEqual(st.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
5100        self.assertRaises(ValueError, st.log_prob, self.tensor_sample_2)
5101        self.assertEqual(st.log_prob(torch.ones(2, 1)).size(), torch.Size((2, 2)))
5102
5103    def test_pareto_shape_scalar_params(self):
5104        pareto = Pareto(1, 1)
5105        self.assertEqual(pareto._batch_shape, torch.Size())
5106        self.assertEqual(pareto._event_shape, torch.Size())
5107        self.assertEqual(pareto.sample().size(), torch.Size())
5108        self.assertEqual(pareto.sample((3, 2)).size(), torch.Size((3, 2)))
5109        self.assertEqual(
5110            pareto.log_prob(self.tensor_sample_1 + 1).size(), torch.Size((3, 2))
5111        )
5112        self.assertEqual(
5113            pareto.log_prob(self.tensor_sample_2 + 1).size(), torch.Size((3, 2, 3))
5114        )
5115
5116    def test_gumbel_shape_scalar_params(self):
5117        gumbel = Gumbel(1, 1)
5118        self.assertEqual(gumbel._batch_shape, torch.Size())
5119        self.assertEqual(gumbel._event_shape, torch.Size())
5120        self.assertEqual(gumbel.sample().size(), torch.Size())
5121        self.assertEqual(gumbel.sample((3, 2)).size(), torch.Size((3, 2)))
5122        self.assertEqual(
5123            gumbel.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
5124        )
5125        self.assertEqual(
5126            gumbel.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
5127        )
5128
5129    def test_kumaraswamy_shape_scalar_params(self):
5130        kumaraswamy = Kumaraswamy(1, 1)
5131        self.assertEqual(kumaraswamy._batch_shape, torch.Size())
5132        self.assertEqual(kumaraswamy._event_shape, torch.Size())
5133        self.assertEqual(kumaraswamy.sample().size(), torch.Size())
5134        self.assertEqual(kumaraswamy.sample((3, 2)).size(), torch.Size((3, 2)))
5135        self.assertEqual(
5136            kumaraswamy.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
5137        )
5138        self.assertEqual(
5139            kumaraswamy.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
5140        )
5141
5142    def test_vonmises_shape_tensor_params(self):
5143        von_mises = VonMises(torch.tensor([0.0, 0.0]), torch.tensor([1.0, 1.0]))
5144        self.assertEqual(von_mises._batch_shape, torch.Size((2,)))
5145        self.assertEqual(von_mises._event_shape, torch.Size(()))
5146        self.assertEqual(von_mises.sample().size(), torch.Size((2,)))
5147        self.assertEqual(
5148            von_mises.sample(torch.Size((3, 2))).size(), torch.Size((3, 2, 2))
5149        )
5150        self.assertEqual(
5151            von_mises.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
5152        )
5153        self.assertEqual(
5154            von_mises.log_prob(torch.ones(2, 1)).size(), torch.Size((2, 2))
5155        )
5156
5157    def test_vonmises_shape_scalar_params(self):
5158        von_mises = VonMises(0.0, 1.0)
5159        self.assertEqual(von_mises._batch_shape, torch.Size())
5160        self.assertEqual(von_mises._event_shape, torch.Size())
5161        self.assertEqual(von_mises.sample().size(), torch.Size())
5162        self.assertEqual(
5163            von_mises.sample(torch.Size((3, 2))).size(), torch.Size((3, 2))
5164        )
5165        self.assertEqual(
5166            von_mises.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
5167        )
5168        self.assertEqual(
5169            von_mises.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
5170        )
5171
5172    def test_weibull_scale_scalar_params(self):
5173        weibull = Weibull(1, 1)
5174        self.assertEqual(weibull._batch_shape, torch.Size())
5175        self.assertEqual(weibull._event_shape, torch.Size())
5176        self.assertEqual(weibull.sample().size(), torch.Size())
5177        self.assertEqual(weibull.sample((3, 2)).size(), torch.Size((3, 2)))
5178        self.assertEqual(
5179            weibull.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
5180        )
5181        self.assertEqual(
5182            weibull.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
5183        )
5184
5185    def test_wishart_shape_scalar_params(self):
5186        wishart = Wishart(torch.tensor(1), torch.tensor([[1.0]]))
5187        self.assertEqual(wishart._batch_shape, torch.Size())
5188        self.assertEqual(wishart._event_shape, torch.Size((1, 1)))
5189        self.assertEqual(wishart.sample().size(), torch.Size((1, 1)))
5190        self.assertEqual(wishart.sample((3, 2)).size(), torch.Size((3, 2, 1, 1)))
5191        self.assertRaises(ValueError, wishart.log_prob, self.scalar_sample)
5192
5193    def test_wishart_shape_tensor_params(self):
5194        wishart = Wishart(torch.tensor([1.0, 1.0]), torch.tensor([[[1.0]], [[1.0]]]))
5195        self.assertEqual(wishart._batch_shape, torch.Size((2,)))
5196        self.assertEqual(wishart._event_shape, torch.Size((1, 1)))
5197        self.assertEqual(wishart.sample().size(), torch.Size((2, 1, 1)))
5198        self.assertEqual(wishart.sample((3, 2)).size(), torch.Size((3, 2, 2, 1, 1)))
5199        self.assertRaises(ValueError, wishart.log_prob, self.tensor_sample_2)
5200        self.assertEqual(wishart.log_prob(torch.ones(2, 1, 1)).size(), torch.Size((2,)))
5201
5202    def test_normal_shape_scalar_params(self):
5203        normal = Normal(0, 1)
5204        self.assertEqual(normal._batch_shape, torch.Size())
5205        self.assertEqual(normal._event_shape, torch.Size())
5206        self.assertEqual(normal.sample().size(), torch.Size())
5207        self.assertEqual(normal.sample((3, 2)).size(), torch.Size((3, 2)))
5208        self.assertRaises(ValueError, normal.log_prob, self.scalar_sample)
5209        self.assertEqual(
5210            normal.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
5211        )
5212        self.assertEqual(
5213            normal.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
5214        )
5215
5216    def test_normal_shape_tensor_params(self):
5217        normal = Normal(torch.tensor([0.0, 0.0]), torch.tensor([1.0, 1.0]))
5218        self.assertEqual(normal._batch_shape, torch.Size((2,)))
5219        self.assertEqual(normal._event_shape, torch.Size(()))
5220        self.assertEqual(normal.sample().size(), torch.Size((2,)))
5221        self.assertEqual(normal.sample((3, 2)).size(), torch.Size((3, 2, 2)))
5222        self.assertEqual(
5223            normal.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
5224        )
5225        self.assertRaises(ValueError, normal.log_prob, self.tensor_sample_2)
5226        self.assertEqual(normal.log_prob(torch.ones(2, 1)).size(), torch.Size((2, 2)))
5227
5228    def test_uniform_shape_scalar_params(self):
5229        uniform = Uniform(0, 1)
5230        self.assertEqual(uniform._batch_shape, torch.Size())
5231        self.assertEqual(uniform._event_shape, torch.Size())
5232        self.assertEqual(uniform.sample().size(), torch.Size())
5233        self.assertEqual(uniform.sample(torch.Size((3, 2))).size(), torch.Size((3, 2)))
5234        self.assertRaises(ValueError, uniform.log_prob, self.scalar_sample)
5235        self.assertEqual(
5236            uniform.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
5237        )
5238        self.assertEqual(
5239            uniform.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
5240        )
5241
5242    def test_uniform_shape_tensor_params(self):
5243        uniform = Uniform(torch.tensor([0.0, 0.0]), torch.tensor([1.0, 1.0]))
5244        self.assertEqual(uniform._batch_shape, torch.Size((2,)))
5245        self.assertEqual(uniform._event_shape, torch.Size(()))
5246        self.assertEqual(uniform.sample().size(), torch.Size((2,)))
5247        self.assertEqual(
5248            uniform.sample(torch.Size((3, 2))).size(), torch.Size((3, 2, 2))
5249        )
5250        self.assertEqual(
5251            uniform.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
5252        )
5253        self.assertRaises(ValueError, uniform.log_prob, self.tensor_sample_2)
5254        self.assertEqual(uniform.log_prob(torch.ones(2, 1)).size(), torch.Size((2, 2)))
5255
5256    def test_exponential_shape_scalar_param(self):
5257        expon = Exponential(1.0)
5258        self.assertEqual(expon._batch_shape, torch.Size())
5259        self.assertEqual(expon._event_shape, torch.Size())
5260        self.assertEqual(expon.sample().size(), torch.Size())
5261        self.assertEqual(expon.sample((3, 2)).size(), torch.Size((3, 2)))
5262        self.assertRaises(ValueError, expon.log_prob, self.scalar_sample)
5263        self.assertEqual(
5264            expon.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
5265        )
5266        self.assertEqual(
5267            expon.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
5268        )
5269
5270    def test_exponential_shape_tensor_param(self):
5271        expon = Exponential(torch.tensor([1.0, 1.0]))
5272        self.assertEqual(expon._batch_shape, torch.Size((2,)))
5273        self.assertEqual(expon._event_shape, torch.Size(()))
5274        self.assertEqual(expon.sample().size(), torch.Size((2,)))
5275        self.assertEqual(expon.sample((3, 2)).size(), torch.Size((3, 2, 2)))
5276        self.assertEqual(
5277            expon.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
5278        )
5279        self.assertRaises(ValueError, expon.log_prob, self.tensor_sample_2)
5280        self.assertEqual(expon.log_prob(torch.ones(2, 2)).size(), torch.Size((2, 2)))
5281
5282    def test_laplace_shape_scalar_params(self):
5283        laplace = Laplace(0, 1)
5284        self.assertEqual(laplace._batch_shape, torch.Size())
5285        self.assertEqual(laplace._event_shape, torch.Size())
5286        self.assertEqual(laplace.sample().size(), torch.Size())
5287        self.assertEqual(laplace.sample((3, 2)).size(), torch.Size((3, 2)))
5288        self.assertRaises(ValueError, laplace.log_prob, self.scalar_sample)
5289        self.assertEqual(
5290            laplace.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
5291        )
5292        self.assertEqual(
5293            laplace.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
5294        )
5295
5296    def test_laplace_shape_tensor_params(self):
5297        laplace = Laplace(torch.tensor([0.0, 0.0]), torch.tensor([1.0, 1.0]))
5298        self.assertEqual(laplace._batch_shape, torch.Size((2,)))
5299        self.assertEqual(laplace._event_shape, torch.Size(()))
5300        self.assertEqual(laplace.sample().size(), torch.Size((2,)))
5301        self.assertEqual(laplace.sample((3, 2)).size(), torch.Size((3, 2, 2)))
5302        self.assertEqual(
5303            laplace.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
5304        )
5305        self.assertRaises(ValueError, laplace.log_prob, self.tensor_sample_2)
5306        self.assertEqual(laplace.log_prob(torch.ones(2, 1)).size(), torch.Size((2, 2)))
5307
5308    def test_continuous_bernoulli_shape_scalar_params(self):
5309        continuous_bernoulli = ContinuousBernoulli(0.3)
5310        self.assertEqual(continuous_bernoulli._batch_shape, torch.Size())
5311        self.assertEqual(continuous_bernoulli._event_shape, torch.Size())
5312        self.assertEqual(continuous_bernoulli.sample().size(), torch.Size())
5313        self.assertEqual(continuous_bernoulli.sample((3, 2)).size(), torch.Size((3, 2)))
5314        self.assertRaises(ValueError, continuous_bernoulli.log_prob, self.scalar_sample)
5315        self.assertEqual(
5316            continuous_bernoulli.log_prob(self.tensor_sample_1).size(),
5317            torch.Size((3, 2)),
5318        )
5319        self.assertEqual(
5320            continuous_bernoulli.log_prob(self.tensor_sample_2).size(),
5321            torch.Size((3, 2, 3)),
5322        )
5323
5324    def test_continuous_bernoulli_shape_tensor_params(self):
5325        continuous_bernoulli = ContinuousBernoulli(
5326            torch.tensor([[0.6, 0.3], [0.6, 0.3], [0.6, 0.3]])
5327        )
5328        self.assertEqual(continuous_bernoulli._batch_shape, torch.Size((3, 2)))
5329        self.assertEqual(continuous_bernoulli._event_shape, torch.Size(()))
5330        self.assertEqual(continuous_bernoulli.sample().size(), torch.Size((3, 2)))
5331        self.assertEqual(
5332            continuous_bernoulli.sample((3, 2)).size(), torch.Size((3, 2, 3, 2))
5333        )
5334        self.assertEqual(
5335            continuous_bernoulli.log_prob(self.tensor_sample_1).size(),
5336            torch.Size((3, 2)),
5337        )
5338        self.assertRaises(
5339            ValueError, continuous_bernoulli.log_prob, self.tensor_sample_2
5340        )
5341        self.assertEqual(
5342            continuous_bernoulli.log_prob(torch.ones(3, 1, 1)).size(),
5343            torch.Size((3, 3, 2)),
5344        )
5345
5346    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
5347    def test_mixture_same_family_mean_shape(self):
5348        mix_distribution = Categorical(torch.ones([3, 1, 3]))
5349        component_distribution = Normal(torch.zeros([3, 3, 3]), torch.ones([3, 3, 3]))
5350        gmm = MixtureSameFamily(mix_distribution, component_distribution)
5351        self.assertEqual(len(gmm.mean.shape), 2)
5352
5353
5354@skipIfTorchDynamo("Not a TorchDynamo suitable test")
5355class TestKL(DistributionsTestCase):
5356    def setUp(self):
5357        super().setUp()
5358
5359        class Binomial30(Binomial):
5360            def __init__(self, probs):
5361                super().__init__(30, probs)
5362
5363        # These are pairs of distributions with 4 x 4 parameters as specified.
5364        # The first of the pair e.g. bernoulli[0] varies column-wise and the second
5365        # e.g. bernoulli[1] varies row-wise; that way we test all param pairs.
5366        bernoulli = pairwise(Bernoulli, [0.1, 0.2, 0.6, 0.9])
5367        binomial30 = pairwise(Binomial30, [0.1, 0.2, 0.6, 0.9])
5368        binomial_vectorized_count = (
5369            Binomial(torch.tensor([3, 4]), torch.tensor([0.4, 0.6])),
5370            Binomial(torch.tensor([3, 4]), torch.tensor([0.5, 0.8])),
5371        )
5372        beta = pairwise(Beta, [1.0, 2.5, 1.0, 2.5], [1.5, 1.5, 3.5, 3.5])
5373        categorical = pairwise(
5374            Categorical,
5375            [[0.4, 0.3, 0.3], [0.2, 0.7, 0.1], [0.33, 0.33, 0.34], [0.2, 0.2, 0.6]],
5376        )
5377        cauchy = pairwise(Cauchy, [-2.0, 2.0, -3.0, 3.0], [1.0, 2.0, 1.0, 2.0])
5378        chi2 = pairwise(Chi2, [1.0, 2.0, 2.5, 5.0])
5379        dirichlet = pairwise(
5380            Dirichlet,
5381            [[0.1, 0.2, 0.7], [0.5, 0.4, 0.1], [0.33, 0.33, 0.34], [0.2, 0.2, 0.4]],
5382        )
5383        exponential = pairwise(Exponential, [1.0, 2.5, 5.0, 10.0])
5384        gamma = pairwise(Gamma, [1.0, 2.5, 1.0, 2.5], [1.5, 1.5, 3.5, 3.5])
5385        gumbel = pairwise(Gumbel, [-2.0, 4.0, -3.0, 6.0], [1.0, 2.5, 1.0, 2.5])
5386        halfnormal = pairwise(HalfNormal, [1.0, 2.0, 1.0, 2.0])
5387        inversegamma = pairwise(
5388            InverseGamma, [1.0, 2.5, 1.0, 2.5], [1.5, 1.5, 3.5, 3.5]
5389        )
5390        laplace = pairwise(Laplace, [-2.0, 4.0, -3.0, 6.0], [1.0, 2.5, 1.0, 2.5])
5391        lognormal = pairwise(LogNormal, [-2.0, 2.0, -3.0, 3.0], [1.0, 2.0, 1.0, 2.0])
5392        normal = pairwise(Normal, [-2.0, 2.0, -3.0, 3.0], [1.0, 2.0, 1.0, 2.0])
5393        independent = (Independent(normal[0], 1), Independent(normal[1], 1))
5394        onehotcategorical = pairwise(
5395            OneHotCategorical,
5396            [[0.4, 0.3, 0.3], [0.2, 0.7, 0.1], [0.33, 0.33, 0.34], [0.2, 0.2, 0.6]],
5397        )
5398        pareto = (
5399            Pareto(
5400                torch.tensor([2.5, 4.0, 2.5, 4.0]).expand(4, 4),
5401                torch.tensor([2.25, 3.75, 2.25, 3.75]).expand(4, 4),
5402            ),
5403            Pareto(
5404                torch.tensor([2.25, 3.75, 2.25, 3.8]).expand(4, 4),
5405                torch.tensor([2.25, 3.75, 2.25, 3.75]).expand(4, 4),
5406            ),
5407        )
5408        poisson = pairwise(Poisson, [0.3, 1.0, 5.0, 10.0])
5409        uniform_within_unit = pairwise(
5410            Uniform, [0.1, 0.9, 0.2, 0.75], [0.15, 0.95, 0.25, 0.8]
5411        )
5412        uniform_positive = pairwise(Uniform, [1, 1.5, 2, 4], [1.2, 2.0, 3, 7])
5413        uniform_real = pairwise(Uniform, [-2.0, -1, 0, 2], [-1.0, 1, 1, 4])
5414        uniform_pareto = pairwise(Uniform, [6.5, 7.5, 6.5, 8.5], [7.5, 8.5, 9.5, 9.5])
5415        continuous_bernoulli = pairwise(ContinuousBernoulli, [0.1, 0.2, 0.5, 0.9])
5416
5417        # These tests should pass with precision = 0.01, but that makes tests very expensive.
5418        # Instead, we test with precision = 0.1 and only test with higher precision locally
5419        # when adding a new KL implementation.
5420        # The following pairs are not tested due to very high variance of the monte carlo
5421        # estimator; their implementations have been reviewed with extra care:
5422        # - (pareto, normal)
5423        self.precision = 0.1  # Set this to 0.01 when testing a new KL implementation.
5424        self.max_samples = int(1e07)  # Increase this when testing at smaller precision.
5425        self.samples_per_batch = int(1e04)
5426        self.finite_examples = [
5427            (bernoulli, bernoulli),
5428            (bernoulli, poisson),
5429            (beta, beta),
5430            (beta, chi2),
5431            (beta, exponential),
5432            (beta, gamma),
5433            (beta, normal),
5434            (binomial30, binomial30),
5435            (binomial_vectorized_count, binomial_vectorized_count),
5436            (categorical, categorical),
5437            (cauchy, cauchy),
5438            (chi2, chi2),
5439            (chi2, exponential),
5440            (chi2, gamma),
5441            (chi2, normal),
5442            (dirichlet, dirichlet),
5443            (exponential, chi2),
5444            (exponential, exponential),
5445            (exponential, gamma),
5446            (exponential, gumbel),
5447            (exponential, normal),
5448            (gamma, chi2),
5449            (gamma, exponential),
5450            (gamma, gamma),
5451            (gamma, gumbel),
5452            (gamma, normal),
5453            (gumbel, gumbel),
5454            (gumbel, normal),
5455            (halfnormal, halfnormal),
5456            (independent, independent),
5457            (inversegamma, inversegamma),
5458            (laplace, laplace),
5459            (lognormal, lognormal),
5460            (laplace, normal),
5461            (normal, gumbel),
5462            (normal, laplace),
5463            (normal, normal),
5464            (onehotcategorical, onehotcategorical),
5465            (pareto, chi2),
5466            (pareto, pareto),
5467            (pareto, exponential),
5468            (pareto, gamma),
5469            (poisson, poisson),
5470            (uniform_within_unit, beta),
5471            (uniform_positive, chi2),
5472            (uniform_positive, exponential),
5473            (uniform_positive, gamma),
5474            (uniform_real, gumbel),
5475            (uniform_real, normal),
5476            (uniform_pareto, pareto),
5477            (continuous_bernoulli, continuous_bernoulli),
5478            (continuous_bernoulli, exponential),
5479            (continuous_bernoulli, normal),
5480            (beta, continuous_bernoulli),
5481        ]
5482
5483        self.infinite_examples = [
5484            (Bernoulli(0), Bernoulli(1)),
5485            (Bernoulli(1), Bernoulli(0)),
5486            (
5487                Categorical(torch.tensor([0.9, 0.1])),
5488                Categorical(torch.tensor([1.0, 0.0])),
5489            ),
5490            (
5491                Categorical(torch.tensor([[0.9, 0.1], [0.9, 0.1]])),
5492                Categorical(torch.tensor([1.0, 0.0])),
5493            ),
5494            (Beta(1, 2), Uniform(0.25, 1)),
5495            (Beta(1, 2), Uniform(0, 0.75)),
5496            (Beta(1, 2), Uniform(0.25, 0.75)),
5497            (Beta(1, 2), Pareto(1, 2)),
5498            (Binomial(31, 0.7), Binomial(30, 0.3)),
5499            (
5500                Binomial(torch.tensor([3, 4]), torch.tensor([0.4, 0.6])),
5501                Binomial(torch.tensor([2, 3]), torch.tensor([0.5, 0.8])),
5502            ),
5503            (Chi2(1), Beta(2, 3)),
5504            (Chi2(1), Pareto(2, 3)),
5505            (Chi2(1), Uniform(-2, 3)),
5506            (Exponential(1), Beta(2, 3)),
5507            (Exponential(1), Pareto(2, 3)),
5508            (Exponential(1), Uniform(-2, 3)),
5509            (Gamma(1, 2), Beta(3, 4)),
5510            (Gamma(1, 2), Pareto(3, 4)),
5511            (Gamma(1, 2), Uniform(-3, 4)),
5512            (Gumbel(-1, 2), Beta(3, 4)),
5513            (Gumbel(-1, 2), Chi2(3)),
5514            (Gumbel(-1, 2), Exponential(3)),
5515            (Gumbel(-1, 2), Gamma(3, 4)),
5516            (Gumbel(-1, 2), Pareto(3, 4)),
5517            (Gumbel(-1, 2), Uniform(-3, 4)),
5518            (Laplace(-1, 2), Beta(3, 4)),
5519            (Laplace(-1, 2), Chi2(3)),
5520            (Laplace(-1, 2), Exponential(3)),
5521            (Laplace(-1, 2), Gamma(3, 4)),
5522            (Laplace(-1, 2), Pareto(3, 4)),
5523            (Laplace(-1, 2), Uniform(-3, 4)),
5524            (Normal(-1, 2), Beta(3, 4)),
5525            (Normal(-1, 2), Chi2(3)),
5526            (Normal(-1, 2), Exponential(3)),
5527            (Normal(-1, 2), Gamma(3, 4)),
5528            (Normal(-1, 2), Pareto(3, 4)),
5529            (Normal(-1, 2), Uniform(-3, 4)),
5530            (Pareto(2, 1), Chi2(3)),
5531            (Pareto(2, 1), Exponential(3)),
5532            (Pareto(2, 1), Gamma(3, 4)),
5533            (Pareto(1, 2), Normal(-3, 4)),
5534            (Pareto(1, 2), Pareto(3, 4)),
5535            (Poisson(2), Bernoulli(0.5)),
5536            (Poisson(2.3), Binomial(10, 0.2)),
5537            (Uniform(-1, 1), Beta(2, 2)),
5538            (Uniform(0, 2), Beta(3, 4)),
5539            (Uniform(-1, 2), Beta(3, 4)),
5540            (Uniform(-1, 2), Chi2(3)),
5541            (Uniform(-1, 2), Exponential(3)),
5542            (Uniform(-1, 2), Gamma(3, 4)),
5543            (Uniform(-1, 2), Pareto(3, 4)),
5544            (ContinuousBernoulli(0.25), Uniform(0.25, 1)),
5545            (ContinuousBernoulli(0.25), Uniform(0, 0.75)),
5546            (ContinuousBernoulli(0.25), Uniform(0.25, 0.75)),
5547            (ContinuousBernoulli(0.25), Pareto(1, 2)),
5548            (Exponential(1), ContinuousBernoulli(0.75)),
5549            (Gamma(1, 2), ContinuousBernoulli(0.75)),
5550            (Gumbel(-1, 2), ContinuousBernoulli(0.75)),
5551            (Laplace(-1, 2), ContinuousBernoulli(0.75)),
5552            (Normal(-1, 2), ContinuousBernoulli(0.75)),
5553            (Uniform(-1, 1), ContinuousBernoulli(0.75)),
5554            (Uniform(0, 2), ContinuousBernoulli(0.75)),
5555            (Uniform(-1, 2), ContinuousBernoulli(0.75)),
5556        ]
5557
5558    def test_kl_monte_carlo(self):
5559        set_rng_seed(0)  # see Note [Randomized statistical tests]
5560        for (p, _), (_, q) in self.finite_examples:
5561            actual = kl_divergence(p, q)
5562            numerator = 0
5563            denominator = 0
5564            while denominator < self.max_samples:
5565                x = p.sample(sample_shape=(self.samples_per_batch,))
5566                numerator += (p.log_prob(x) - q.log_prob(x)).sum(0)
5567                denominator += x.size(0)
5568                expected = numerator / denominator
5569                error = torch.abs(expected - actual) / (1 + expected)
5570                if error[error == error].max() < self.precision:
5571                    break
5572            self.assertLess(
5573                error[error == error].max(),
5574                self.precision,
5575                "\n".join(
5576                    [
5577                        f"Incorrect KL({type(p).__name__}, {type(q).__name__}).",
5578                        f"Expected ({denominator} Monte Carlo samples): {expected}",
5579                        f"Actual (analytic): {actual}",
5580                    ]
5581                ),
5582            )
5583
5584    # Multivariate normal has a separate Monte Carlo based test due to the requirement of random generation of
5585    # positive (semi) definite matrices. n is set to 5, but can be increased during testing.
5586    def test_kl_multivariate_normal(self):
5587        set_rng_seed(0)  # see Note [Randomized statistical tests]
5588        n = 5  # Number of tests for multivariate_normal
5589        for i in range(0, n):
5590            loc = [torch.randn(4) for _ in range(0, 2)]
5591            scale_tril = [
5592                transform_to(constraints.lower_cholesky)(torch.randn(4, 4))
5593                for _ in range(0, 2)
5594            ]
5595            p = MultivariateNormal(loc=loc[0], scale_tril=scale_tril[0])
5596            q = MultivariateNormal(loc=loc[1], scale_tril=scale_tril[1])
5597            actual = kl_divergence(p, q)
5598            numerator = 0
5599            denominator = 0
5600            while denominator < self.max_samples:
5601                x = p.sample(sample_shape=(self.samples_per_batch,))
5602                numerator += (p.log_prob(x) - q.log_prob(x)).sum(0)
5603                denominator += x.size(0)
5604                expected = numerator / denominator
5605                error = torch.abs(expected - actual) / (1 + expected)
5606                if error[error == error].max() < self.precision:
5607                    break
5608            self.assertLess(
5609                error[error == error].max(),
5610                self.precision,
5611                "\n".join(
5612                    [
5613                        f"Incorrect KL(MultivariateNormal, MultivariateNormal) instance {i + 1}/{n}",
5614                        f"Expected ({denominator} Monte Carlo sample): {expected}",
5615                        f"Actual (analytic): {actual}",
5616                    ]
5617                ),
5618            )
5619
5620    def test_kl_multivariate_normal_batched(self):
5621        b = 7  # Number of batches
5622        loc = [torch.randn(b, 3) for _ in range(0, 2)]
5623        scale_tril = [
5624            transform_to(constraints.lower_cholesky)(torch.randn(b, 3, 3))
5625            for _ in range(0, 2)
5626        ]
5627        expected_kl = torch.stack(
5628            [
5629                kl_divergence(
5630                    MultivariateNormal(loc[0][i], scale_tril=scale_tril[0][i]),
5631                    MultivariateNormal(loc[1][i], scale_tril=scale_tril[1][i]),
5632                )
5633                for i in range(0, b)
5634            ]
5635        )
5636        actual_kl = kl_divergence(
5637            MultivariateNormal(loc[0], scale_tril=scale_tril[0]),
5638            MultivariateNormal(loc[1], scale_tril=scale_tril[1]),
5639        )
5640        self.assertEqual(expected_kl, actual_kl)
5641
5642    def test_kl_multivariate_normal_batched_broadcasted(self):
5643        b = 7  # Number of batches
5644        loc = [torch.randn(b, 3) for _ in range(0, 2)]
5645        scale_tril = [
5646            transform_to(constraints.lower_cholesky)(torch.randn(b, 3, 3)),
5647            transform_to(constraints.lower_cholesky)(torch.randn(3, 3)),
5648        ]
5649        expected_kl = torch.stack(
5650            [
5651                kl_divergence(
5652                    MultivariateNormal(loc[0][i], scale_tril=scale_tril[0][i]),
5653                    MultivariateNormal(loc[1][i], scale_tril=scale_tril[1]),
5654                )
5655                for i in range(0, b)
5656            ]
5657        )
5658        actual_kl = kl_divergence(
5659            MultivariateNormal(loc[0], scale_tril=scale_tril[0]),
5660            MultivariateNormal(loc[1], scale_tril=scale_tril[1]),
5661        )
5662        self.assertEqual(expected_kl, actual_kl)
5663
5664    def test_kl_lowrank_multivariate_normal(self):
5665        set_rng_seed(0)  # see Note [Randomized statistical tests]
5666        n = 5  # Number of tests for lowrank_multivariate_normal
5667        for i in range(0, n):
5668            loc = [torch.randn(4) for _ in range(0, 2)]
5669            cov_factor = [torch.randn(4, 3) for _ in range(0, 2)]
5670            cov_diag = [
5671                transform_to(constraints.positive)(torch.randn(4)) for _ in range(0, 2)
5672            ]
5673            covariance_matrix = [
5674                cov_factor[i].matmul(cov_factor[i].t()) + cov_diag[i].diag()
5675                for i in range(0, 2)
5676            ]
5677            p = LowRankMultivariateNormal(loc[0], cov_factor[0], cov_diag[0])
5678            q = LowRankMultivariateNormal(loc[1], cov_factor[1], cov_diag[1])
5679            p_full = MultivariateNormal(loc[0], covariance_matrix[0])
5680            q_full = MultivariateNormal(loc[1], covariance_matrix[1])
5681            expected = kl_divergence(p_full, q_full)
5682
5683            actual_lowrank_lowrank = kl_divergence(p, q)
5684            actual_lowrank_full = kl_divergence(p, q_full)
5685            actual_full_lowrank = kl_divergence(p_full, q)
5686
5687            error_lowrank_lowrank = torch.abs(actual_lowrank_lowrank - expected).max()
5688            self.assertLess(
5689                error_lowrank_lowrank,
5690                self.precision,
5691                "\n".join(
5692                    [
5693                        f"Incorrect KL(LowRankMultivariateNormal, LowRankMultivariateNormal) instance {i + 1}/{n}",
5694                        f"Expected (from KL MultivariateNormal): {expected}",
5695                        f"Actual (analytic): {actual_lowrank_lowrank}",
5696                    ]
5697                ),
5698            )
5699
5700            error_lowrank_full = torch.abs(actual_lowrank_full - expected).max()
5701            self.assertLess(
5702                error_lowrank_full,
5703                self.precision,
5704                "\n".join(
5705                    [
5706                        f"Incorrect KL(LowRankMultivariateNormal, MultivariateNormal) instance {i + 1}/{n}",
5707                        f"Expected (from KL MultivariateNormal): {expected}",
5708                        f"Actual (analytic): {actual_lowrank_full}",
5709                    ]
5710                ),
5711            )
5712
5713            error_full_lowrank = torch.abs(actual_full_lowrank - expected).max()
5714            self.assertLess(
5715                error_full_lowrank,
5716                self.precision,
5717                "\n".join(
5718                    [
5719                        f"Incorrect KL(MultivariateNormal, LowRankMultivariateNormal) instance {i + 1}/{n}",
5720                        f"Expected (from KL MultivariateNormal): {expected}",
5721                        f"Actual (analytic): {actual_full_lowrank}",
5722                    ]
5723                ),
5724            )
5725
5726    def test_kl_lowrank_multivariate_normal_batched(self):
5727        b = 7  # Number of batches
5728        loc = [torch.randn(b, 3) for _ in range(0, 2)]
5729        cov_factor = [torch.randn(b, 3, 2) for _ in range(0, 2)]
5730        cov_diag = [
5731            transform_to(constraints.positive)(torch.randn(b, 3)) for _ in range(0, 2)
5732        ]
5733        expected_kl = torch.stack(
5734            [
5735                kl_divergence(
5736                    LowRankMultivariateNormal(
5737                        loc[0][i], cov_factor[0][i], cov_diag[0][i]
5738                    ),
5739                    LowRankMultivariateNormal(
5740                        loc[1][i], cov_factor[1][i], cov_diag[1][i]
5741                    ),
5742                )
5743                for i in range(0, b)
5744            ]
5745        )
5746        actual_kl = kl_divergence(
5747            LowRankMultivariateNormal(loc[0], cov_factor[0], cov_diag[0]),
5748            LowRankMultivariateNormal(loc[1], cov_factor[1], cov_diag[1]),
5749        )
5750        self.assertEqual(expected_kl, actual_kl)
5751
5752    def test_kl_exponential_family(self):
5753        for (p, _), (_, q) in self.finite_examples:
5754            if type(p) == type(q) and issubclass(type(p), ExponentialFamily):
5755                actual = kl_divergence(p, q)
5756                expected = _kl_expfamily_expfamily(p, q)
5757                self.assertEqual(
5758                    actual,
5759                    expected,
5760                    msg="\n".join(
5761                        [
5762                            f"Incorrect KL({type(p).__name__}, {type(q).__name__}).",
5763                            f"Expected (using Bregman Divergence) {expected}",
5764                            f"Actual (analytic) {actual}",
5765                            f"max error = {torch.abs(actual - expected).max()}",
5766                        ]
5767                    ),
5768                )
5769
5770    def test_kl_infinite(self):
5771        for p, q in self.infinite_examples:
5772            self.assertTrue(
5773                (kl_divergence(p, q) == inf).all(),
5774                f"Incorrect KL({type(p).__name__}, {type(q).__name__})",
5775            )
5776
5777    def test_kl_edgecases(self):
5778        self.assertEqual(kl_divergence(Bernoulli(0), Bernoulli(0)), 0)
5779        self.assertEqual(kl_divergence(Bernoulli(1), Bernoulli(1)), 0)
5780        self.assertEqual(
5781            kl_divergence(
5782                Categorical(torch.tensor([0.0, 1.0])),
5783                Categorical(torch.tensor([0.0, 1.0])),
5784            ),
5785            0,
5786        )
5787
5788    def test_kl_shape(self):
5789        for Dist, params in _get_examples():
5790            for i, param in enumerate(params):
5791                dist = Dist(**param)
5792                try:
5793                    kl = kl_divergence(dist, dist)
5794                except NotImplementedError:
5795                    continue
5796                expected_shape = dist.batch_shape if dist.batch_shape else torch.Size()
5797                self.assertEqual(
5798                    kl.shape,
5799                    expected_shape,
5800                    msg="\n".join(
5801                        [
5802                            f"{Dist.__name__} example {i + 1}/{len(params)}",
5803                            f"Expected {expected_shape}",
5804                            f"Actual {kl.shape}",
5805                        ]
5806                    ),
5807                )
5808
5809    def test_kl_transformed(self):
5810        # Regression test for https://github.com/pytorch/pytorch/issues/34859
5811        scale = torch.ones(2, 3)
5812        loc = torch.zeros(2, 3)
5813        normal = Normal(loc=loc, scale=scale)
5814        diag_normal = Independent(normal, reinterpreted_batch_ndims=1)
5815        trans_dist = TransformedDistribution(
5816            diag_normal, AffineTransform(loc=0.0, scale=2.0)
5817        )
5818        self.assertEqual(kl_divergence(diag_normal, diag_normal).shape, (2,))
5819        self.assertEqual(kl_divergence(trans_dist, trans_dist).shape, (2,))
5820
5821    @set_default_dtype(torch.double)
5822    def test_entropy_monte_carlo(self):
5823        set_rng_seed(0)  # see Note [Randomized statistical tests]
5824        for Dist, params in _get_examples():
5825            for i, param in enumerate(params):
5826                dist = Dist(**param)
5827                try:
5828                    actual = dist.entropy()
5829                except NotImplementedError:
5830                    continue
5831                x = dist.sample(sample_shape=(60000,))
5832                expected = -dist.log_prob(x).mean(0)
5833                ignore = (expected == inf) | (expected == -inf)
5834                expected[ignore] = actual[ignore]
5835                self.assertEqual(
5836                    actual,
5837                    expected,
5838                    atol=0.2,
5839                    rtol=0,
5840                    msg="\n".join(
5841                        [
5842                            f"{Dist.__name__} example {i + 1}/{len(params)}, incorrect .entropy().",
5843                            f"Expected (monte carlo) {expected}",
5844                            f"Actual (analytic) {actual}",
5845                            f"max error = {torch.abs(actual - expected).max()}",
5846                        ]
5847                    ),
5848                )
5849
5850    @set_default_dtype(torch.double)
5851    def test_entropy_exponential_family(self):
5852        for Dist, params in _get_examples():
5853            if not issubclass(Dist, ExponentialFamily):
5854                continue
5855            for i, param in enumerate(params):
5856                dist = Dist(**param)
5857                try:
5858                    actual = dist.entropy()
5859                except NotImplementedError:
5860                    continue
5861                try:
5862                    expected = ExponentialFamily.entropy(dist)
5863                except NotImplementedError:
5864                    continue
5865                self.assertEqual(
5866                    actual,
5867                    expected,
5868                    msg="\n".join(
5869                        [
5870                            f"{Dist.__name__} example {i + 1}/{len(params)}, incorrect .entropy().",
5871                            f"Expected (Bregman Divergence) {expected}",
5872                            f"Actual (analytic) {actual}",
5873                            f"max error = {torch.abs(actual - expected).max()}",
5874                        ]
5875                    ),
5876                )
5877
5878
5879class TestConstraints(DistributionsTestCase):
5880    def test_params_constraints(self):
5881        normalize_probs_dists = (
5882            Categorical,
5883            Multinomial,
5884            OneHotCategorical,
5885            OneHotCategoricalStraightThrough,
5886            RelaxedOneHotCategorical,
5887        )
5888
5889        for Dist, params in _get_examples():
5890            for i, param in enumerate(params):
5891                dist = Dist(**param)
5892                for name, value in param.items():
5893                    if isinstance(value, numbers.Number):
5894                        value = torch.tensor([value])
5895                    if Dist in normalize_probs_dists and name == "probs":
5896                        # These distributions accept positive probs, but elsewhere we
5897                        # use a stricter constraint to the simplex.
5898                        value = value / value.sum(-1, True)
5899                    try:
5900                        constraint = dist.arg_constraints[name]
5901                    except KeyError:
5902                        continue  # ignore optional parameters
5903
5904                    # Check param shape is compatible with distribution shape.
5905                    self.assertGreaterEqual(value.dim(), constraint.event_dim)
5906                    value_batch_shape = value.shape[
5907                        : value.dim() - constraint.event_dim
5908                    ]
5909                    torch.broadcast_shapes(dist.batch_shape, value_batch_shape)
5910
5911                    if is_dependent(constraint):
5912                        continue
5913
5914                    message = f"{Dist.__name__} example {i + 1}/{len(params)} parameter {name} = {value}"
5915                    self.assertTrue(constraint.check(value).all(), msg=message)
5916
5917    def test_support_constraints(self):
5918        for Dist, params in _get_examples():
5919            self.assertIsInstance(Dist.support, Constraint)
5920            for i, param in enumerate(params):
5921                dist = Dist(**param)
5922                value = dist.sample()
5923                constraint = dist.support
5924                message = (
5925                    f"{Dist.__name__} example {i + 1}/{len(params)} sample = {value}"
5926                )
5927                self.assertEqual(
5928                    constraint.event_dim, len(dist.event_shape), msg=message
5929                )
5930                ok = constraint.check(value)
5931                self.assertEqual(ok.shape, dist.batch_shape, msg=message)
5932                self.assertTrue(ok.all(), msg=message)
5933
5934
5935@skipIfTorchDynamo("Not a TorchDynamo suitable test")
5936class TestNumericalStability(DistributionsTestCase):
5937    def _test_pdf_score(
5938        self,
5939        dist_class,
5940        x,
5941        expected_value,
5942        probs=None,
5943        logits=None,
5944        expected_gradient=None,
5945        atol=1e-5,
5946    ):
5947        if probs is not None:
5948            p = probs.detach().requires_grad_()
5949            dist = dist_class(p)
5950        else:
5951            p = logits.detach().requires_grad_()
5952            dist = dist_class(logits=p)
5953        log_pdf = dist.log_prob(x)
5954        log_pdf.sum().backward()
5955        self.assertEqual(
5956            log_pdf,
5957            expected_value,
5958            atol=atol,
5959            rtol=0,
5960            msg=f"Incorrect value for tensor type: {type(x)}. Expected = {expected_value}, Actual = {log_pdf}",
5961        )
5962        if expected_gradient is not None:
5963            self.assertEqual(
5964                p.grad,
5965                expected_gradient,
5966                atol=atol,
5967                rtol=0,
5968                msg=f"Incorrect gradient for tensor type: {type(x)}. Expected = {expected_gradient}, Actual = {p.grad}",
5969            )
5970
5971    def test_bernoulli_gradient(self):
5972        for tensor_type in [torch.FloatTensor, torch.DoubleTensor]:
5973            self._test_pdf_score(
5974                dist_class=Bernoulli,
5975                probs=tensor_type([0]),
5976                x=tensor_type([0]),
5977                expected_value=tensor_type([0]),
5978                expected_gradient=tensor_type([0]),
5979            )
5980
5981            self._test_pdf_score(
5982                dist_class=Bernoulli,
5983                probs=tensor_type([0]),
5984                x=tensor_type([1]),
5985                expected_value=tensor_type(
5986                    [torch.finfo(tensor_type([]).dtype).eps]
5987                ).log(),
5988                expected_gradient=tensor_type([0]),
5989            )
5990
5991            self._test_pdf_score(
5992                dist_class=Bernoulli,
5993                probs=tensor_type([1e-4]),
5994                x=tensor_type([1]),
5995                expected_value=tensor_type([math.log(1e-4)]),
5996                expected_gradient=tensor_type([10000]),
5997            )
5998
5999            # Lower precision due to:
6000            # >>> 1 / (1 - torch.FloatTensor([0.9999]))
6001            # 9998.3408
6002            # [torch.FloatTensor of size 1]
6003            self._test_pdf_score(
6004                dist_class=Bernoulli,
6005                probs=tensor_type([1 - 1e-4]),
6006                x=tensor_type([0]),
6007                expected_value=tensor_type([math.log(1e-4)]),
6008                expected_gradient=tensor_type([-10000]),
6009                atol=2,
6010            )
6011
6012            self._test_pdf_score(
6013                dist_class=Bernoulli,
6014                logits=tensor_type([math.log(9999)]),
6015                x=tensor_type([0]),
6016                expected_value=tensor_type([math.log(1e-4)]),
6017                expected_gradient=tensor_type([-1]),
6018                atol=1e-3,
6019            )
6020
6021    def test_bernoulli_with_logits_underflow(self):
6022        for tensor_type, lim in [
6023            (torch.FloatTensor, -1e38),
6024            (torch.DoubleTensor, -1e308),
6025        ]:
6026            self._test_pdf_score(
6027                dist_class=Bernoulli,
6028                logits=tensor_type([lim]),
6029                x=tensor_type([0]),
6030                expected_value=tensor_type([0]),
6031                expected_gradient=tensor_type([0]),
6032            )
6033
6034    def test_bernoulli_with_logits_overflow(self):
6035        for tensor_type, lim in [
6036            (torch.FloatTensor, 1e38),
6037            (torch.DoubleTensor, 1e308),
6038        ]:
6039            self._test_pdf_score(
6040                dist_class=Bernoulli,
6041                logits=tensor_type([lim]),
6042                x=tensor_type([1]),
6043                expected_value=tensor_type([0]),
6044                expected_gradient=tensor_type([0]),
6045            )
6046
6047    def test_categorical_log_prob(self):
6048        for dtype in [torch.float, torch.double]:
6049            p = torch.tensor([0, 1], dtype=dtype, requires_grad=True)
6050            categorical = OneHotCategorical(p)
6051            log_pdf = categorical.log_prob(torch.tensor([0, 1], dtype=dtype))
6052            self.assertEqual(log_pdf.item(), 0)
6053
6054    def test_categorical_log_prob_with_logits(self):
6055        for dtype in [torch.float, torch.double]:
6056            p = torch.tensor([-inf, 0], dtype=dtype, requires_grad=True)
6057            categorical = OneHotCategorical(logits=p)
6058            log_pdf_prob_1 = categorical.log_prob(torch.tensor([0, 1], dtype=dtype))
6059            self.assertEqual(log_pdf_prob_1.item(), 0)
6060            log_pdf_prob_0 = categorical.log_prob(torch.tensor([1, 0], dtype=dtype))
6061            self.assertEqual(log_pdf_prob_0.item(), -inf)
6062
6063    def test_multinomial_log_prob(self):
6064        for dtype in [torch.float, torch.double]:
6065            p = torch.tensor([0, 1], dtype=dtype, requires_grad=True)
6066            s = torch.tensor([0, 10], dtype=dtype)
6067            multinomial = Multinomial(10, p)
6068            log_pdf = multinomial.log_prob(s)
6069            self.assertEqual(log_pdf.item(), 0)
6070
6071    def test_multinomial_log_prob_with_logits(self):
6072        for dtype in [torch.float, torch.double]:
6073            p = torch.tensor([-inf, 0], dtype=dtype, requires_grad=True)
6074            multinomial = Multinomial(10, logits=p)
6075            log_pdf_prob_1 = multinomial.log_prob(torch.tensor([0, 10], dtype=dtype))
6076            self.assertEqual(log_pdf_prob_1.item(), 0)
6077            log_pdf_prob_0 = multinomial.log_prob(torch.tensor([10, 0], dtype=dtype))
6078            self.assertEqual(log_pdf_prob_0.item(), -inf)
6079
6080    def test_continuous_bernoulli_gradient(self):
6081        def expec_val(x, probs=None, logits=None):
6082            assert not (probs is None and logits is None)
6083            if logits is not None:
6084                probs = 1.0 / (1.0 + math.exp(-logits))
6085            bern_log_lik = x * math.log(probs) + (1.0 - x) * math.log1p(-probs)
6086            if probs < 0.499 or probs > 0.501:  # using default values of lims here
6087                log_norm_const = (
6088                    math.log(math.fabs(math.atanh(1.0 - 2.0 * probs)))
6089                    - math.log(math.fabs(1.0 - 2.0 * probs))
6090                    + math.log(2.0)
6091                )
6092            else:
6093                aux = math.pow(probs - 0.5, 2)
6094                log_norm_const = math.log(2.0) + (4.0 / 3.0 + 104.0 / 45.0 * aux) * aux
6095            log_lik = bern_log_lik + log_norm_const
6096            return log_lik
6097
6098        def expec_grad(x, probs=None, logits=None):
6099            assert not (probs is None and logits is None)
6100            if logits is not None:
6101                probs = 1.0 / (1.0 + math.exp(-logits))
6102            grad_bern_log_lik = x / probs - (1.0 - x) / (1.0 - probs)
6103            if probs < 0.499 or probs > 0.501:  # using default values of lims here
6104                grad_log_c = (
6105                    2.0 * probs
6106                    - 4.0 * (probs - 1.0) * probs * math.atanh(1.0 - 2.0 * probs)
6107                    - 1.0
6108                )
6109                grad_log_c /= (
6110                    2.0
6111                    * (probs - 1.0)
6112                    * probs
6113                    * (2.0 * probs - 1.0)
6114                    * math.atanh(1.0 - 2.0 * probs)
6115                )
6116            else:
6117                grad_log_c = 8.0 / 3.0 * (probs - 0.5) + 416.0 / 45.0 * math.pow(
6118                    probs - 0.5, 3
6119                )
6120            grad = grad_bern_log_lik + grad_log_c
6121            if logits is not None:
6122                grad *= 1.0 / (1.0 + math.exp(logits)) - 1.0 / math.pow(
6123                    1.0 + math.exp(logits), 2
6124                )
6125            return grad
6126
6127        for tensor_type in [torch.FloatTensor, torch.DoubleTensor]:
6128            self._test_pdf_score(
6129                dist_class=ContinuousBernoulli,
6130                probs=tensor_type([0.1]),
6131                x=tensor_type([0.1]),
6132                expected_value=tensor_type([expec_val(0.1, probs=0.1)]),
6133                expected_gradient=tensor_type([expec_grad(0.1, probs=0.1)]),
6134            )
6135
6136            self._test_pdf_score(
6137                dist_class=ContinuousBernoulli,
6138                probs=tensor_type([0.1]),
6139                x=tensor_type([1.0]),
6140                expected_value=tensor_type([expec_val(1.0, probs=0.1)]),
6141                expected_gradient=tensor_type([expec_grad(1.0, probs=0.1)]),
6142            )
6143
6144            self._test_pdf_score(
6145                dist_class=ContinuousBernoulli,
6146                probs=tensor_type([0.4999]),
6147                x=tensor_type([0.9]),
6148                expected_value=tensor_type([expec_val(0.9, probs=0.4999)]),
6149                expected_gradient=tensor_type([expec_grad(0.9, probs=0.4999)]),
6150            )
6151
6152            self._test_pdf_score(
6153                dist_class=ContinuousBernoulli,
6154                probs=tensor_type([1e-4]),
6155                x=tensor_type([1]),
6156                expected_value=tensor_type([expec_val(1, probs=1e-4)]),
6157                expected_gradient=tensor_type(tensor_type([expec_grad(1, probs=1e-4)])),
6158                atol=1e-3,
6159            )
6160
6161            self._test_pdf_score(
6162                dist_class=ContinuousBernoulli,
6163                probs=tensor_type([1 - 1e-4]),
6164                x=tensor_type([0.1]),
6165                expected_value=tensor_type([expec_val(0.1, probs=1 - 1e-4)]),
6166                expected_gradient=tensor_type([expec_grad(0.1, probs=1 - 1e-4)]),
6167                atol=2,
6168            )
6169
6170            self._test_pdf_score(
6171                dist_class=ContinuousBernoulli,
6172                logits=tensor_type([math.log(9999)]),
6173                x=tensor_type([0]),
6174                expected_value=tensor_type([expec_val(0, logits=math.log(9999))]),
6175                expected_gradient=tensor_type([expec_grad(0, logits=math.log(9999))]),
6176                atol=1e-3,
6177            )
6178
6179            self._test_pdf_score(
6180                dist_class=ContinuousBernoulli,
6181                logits=tensor_type([0.001]),
6182                x=tensor_type([0.5]),
6183                expected_value=tensor_type([expec_val(0.5, logits=0.001)]),
6184                expected_gradient=tensor_type([expec_grad(0.5, logits=0.001)]),
6185            )
6186
6187    def test_continuous_bernoulli_with_logits_underflow(self):
6188        for tensor_type, lim, expected in [
6189            (torch.FloatTensor, -1e38, 2.76898),
6190            (torch.DoubleTensor, -1e308, 3.58473),
6191        ]:
6192            self._test_pdf_score(
6193                dist_class=ContinuousBernoulli,
6194                logits=tensor_type([lim]),
6195                x=tensor_type([0]),
6196                expected_value=tensor_type([expected]),
6197                expected_gradient=tensor_type([0.0]),
6198            )
6199
6200    def test_continuous_bernoulli_with_logits_overflow(self):
6201        for tensor_type, lim, expected in [
6202            (torch.FloatTensor, 1e38, 2.76898),
6203            (torch.DoubleTensor, 1e308, 3.58473),
6204        ]:
6205            self._test_pdf_score(
6206                dist_class=ContinuousBernoulli,
6207                logits=tensor_type([lim]),
6208                x=tensor_type([1]),
6209                expected_value=tensor_type([expected]),
6210                expected_gradient=tensor_type([0.0]),
6211            )
6212
6213
6214# TODO: make this a pytest parameterized test
6215class TestLazyLogitsInitialization(DistributionsTestCase):
6216    def setUp(self):
6217        super().setUp()
6218        # ContinuousBernoulli is not tested because log_prob is not computed simply
6219        # from 'logits', but 'probs' is also needed
6220        self.examples = [
6221            e
6222            for e in _get_examples()
6223            if e.Dist
6224            in (Categorical, OneHotCategorical, Bernoulli, Binomial, Multinomial)
6225        ]
6226
6227    def test_lazy_logits_initialization(self):
6228        for Dist, params in self.examples:
6229            param = params[0].copy()
6230            if "probs" not in param:
6231                continue
6232            probs = param.pop("probs")
6233            param["logits"] = probs_to_logits(probs)
6234            dist = Dist(**param)
6235            # Create new instance to generate a valid sample
6236            dist.log_prob(Dist(**param).sample())
6237            message = f"Failed for {Dist.__name__} example 0/{len(params)}"
6238            self.assertNotIn("probs", dist.__dict__, msg=message)
6239            try:
6240                dist.enumerate_support()
6241            except NotImplementedError:
6242                pass
6243            self.assertNotIn("probs", dist.__dict__, msg=message)
6244            batch_shape, event_shape = dist.batch_shape, dist.event_shape
6245            self.assertNotIn("probs", dist.__dict__, msg=message)
6246
6247    def test_lazy_probs_initialization(self):
6248        for Dist, params in self.examples:
6249            param = params[0].copy()
6250            if "probs" not in param:
6251                continue
6252            dist = Dist(**param)
6253            dist.sample()
6254            message = f"Failed for {Dist.__name__} example 0/{len(params)}"
6255            self.assertNotIn("logits", dist.__dict__, msg=message)
6256            try:
6257                dist.enumerate_support()
6258            except NotImplementedError:
6259                pass
6260            self.assertNotIn("logits", dist.__dict__, msg=message)
6261            batch_shape, event_shape = dist.batch_shape, dist.event_shape
6262            self.assertNotIn("logits", dist.__dict__, msg=message)
6263
6264
6265@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
6266@skipIfTorchDynamo("FIXME: Tries to trace through SciPy and fails")
6267class TestAgainstScipy(DistributionsTestCase):
6268    def setUp(self):
6269        super().setUp()
6270        positive_var = torch.randn(20, dtype=torch.double).exp()
6271        positive_var2 = torch.randn(20, dtype=torch.double).exp()
6272        random_var = torch.randn(20, dtype=torch.double)
6273        simplex_tensor = softmax(torch.randn(20, dtype=torch.double), dim=-1)
6274        cov_tensor = torch.randn(20, 20, dtype=torch.double)
6275        cov_tensor = cov_tensor @ cov_tensor.mT
6276        self.distribution_pairs = [
6277            (Bernoulli(simplex_tensor), scipy.stats.bernoulli(simplex_tensor)),
6278            (
6279                Beta(positive_var, positive_var2),
6280                scipy.stats.beta(positive_var, positive_var2),
6281            ),
6282            (
6283                Binomial(10, simplex_tensor),
6284                scipy.stats.binom(
6285                    10 * np.ones(simplex_tensor.shape), simplex_tensor.numpy()
6286                ),
6287            ),
6288            (
6289                Cauchy(random_var, positive_var),
6290                scipy.stats.cauchy(loc=random_var, scale=positive_var),
6291            ),
6292            (Dirichlet(positive_var), scipy.stats.dirichlet(positive_var)),
6293            (
6294                Exponential(positive_var),
6295                scipy.stats.expon(scale=positive_var.reciprocal()),
6296            ),
6297            (
6298                FisherSnedecor(
6299                    positive_var, 4 + positive_var2
6300                ),  # var for df2<=4 is undefined
6301                scipy.stats.f(positive_var, 4 + positive_var2),
6302            ),
6303            (
6304                Gamma(positive_var, positive_var2),
6305                scipy.stats.gamma(positive_var, scale=positive_var2.reciprocal()),
6306            ),
6307            (Geometric(simplex_tensor), scipy.stats.geom(simplex_tensor, loc=-1)),
6308            (
6309                Gumbel(random_var, positive_var2),
6310                scipy.stats.gumbel_r(random_var, positive_var2),
6311            ),
6312            (HalfCauchy(positive_var), scipy.stats.halfcauchy(scale=positive_var)),
6313            (HalfNormal(positive_var2), scipy.stats.halfnorm(scale=positive_var2)),
6314            (
6315                InverseGamma(positive_var, positive_var2),
6316                scipy.stats.invgamma(positive_var, scale=positive_var2),
6317            ),
6318            (
6319                Laplace(random_var, positive_var2),
6320                scipy.stats.laplace(random_var, positive_var2),
6321            ),
6322            (
6323                # Tests fail 1e-5 threshold if scale > 3
6324                LogNormal(random_var, positive_var.clamp(max=3)),
6325                scipy.stats.lognorm(
6326                    s=positive_var.clamp(max=3), scale=random_var.exp()
6327                ),
6328            ),
6329            (
6330                LowRankMultivariateNormal(
6331                    random_var, torch.zeros(20, 1, dtype=torch.double), positive_var2
6332                ),
6333                scipy.stats.multivariate_normal(random_var, torch.diag(positive_var2)),
6334            ),
6335            (
6336                Multinomial(10, simplex_tensor),
6337                scipy.stats.multinomial(10, simplex_tensor),
6338            ),
6339            (
6340                MultivariateNormal(random_var, torch.diag(positive_var2)),
6341                scipy.stats.multivariate_normal(random_var, torch.diag(positive_var2)),
6342            ),
6343            (
6344                MultivariateNormal(random_var, cov_tensor),
6345                scipy.stats.multivariate_normal(random_var, cov_tensor),
6346            ),
6347            (
6348                Normal(random_var, positive_var2),
6349                scipy.stats.norm(random_var, positive_var2),
6350            ),
6351            (
6352                OneHotCategorical(simplex_tensor),
6353                scipy.stats.multinomial(1, simplex_tensor),
6354            ),
6355            (
6356                Pareto(positive_var, 2 + positive_var2),
6357                scipy.stats.pareto(2 + positive_var2, scale=positive_var),
6358            ),
6359            (Poisson(positive_var), scipy.stats.poisson(positive_var)),
6360            (
6361                StudentT(2 + positive_var, random_var, positive_var2),
6362                scipy.stats.t(2 + positive_var, random_var, positive_var2),
6363            ),
6364            (
6365                Uniform(random_var, random_var + positive_var),
6366                scipy.stats.uniform(random_var, positive_var),
6367            ),
6368            (
6369                VonMises(random_var, positive_var),
6370                scipy.stats.vonmises(positive_var, loc=random_var),
6371            ),
6372            (
6373                Weibull(
6374                    positive_var[0], positive_var2[0]
6375                ),  # scipy var for Weibull only supports scalars
6376                scipy.stats.weibull_min(c=positive_var2[0], scale=positive_var[0]),
6377            ),
6378            (
6379                # scipy var for Wishart only supports scalars
6380                # SciPy allowed ndim -1 < df < ndim for Wishar distribution after version 1.7.0
6381                Wishart(
6382                    (
6383                        20
6384                        if version.parse(scipy.__version__) < version.parse("1.7.0")
6385                        else 19
6386                    )
6387                    + positive_var[0],
6388                    cov_tensor,
6389                ),
6390                scipy.stats.wishart(
6391                    (
6392                        20
6393                        if version.parse(scipy.__version__) < version.parse("1.7.0")
6394                        else 19
6395                    )
6396                    + positive_var[0].item(),
6397                    cov_tensor,
6398                ),
6399            ),
6400        ]
6401
6402    def test_mean(self):
6403        for pytorch_dist, scipy_dist in self.distribution_pairs:
6404            if isinstance(pytorch_dist, (Cauchy, HalfCauchy)):
6405                # Cauchy, HalfCauchy distributions' mean is nan, skipping check
6406                continue
6407            elif isinstance(
6408                pytorch_dist, (LowRankMultivariateNormal, MultivariateNormal)
6409            ):
6410                self.assertEqual(pytorch_dist.mean, scipy_dist.mean, msg=pytorch_dist)
6411            else:
6412                self.assertEqual(pytorch_dist.mean, scipy_dist.mean(), msg=pytorch_dist)
6413
6414    def test_variance_stddev(self):
6415        for pytorch_dist, scipy_dist in self.distribution_pairs:
6416            if isinstance(pytorch_dist, (Cauchy, HalfCauchy, VonMises)):
6417                # Cauchy, HalfCauchy distributions' standard deviation is nan, skipping check
6418                # VonMises variance is circular and scipy doesn't produce a correct result
6419                continue
6420            elif isinstance(pytorch_dist, (Multinomial, OneHotCategorical)):
6421                self.assertEqual(
6422                    pytorch_dist.variance, np.diag(scipy_dist.cov()), msg=pytorch_dist
6423                )
6424                self.assertEqual(
6425                    pytorch_dist.stddev,
6426                    np.diag(scipy_dist.cov()) ** 0.5,
6427                    msg=pytorch_dist,
6428                )
6429            elif isinstance(
6430                pytorch_dist, (LowRankMultivariateNormal, MultivariateNormal)
6431            ):
6432                self.assertEqual(
6433                    pytorch_dist.variance, np.diag(scipy_dist.cov), msg=pytorch_dist
6434                )
6435                self.assertEqual(
6436                    pytorch_dist.stddev,
6437                    np.diag(scipy_dist.cov) ** 0.5,
6438                    msg=pytorch_dist,
6439                )
6440            else:
6441                self.assertEqual(
6442                    pytorch_dist.variance, scipy_dist.var(), msg=pytorch_dist
6443                )
6444                self.assertEqual(
6445                    pytorch_dist.stddev, scipy_dist.var() ** 0.5, msg=pytorch_dist
6446                )
6447
6448    @set_default_dtype(torch.double)
6449    def test_cdf(self):
6450        for pytorch_dist, scipy_dist in self.distribution_pairs:
6451            samples = pytorch_dist.sample((5,))
6452            try:
6453                cdf = pytorch_dist.cdf(samples)
6454            except NotImplementedError:
6455                continue
6456            self.assertEqual(cdf, scipy_dist.cdf(samples), msg=pytorch_dist)
6457
6458    def test_icdf(self):
6459        for pytorch_dist, scipy_dist in self.distribution_pairs:
6460            samples = torch.rand((5,) + pytorch_dist.batch_shape, dtype=torch.double)
6461            try:
6462                icdf = pytorch_dist.icdf(samples)
6463            except NotImplementedError:
6464                continue
6465            self.assertEqual(icdf, scipy_dist.ppf(samples), msg=pytorch_dist)
6466
6467
6468class TestFunctors(DistributionsTestCase):
6469    def test_cat_transform(self):
6470        x1 = -1 * torch.arange(1, 101, dtype=torch.float).view(-1, 100)
6471        x2 = (torch.arange(1, 101, dtype=torch.float).view(-1, 100) - 1) / 100
6472        x3 = torch.arange(1, 101, dtype=torch.float).view(-1, 100)
6473        t1, t2, t3 = ExpTransform(), AffineTransform(1, 100), identity_transform
6474        dim = 0
6475        x = torch.cat([x1, x2, x3], dim=dim)
6476        t = CatTransform([t1, t2, t3], dim=dim)
6477        actual_dom_check = t.domain.check(x)
6478        expected_dom_check = torch.cat(
6479            [t1.domain.check(x1), t2.domain.check(x2), t3.domain.check(x3)], dim=dim
6480        )
6481        self.assertEqual(expected_dom_check, actual_dom_check)
6482        actual = t(x)
6483        expected = torch.cat([t1(x1), t2(x2), t3(x3)], dim=dim)
6484        self.assertEqual(expected, actual)
6485        y1 = torch.arange(1, 101, dtype=torch.float).view(-1, 100)
6486        y2 = torch.arange(1, 101, dtype=torch.float).view(-1, 100)
6487        y3 = torch.arange(1, 101, dtype=torch.float).view(-1, 100)
6488        y = torch.cat([y1, y2, y3], dim=dim)
6489        actual_cod_check = t.codomain.check(y)
6490        expected_cod_check = torch.cat(
6491            [t1.codomain.check(y1), t2.codomain.check(y2), t3.codomain.check(y3)],
6492            dim=dim,
6493        )
6494        self.assertEqual(actual_cod_check, expected_cod_check)
6495        actual_inv = t.inv(y)
6496        expected_inv = torch.cat([t1.inv(y1), t2.inv(y2), t3.inv(y3)], dim=dim)
6497        self.assertEqual(expected_inv, actual_inv)
6498        actual_jac = t.log_abs_det_jacobian(x, y)
6499        expected_jac = torch.cat(
6500            [
6501                t1.log_abs_det_jacobian(x1, y1),
6502                t2.log_abs_det_jacobian(x2, y2),
6503                t3.log_abs_det_jacobian(x3, y3),
6504            ],
6505            dim=dim,
6506        )
6507        self.assertEqual(actual_jac, expected_jac)
6508
6509    def test_cat_transform_non_uniform(self):
6510        x1 = -1 * torch.arange(1, 101, dtype=torch.float).view(-1, 100)
6511        x2 = torch.cat(
6512            [
6513                (torch.arange(1, 101, dtype=torch.float).view(-1, 100) - 1) / 100,
6514                torch.arange(1, 101, dtype=torch.float).view(-1, 100),
6515            ]
6516        )
6517        t1 = ExpTransform()
6518        t2 = CatTransform([AffineTransform(1, 100), identity_transform], dim=0)
6519        dim = 0
6520        x = torch.cat([x1, x2], dim=dim)
6521        t = CatTransform([t1, t2], dim=dim, lengths=[1, 2])
6522        actual_dom_check = t.domain.check(x)
6523        expected_dom_check = torch.cat(
6524            [t1.domain.check(x1), t2.domain.check(x2)], dim=dim
6525        )
6526        self.assertEqual(expected_dom_check, actual_dom_check)
6527        actual = t(x)
6528        expected = torch.cat([t1(x1), t2(x2)], dim=dim)
6529        self.assertEqual(expected, actual)
6530        y1 = torch.arange(1, 101, dtype=torch.float).view(-1, 100)
6531        y2 = torch.cat(
6532            [
6533                torch.arange(1, 101, dtype=torch.float).view(-1, 100),
6534                torch.arange(1, 101, dtype=torch.float).view(-1, 100),
6535            ]
6536        )
6537        y = torch.cat([y1, y2], dim=dim)
6538        actual_cod_check = t.codomain.check(y)
6539        expected_cod_check = torch.cat(
6540            [t1.codomain.check(y1), t2.codomain.check(y2)], dim=dim
6541        )
6542        self.assertEqual(actual_cod_check, expected_cod_check)
6543        actual_inv = t.inv(y)
6544        expected_inv = torch.cat([t1.inv(y1), t2.inv(y2)], dim=dim)
6545        self.assertEqual(expected_inv, actual_inv)
6546        actual_jac = t.log_abs_det_jacobian(x, y)
6547        expected_jac = torch.cat(
6548            [t1.log_abs_det_jacobian(x1, y1), t2.log_abs_det_jacobian(x2, y2)], dim=dim
6549        )
6550        self.assertEqual(actual_jac, expected_jac)
6551
6552    def test_cat_event_dim(self):
6553        t1 = AffineTransform(0, 2 * torch.ones(2), event_dim=1)
6554        t2 = AffineTransform(0, 2 * torch.ones(2), event_dim=1)
6555        dim = 1
6556        bs = 16
6557        x1 = torch.randn(bs, 2)
6558        x2 = torch.randn(bs, 2)
6559        x = torch.cat([x1, x2], dim=1)
6560        t = CatTransform([t1, t2], dim=dim, lengths=[2, 2])
6561        y1 = t1(x1)
6562        y2 = t2(x2)
6563        y = t(x)
6564        actual_jac = t.log_abs_det_jacobian(x, y)
6565        expected_jac = sum(
6566            [t1.log_abs_det_jacobian(x1, y1), t2.log_abs_det_jacobian(x2, y2)]
6567        )
6568
6569    def test_stack_transform(self):
6570        x1 = -1 * torch.arange(1, 101, dtype=torch.float)
6571        x2 = (torch.arange(1, 101, dtype=torch.float) - 1) / 100
6572        x3 = torch.arange(1, 101, dtype=torch.float)
6573        t1, t2, t3 = ExpTransform(), AffineTransform(1, 100), identity_transform
6574        dim = 0
6575        x = torch.stack([x1, x2, x3], dim=dim)
6576        t = StackTransform([t1, t2, t3], dim=dim)
6577        actual_dom_check = t.domain.check(x)
6578        expected_dom_check = torch.stack(
6579            [t1.domain.check(x1), t2.domain.check(x2), t3.domain.check(x3)], dim=dim
6580        )
6581        self.assertEqual(expected_dom_check, actual_dom_check)
6582        actual = t(x)
6583        expected = torch.stack([t1(x1), t2(x2), t3(x3)], dim=dim)
6584        self.assertEqual(expected, actual)
6585        y1 = torch.arange(1, 101, dtype=torch.float)
6586        y2 = torch.arange(1, 101, dtype=torch.float)
6587        y3 = torch.arange(1, 101, dtype=torch.float)
6588        y = torch.stack([y1, y2, y3], dim=dim)
6589        actual_cod_check = t.codomain.check(y)
6590        expected_cod_check = torch.stack(
6591            [t1.codomain.check(y1), t2.codomain.check(y2), t3.codomain.check(y3)],
6592            dim=dim,
6593        )
6594        self.assertEqual(actual_cod_check, expected_cod_check)
6595        actual_inv = t.inv(x)
6596        expected_inv = torch.stack([t1.inv(x1), t2.inv(x2), t3.inv(x3)], dim=dim)
6597        self.assertEqual(expected_inv, actual_inv)
6598        actual_jac = t.log_abs_det_jacobian(x, y)
6599        expected_jac = torch.stack(
6600            [
6601                t1.log_abs_det_jacobian(x1, y1),
6602                t2.log_abs_det_jacobian(x2, y2),
6603                t3.log_abs_det_jacobian(x3, y3),
6604            ],
6605            dim=dim,
6606        )
6607        self.assertEqual(actual_jac, expected_jac)
6608
6609
6610class TestValidation(DistributionsTestCase):
6611    def test_valid(self):
6612        for Dist, params in _get_examples():
6613            for param in params:
6614                Dist(validate_args=True, **param)
6615
6616    @set_default_dtype(torch.double)
6617    def test_invalid_log_probs_arg(self):
6618        # Check that validation errors are indeed disabled,
6619        # but they might raise another error
6620        for Dist, params in _get_examples():
6621            if Dist == TransformedDistribution:
6622                # TransformedDistribution has a distribution instance
6623                # as the argument, so we cannot do much about that
6624                continue
6625            for i, param in enumerate(params):
6626                d_nonval = Dist(validate_args=False, **param)
6627                d_val = Dist(validate_args=True, **param)
6628                for v in torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0]):
6629                    # samples with incorrect shape must throw ValueError only
6630                    try:
6631                        log_prob = d_val.log_prob(v)
6632                    except ValueError:
6633                        pass
6634                    # get sample of correct shape
6635                    val = torch.full(d_val.batch_shape + d_val.event_shape, v)
6636                    # check samples with incorrect support
6637                    try:
6638                        log_prob = d_val.log_prob(val)
6639                    except ValueError as e:
6640                        if e.args and "must be within the support" in e.args[0]:
6641                            try:
6642                                log_prob = d_nonval.log_prob(val)
6643                            except RuntimeError:
6644                                pass
6645
6646                # check correct samples are ok
6647                valid_value = d_val.sample()
6648                d_val.log_prob(valid_value)
6649                # check invalid values raise ValueError
6650                if valid_value.dtype == torch.long:
6651                    valid_value = valid_value.float()
6652                invalid_value = torch.full_like(valid_value, math.nan)
6653                try:
6654                    with self.assertRaisesRegex(
6655                        ValueError,
6656                        "Expected value argument .* to be within the support .*",
6657                    ):
6658                        d_val.log_prob(invalid_value)
6659                except AssertionError as e:
6660                    fail_string = "Support ValueError not raised for {} example {}/{}"
6661                    raise AssertionError(
6662                        fail_string.format(Dist.__name__, i + 1, len(params))
6663                    ) from e
6664
6665    @set_default_dtype(torch.double)
6666    def test_invalid(self):
6667        for Dist, params in _get_bad_examples():
6668            for i, param in enumerate(params):
6669                try:
6670                    with self.assertRaises(ValueError):
6671                        Dist(validate_args=True, **param)
6672                except AssertionError as e:
6673                    fail_string = "ValueError not raised for {} example {}/{}"
6674                    raise AssertionError(
6675                        fail_string.format(Dist.__name__, i + 1, len(params))
6676                    ) from e
6677
6678    def test_warning_unimplemented_constraints(self):
6679        class Delta(Distribution):
6680            def __init__(self, validate_args=True):
6681                super().__init__(validate_args=validate_args)
6682
6683            def sample(self, sample_shape=torch.Size()):
6684                return torch.tensor(0.0).expand(sample_shape)
6685
6686            def log_prob(self, value):
6687                if self._validate_args:
6688                    self._validate_sample(value)
6689                value[value != 0.0] = -float("inf")
6690                value[value == 0.0] = 0.0
6691                return value
6692
6693        with self.assertWarns(UserWarning):
6694            d = Delta()
6695        sample = d.sample((2,))
6696        with self.assertWarns(UserWarning):
6697            d.log_prob(sample)
6698
6699
6700class TestJit(DistributionsTestCase):
6701    def _examples(self):
6702        for Dist, params in _get_examples():
6703            for param in params:
6704                keys = param.keys()
6705                values = tuple(param[key] for key in keys)
6706                if not all(isinstance(x, torch.Tensor) for x in values):
6707                    continue
6708                sample = Dist(**param).sample()
6709                yield Dist, keys, values, sample
6710
6711    def _perturb_tensor(self, value, constraint):
6712        if isinstance(constraint, constraints._IntegerGreaterThan):
6713            return value + 1
6714        if isinstance(
6715            constraint,
6716            (constraints._PositiveDefinite, constraints._PositiveSemidefinite),
6717        ):
6718            return value + torch.eye(value.shape[-1])
6719        if value.dtype in [torch.float, torch.double]:
6720            transform = transform_to(constraint)
6721            delta = value.new(value.shape).normal_()
6722            return transform(transform.inv(value) + delta)
6723        if value.dtype == torch.long:
6724            result = value.clone()
6725            result[value == 0] = 1
6726            result[value == 1] = 0
6727            return result
6728        raise NotImplementedError
6729
6730    def _perturb(self, Dist, keys, values, sample):
6731        with torch.no_grad():
6732            if Dist is Uniform:
6733                param = dict(zip(keys, values))
6734                param["low"] = param["low"] - torch.rand(param["low"].shape)
6735                param["high"] = param["high"] + torch.rand(param["high"].shape)
6736                values = [param[key] for key in keys]
6737            else:
6738                values = [
6739                    self._perturb_tensor(
6740                        value, Dist.arg_constraints.get(key, constraints.real)
6741                    )
6742                    for key, value in zip(keys, values)
6743                ]
6744            param = dict(zip(keys, values))
6745            sample = Dist(**param).sample()
6746            return values, sample
6747
6748    @set_default_dtype(torch.double)
6749    def test_sample(self):
6750        for Dist, keys, values, sample in self._examples():
6751
6752            def f(*values):
6753                param = dict(zip(keys, values))
6754                dist = Dist(**param)
6755                return dist.sample()
6756
6757            traced_f = torch.jit.trace(f, values, check_trace=False)
6758
6759            # FIXME Schema not found for node
6760            xfail = [
6761                Cauchy,  # aten::cauchy(Double(2,1), float, float, Generator)
6762                HalfCauchy,  # aten::cauchy(Double(2, 1), float, float, Generator)
6763                VonMises,  # Variance is not Euclidean
6764            ]
6765            if Dist in xfail:
6766                continue
6767
6768            with torch.random.fork_rng():
6769                sample = f(*values)
6770            traced_sample = traced_f(*values)
6771            self.assertEqual(sample, traced_sample)
6772
6773            # FIXME no nondeterministic nodes found in trace
6774            xfail = [Beta, Dirichlet]
6775            if Dist not in xfail:
6776                self.assertTrue(
6777                    any(n.isNondeterministic() for n in traced_f.graph.nodes())
6778                )
6779
6780    def test_rsample(self):
6781        for Dist, keys, values, sample in self._examples():
6782            if not Dist.has_rsample:
6783                continue
6784
6785            def f(*values):
6786                param = dict(zip(keys, values))
6787                dist = Dist(**param)
6788                return dist.rsample()
6789
6790            traced_f = torch.jit.trace(f, values, check_trace=False)
6791
6792            # FIXME Schema not found for node
6793            xfail = [
6794                Cauchy,  # aten::cauchy(Double(2,1), float, float, Generator)
6795                HalfCauchy,  # aten::cauchy(Double(2, 1), float, float, Generator)
6796            ]
6797            if Dist in xfail:
6798                continue
6799
6800            with torch.random.fork_rng():
6801                sample = f(*values)
6802            traced_sample = traced_f(*values)
6803            self.assertEqual(sample, traced_sample)
6804
6805            # FIXME no nondeterministic nodes found in trace
6806            xfail = [Beta, Dirichlet]
6807            if Dist not in xfail:
6808                self.assertTrue(
6809                    any(n.isNondeterministic() for n in traced_f.graph.nodes())
6810                )
6811
6812    @set_default_dtype(torch.double)
6813    def test_log_prob(self):
6814        for Dist, keys, values, sample in self._examples():
6815            # FIXME traced functions produce incorrect results
6816            xfail = [LowRankMultivariateNormal, MultivariateNormal]
6817            if Dist in xfail:
6818                continue
6819
6820            def f(sample, *values):
6821                param = dict(zip(keys, values))
6822                dist = Dist(**param)
6823                return dist.log_prob(sample)
6824
6825            traced_f = torch.jit.trace(f, (sample,) + values)
6826
6827            # check on different data
6828            values, sample = self._perturb(Dist, keys, values, sample)
6829            expected = f(sample, *values)
6830            actual = traced_f(sample, *values)
6831            self.assertEqual(
6832                expected,
6833                actual,
6834                msg=f"{Dist.__name__}\nExpected:\n{expected}\nActual:\n{actual}",
6835            )
6836
6837    def test_enumerate_support(self):
6838        for Dist, keys, values, sample in self._examples():
6839            # FIXME traced functions produce incorrect results
6840            xfail = [Binomial]
6841            if Dist in xfail:
6842                continue
6843
6844            def f(*values):
6845                param = dict(zip(keys, values))
6846                dist = Dist(**param)
6847                return dist.enumerate_support()
6848
6849            try:
6850                traced_f = torch.jit.trace(f, values)
6851            except NotImplementedError:
6852                continue
6853
6854            # check on different data
6855            values, sample = self._perturb(Dist, keys, values, sample)
6856            expected = f(*values)
6857            actual = traced_f(*values)
6858            self.assertEqual(
6859                expected,
6860                actual,
6861                msg=f"{Dist.__name__}\nExpected:\n{expected}\nActual:\n{actual}",
6862            )
6863
6864    def test_mean(self):
6865        for Dist, keys, values, sample in self._examples():
6866
6867            def f(*values):
6868                param = dict(zip(keys, values))
6869                dist = Dist(**param)
6870                return dist.mean
6871
6872            try:
6873                traced_f = torch.jit.trace(f, values)
6874            except NotImplementedError:
6875                continue
6876
6877            # check on different data
6878            values, sample = self._perturb(Dist, keys, values, sample)
6879            expected = f(*values)
6880            actual = traced_f(*values)
6881            expected[expected == float("inf")] = 0.0
6882            actual[actual == float("inf")] = 0.0
6883            self.assertEqual(
6884                expected,
6885                actual,
6886                msg=f"{Dist.__name__}\nExpected:\n{expected}\nActual:\n{actual}",
6887            )
6888
6889    def test_variance(self):
6890        for Dist, keys, values, sample in self._examples():
6891            if Dist in [Cauchy, HalfCauchy]:
6892                continue  # infinite variance
6893
6894            def f(*values):
6895                param = dict(zip(keys, values))
6896                dist = Dist(**param)
6897                return dist.variance
6898
6899            try:
6900                traced_f = torch.jit.trace(f, values)
6901            except NotImplementedError:
6902                continue
6903
6904            # check on different data
6905            values, sample = self._perturb(Dist, keys, values, sample)
6906            expected = f(*values).clone()
6907            actual = traced_f(*values).clone()
6908            expected[expected == float("inf")] = 0.0
6909            actual[actual == float("inf")] = 0.0
6910            self.assertEqual(
6911                expected,
6912                actual,
6913                msg=f"{Dist.__name__}\nExpected:\n{expected}\nActual:\n{actual}",
6914            )
6915
6916    @set_default_dtype(torch.double)
6917    def test_entropy(self):
6918        for Dist, keys, values, sample in self._examples():
6919            # FIXME traced functions produce incorrect results
6920            xfail = [LowRankMultivariateNormal, MultivariateNormal]
6921            if Dist in xfail:
6922                continue
6923
6924            def f(*values):
6925                param = dict(zip(keys, values))
6926                dist = Dist(**param)
6927                return dist.entropy()
6928
6929            try:
6930                traced_f = torch.jit.trace(f, values)
6931            except NotImplementedError:
6932                continue
6933
6934            # check on different data
6935            values, sample = self._perturb(Dist, keys, values, sample)
6936            expected = f(*values)
6937            actual = traced_f(*values)
6938            self.assertEqual(
6939                expected,
6940                actual,
6941                msg=f"{Dist.__name__}\nExpected:\n{expected}\nActual:\n{actual}",
6942            )
6943
6944    @set_default_dtype(torch.double)
6945    def test_cdf(self):
6946        for Dist, keys, values, sample in self._examples():
6947
6948            def f(sample, *values):
6949                param = dict(zip(keys, values))
6950                dist = Dist(**param)
6951                cdf = dist.cdf(sample)
6952                return dist.icdf(cdf)
6953
6954            try:
6955                traced_f = torch.jit.trace(f, (sample,) + values)
6956            except NotImplementedError:
6957                continue
6958
6959            # check on different data
6960            values, sample = self._perturb(Dist, keys, values, sample)
6961            expected = f(sample, *values)
6962            actual = traced_f(sample, *values)
6963            self.assertEqual(
6964                expected,
6965                actual,
6966                msg=f"{Dist.__name__}\nExpected:\n{expected}\nActual:\n{actual}",
6967            )
6968
6969
6970if __name__ == "__main__" and torch._C.has_lapack:
6971    TestCase._default_dtype_check_enabled = True
6972    run_tests()
6973