xref: /aosp_15_r20/external/pytorch/test/distributions/test_distributions.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: distributions"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Worker"""
4*da0073e9SAndroid Build Coastguard WorkerNote [Randomized statistical tests]
5*da0073e9SAndroid Build Coastguard Worker-----------------------------------
6*da0073e9SAndroid Build Coastguard Worker
7*da0073e9SAndroid Build Coastguard WorkerThis note describes how to maintain tests in this file as random sources
change. This file contains two types of randomized tests (items 1 and 2 below),
followed by a caveat about scipy-validated sampling tests (item 3):
9*da0073e9SAndroid Build Coastguard Worker
1. The easier type of randomized tests are those that should always pass but are
11*da0073e9SAndroid Build Coastguard Worker   initialized with random data. If these fail something is wrong, but it's
12*da0073e9SAndroid Build Coastguard Worker   fine to use a fixed seed by inheriting from common.TestCase.
13*da0073e9SAndroid Build Coastguard Worker
14*da0073e9SAndroid Build Coastguard Worker2. The trickier tests are statistical tests. These tests explicitly call
15*da0073e9SAndroid Build Coastguard Worker   set_rng_seed(n) and are marked "see Note [Randomized statistical tests]".
16*da0073e9SAndroid Build Coastguard Worker   These statistical tests have a known positive failure rate
17*da0073e9SAndroid Build Coastguard Worker   (we set failure_rate=1e-3 by default). We need to balance strength of these
18*da0073e9SAndroid Build Coastguard Worker   tests with annoyance of false alarms. One way that works is to specifically
19*da0073e9SAndroid Build Coastguard Worker   set seeds in each of the randomized tests. When a random generator
20*da0073e9SAndroid Build Coastguard Worker   occasionally changes (as in #4312 vectorizing the Box-Muller sampler), some
21*da0073e9SAndroid Build Coastguard Worker   of these statistical tests may (rarely) fail. If one fails in this case,
22*da0073e9SAndroid Build Coastguard Worker   it's fine to increment the seed of the failing test (but you shouldn't need
23*da0073e9SAndroid Build Coastguard Worker   to increment it more than once; otherwise something is probably actually
24*da0073e9SAndroid Build Coastguard Worker   wrong).
25*da0073e9SAndroid Build Coastguard Worker
3. `test_geometric_sample`, `test_binomial_sample` and `test_poisson_sample`
   are validated against `scipy.stats`, whose results are not guaranteed to be
   identical across different versions of scipy (namely, they yield invalid
   results in 1.7+).
29*da0073e9SAndroid Build Coastguard Worker"""
30*da0073e9SAndroid Build Coastguard Worker
31*da0073e9SAndroid Build Coastguard Workerimport math
32*da0073e9SAndroid Build Coastguard Workerimport numbers
33*da0073e9SAndroid Build Coastguard Workerimport unittest
34*da0073e9SAndroid Build Coastguard Workerfrom collections import namedtuple
35*da0073e9SAndroid Build Coastguard Workerfrom itertools import product
36*da0073e9SAndroid Build Coastguard Workerfrom random import shuffle
37*da0073e9SAndroid Build Coastguard Worker
38*da0073e9SAndroid Build Coastguard Workerfrom packaging import version
39*da0073e9SAndroid Build Coastguard Worker
40*da0073e9SAndroid Build Coastguard Workerimport torch
41*da0073e9SAndroid Build Coastguard Workerimport torch.autograd.forward_ad as fwAD
42*da0073e9SAndroid Build Coastguard Workerfrom torch import inf, nan
43*da0073e9SAndroid Build Coastguard Workerfrom torch.autograd import grad
44*da0073e9SAndroid Build Coastguard Workerfrom torch.autograd.functional import jacobian
45*da0073e9SAndroid Build Coastguard Workerfrom torch.distributions import (
46*da0073e9SAndroid Build Coastguard Worker    Bernoulli,
47*da0073e9SAndroid Build Coastguard Worker    Beta,
48*da0073e9SAndroid Build Coastguard Worker    Binomial,
49*da0073e9SAndroid Build Coastguard Worker    Categorical,
50*da0073e9SAndroid Build Coastguard Worker    Cauchy,
51*da0073e9SAndroid Build Coastguard Worker    Chi2,
52*da0073e9SAndroid Build Coastguard Worker    constraints,
53*da0073e9SAndroid Build Coastguard Worker    ContinuousBernoulli,
54*da0073e9SAndroid Build Coastguard Worker    Dirichlet,
55*da0073e9SAndroid Build Coastguard Worker    Distribution,
56*da0073e9SAndroid Build Coastguard Worker    Exponential,
57*da0073e9SAndroid Build Coastguard Worker    ExponentialFamily,
58*da0073e9SAndroid Build Coastguard Worker    FisherSnedecor,
59*da0073e9SAndroid Build Coastguard Worker    Gamma,
60*da0073e9SAndroid Build Coastguard Worker    Geometric,
61*da0073e9SAndroid Build Coastguard Worker    Gumbel,
62*da0073e9SAndroid Build Coastguard Worker    HalfCauchy,
63*da0073e9SAndroid Build Coastguard Worker    HalfNormal,
64*da0073e9SAndroid Build Coastguard Worker    Independent,
65*da0073e9SAndroid Build Coastguard Worker    InverseGamma,
66*da0073e9SAndroid Build Coastguard Worker    kl_divergence,
67*da0073e9SAndroid Build Coastguard Worker    Kumaraswamy,
68*da0073e9SAndroid Build Coastguard Worker    Laplace,
69*da0073e9SAndroid Build Coastguard Worker    LKJCholesky,
70*da0073e9SAndroid Build Coastguard Worker    LogisticNormal,
71*da0073e9SAndroid Build Coastguard Worker    LogNormal,
72*da0073e9SAndroid Build Coastguard Worker    LowRankMultivariateNormal,
73*da0073e9SAndroid Build Coastguard Worker    MixtureSameFamily,
74*da0073e9SAndroid Build Coastguard Worker    Multinomial,
75*da0073e9SAndroid Build Coastguard Worker    MultivariateNormal,
76*da0073e9SAndroid Build Coastguard Worker    NegativeBinomial,
77*da0073e9SAndroid Build Coastguard Worker    Normal,
78*da0073e9SAndroid Build Coastguard Worker    OneHotCategorical,
79*da0073e9SAndroid Build Coastguard Worker    OneHotCategoricalStraightThrough,
80*da0073e9SAndroid Build Coastguard Worker    Pareto,
81*da0073e9SAndroid Build Coastguard Worker    Poisson,
82*da0073e9SAndroid Build Coastguard Worker    RelaxedBernoulli,
83*da0073e9SAndroid Build Coastguard Worker    RelaxedOneHotCategorical,
84*da0073e9SAndroid Build Coastguard Worker    StudentT,
85*da0073e9SAndroid Build Coastguard Worker    TransformedDistribution,
86*da0073e9SAndroid Build Coastguard Worker    Uniform,
87*da0073e9SAndroid Build Coastguard Worker    VonMises,
88*da0073e9SAndroid Build Coastguard Worker    Weibull,
89*da0073e9SAndroid Build Coastguard Worker    Wishart,
90*da0073e9SAndroid Build Coastguard Worker)
91*da0073e9SAndroid Build Coastguard Workerfrom torch.distributions.constraint_registry import transform_to
92*da0073e9SAndroid Build Coastguard Workerfrom torch.distributions.constraints import Constraint, is_dependent
93*da0073e9SAndroid Build Coastguard Workerfrom torch.distributions.dirichlet import _Dirichlet_backward
94*da0073e9SAndroid Build Coastguard Workerfrom torch.distributions.kl import _kl_expfamily_expfamily
95*da0073e9SAndroid Build Coastguard Workerfrom torch.distributions.transforms import (
96*da0073e9SAndroid Build Coastguard Worker    AffineTransform,
97*da0073e9SAndroid Build Coastguard Worker    CatTransform,
98*da0073e9SAndroid Build Coastguard Worker    ExpTransform,
99*da0073e9SAndroid Build Coastguard Worker    identity_transform,
100*da0073e9SAndroid Build Coastguard Worker    StackTransform,
101*da0073e9SAndroid Build Coastguard Worker)
102*da0073e9SAndroid Build Coastguard Workerfrom torch.distributions.utils import (
103*da0073e9SAndroid Build Coastguard Worker    lazy_property,
104*da0073e9SAndroid Build Coastguard Worker    probs_to_logits,
105*da0073e9SAndroid Build Coastguard Worker    tril_matrix_to_vec,
106*da0073e9SAndroid Build Coastguard Worker    vec_to_tril_matrix,
107*da0073e9SAndroid Build Coastguard Worker)
108*da0073e9SAndroid Build Coastguard Workerfrom torch.nn.functional import softmax
109*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_cuda import TEST_CUDA
110*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import (
111*da0073e9SAndroid Build Coastguard Worker    gradcheck,
112*da0073e9SAndroid Build Coastguard Worker    load_tests,
113*da0073e9SAndroid Build Coastguard Worker    run_tests,
114*da0073e9SAndroid Build Coastguard Worker    set_default_dtype,
115*da0073e9SAndroid Build Coastguard Worker    set_rng_seed,
116*da0073e9SAndroid Build Coastguard Worker    skipIfTorchDynamo,
117*da0073e9SAndroid Build Coastguard Worker    TestCase,
118*da0073e9SAndroid Build Coastguard Worker)
119*da0073e9SAndroid Build Coastguard Worker
120*da0073e9SAndroid Build Coastguard Worker
# `load_tests` (imported from torch.testing._internal.common_utils) is used to
# automatically filter tests for sharding on sandcastle. The self-assignment
# below silences flake warnings.
load_tests = load_tests
124*da0073e9SAndroid Build Coastguard Worker
# TEST_NUMPY records whether numpy and the scipy submodules used by this file
# could all be imported; it is False as soon as any one of them is missing.
try:
    import numpy as np
    import scipy.special
    import scipy.stats
except ImportError:
    TEST_NUMPY = False
else:
    TEST_NUMPY = True
132*da0073e9SAndroid Build Coastguard Worker
133*da0073e9SAndroid Build Coastguard Worker
def pairwise(Dist, *params):
    """
    Builds two `Dist` instances whose parameters pair every entry of each
    element of `params` with every other entry.

    Each `p` in `params` is stacked `len(p)` times into a tensor whose rows
    are all copies of `p`. The first returned distribution is constructed from
    these tensors and the second from their transposes (dims 0 and 1 swapped),
    so position (i, j) holds `p[j]` in the first and `p[i]` in the second.
    """
    row_major = []
    col_major = []
    for values in params:
        # One row per entry of `values`; every row is the full sequence.
        grid = torch.tensor([values for _ in range(len(values))])
        row_major.append(grid)
        col_major.append(grid.transpose(0, 1))
    return Dist(*row_major), Dist(*col_major)
142*da0073e9SAndroid Build Coastguard Worker
143*da0073e9SAndroid Build Coastguard Worker
def is_all_nan(tensor):
    """
    Returns whether every entry of `tensor` is NaN.

    The elementwise NaN mask is reduced with `.all()`, so the result is a
    boolean tensor holding a single value rather than a Python bool.
    """
    nan_mask = torch.isnan(tensor)
    return nan_mask.all()
149*da0073e9SAndroid Build Coastguard Worker
150*da0073e9SAndroid Build Coastguard Worker
# Pairs a distribution class (`Dist`) with the collection of parameter sets
# (`params`) it is exercised with.
Example = namedtuple("Example", ("Dist", "params"))
152*da0073e9SAndroid Build Coastguard Worker
153*da0073e9SAndroid Build Coastguard Worker
154*da0073e9SAndroid Build Coastguard Worker# Register all distributions for generic tests.
155*da0073e9SAndroid Build Coastguard Workerdef _get_examples():
156*da0073e9SAndroid Build Coastguard Worker    return [
157*da0073e9SAndroid Build Coastguard Worker        Example(
158*da0073e9SAndroid Build Coastguard Worker            Bernoulli,
159*da0073e9SAndroid Build Coastguard Worker            [
160*da0073e9SAndroid Build Coastguard Worker                {"probs": torch.tensor([0.7, 0.2, 0.4], requires_grad=True)},
161*da0073e9SAndroid Build Coastguard Worker                {"probs": torch.tensor([0.3], requires_grad=True)},
162*da0073e9SAndroid Build Coastguard Worker                {"probs": 0.3},
163*da0073e9SAndroid Build Coastguard Worker                {"logits": torch.tensor([0.0], requires_grad=True)},
164*da0073e9SAndroid Build Coastguard Worker            ],
165*da0073e9SAndroid Build Coastguard Worker        ),
166*da0073e9SAndroid Build Coastguard Worker        Example(
167*da0073e9SAndroid Build Coastguard Worker            Geometric,
168*da0073e9SAndroid Build Coastguard Worker            [
169*da0073e9SAndroid Build Coastguard Worker                {"probs": torch.tensor([0.7, 0.2, 0.4], requires_grad=True)},
170*da0073e9SAndroid Build Coastguard Worker                {"probs": torch.tensor([0.3], requires_grad=True)},
171*da0073e9SAndroid Build Coastguard Worker                {"probs": 0.3},
172*da0073e9SAndroid Build Coastguard Worker            ],
173*da0073e9SAndroid Build Coastguard Worker        ),
174*da0073e9SAndroid Build Coastguard Worker        Example(
175*da0073e9SAndroid Build Coastguard Worker            Beta,
176*da0073e9SAndroid Build Coastguard Worker            [
177*da0073e9SAndroid Build Coastguard Worker                {
178*da0073e9SAndroid Build Coastguard Worker                    "concentration1": torch.randn(2, 3).exp().requires_grad_(),
179*da0073e9SAndroid Build Coastguard Worker                    "concentration0": torch.randn(2, 3).exp().requires_grad_(),
180*da0073e9SAndroid Build Coastguard Worker                },
181*da0073e9SAndroid Build Coastguard Worker                {
182*da0073e9SAndroid Build Coastguard Worker                    "concentration1": torch.randn(4).exp().requires_grad_(),
183*da0073e9SAndroid Build Coastguard Worker                    "concentration0": torch.randn(4).exp().requires_grad_(),
184*da0073e9SAndroid Build Coastguard Worker                },
185*da0073e9SAndroid Build Coastguard Worker            ],
186*da0073e9SAndroid Build Coastguard Worker        ),
187*da0073e9SAndroid Build Coastguard Worker        Example(
188*da0073e9SAndroid Build Coastguard Worker            Categorical,
189*da0073e9SAndroid Build Coastguard Worker            [
190*da0073e9SAndroid Build Coastguard Worker                {
191*da0073e9SAndroid Build Coastguard Worker                    "probs": torch.tensor(
192*da0073e9SAndroid Build Coastguard Worker                        [[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True
193*da0073e9SAndroid Build Coastguard Worker                    )
194*da0073e9SAndroid Build Coastguard Worker                },
195*da0073e9SAndroid Build Coastguard Worker                {"probs": torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True)},
196*da0073e9SAndroid Build Coastguard Worker                {"logits": torch.tensor([[0.0, 0.0], [0.0, 0.0]], requires_grad=True)},
197*da0073e9SAndroid Build Coastguard Worker            ],
198*da0073e9SAndroid Build Coastguard Worker        ),
199*da0073e9SAndroid Build Coastguard Worker        Example(
200*da0073e9SAndroid Build Coastguard Worker            Binomial,
201*da0073e9SAndroid Build Coastguard Worker            [
202*da0073e9SAndroid Build Coastguard Worker                {
203*da0073e9SAndroid Build Coastguard Worker                    "probs": torch.tensor(
204*da0073e9SAndroid Build Coastguard Worker                        [[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True
205*da0073e9SAndroid Build Coastguard Worker                    ),
206*da0073e9SAndroid Build Coastguard Worker                    "total_count": 10,
207*da0073e9SAndroid Build Coastguard Worker                },
208*da0073e9SAndroid Build Coastguard Worker                {
209*da0073e9SAndroid Build Coastguard Worker                    "probs": torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True),
210*da0073e9SAndroid Build Coastguard Worker                    "total_count": 10,
211*da0073e9SAndroid Build Coastguard Worker                },
212*da0073e9SAndroid Build Coastguard Worker                {
213*da0073e9SAndroid Build Coastguard Worker                    "probs": torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True),
214*da0073e9SAndroid Build Coastguard Worker                    "total_count": torch.tensor([10]),
215*da0073e9SAndroid Build Coastguard Worker                },
216*da0073e9SAndroid Build Coastguard Worker                {
217*da0073e9SAndroid Build Coastguard Worker                    "probs": torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True),
218*da0073e9SAndroid Build Coastguard Worker                    "total_count": torch.tensor([10, 8]),
219*da0073e9SAndroid Build Coastguard Worker                },
220*da0073e9SAndroid Build Coastguard Worker                {
221*da0073e9SAndroid Build Coastguard Worker                    "probs": torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True),
222*da0073e9SAndroid Build Coastguard Worker                    "total_count": torch.tensor([[10.0, 8.0], [5.0, 3.0]]),
223*da0073e9SAndroid Build Coastguard Worker                },
224*da0073e9SAndroid Build Coastguard Worker                {
225*da0073e9SAndroid Build Coastguard Worker                    "probs": torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True),
226*da0073e9SAndroid Build Coastguard Worker                    "total_count": torch.tensor(0.0),
227*da0073e9SAndroid Build Coastguard Worker                },
228*da0073e9SAndroid Build Coastguard Worker            ],
229*da0073e9SAndroid Build Coastguard Worker        ),
230*da0073e9SAndroid Build Coastguard Worker        Example(
231*da0073e9SAndroid Build Coastguard Worker            NegativeBinomial,
232*da0073e9SAndroid Build Coastguard Worker            [
233*da0073e9SAndroid Build Coastguard Worker                {
234*da0073e9SAndroid Build Coastguard Worker                    "probs": torch.tensor(
235*da0073e9SAndroid Build Coastguard Worker                        [[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True
236*da0073e9SAndroid Build Coastguard Worker                    ),
237*da0073e9SAndroid Build Coastguard Worker                    "total_count": 10,
238*da0073e9SAndroid Build Coastguard Worker                },
239*da0073e9SAndroid Build Coastguard Worker                {
240*da0073e9SAndroid Build Coastguard Worker                    "probs": torch.tensor([[0.9, 0.0], [0.0, 0.9]], requires_grad=True),
241*da0073e9SAndroid Build Coastguard Worker                    "total_count": 10,
242*da0073e9SAndroid Build Coastguard Worker                },
243*da0073e9SAndroid Build Coastguard Worker                {
244*da0073e9SAndroid Build Coastguard Worker                    "probs": torch.tensor([[0.9, 0.0], [0.0, 0.9]], requires_grad=True),
245*da0073e9SAndroid Build Coastguard Worker                    "total_count": torch.tensor([10]),
246*da0073e9SAndroid Build Coastguard Worker                },
247*da0073e9SAndroid Build Coastguard Worker                {
248*da0073e9SAndroid Build Coastguard Worker                    "probs": torch.tensor([[0.9, 0.0], [0.0, 0.9]], requires_grad=True),
249*da0073e9SAndroid Build Coastguard Worker                    "total_count": torch.tensor([10, 8]),
250*da0073e9SAndroid Build Coastguard Worker                },
251*da0073e9SAndroid Build Coastguard Worker                {
252*da0073e9SAndroid Build Coastguard Worker                    "probs": torch.tensor([[0.9, 0.0], [0.0, 0.9]], requires_grad=True),
253*da0073e9SAndroid Build Coastguard Worker                    "total_count": torch.tensor([[10.0, 8.0], [5.0, 3.0]]),
254*da0073e9SAndroid Build Coastguard Worker                },
255*da0073e9SAndroid Build Coastguard Worker                {
256*da0073e9SAndroid Build Coastguard Worker                    "probs": torch.tensor([[0.9, 0.0], [0.0, 0.9]], requires_grad=True),
257*da0073e9SAndroid Build Coastguard Worker                    "total_count": torch.tensor(0.0),
258*da0073e9SAndroid Build Coastguard Worker                },
259*da0073e9SAndroid Build Coastguard Worker            ],
260*da0073e9SAndroid Build Coastguard Worker        ),
261*da0073e9SAndroid Build Coastguard Worker        Example(
262*da0073e9SAndroid Build Coastguard Worker            Multinomial,
263*da0073e9SAndroid Build Coastguard Worker            [
264*da0073e9SAndroid Build Coastguard Worker                {
265*da0073e9SAndroid Build Coastguard Worker                    "probs": torch.tensor(
266*da0073e9SAndroid Build Coastguard Worker                        [[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True
267*da0073e9SAndroid Build Coastguard Worker                    ),
268*da0073e9SAndroid Build Coastguard Worker                    "total_count": 10,
269*da0073e9SAndroid Build Coastguard Worker                },
270*da0073e9SAndroid Build Coastguard Worker                {
271*da0073e9SAndroid Build Coastguard Worker                    "probs": torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True),
272*da0073e9SAndroid Build Coastguard Worker                    "total_count": 10,
273*da0073e9SAndroid Build Coastguard Worker                },
274*da0073e9SAndroid Build Coastguard Worker            ],
275*da0073e9SAndroid Build Coastguard Worker        ),
276*da0073e9SAndroid Build Coastguard Worker        Example(
277*da0073e9SAndroid Build Coastguard Worker            Cauchy,
278*da0073e9SAndroid Build Coastguard Worker            [
279*da0073e9SAndroid Build Coastguard Worker                {"loc": 0.0, "scale": 1.0},
280*da0073e9SAndroid Build Coastguard Worker                {"loc": torch.tensor([0.0]), "scale": 1.0},
281*da0073e9SAndroid Build Coastguard Worker                {
282*da0073e9SAndroid Build Coastguard Worker                    "loc": torch.tensor([[0.0], [0.0]]),
283*da0073e9SAndroid Build Coastguard Worker                    "scale": torch.tensor([[1.0], [1.0]]),
284*da0073e9SAndroid Build Coastguard Worker                },
285*da0073e9SAndroid Build Coastguard Worker            ],
286*da0073e9SAndroid Build Coastguard Worker        ),
287*da0073e9SAndroid Build Coastguard Worker        Example(
288*da0073e9SAndroid Build Coastguard Worker            Chi2,
289*da0073e9SAndroid Build Coastguard Worker            [
290*da0073e9SAndroid Build Coastguard Worker                {"df": torch.randn(2, 3).exp().requires_grad_()},
291*da0073e9SAndroid Build Coastguard Worker                {"df": torch.randn(1).exp().requires_grad_()},
292*da0073e9SAndroid Build Coastguard Worker            ],
293*da0073e9SAndroid Build Coastguard Worker        ),
294*da0073e9SAndroid Build Coastguard Worker        Example(
295*da0073e9SAndroid Build Coastguard Worker            StudentT,
296*da0073e9SAndroid Build Coastguard Worker            [
297*da0073e9SAndroid Build Coastguard Worker                {"df": torch.randn(2, 3).exp().requires_grad_()},
298*da0073e9SAndroid Build Coastguard Worker                {"df": torch.randn(1).exp().requires_grad_()},
299*da0073e9SAndroid Build Coastguard Worker            ],
300*da0073e9SAndroid Build Coastguard Worker        ),
301*da0073e9SAndroid Build Coastguard Worker        Example(
302*da0073e9SAndroid Build Coastguard Worker            Dirichlet,
303*da0073e9SAndroid Build Coastguard Worker            [
304*da0073e9SAndroid Build Coastguard Worker                {"concentration": torch.randn(2, 3).exp().requires_grad_()},
305*da0073e9SAndroid Build Coastguard Worker                {"concentration": torch.randn(4).exp().requires_grad_()},
306*da0073e9SAndroid Build Coastguard Worker            ],
307*da0073e9SAndroid Build Coastguard Worker        ),
308*da0073e9SAndroid Build Coastguard Worker        Example(
309*da0073e9SAndroid Build Coastguard Worker            Exponential,
310*da0073e9SAndroid Build Coastguard Worker            [
311*da0073e9SAndroid Build Coastguard Worker                {"rate": torch.randn(5, 5).abs().requires_grad_()},
312*da0073e9SAndroid Build Coastguard Worker                {"rate": torch.randn(1).abs().requires_grad_()},
313*da0073e9SAndroid Build Coastguard Worker            ],
314*da0073e9SAndroid Build Coastguard Worker        ),
315*da0073e9SAndroid Build Coastguard Worker        Example(
316*da0073e9SAndroid Build Coastguard Worker            FisherSnedecor,
317*da0073e9SAndroid Build Coastguard Worker            [
318*da0073e9SAndroid Build Coastguard Worker                {
319*da0073e9SAndroid Build Coastguard Worker                    "df1": torch.randn(5, 5).abs().requires_grad_(),
320*da0073e9SAndroid Build Coastguard Worker                    "df2": torch.randn(5, 5).abs().requires_grad_(),
321*da0073e9SAndroid Build Coastguard Worker                },
322*da0073e9SAndroid Build Coastguard Worker                {
323*da0073e9SAndroid Build Coastguard Worker                    "df1": torch.randn(1).abs().requires_grad_(),
324*da0073e9SAndroid Build Coastguard Worker                    "df2": torch.randn(1).abs().requires_grad_(),
325*da0073e9SAndroid Build Coastguard Worker                },
326*da0073e9SAndroid Build Coastguard Worker                {
327*da0073e9SAndroid Build Coastguard Worker                    "df1": torch.tensor([1.0]),
328*da0073e9SAndroid Build Coastguard Worker                    "df2": 1.0,
329*da0073e9SAndroid Build Coastguard Worker                },
330*da0073e9SAndroid Build Coastguard Worker            ],
331*da0073e9SAndroid Build Coastguard Worker        ),
332*da0073e9SAndroid Build Coastguard Worker        Example(
333*da0073e9SAndroid Build Coastguard Worker            Gamma,
334*da0073e9SAndroid Build Coastguard Worker            [
335*da0073e9SAndroid Build Coastguard Worker                {
336*da0073e9SAndroid Build Coastguard Worker                    "concentration": torch.randn(2, 3).exp().requires_grad_(),
337*da0073e9SAndroid Build Coastguard Worker                    "rate": torch.randn(2, 3).exp().requires_grad_(),
338*da0073e9SAndroid Build Coastguard Worker                },
339*da0073e9SAndroid Build Coastguard Worker                {
340*da0073e9SAndroid Build Coastguard Worker                    "concentration": torch.randn(1).exp().requires_grad_(),
341*da0073e9SAndroid Build Coastguard Worker                    "rate": torch.randn(1).exp().requires_grad_(),
342*da0073e9SAndroid Build Coastguard Worker                },
343*da0073e9SAndroid Build Coastguard Worker            ],
344*da0073e9SAndroid Build Coastguard Worker        ),
345*da0073e9SAndroid Build Coastguard Worker        Example(
346*da0073e9SAndroid Build Coastguard Worker            Gumbel,
347*da0073e9SAndroid Build Coastguard Worker            [
348*da0073e9SAndroid Build Coastguard Worker                {
349*da0073e9SAndroid Build Coastguard Worker                    "loc": torch.randn(5, 5, requires_grad=True),
350*da0073e9SAndroid Build Coastguard Worker                    "scale": torch.randn(5, 5).abs().requires_grad_(),
351*da0073e9SAndroid Build Coastguard Worker                },
352*da0073e9SAndroid Build Coastguard Worker                {
353*da0073e9SAndroid Build Coastguard Worker                    "loc": torch.randn(1, requires_grad=True),
354*da0073e9SAndroid Build Coastguard Worker                    "scale": torch.randn(1).abs().requires_grad_(),
355*da0073e9SAndroid Build Coastguard Worker                },
356*da0073e9SAndroid Build Coastguard Worker            ],
357*da0073e9SAndroid Build Coastguard Worker        ),
358*da0073e9SAndroid Build Coastguard Worker        Example(HalfCauchy, [{"scale": 1.0}, {"scale": torch.tensor([[1.0], [1.0]])}]),
359*da0073e9SAndroid Build Coastguard Worker        Example(
360*da0073e9SAndroid Build Coastguard Worker            HalfNormal,
361*da0073e9SAndroid Build Coastguard Worker            [
362*da0073e9SAndroid Build Coastguard Worker                {"scale": torch.randn(5, 5).abs().requires_grad_()},
363*da0073e9SAndroid Build Coastguard Worker                {"scale": torch.randn(1).abs().requires_grad_()},
364*da0073e9SAndroid Build Coastguard Worker                {"scale": torch.tensor([1e-5, 1e-5], requires_grad=True)},
365*da0073e9SAndroid Build Coastguard Worker            ],
366*da0073e9SAndroid Build Coastguard Worker        ),
367*da0073e9SAndroid Build Coastguard Worker        Example(
368*da0073e9SAndroid Build Coastguard Worker            Independent,
369*da0073e9SAndroid Build Coastguard Worker            [
370*da0073e9SAndroid Build Coastguard Worker                {
371*da0073e9SAndroid Build Coastguard Worker                    "base_distribution": Normal(
372*da0073e9SAndroid Build Coastguard Worker                        torch.randn(2, 3, requires_grad=True),
373*da0073e9SAndroid Build Coastguard Worker                        torch.randn(2, 3).abs().requires_grad_(),
374*da0073e9SAndroid Build Coastguard Worker                    ),
375*da0073e9SAndroid Build Coastguard Worker                    "reinterpreted_batch_ndims": 0,
376*da0073e9SAndroid Build Coastguard Worker                },
377*da0073e9SAndroid Build Coastguard Worker                {
378*da0073e9SAndroid Build Coastguard Worker                    "base_distribution": Normal(
379*da0073e9SAndroid Build Coastguard Worker                        torch.randn(2, 3, requires_grad=True),
380*da0073e9SAndroid Build Coastguard Worker                        torch.randn(2, 3).abs().requires_grad_(),
381*da0073e9SAndroid Build Coastguard Worker                    ),
382*da0073e9SAndroid Build Coastguard Worker                    "reinterpreted_batch_ndims": 1,
383*da0073e9SAndroid Build Coastguard Worker                },
384*da0073e9SAndroid Build Coastguard Worker                {
385*da0073e9SAndroid Build Coastguard Worker                    "base_distribution": Normal(
386*da0073e9SAndroid Build Coastguard Worker                        torch.randn(2, 3, requires_grad=True),
387*da0073e9SAndroid Build Coastguard Worker                        torch.randn(2, 3).abs().requires_grad_(),
388*da0073e9SAndroid Build Coastguard Worker                    ),
389*da0073e9SAndroid Build Coastguard Worker                    "reinterpreted_batch_ndims": 2,
390*da0073e9SAndroid Build Coastguard Worker                },
391*da0073e9SAndroid Build Coastguard Worker                {
392*da0073e9SAndroid Build Coastguard Worker                    "base_distribution": Normal(
393*da0073e9SAndroid Build Coastguard Worker                        torch.randn(2, 3, 5, requires_grad=True),
394*da0073e9SAndroid Build Coastguard Worker                        torch.randn(2, 3, 5).abs().requires_grad_(),
395*da0073e9SAndroid Build Coastguard Worker                    ),
396*da0073e9SAndroid Build Coastguard Worker                    "reinterpreted_batch_ndims": 2,
397*da0073e9SAndroid Build Coastguard Worker                },
398*da0073e9SAndroid Build Coastguard Worker                {
399*da0073e9SAndroid Build Coastguard Worker                    "base_distribution": Normal(
400*da0073e9SAndroid Build Coastguard Worker                        torch.randn(2, 3, 5, requires_grad=True),
401*da0073e9SAndroid Build Coastguard Worker                        torch.randn(2, 3, 5).abs().requires_grad_(),
402*da0073e9SAndroid Build Coastguard Worker                    ),
403*da0073e9SAndroid Build Coastguard Worker                    "reinterpreted_batch_ndims": 3,
404*da0073e9SAndroid Build Coastguard Worker                },
405*da0073e9SAndroid Build Coastguard Worker            ],
406*da0073e9SAndroid Build Coastguard Worker        ),
407*da0073e9SAndroid Build Coastguard Worker        Example(
408*da0073e9SAndroid Build Coastguard Worker            Kumaraswamy,
409*da0073e9SAndroid Build Coastguard Worker            [
410*da0073e9SAndroid Build Coastguard Worker                {
411*da0073e9SAndroid Build Coastguard Worker                    "concentration1": torch.empty(2, 3).uniform_(1, 2).requires_grad_(),
412*da0073e9SAndroid Build Coastguard Worker                    "concentration0": torch.empty(2, 3).uniform_(1, 2).requires_grad_(),
413*da0073e9SAndroid Build Coastguard Worker                },
414*da0073e9SAndroid Build Coastguard Worker                {
415*da0073e9SAndroid Build Coastguard Worker                    "concentration1": torch.rand(4).uniform_(1, 2).requires_grad_(),
416*da0073e9SAndroid Build Coastguard Worker                    "concentration0": torch.rand(4).uniform_(1, 2).requires_grad_(),
417*da0073e9SAndroid Build Coastguard Worker                },
418*da0073e9SAndroid Build Coastguard Worker            ],
419*da0073e9SAndroid Build Coastguard Worker        ),
420*da0073e9SAndroid Build Coastguard Worker        Example(
421*da0073e9SAndroid Build Coastguard Worker            LKJCholesky,
422*da0073e9SAndroid Build Coastguard Worker            [
423*da0073e9SAndroid Build Coastguard Worker                {"dim": 2, "concentration": 0.5},
424*da0073e9SAndroid Build Coastguard Worker                {
425*da0073e9SAndroid Build Coastguard Worker                    "dim": 3,
426*da0073e9SAndroid Build Coastguard Worker                    "concentration": torch.tensor([0.5, 1.0, 2.0]),
427*da0073e9SAndroid Build Coastguard Worker                },
428*da0073e9SAndroid Build Coastguard Worker                {"dim": 100, "concentration": 4.0},
429*da0073e9SAndroid Build Coastguard Worker            ],
430*da0073e9SAndroid Build Coastguard Worker        ),
431*da0073e9SAndroid Build Coastguard Worker        Example(
432*da0073e9SAndroid Build Coastguard Worker            Laplace,
433*da0073e9SAndroid Build Coastguard Worker            [
434*da0073e9SAndroid Build Coastguard Worker                {
435*da0073e9SAndroid Build Coastguard Worker                    "loc": torch.randn(5, 5, requires_grad=True),
436*da0073e9SAndroid Build Coastguard Worker                    "scale": torch.randn(5, 5).abs().requires_grad_(),
437*da0073e9SAndroid Build Coastguard Worker                },
438*da0073e9SAndroid Build Coastguard Worker                {
439*da0073e9SAndroid Build Coastguard Worker                    "loc": torch.randn(1, requires_grad=True),
440*da0073e9SAndroid Build Coastguard Worker                    "scale": torch.randn(1).abs().requires_grad_(),
441*da0073e9SAndroid Build Coastguard Worker                },
442*da0073e9SAndroid Build Coastguard Worker                {
443*da0073e9SAndroid Build Coastguard Worker                    "loc": torch.tensor([1.0, 0.0], requires_grad=True),
444*da0073e9SAndroid Build Coastguard Worker                    "scale": torch.tensor([1e-5, 1e-5], requires_grad=True),
445*da0073e9SAndroid Build Coastguard Worker                },
446*da0073e9SAndroid Build Coastguard Worker            ],
447*da0073e9SAndroid Build Coastguard Worker        ),
448*da0073e9SAndroid Build Coastguard Worker        Example(
449*da0073e9SAndroid Build Coastguard Worker            LogNormal,
450*da0073e9SAndroid Build Coastguard Worker            [
451*da0073e9SAndroid Build Coastguard Worker                {
452*da0073e9SAndroid Build Coastguard Worker                    "loc": torch.randn(5, 5, requires_grad=True),
453*da0073e9SAndroid Build Coastguard Worker                    "scale": torch.randn(5, 5).abs().requires_grad_(),
454*da0073e9SAndroid Build Coastguard Worker                },
455*da0073e9SAndroid Build Coastguard Worker                {
456*da0073e9SAndroid Build Coastguard Worker                    "loc": torch.randn(1, requires_grad=True),
457*da0073e9SAndroid Build Coastguard Worker                    "scale": torch.randn(1).abs().requires_grad_(),
458*da0073e9SAndroid Build Coastguard Worker                },
459*da0073e9SAndroid Build Coastguard Worker                {
460*da0073e9SAndroid Build Coastguard Worker                    "loc": torch.tensor([1.0, 0.0], requires_grad=True),
461*da0073e9SAndroid Build Coastguard Worker                    "scale": torch.tensor([1e-5, 1e-5], requires_grad=True),
462*da0073e9SAndroid Build Coastguard Worker                },
463*da0073e9SAndroid Build Coastguard Worker            ],
464*da0073e9SAndroid Build Coastguard Worker        ),
465*da0073e9SAndroid Build Coastguard Worker        Example(
466*da0073e9SAndroid Build Coastguard Worker            LogisticNormal,
467*da0073e9SAndroid Build Coastguard Worker            [
468*da0073e9SAndroid Build Coastguard Worker                {
469*da0073e9SAndroid Build Coastguard Worker                    "loc": torch.randn(5, 5).requires_grad_(),
470*da0073e9SAndroid Build Coastguard Worker                    "scale": torch.randn(5, 5).abs().requires_grad_(),
471*da0073e9SAndroid Build Coastguard Worker                },
472*da0073e9SAndroid Build Coastguard Worker                {
473*da0073e9SAndroid Build Coastguard Worker                    "loc": torch.randn(1).requires_grad_(),
474*da0073e9SAndroid Build Coastguard Worker                    "scale": torch.randn(1).abs().requires_grad_(),
475*da0073e9SAndroid Build Coastguard Worker                },
476*da0073e9SAndroid Build Coastguard Worker                {
477*da0073e9SAndroid Build Coastguard Worker                    "loc": torch.tensor([1.0, 0.0], requires_grad=True),
478*da0073e9SAndroid Build Coastguard Worker                    "scale": torch.tensor([1e-5, 1e-5], requires_grad=True),
479*da0073e9SAndroid Build Coastguard Worker                },
480*da0073e9SAndroid Build Coastguard Worker            ],
481*da0073e9SAndroid Build Coastguard Worker        ),
482*da0073e9SAndroid Build Coastguard Worker        Example(
483*da0073e9SAndroid Build Coastguard Worker            LowRankMultivariateNormal,
484*da0073e9SAndroid Build Coastguard Worker            [
485*da0073e9SAndroid Build Coastguard Worker                {
486*da0073e9SAndroid Build Coastguard Worker                    "loc": torch.randn(5, 2, requires_grad=True),
487*da0073e9SAndroid Build Coastguard Worker                    "cov_factor": torch.randn(5, 2, 1, requires_grad=True),
488*da0073e9SAndroid Build Coastguard Worker                    "cov_diag": torch.tensor([2.0, 0.25], requires_grad=True),
489*da0073e9SAndroid Build Coastguard Worker                },
490*da0073e9SAndroid Build Coastguard Worker                {
491*da0073e9SAndroid Build Coastguard Worker                    "loc": torch.randn(4, 3, requires_grad=True),
492*da0073e9SAndroid Build Coastguard Worker                    "cov_factor": torch.randn(3, 2, requires_grad=True),
493*da0073e9SAndroid Build Coastguard Worker                    "cov_diag": torch.tensor([5.0, 1.5, 3.0], requires_grad=True),
494*da0073e9SAndroid Build Coastguard Worker                },
495*da0073e9SAndroid Build Coastguard Worker            ],
496*da0073e9SAndroid Build Coastguard Worker        ),
497*da0073e9SAndroid Build Coastguard Worker        Example(
498*da0073e9SAndroid Build Coastguard Worker            MultivariateNormal,
499*da0073e9SAndroid Build Coastguard Worker            [
500*da0073e9SAndroid Build Coastguard Worker                {
501*da0073e9SAndroid Build Coastguard Worker                    "loc": torch.randn(5, 2, requires_grad=True),
502*da0073e9SAndroid Build Coastguard Worker                    "covariance_matrix": torch.tensor(
503*da0073e9SAndroid Build Coastguard Worker                        [[2.0, 0.3], [0.3, 0.25]], requires_grad=True
504*da0073e9SAndroid Build Coastguard Worker                    ),
505*da0073e9SAndroid Build Coastguard Worker                },
506*da0073e9SAndroid Build Coastguard Worker                {
507*da0073e9SAndroid Build Coastguard Worker                    "loc": torch.randn(2, 3, requires_grad=True),
508*da0073e9SAndroid Build Coastguard Worker                    "precision_matrix": torch.tensor(
509*da0073e9SAndroid Build Coastguard Worker                        [[2.0, 0.1, 0.0], [0.1, 0.25, 0.0], [0.0, 0.0, 0.3]],
510*da0073e9SAndroid Build Coastguard Worker                        requires_grad=True,
511*da0073e9SAndroid Build Coastguard Worker                    ),
512*da0073e9SAndroid Build Coastguard Worker                },
513*da0073e9SAndroid Build Coastguard Worker                {
514*da0073e9SAndroid Build Coastguard Worker                    "loc": torch.randn(5, 3, 2, requires_grad=True),
515*da0073e9SAndroid Build Coastguard Worker                    "scale_tril": torch.tensor(
516*da0073e9SAndroid Build Coastguard Worker                        [
517*da0073e9SAndroid Build Coastguard Worker                            [[2.0, 0.0], [-0.5, 0.25]],
518*da0073e9SAndroid Build Coastguard Worker                            [[2.0, 0.0], [0.3, 0.25]],
519*da0073e9SAndroid Build Coastguard Worker                            [[5.0, 0.0], [-0.5, 1.5]],
520*da0073e9SAndroid Build Coastguard Worker                        ],
521*da0073e9SAndroid Build Coastguard Worker                        requires_grad=True,
522*da0073e9SAndroid Build Coastguard Worker                    ),
523*da0073e9SAndroid Build Coastguard Worker                },
524*da0073e9SAndroid Build Coastguard Worker                {
525*da0073e9SAndroid Build Coastguard Worker                    "loc": torch.tensor([1.0, -1.0]),
526*da0073e9SAndroid Build Coastguard Worker                    "covariance_matrix": torch.tensor([[5.0, -0.5], [-0.5, 1.5]]),
527*da0073e9SAndroid Build Coastguard Worker                },
528*da0073e9SAndroid Build Coastguard Worker            ],
529*da0073e9SAndroid Build Coastguard Worker        ),
530*da0073e9SAndroid Build Coastguard Worker        Example(
531*da0073e9SAndroid Build Coastguard Worker            Normal,
532*da0073e9SAndroid Build Coastguard Worker            [
533*da0073e9SAndroid Build Coastguard Worker                {
534*da0073e9SAndroid Build Coastguard Worker                    "loc": torch.randn(5, 5, requires_grad=True),
535*da0073e9SAndroid Build Coastguard Worker                    "scale": torch.randn(5, 5).abs().requires_grad_(),
536*da0073e9SAndroid Build Coastguard Worker                },
537*da0073e9SAndroid Build Coastguard Worker                {
538*da0073e9SAndroid Build Coastguard Worker                    "loc": torch.randn(1, requires_grad=True),
539*da0073e9SAndroid Build Coastguard Worker                    "scale": torch.randn(1).abs().requires_grad_(),
540*da0073e9SAndroid Build Coastguard Worker                },
541*da0073e9SAndroid Build Coastguard Worker                {
542*da0073e9SAndroid Build Coastguard Worker                    "loc": torch.tensor([1.0, 0.0], requires_grad=True),
543*da0073e9SAndroid Build Coastguard Worker                    "scale": torch.tensor([1e-5, 1e-5], requires_grad=True),
544*da0073e9SAndroid Build Coastguard Worker                },
545*da0073e9SAndroid Build Coastguard Worker            ],
546*da0073e9SAndroid Build Coastguard Worker        ),
547*da0073e9SAndroid Build Coastguard Worker        Example(
548*da0073e9SAndroid Build Coastguard Worker            OneHotCategorical,
549*da0073e9SAndroid Build Coastguard Worker            [
550*da0073e9SAndroid Build Coastguard Worker                {
551*da0073e9SAndroid Build Coastguard Worker                    "probs": torch.tensor(
552*da0073e9SAndroid Build Coastguard Worker                        [[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True
553*da0073e9SAndroid Build Coastguard Worker                    )
554*da0073e9SAndroid Build Coastguard Worker                },
555*da0073e9SAndroid Build Coastguard Worker                {"probs": torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True)},
556*da0073e9SAndroid Build Coastguard Worker                {"logits": torch.tensor([[0.0, 0.0], [0.0, 0.0]], requires_grad=True)},
557*da0073e9SAndroid Build Coastguard Worker            ],
558*da0073e9SAndroid Build Coastguard Worker        ),
559*da0073e9SAndroid Build Coastguard Worker        Example(
560*da0073e9SAndroid Build Coastguard Worker            OneHotCategoricalStraightThrough,
561*da0073e9SAndroid Build Coastguard Worker            [
562*da0073e9SAndroid Build Coastguard Worker                {
563*da0073e9SAndroid Build Coastguard Worker                    "probs": torch.tensor(
564*da0073e9SAndroid Build Coastguard Worker                        [[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True
565*da0073e9SAndroid Build Coastguard Worker                    )
566*da0073e9SAndroid Build Coastguard Worker                },
567*da0073e9SAndroid Build Coastguard Worker                {"probs": torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True)},
568*da0073e9SAndroid Build Coastguard Worker                {"logits": torch.tensor([[0.0, 0.0], [0.0, 0.0]], requires_grad=True)},
569*da0073e9SAndroid Build Coastguard Worker            ],
570*da0073e9SAndroid Build Coastguard Worker        ),
571*da0073e9SAndroid Build Coastguard Worker        Example(
572*da0073e9SAndroid Build Coastguard Worker            Pareto,
573*da0073e9SAndroid Build Coastguard Worker            [
574*da0073e9SAndroid Build Coastguard Worker                {"scale": 1.0, "alpha": 1.0},
575*da0073e9SAndroid Build Coastguard Worker                {
576*da0073e9SAndroid Build Coastguard Worker                    "scale": (torch.randn(5, 5).abs() + 0.1).requires_grad_(),
577*da0073e9SAndroid Build Coastguard Worker                    "alpha": (torch.randn(5, 5).abs() + 0.1).requires_grad_(),
578*da0073e9SAndroid Build Coastguard Worker                },
579*da0073e9SAndroid Build Coastguard Worker                {"scale": torch.tensor([1.0]), "alpha": 1.0},
580*da0073e9SAndroid Build Coastguard Worker            ],
581*da0073e9SAndroid Build Coastguard Worker        ),
582*da0073e9SAndroid Build Coastguard Worker        Example(
583*da0073e9SAndroid Build Coastguard Worker            Poisson,
584*da0073e9SAndroid Build Coastguard Worker            [
585*da0073e9SAndroid Build Coastguard Worker                {
586*da0073e9SAndroid Build Coastguard Worker                    "rate": torch.randn(5, 5).abs().requires_grad_(),
587*da0073e9SAndroid Build Coastguard Worker                },
588*da0073e9SAndroid Build Coastguard Worker                {
589*da0073e9SAndroid Build Coastguard Worker                    "rate": torch.randn(3).abs().requires_grad_(),
590*da0073e9SAndroid Build Coastguard Worker                },
591*da0073e9SAndroid Build Coastguard Worker                {
592*da0073e9SAndroid Build Coastguard Worker                    "rate": 0.2,
593*da0073e9SAndroid Build Coastguard Worker                },
594*da0073e9SAndroid Build Coastguard Worker                {
595*da0073e9SAndroid Build Coastguard Worker                    "rate": torch.tensor([0.0], requires_grad=True),
596*da0073e9SAndroid Build Coastguard Worker                },
597*da0073e9SAndroid Build Coastguard Worker                {
598*da0073e9SAndroid Build Coastguard Worker                    "rate": 0.0,
599*da0073e9SAndroid Build Coastguard Worker                },
600*da0073e9SAndroid Build Coastguard Worker            ],
601*da0073e9SAndroid Build Coastguard Worker        ),
602*da0073e9SAndroid Build Coastguard Worker        Example(
603*da0073e9SAndroid Build Coastguard Worker            RelaxedBernoulli,
604*da0073e9SAndroid Build Coastguard Worker            [
605*da0073e9SAndroid Build Coastguard Worker                {
606*da0073e9SAndroid Build Coastguard Worker                    "temperature": torch.tensor([0.5], requires_grad=True),
607*da0073e9SAndroid Build Coastguard Worker                    "probs": torch.tensor([0.7, 0.2, 0.4], requires_grad=True),
608*da0073e9SAndroid Build Coastguard Worker                },
609*da0073e9SAndroid Build Coastguard Worker                {
610*da0073e9SAndroid Build Coastguard Worker                    "temperature": torch.tensor([2.0]),
611*da0073e9SAndroid Build Coastguard Worker                    "probs": torch.tensor([0.3]),
612*da0073e9SAndroid Build Coastguard Worker                },
613*da0073e9SAndroid Build Coastguard Worker                {
614*da0073e9SAndroid Build Coastguard Worker                    "temperature": torch.tensor([7.2]),
615*da0073e9SAndroid Build Coastguard Worker                    "logits": torch.tensor([-2.0, 2.0, 1.0, 5.0]),
616*da0073e9SAndroid Build Coastguard Worker                },
617*da0073e9SAndroid Build Coastguard Worker            ],
618*da0073e9SAndroid Build Coastguard Worker        ),
619*da0073e9SAndroid Build Coastguard Worker        Example(
620*da0073e9SAndroid Build Coastguard Worker            RelaxedOneHotCategorical,
621*da0073e9SAndroid Build Coastguard Worker            [
622*da0073e9SAndroid Build Coastguard Worker                {
623*da0073e9SAndroid Build Coastguard Worker                    "temperature": torch.tensor([0.5], requires_grad=True),
624*da0073e9SAndroid Build Coastguard Worker                    "probs": torch.tensor(
625*da0073e9SAndroid Build Coastguard Worker                        [[0.1, 0.2, 0.7], [0.5, 0.3, 0.2]], requires_grad=True
626*da0073e9SAndroid Build Coastguard Worker                    ),
627*da0073e9SAndroid Build Coastguard Worker                },
628*da0073e9SAndroid Build Coastguard Worker                {
629*da0073e9SAndroid Build Coastguard Worker                    "temperature": torch.tensor([2.0]),
630*da0073e9SAndroid Build Coastguard Worker                    "probs": torch.tensor([[1.0, 0.0], [0.0, 1.0]]),
631*da0073e9SAndroid Build Coastguard Worker                },
632*da0073e9SAndroid Build Coastguard Worker                {
633*da0073e9SAndroid Build Coastguard Worker                    "temperature": torch.tensor([7.2]),
634*da0073e9SAndroid Build Coastguard Worker                    "logits": torch.tensor([[-2.0, 2.0], [1.0, 5.0]]),
635*da0073e9SAndroid Build Coastguard Worker                },
636*da0073e9SAndroid Build Coastguard Worker            ],
637*da0073e9SAndroid Build Coastguard Worker        ),
638*da0073e9SAndroid Build Coastguard Worker        Example(
639*da0073e9SAndroid Build Coastguard Worker            TransformedDistribution,
640*da0073e9SAndroid Build Coastguard Worker            [
641*da0073e9SAndroid Build Coastguard Worker                {
642*da0073e9SAndroid Build Coastguard Worker                    "base_distribution": Normal(
643*da0073e9SAndroid Build Coastguard Worker                        torch.randn(2, 3, requires_grad=True),
644*da0073e9SAndroid Build Coastguard Worker                        torch.randn(2, 3).abs().requires_grad_(),
645*da0073e9SAndroid Build Coastguard Worker                    ),
646*da0073e9SAndroid Build Coastguard Worker                    "transforms": [],
647*da0073e9SAndroid Build Coastguard Worker                },
648*da0073e9SAndroid Build Coastguard Worker                {
649*da0073e9SAndroid Build Coastguard Worker                    "base_distribution": Normal(
650*da0073e9SAndroid Build Coastguard Worker                        torch.randn(2, 3, requires_grad=True),
651*da0073e9SAndroid Build Coastguard Worker                        torch.randn(2, 3).abs().requires_grad_(),
652*da0073e9SAndroid Build Coastguard Worker                    ),
653*da0073e9SAndroid Build Coastguard Worker                    "transforms": ExpTransform(),
654*da0073e9SAndroid Build Coastguard Worker                },
655*da0073e9SAndroid Build Coastguard Worker                {
656*da0073e9SAndroid Build Coastguard Worker                    "base_distribution": Normal(
657*da0073e9SAndroid Build Coastguard Worker                        torch.randn(2, 3, 5, requires_grad=True),
658*da0073e9SAndroid Build Coastguard Worker                        torch.randn(2, 3, 5).abs().requires_grad_(),
659*da0073e9SAndroid Build Coastguard Worker                    ),
660*da0073e9SAndroid Build Coastguard Worker                    "transforms": [
661*da0073e9SAndroid Build Coastguard Worker                        AffineTransform(torch.randn(3, 5), torch.randn(3, 5)),
662*da0073e9SAndroid Build Coastguard Worker                        ExpTransform(),
663*da0073e9SAndroid Build Coastguard Worker                    ],
664*da0073e9SAndroid Build Coastguard Worker                },
665*da0073e9SAndroid Build Coastguard Worker                {
666*da0073e9SAndroid Build Coastguard Worker                    "base_distribution": Normal(
667*da0073e9SAndroid Build Coastguard Worker                        torch.randn(2, 3, 5, requires_grad=True),
668*da0073e9SAndroid Build Coastguard Worker                        torch.randn(2, 3, 5).abs().requires_grad_(),
669*da0073e9SAndroid Build Coastguard Worker                    ),
670*da0073e9SAndroid Build Coastguard Worker                    "transforms": AffineTransform(1, 2),
671*da0073e9SAndroid Build Coastguard Worker                },
672*da0073e9SAndroid Build Coastguard Worker                {
673*da0073e9SAndroid Build Coastguard Worker                    "base_distribution": Uniform(
674*da0073e9SAndroid Build Coastguard Worker                        torch.tensor(1e8).log(), torch.tensor(1e10).log()
675*da0073e9SAndroid Build Coastguard Worker                    ),
676*da0073e9SAndroid Build Coastguard Worker                    "transforms": ExpTransform(),
677*da0073e9SAndroid Build Coastguard Worker                },
678*da0073e9SAndroid Build Coastguard Worker            ],
679*da0073e9SAndroid Build Coastguard Worker        ),
680*da0073e9SAndroid Build Coastguard Worker        Example(
681*da0073e9SAndroid Build Coastguard Worker            Uniform,
682*da0073e9SAndroid Build Coastguard Worker            [
683*da0073e9SAndroid Build Coastguard Worker                {
684*da0073e9SAndroid Build Coastguard Worker                    "low": torch.zeros(5, 5, requires_grad=True),
685*da0073e9SAndroid Build Coastguard Worker                    "high": torch.ones(5, 5, requires_grad=True),
686*da0073e9SAndroid Build Coastguard Worker                },
687*da0073e9SAndroid Build Coastguard Worker                {
688*da0073e9SAndroid Build Coastguard Worker                    "low": torch.zeros(1, requires_grad=True),
689*da0073e9SAndroid Build Coastguard Worker                    "high": torch.ones(1, requires_grad=True),
690*da0073e9SAndroid Build Coastguard Worker                },
691*da0073e9SAndroid Build Coastguard Worker                {
692*da0073e9SAndroid Build Coastguard Worker                    "low": torch.tensor([1.0, 1.0], requires_grad=True),
693*da0073e9SAndroid Build Coastguard Worker                    "high": torch.tensor([2.0, 3.0], requires_grad=True),
694*da0073e9SAndroid Build Coastguard Worker                },
695*da0073e9SAndroid Build Coastguard Worker            ],
696*da0073e9SAndroid Build Coastguard Worker        ),
697*da0073e9SAndroid Build Coastguard Worker        Example(
698*da0073e9SAndroid Build Coastguard Worker            Weibull,
699*da0073e9SAndroid Build Coastguard Worker            [
700*da0073e9SAndroid Build Coastguard Worker                {
701*da0073e9SAndroid Build Coastguard Worker                    "scale": torch.randn(5, 5).abs().requires_grad_(),
702*da0073e9SAndroid Build Coastguard Worker                    "concentration": torch.randn(1).abs().requires_grad_(),
703*da0073e9SAndroid Build Coastguard Worker                }
704*da0073e9SAndroid Build Coastguard Worker            ],
705*da0073e9SAndroid Build Coastguard Worker        ),
706*da0073e9SAndroid Build Coastguard Worker        Example(
707*da0073e9SAndroid Build Coastguard Worker            Wishart,
708*da0073e9SAndroid Build Coastguard Worker            [
709*da0073e9SAndroid Build Coastguard Worker                {
710*da0073e9SAndroid Build Coastguard Worker                    "covariance_matrix": torch.tensor(
711*da0073e9SAndroid Build Coastguard Worker                        [[2.0, 0.3], [0.3, 0.25]], requires_grad=True
712*da0073e9SAndroid Build Coastguard Worker                    ),
713*da0073e9SAndroid Build Coastguard Worker                    "df": torch.tensor([3.0], requires_grad=True),
714*da0073e9SAndroid Build Coastguard Worker                },
715*da0073e9SAndroid Build Coastguard Worker                {
716*da0073e9SAndroid Build Coastguard Worker                    "precision_matrix": torch.tensor(
717*da0073e9SAndroid Build Coastguard Worker                        [[2.0, 0.1, 0.0], [0.1, 0.25, 0.0], [0.0, 0.0, 0.3]],
718*da0073e9SAndroid Build Coastguard Worker                        requires_grad=True,
719*da0073e9SAndroid Build Coastguard Worker                    ),
720*da0073e9SAndroid Build Coastguard Worker                    "df": torch.tensor([5.0, 4], requires_grad=True),
721*da0073e9SAndroid Build Coastguard Worker                },
722*da0073e9SAndroid Build Coastguard Worker                {
723*da0073e9SAndroid Build Coastguard Worker                    "scale_tril": torch.tensor(
724*da0073e9SAndroid Build Coastguard Worker                        [
725*da0073e9SAndroid Build Coastguard Worker                            [[2.0, 0.0], [-0.5, 0.25]],
726*da0073e9SAndroid Build Coastguard Worker                            [[2.0, 0.0], [0.3, 0.25]],
727*da0073e9SAndroid Build Coastguard Worker                            [[5.0, 0.0], [-0.5, 1.5]],
728*da0073e9SAndroid Build Coastguard Worker                        ],
729*da0073e9SAndroid Build Coastguard Worker                        requires_grad=True,
730*da0073e9SAndroid Build Coastguard Worker                    ),
731*da0073e9SAndroid Build Coastguard Worker                    "df": torch.tensor([5.0, 3.5, 3], requires_grad=True),
732*da0073e9SAndroid Build Coastguard Worker                },
733*da0073e9SAndroid Build Coastguard Worker                {
734*da0073e9SAndroid Build Coastguard Worker                    "covariance_matrix": torch.tensor([[5.0, -0.5], [-0.5, 1.5]]),
735*da0073e9SAndroid Build Coastguard Worker                    "df": torch.tensor([3.0]),
736*da0073e9SAndroid Build Coastguard Worker                },
737*da0073e9SAndroid Build Coastguard Worker                {
738*da0073e9SAndroid Build Coastguard Worker                    "covariance_matrix": torch.tensor([[5.0, -0.5], [-0.5, 1.5]]),
739*da0073e9SAndroid Build Coastguard Worker                    "df": 3.0,
740*da0073e9SAndroid Build Coastguard Worker                },
741*da0073e9SAndroid Build Coastguard Worker            ],
742*da0073e9SAndroid Build Coastguard Worker        ),
743*da0073e9SAndroid Build Coastguard Worker        Example(
744*da0073e9SAndroid Build Coastguard Worker            MixtureSameFamily,
745*da0073e9SAndroid Build Coastguard Worker            [
746*da0073e9SAndroid Build Coastguard Worker                {
747*da0073e9SAndroid Build Coastguard Worker                    "mixture_distribution": Categorical(
748*da0073e9SAndroid Build Coastguard Worker                        torch.rand(5, requires_grad=True)
749*da0073e9SAndroid Build Coastguard Worker                    ),
750*da0073e9SAndroid Build Coastguard Worker                    "component_distribution": Normal(
751*da0073e9SAndroid Build Coastguard Worker                        torch.randn(5, requires_grad=True),
752*da0073e9SAndroid Build Coastguard Worker                        torch.rand(5, requires_grad=True),
753*da0073e9SAndroid Build Coastguard Worker                    ),
754*da0073e9SAndroid Build Coastguard Worker                },
755*da0073e9SAndroid Build Coastguard Worker                {
756*da0073e9SAndroid Build Coastguard Worker                    "mixture_distribution": Categorical(
757*da0073e9SAndroid Build Coastguard Worker                        torch.rand(5, requires_grad=True)
758*da0073e9SAndroid Build Coastguard Worker                    ),
759*da0073e9SAndroid Build Coastguard Worker                    "component_distribution": MultivariateNormal(
760*da0073e9SAndroid Build Coastguard Worker                        loc=torch.randn(5, 2, requires_grad=True),
761*da0073e9SAndroid Build Coastguard Worker                        covariance_matrix=torch.tensor(
762*da0073e9SAndroid Build Coastguard Worker                            [[2.0, 0.3], [0.3, 0.25]], requires_grad=True
763*da0073e9SAndroid Build Coastguard Worker                        ),
764*da0073e9SAndroid Build Coastguard Worker                    ),
765*da0073e9SAndroid Build Coastguard Worker                },
766*da0073e9SAndroid Build Coastguard Worker            ],
767*da0073e9SAndroid Build Coastguard Worker        ),
768*da0073e9SAndroid Build Coastguard Worker        Example(
769*da0073e9SAndroid Build Coastguard Worker            VonMises,
770*da0073e9SAndroid Build Coastguard Worker            [
771*da0073e9SAndroid Build Coastguard Worker                {
772*da0073e9SAndroid Build Coastguard Worker                    "loc": torch.tensor(1.0, requires_grad=True),
773*da0073e9SAndroid Build Coastguard Worker                    "concentration": torch.tensor(10.0, requires_grad=True),
774*da0073e9SAndroid Build Coastguard Worker                },
775*da0073e9SAndroid Build Coastguard Worker                {
776*da0073e9SAndroid Build Coastguard Worker                    "loc": torch.tensor([0.0, math.pi / 2], requires_grad=True),
777*da0073e9SAndroid Build Coastguard Worker                    "concentration": torch.tensor([1.0, 10.0], requires_grad=True),
778*da0073e9SAndroid Build Coastguard Worker                },
779*da0073e9SAndroid Build Coastguard Worker            ],
780*da0073e9SAndroid Build Coastguard Worker        ),
781*da0073e9SAndroid Build Coastguard Worker        Example(
782*da0073e9SAndroid Build Coastguard Worker            ContinuousBernoulli,
783*da0073e9SAndroid Build Coastguard Worker            [
784*da0073e9SAndroid Build Coastguard Worker                {"probs": torch.tensor([0.7, 0.2, 0.4], requires_grad=True)},
785*da0073e9SAndroid Build Coastguard Worker                {"probs": torch.tensor([0.3], requires_grad=True)},
786*da0073e9SAndroid Build Coastguard Worker                {"probs": 0.3},
787*da0073e9SAndroid Build Coastguard Worker                {"logits": torch.tensor([0.0], requires_grad=True)},
788*da0073e9SAndroid Build Coastguard Worker            ],
789*da0073e9SAndroid Build Coastguard Worker        ),
790*da0073e9SAndroid Build Coastguard Worker        Example(
791*da0073e9SAndroid Build Coastguard Worker            InverseGamma,
792*da0073e9SAndroid Build Coastguard Worker            [
793*da0073e9SAndroid Build Coastguard Worker                {
794*da0073e9SAndroid Build Coastguard Worker                    "concentration": torch.randn(2, 3).exp().requires_grad_(),
795*da0073e9SAndroid Build Coastguard Worker                    "rate": torch.randn(2, 3).exp().requires_grad_(),
796*da0073e9SAndroid Build Coastguard Worker                },
797*da0073e9SAndroid Build Coastguard Worker                {
798*da0073e9SAndroid Build Coastguard Worker                    "concentration": torch.randn(1).exp().requires_grad_(),
799*da0073e9SAndroid Build Coastguard Worker                    "rate": torch.randn(1).exp().requires_grad_(),
800*da0073e9SAndroid Build Coastguard Worker                },
801*da0073e9SAndroid Build Coastguard Worker            ],
802*da0073e9SAndroid Build Coastguard Worker        ),
803*da0073e9SAndroid Build Coastguard Worker    ]
804*da0073e9SAndroid Build Coastguard Worker
805*da0073e9SAndroid Build Coastguard Worker
def _get_bad_examples():
    """Return ``Example`` entries pairing distributions with bad parameters.

    Each entry holds a distribution class and a list of keyword-argument
    dicts. Judging by the values (negative scales, probabilities outside
    [0, 1], non-positive rates, ...) these are presumably meant to be rejected
    by argument validation -- the consumers live outside this function, so
    confirm against them. Every tensor is created anew on each call.
    """

    def grad(values):
        # Fresh gradient-tracking tensor at every use; nothing is shared.
        return torch.tensor(values, requires_grad=True)

    def probs_outside_unit_interval():
        # Used for Bernoulli and ContinuousBernoulli.
        return [
            {"probs": grad([1.1, 0.2, 0.4])},
            {"probs": grad([-0.5])},
            {"probs": 1.00001},
        ]

    def count_probs():
        # Used for Binomial and NegativeBinomial.
        return [
            {
                "probs": grad([[-0.0000001, 0.2, 0.3], [0.5, 0.3, 0.2]]),
                "total_count": 10,
            },
            {
                "probs": grad([[1.0, 0.0], [0.0, 2.0]]),
                "total_count": 10,
            },
        ]

    def df_only():
        # Used for Chi2 and StudentT.
        return [
            {"df": grad([0.0])},
            {"df": grad([-2.0])},
        ]

    def concentration_rate():
        # Used for Gamma and InverseGamma.
        return [
            {
                "concentration": grad([0.0, 0.0]),
                "rate": grad([-1.0, -100.0]),
            },
            {
                "concentration": grad([1.0, 1.0]),
                "rate": grad([0.0, 0.0]),
            },
        ]

    def loc_scale():
        # Used for Gumbel, Laplace, LogNormal and (as a prefix) Normal.
        return [
            {
                "loc": grad([1.0, 1.0]),
                "scale": grad([0.0, 1.0]),
            },
            {
                "loc": grad([1.0, 1.0]),
                "scale": grad([1.0, -1.0]),
            },
        ]

    def one_hot_probs():
        # Used for OneHotCategorical and its straight-through variant.
        return [
            {"probs": grad([[0.1, 0.2, 0.3], [0.1, -10.0, 0.2]])},
            {"probs": grad([[0.0, 0.0], [0.0, 0.0]])},
        ]

    # (distribution class, list of keyword-argument dicts), in output order.
    specs = [
        (Bernoulli, probs_outside_unit_interval()),
        (
            Beta,
            [
                {
                    "concentration1": grad([0.0]),
                    "concentration0": grad([0.0]),
                },
                {
                    "concentration1": grad([-1.0]),
                    "concentration0": grad([-2.0]),
                },
            ],
        ),
        (
            Geometric,
            [
                {"probs": grad([1.1, 0.2, 0.4])},
                {"probs": grad([-0.3])},
                {"probs": 1.00000001},
            ],
        ),
        (
            Categorical,
            [
                {"probs": grad([[-0.1, 0.2, 0.3], [0.5, 0.3, 0.2]])},
                {"probs": grad([[-1.0, 10.0], [0.0, -1.0]])},
            ],
        ),
        (Binomial, count_probs()),
        (NegativeBinomial, count_probs()),
        (
            Cauchy,
            [
                {"loc": 0.0, "scale": -1.0},
                {"loc": torch.tensor([0.0]), "scale": 0.0},
                {
                    "loc": torch.tensor([[0.0], [-2.0]]),
                    "scale": torch.tensor([[-0.000001], [1.0]]),
                },
            ],
        ),
        (Chi2, df_only()),
        (StudentT, df_only()),
        (
            Dirichlet,
            [
                {"concentration": grad([0.0])},
                {"concentration": grad([-2.0])},
            ],
        ),
        (
            Exponential,
            [
                {"rate": grad([0.0, 0.0])},
                {"rate": grad([-2.0])},
            ],
        ),
        (
            FisherSnedecor,
            [
                {
                    "df1": grad([0.0, 0.0]),
                    "df2": grad([-1.0, -100.0]),
                },
                {
                    "df1": grad([1.0, 1.0]),
                    "df2": grad([0.0, 0.0]),
                },
            ],
        ),
        (Gamma, concentration_rate()),
        (Gumbel, loc_scale()),
        (
            HalfCauchy,
            [
                {"scale": -1.0},
                {"scale": 0.0},
                {"scale": torch.tensor([[-0.000001], [1.0]])},
            ],
        ),
        (
            HalfNormal,
            [
                {"scale": grad([0.0, 1.0])},
                {"scale": grad([1.0, -1.0])},
            ],
        ),
        (
            LKJCholesky,
            [
                {"dim": -2, "concentration": 0.1},
                {"dim": 1, "concentration": 2.0},
                {"dim": 2, "concentration": 0.0},
            ],
        ),
        (Laplace, loc_scale()),
        (LogNormal, loc_scale()),
        (
            MultivariateNormal,
            [
                {
                    "loc": grad([1.0, 1.0]),
                    "covariance_matrix": grad([[1.0, 0.0], [0.0, -2.0]]),
                },
            ],
        ),
        (
            Normal,
            loc_scale()
            + [
                {
                    "loc": grad([1.0, 0.0]),
                    "scale": grad([1e-5, -1e-5]),
                },
            ],
        ),
        (OneHotCategorical, one_hot_probs()),
        (OneHotCategoricalStraightThrough, one_hot_probs()),
        (
            Pareto,
            [
                {"scale": 0.0, "alpha": 0.0},
                {
                    "scale": grad([0.0, 0.0]),
                    "alpha": grad([-1e-5, 0.0]),
                },
                {"scale": torch.tensor([1.0]), "alpha": -1.0},
            ],
        ),
        (
            Poisson,
            [
                {"rate": grad([-0.1])},
                {"rate": -1.0},
            ],
        ),
        (
            RelaxedBernoulli,
            [
                {
                    "temperature": grad([1.5]),
                    "probs": grad([1.7, 0.2, 0.4]),
                },
                {
                    "temperature": torch.tensor([2.0]),
                    "probs": torch.tensor([-1.0]),
                },
            ],
        ),
        (
            RelaxedOneHotCategorical,
            [
                {
                    "temperature": grad([0.5]),
                    "probs": grad([[-0.1, 0.2, 0.7], [0.5, 0.3, 0.2]]),
                },
                {
                    "temperature": torch.tensor([2.0]),
                    "probs": torch.tensor([[-1.0, 0.0], [-1.0, 1.1]]),
                },
            ],
        ),
        (
            TransformedDistribution,
            [
                # ``transforms`` is a bare lambda, then a list holding one.
                {
                    "base_distribution": Normal(0, 1),
                    "transforms": lambda x: x,
                },
                {
                    "base_distribution": Normal(0, 1),
                    "transforms": [lambda x: x],
                },
            ],
        ),
        (
            Uniform,
            [
                {"low": grad([2.0]), "high": grad([2.0])},
                {"low": grad([0.0]), "high": grad([0.0])},
                {"low": grad([1.0]), "high": grad([0.0])},
            ],
        ),
        (
            Weibull,
            [
                {
                    "scale": grad([0.0]),
                    "concentration": grad([0.0]),
                },
                {
                    "scale": grad([1.0]),
                    "concentration": grad([-1.0]),
                },
            ],
        ),
        (
            Wishart,
            [
                {
                    "covariance_matrix": grad([[1.0, 0.0], [0.0, -2.0]]),
                    "df": grad([1.5]),
                },
                {
                    "covariance_matrix": grad([[1.0, 1.0], [1.0, -2.0]]),
                    "df": grad([3.0]),
                },
                {
                    "covariance_matrix": grad([[1.0, 1.0], [1.0, -2.0]]),
                    "df": 3.0,
                },
            ],
        ),
        (ContinuousBernoulli, probs_outside_unit_interval()),
        (InverseGamma, concentration_rate()),
    ]

    return [Example(dist_cls, param_sets) for dist_cls, param_sets in specs]
1203*da0073e9SAndroid Build Coastguard Worker
1204*da0073e9SAndroid Build Coastguard Worker
class DistributionsTestCase(TestCase):
    """Base test case that switches on distribution argument validation."""

    def setUp(self):
        """Enable default argument validation, then run the parent ``setUp``.

        The tests in this file assume that the validation flag is set.
        """
        distribution_base = torch.distributions.Distribution
        distribution_base.set_default_validate_args(True)
        super().setUp()
1210*da0073e9SAndroid Build Coastguard Worker
1211*da0073e9SAndroid Build Coastguard Worker
1212*da0073e9SAndroid Build Coastguard Worker@skipIfTorchDynamo("Not a TorchDynamo suitable test")
1213*da0073e9SAndroid Build Coastguard Workerclass TestDistributions(DistributionsTestCase):
1214*da0073e9SAndroid Build Coastguard Worker    _do_cuda_memory_leak_check = True
1215*da0073e9SAndroid Build Coastguard Worker    _do_cuda_non_default_stream = True
1216*da0073e9SAndroid Build Coastguard Worker
1217*da0073e9SAndroid Build Coastguard Worker    def _gradcheck_log_prob(self, dist_ctor, ctor_params):
1218*da0073e9SAndroid Build Coastguard Worker        # performs gradient checks on log_prob
1219*da0073e9SAndroid Build Coastguard Worker        distribution = dist_ctor(*ctor_params)
1220*da0073e9SAndroid Build Coastguard Worker        s = distribution.sample()
1221*da0073e9SAndroid Build Coastguard Worker        if not distribution.support.is_discrete:
1222*da0073e9SAndroid Build Coastguard Worker            s = s.detach().requires_grad_()
1223*da0073e9SAndroid Build Coastguard Worker
1224*da0073e9SAndroid Build Coastguard Worker        expected_shape = distribution.batch_shape + distribution.event_shape
1225*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(s.size(), expected_shape)
1226*da0073e9SAndroid Build Coastguard Worker
1227*da0073e9SAndroid Build Coastguard Worker        def apply_fn(s, *params):
1228*da0073e9SAndroid Build Coastguard Worker            return dist_ctor(*params).log_prob(s)
1229*da0073e9SAndroid Build Coastguard Worker
1230*da0073e9SAndroid Build Coastguard Worker        gradcheck(apply_fn, (s,) + tuple(ctor_params), raise_exception=True)
1231*da0073e9SAndroid Build Coastguard Worker
1232*da0073e9SAndroid Build Coastguard Worker    def _check_forward_ad(self, fn):
1233*da0073e9SAndroid Build Coastguard Worker        with fwAD.dual_level():
1234*da0073e9SAndroid Build Coastguard Worker            x = torch.tensor(1.0)
1235*da0073e9SAndroid Build Coastguard Worker            t = torch.tensor(1.0)
1236*da0073e9SAndroid Build Coastguard Worker            dual = fwAD.make_dual(x, t)
1237*da0073e9SAndroid Build Coastguard Worker            dual_out = fn(dual)
1238*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
1239*da0073e9SAndroid Build Coastguard Worker                torch.count_nonzero(fwAD.unpack_dual(dual_out).tangent).item(), 0
1240*da0073e9SAndroid Build Coastguard Worker            )
1241*da0073e9SAndroid Build Coastguard Worker
1242*da0073e9SAndroid Build Coastguard Worker    def _check_log_prob(self, dist, asset_fn):
1243*da0073e9SAndroid Build Coastguard Worker        # checks that the log_prob matches a reference function
1244*da0073e9SAndroid Build Coastguard Worker        s = dist.sample()
1245*da0073e9SAndroid Build Coastguard Worker        log_probs = dist.log_prob(s)
1246*da0073e9SAndroid Build Coastguard Worker        log_probs_data_flat = log_probs.view(-1)
1247*da0073e9SAndroid Build Coastguard Worker        s_data_flat = s.view(len(log_probs_data_flat), -1)
1248*da0073e9SAndroid Build Coastguard Worker        for i, (val, log_prob) in enumerate(zip(s_data_flat, log_probs_data_flat)):
1249*da0073e9SAndroid Build Coastguard Worker            asset_fn(i, val.squeeze(), log_prob)
1250*da0073e9SAndroid Build Coastguard Worker
1251*da0073e9SAndroid Build Coastguard Worker    def _check_sampler_sampler(
1252*da0073e9SAndroid Build Coastguard Worker        self,
1253*da0073e9SAndroid Build Coastguard Worker        torch_dist,
1254*da0073e9SAndroid Build Coastguard Worker        ref_dist,
1255*da0073e9SAndroid Build Coastguard Worker        message,
1256*da0073e9SAndroid Build Coastguard Worker        multivariate=False,
1257*da0073e9SAndroid Build Coastguard Worker        circular=False,
1258*da0073e9SAndroid Build Coastguard Worker        num_samples=10000,
1259*da0073e9SAndroid Build Coastguard Worker        failure_rate=1e-3,
1260*da0073e9SAndroid Build Coastguard Worker    ):
1261*da0073e9SAndroid Build Coastguard Worker        # Checks that the .sample() method matches a reference function.
1262*da0073e9SAndroid Build Coastguard Worker        torch_samples = torch_dist.sample((num_samples,)).squeeze()
1263*da0073e9SAndroid Build Coastguard Worker        torch_samples = torch_samples.cpu().numpy()
1264*da0073e9SAndroid Build Coastguard Worker        ref_samples = ref_dist.rvs(num_samples).astype(np.float64)
1265*da0073e9SAndroid Build Coastguard Worker        if multivariate:
1266*da0073e9SAndroid Build Coastguard Worker            # Project onto a random axis.
1267*da0073e9SAndroid Build Coastguard Worker            axis = np.random.normal(size=(1,) + torch_samples.shape[1:])
1268*da0073e9SAndroid Build Coastguard Worker            axis /= np.linalg.norm(axis)
1269*da0073e9SAndroid Build Coastguard Worker            torch_samples = (axis * torch_samples).reshape(num_samples, -1).sum(-1)
1270*da0073e9SAndroid Build Coastguard Worker            ref_samples = (axis * ref_samples).reshape(num_samples, -1).sum(-1)
1271*da0073e9SAndroid Build Coastguard Worker        samples = [(x, +1) for x in torch_samples] + [(x, -1) for x in ref_samples]
1272*da0073e9SAndroid Build Coastguard Worker        if circular:
1273*da0073e9SAndroid Build Coastguard Worker            samples = [(np.cos(x), v) for (x, v) in samples]
1274*da0073e9SAndroid Build Coastguard Worker        shuffle(
1275*da0073e9SAndroid Build Coastguard Worker            samples
1276*da0073e9SAndroid Build Coastguard Worker        )  # necessary to prevent stable sort from making uneven bins for discrete
1277*da0073e9SAndroid Build Coastguard Worker        samples.sort(key=lambda x: x[0])
1278*da0073e9SAndroid Build Coastguard Worker        samples = np.array(samples)[:, 1]
1279*da0073e9SAndroid Build Coastguard Worker
1280*da0073e9SAndroid Build Coastguard Worker        # Aggregate into bins filled with roughly zero-mean unit-variance RVs.
1281*da0073e9SAndroid Build Coastguard Worker        num_bins = 10
1282*da0073e9SAndroid Build Coastguard Worker        samples_per_bin = len(samples) // num_bins
1283*da0073e9SAndroid Build Coastguard Worker        bins = samples.reshape((num_bins, samples_per_bin)).mean(axis=1)
1284*da0073e9SAndroid Build Coastguard Worker        stddev = samples_per_bin**-0.5
1285*da0073e9SAndroid Build Coastguard Worker        threshold = stddev * scipy.special.erfinv(1 - 2 * failure_rate / num_bins)
1286*da0073e9SAndroid Build Coastguard Worker        message = f"{message}.sample() is biased:\n{bins}"
1287*da0073e9SAndroid Build Coastguard Worker        for bias in bins:
1288*da0073e9SAndroid Build Coastguard Worker            self.assertLess(-threshold, bias, message)
1289*da0073e9SAndroid Build Coastguard Worker            self.assertLess(bias, threshold, message)
1290*da0073e9SAndroid Build Coastguard Worker
1291*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
1292*da0073e9SAndroid Build Coastguard Worker    def _check_sampler_discrete(
1293*da0073e9SAndroid Build Coastguard Worker        self, torch_dist, ref_dist, message, num_samples=10000, failure_rate=1e-3
1294*da0073e9SAndroid Build Coastguard Worker    ):
1295*da0073e9SAndroid Build Coastguard Worker        """Runs a Chi2-test for the support, but ignores tail instead of combining"""
1296*da0073e9SAndroid Build Coastguard Worker        torch_samples = torch_dist.sample((num_samples,)).squeeze()
1297*da0073e9SAndroid Build Coastguard Worker        torch_samples = (
1298*da0073e9SAndroid Build Coastguard Worker            torch_samples.float()
1299*da0073e9SAndroid Build Coastguard Worker            if torch_samples.dtype == torch.bfloat16
1300*da0073e9SAndroid Build Coastguard Worker            else torch_samples
1301*da0073e9SAndroid Build Coastguard Worker        )
1302*da0073e9SAndroid Build Coastguard Worker        torch_samples = torch_samples.cpu().numpy()
1303*da0073e9SAndroid Build Coastguard Worker        unique, counts = np.unique(torch_samples, return_counts=True)
1304*da0073e9SAndroid Build Coastguard Worker        pmf = ref_dist.pmf(unique)
1305*da0073e9SAndroid Build Coastguard Worker        pmf = pmf / pmf.sum()  # renormalize to 1.0 for chisq test
1306*da0073e9SAndroid Build Coastguard Worker        msk = (counts > 5) & ((pmf * num_samples) > 5)
1307*da0073e9SAndroid Build Coastguard Worker        self.assertGreater(
1308*da0073e9SAndroid Build Coastguard Worker            pmf[msk].sum(),
1309*da0073e9SAndroid Build Coastguard Worker            0.9,
1310*da0073e9SAndroid Build Coastguard Worker            "Distribution is too sparse for test; try increasing num_samples",
1311*da0073e9SAndroid Build Coastguard Worker        )
1312*da0073e9SAndroid Build Coastguard Worker        # Add a remainder bucket that combines counts for all values
1313*da0073e9SAndroid Build Coastguard Worker        # below threshold, if such values exist (i.e. mask has False entries).
1314*da0073e9SAndroid Build Coastguard Worker        if not msk.all():
1315*da0073e9SAndroid Build Coastguard Worker            counts = np.concatenate([counts[msk], np.sum(counts[~msk], keepdims=True)])
1316*da0073e9SAndroid Build Coastguard Worker            pmf = np.concatenate([pmf[msk], np.sum(pmf[~msk], keepdims=True)])
1317*da0073e9SAndroid Build Coastguard Worker        chisq, p = scipy.stats.chisquare(counts, pmf * num_samples)
1318*da0073e9SAndroid Build Coastguard Worker        self.assertGreater(p, failure_rate, message)
1319*da0073e9SAndroid Build Coastguard Worker
1320*da0073e9SAndroid Build Coastguard Worker    def _check_enumerate_support(self, dist, examples):
1321*da0073e9SAndroid Build Coastguard Worker        for params, expected in examples:
1322*da0073e9SAndroid Build Coastguard Worker            params = {k: torch.tensor(v) for k, v in params.items()}
1323*da0073e9SAndroid Build Coastguard Worker            d = dist(**params)
1324*da0073e9SAndroid Build Coastguard Worker            actual = d.enumerate_support(expand=False)
1325*da0073e9SAndroid Build Coastguard Worker            expected = torch.tensor(expected, dtype=actual.dtype)
1326*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(actual, expected)
1327*da0073e9SAndroid Build Coastguard Worker            actual = d.enumerate_support(expand=True)
1328*da0073e9SAndroid Build Coastguard Worker            expected_with_expand = expected.expand(
1329*da0073e9SAndroid Build Coastguard Worker                (-1,) + d.batch_shape + d.event_shape
1330*da0073e9SAndroid Build Coastguard Worker            )
1331*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(actual, expected_with_expand)
1332*da0073e9SAndroid Build Coastguard Worker
1333*da0073e9SAndroid Build Coastguard Worker    def test_repr(self):
1334*da0073e9SAndroid Build Coastguard Worker        for Dist, params in _get_examples():
1335*da0073e9SAndroid Build Coastguard Worker            for param in params:
1336*da0073e9SAndroid Build Coastguard Worker                dist = Dist(**param)
1337*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(repr(dist).startswith(dist.__class__.__name__))
1338*da0073e9SAndroid Build Coastguard Worker
1339*da0073e9SAndroid Build Coastguard Worker    def test_sample_detached(self):
1340*da0073e9SAndroid Build Coastguard Worker        for Dist, params in _get_examples():
1341*da0073e9SAndroid Build Coastguard Worker            for i, param in enumerate(params):
1342*da0073e9SAndroid Build Coastguard Worker                variable_params = [
1343*da0073e9SAndroid Build Coastguard Worker                    p for p in param.values() if getattr(p, "requires_grad", False)
1344*da0073e9SAndroid Build Coastguard Worker                ]
1345*da0073e9SAndroid Build Coastguard Worker                if not variable_params:
1346*da0073e9SAndroid Build Coastguard Worker                    continue
1347*da0073e9SAndroid Build Coastguard Worker                dist = Dist(**param)
1348*da0073e9SAndroid Build Coastguard Worker                sample = dist.sample()
1349*da0073e9SAndroid Build Coastguard Worker                self.assertFalse(
1350*da0073e9SAndroid Build Coastguard Worker                    sample.requires_grad,
1351*da0073e9SAndroid Build Coastguard Worker                    msg=f"{Dist.__name__} example {i + 1}/{len(params)}, .sample() is not detached",
1352*da0073e9SAndroid Build Coastguard Worker                )
1353*da0073e9SAndroid Build Coastguard Worker
1354*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
1355*da0073e9SAndroid Build Coastguard Worker    def test_rsample_requires_grad(self):
1356*da0073e9SAndroid Build Coastguard Worker        for Dist, params in _get_examples():
1357*da0073e9SAndroid Build Coastguard Worker            for i, param in enumerate(params):
1358*da0073e9SAndroid Build Coastguard Worker                if not any(getattr(p, "requires_grad", False) for p in param.values()):
1359*da0073e9SAndroid Build Coastguard Worker                    continue
1360*da0073e9SAndroid Build Coastguard Worker                dist = Dist(**param)
1361*da0073e9SAndroid Build Coastguard Worker                if not dist.has_rsample:
1362*da0073e9SAndroid Build Coastguard Worker                    continue
1363*da0073e9SAndroid Build Coastguard Worker                sample = dist.rsample()
1364*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(
1365*da0073e9SAndroid Build Coastguard Worker                    sample.requires_grad,
1366*da0073e9SAndroid Build Coastguard Worker                    msg=f"{Dist.__name__} example {i + 1}/{len(params)}, .rsample() does not require grad",
1367*da0073e9SAndroid Build Coastguard Worker                )
1368*da0073e9SAndroid Build Coastguard Worker
1369*da0073e9SAndroid Build Coastguard Worker    def test_enumerate_support_type(self):
1370*da0073e9SAndroid Build Coastguard Worker        for Dist, params in _get_examples():
1371*da0073e9SAndroid Build Coastguard Worker            for i, param in enumerate(params):
1372*da0073e9SAndroid Build Coastguard Worker                dist = Dist(**param)
1373*da0073e9SAndroid Build Coastguard Worker                try:
1374*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(
1375*da0073e9SAndroid Build Coastguard Worker                        type(dist.sample()) is type(dist.enumerate_support()),
1376*da0073e9SAndroid Build Coastguard Worker                        msg=(
1377*da0073e9SAndroid Build Coastguard Worker                            "{} example {}/{}, return type mismatch between "
1378*da0073e9SAndroid Build Coastguard Worker                            + "sample and enumerate_support."
1379*da0073e9SAndroid Build Coastguard Worker                        ).format(Dist.__name__, i + 1, len(params)),
1380*da0073e9SAndroid Build Coastguard Worker                    )
1381*da0073e9SAndroid Build Coastguard Worker                except NotImplementedError:
1382*da0073e9SAndroid Build Coastguard Worker                    pass
1383*da0073e9SAndroid Build Coastguard Worker
1384*da0073e9SAndroid Build Coastguard Worker    def test_lazy_property_grad(self):
1385*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(1, requires_grad=True)
1386*da0073e9SAndroid Build Coastguard Worker
1387*da0073e9SAndroid Build Coastguard Worker        class Dummy:
1388*da0073e9SAndroid Build Coastguard Worker            @lazy_property
1389*da0073e9SAndroid Build Coastguard Worker            def y(self):
1390*da0073e9SAndroid Build Coastguard Worker                return x + 1
1391*da0073e9SAndroid Build Coastguard Worker
1392*da0073e9SAndroid Build Coastguard Worker        def test():
1393*da0073e9SAndroid Build Coastguard Worker            x.grad = None
1394*da0073e9SAndroid Build Coastguard Worker            Dummy().y.backward()
1395*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.grad, torch.ones(1))
1396*da0073e9SAndroid Build Coastguard Worker
1397*da0073e9SAndroid Build Coastguard Worker        test()
1398*da0073e9SAndroid Build Coastguard Worker        with torch.no_grad():
1399*da0073e9SAndroid Build Coastguard Worker            test()
1400*da0073e9SAndroid Build Coastguard Worker
1401*da0073e9SAndroid Build Coastguard Worker        mean = torch.randn(2)
1402*da0073e9SAndroid Build Coastguard Worker        cov = torch.eye(2, requires_grad=True)
1403*da0073e9SAndroid Build Coastguard Worker        distn = MultivariateNormal(mean, cov)
1404*da0073e9SAndroid Build Coastguard Worker        with torch.no_grad():
1405*da0073e9SAndroid Build Coastguard Worker            distn.scale_tril
1406*da0073e9SAndroid Build Coastguard Worker        distn.scale_tril.sum().backward()
1407*da0073e9SAndroid Build Coastguard Worker        self.assertIsNotNone(cov.grad)
1408*da0073e9SAndroid Build Coastguard Worker
1409*da0073e9SAndroid Build Coastguard Worker    def test_has_examples(self):
1410*da0073e9SAndroid Build Coastguard Worker        distributions_with_examples = {e.Dist for e in _get_examples()}
1411*da0073e9SAndroid Build Coastguard Worker        for Dist in globals().values():
1412*da0073e9SAndroid Build Coastguard Worker            if (
1413*da0073e9SAndroid Build Coastguard Worker                isinstance(Dist, type)
1414*da0073e9SAndroid Build Coastguard Worker                and issubclass(Dist, Distribution)
1415*da0073e9SAndroid Build Coastguard Worker                and Dist is not Distribution
1416*da0073e9SAndroid Build Coastguard Worker                and Dist is not ExponentialFamily
1417*da0073e9SAndroid Build Coastguard Worker            ):
1418*da0073e9SAndroid Build Coastguard Worker                self.assertIn(
1419*da0073e9SAndroid Build Coastguard Worker                    Dist,
1420*da0073e9SAndroid Build Coastguard Worker                    distributions_with_examples,
1421*da0073e9SAndroid Build Coastguard Worker                    f"Please add {Dist.__name__} to the _get_examples list in test_distributions.py",
1422*da0073e9SAndroid Build Coastguard Worker                )
1423*da0073e9SAndroid Build Coastguard Worker
1424*da0073e9SAndroid Build Coastguard Worker    def test_support_attributes(self):
1425*da0073e9SAndroid Build Coastguard Worker        for Dist, params in _get_examples():
1426*da0073e9SAndroid Build Coastguard Worker            for param in params:
1427*da0073e9SAndroid Build Coastguard Worker                d = Dist(**param)
1428*da0073e9SAndroid Build Coastguard Worker                event_dim = len(d.event_shape)
1429*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(d.support.event_dim, event_dim)
1430*da0073e9SAndroid Build Coastguard Worker                try:
1431*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(Dist.support.event_dim, event_dim)
1432*da0073e9SAndroid Build Coastguard Worker                except NotImplementedError:
1433*da0073e9SAndroid Build Coastguard Worker                    pass
1434*da0073e9SAndroid Build Coastguard Worker                is_discrete = d.support.is_discrete
1435*da0073e9SAndroid Build Coastguard Worker                try:
1436*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(Dist.support.is_discrete, is_discrete)
1437*da0073e9SAndroid Build Coastguard Worker                except NotImplementedError:
1438*da0073e9SAndroid Build Coastguard Worker                    pass
1439*da0073e9SAndroid Build Coastguard Worker
1440*da0073e9SAndroid Build Coastguard Worker    def test_distribution_expand(self):
1441*da0073e9SAndroid Build Coastguard Worker        shapes = [torch.Size(), torch.Size((2,)), torch.Size((2, 1))]
1442*da0073e9SAndroid Build Coastguard Worker        for Dist, params in _get_examples():
1443*da0073e9SAndroid Build Coastguard Worker            for param in params:
1444*da0073e9SAndroid Build Coastguard Worker                for shape in shapes:
1445*da0073e9SAndroid Build Coastguard Worker                    d = Dist(**param)
1446*da0073e9SAndroid Build Coastguard Worker                    expanded_shape = shape + d.batch_shape
1447*da0073e9SAndroid Build Coastguard Worker                    original_shape = d.batch_shape + d.event_shape
1448*da0073e9SAndroid Build Coastguard Worker                    expected_shape = shape + original_shape
1449*da0073e9SAndroid Build Coastguard Worker                    expanded = d.expand(batch_shape=list(expanded_shape))
1450*da0073e9SAndroid Build Coastguard Worker                    sample = expanded.sample()
1451*da0073e9SAndroid Build Coastguard Worker                    actual_shape = expanded.sample().shape
1452*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(expanded.__class__, d.__class__)
1453*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(d.sample().shape, original_shape)
1454*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(expanded.log_prob(sample), d.log_prob(sample))
1455*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(actual_shape, expected_shape)
1456*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(expanded.batch_shape, expanded_shape)
1457*da0073e9SAndroid Build Coastguard Worker                    try:
1458*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(
1459*da0073e9SAndroid Build Coastguard Worker                            expanded.mean, d.mean.expand(expanded_shape + d.event_shape)
1460*da0073e9SAndroid Build Coastguard Worker                        )
1461*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(
1462*da0073e9SAndroid Build Coastguard Worker                            expanded.variance,
1463*da0073e9SAndroid Build Coastguard Worker                            d.variance.expand(expanded_shape + d.event_shape),
1464*da0073e9SAndroid Build Coastguard Worker                        )
1465*da0073e9SAndroid Build Coastguard Worker                    except NotImplementedError:
1466*da0073e9SAndroid Build Coastguard Worker                        pass
1467*da0073e9SAndroid Build Coastguard Worker
1468*da0073e9SAndroid Build Coastguard Worker    def test_distribution_subclass_expand(self):
1469*da0073e9SAndroid Build Coastguard Worker        expand_by = torch.Size((2,))
1470*da0073e9SAndroid Build Coastguard Worker        for Dist, params in _get_examples():
1471*da0073e9SAndroid Build Coastguard Worker
1472*da0073e9SAndroid Build Coastguard Worker            class SubClass(Dist):
1473*da0073e9SAndroid Build Coastguard Worker                pass
1474*da0073e9SAndroid Build Coastguard Worker
1475*da0073e9SAndroid Build Coastguard Worker            for param in params:
1476*da0073e9SAndroid Build Coastguard Worker                d = SubClass(**param)
1477*da0073e9SAndroid Build Coastguard Worker                expanded_shape = expand_by + d.batch_shape
1478*da0073e9SAndroid Build Coastguard Worker                original_shape = d.batch_shape + d.event_shape
1479*da0073e9SAndroid Build Coastguard Worker                expected_shape = expand_by + original_shape
1480*da0073e9SAndroid Build Coastguard Worker                expanded = d.expand(batch_shape=expanded_shape)
1481*da0073e9SAndroid Build Coastguard Worker                sample = expanded.sample()
1482*da0073e9SAndroid Build Coastguard Worker                actual_shape = expanded.sample().shape
1483*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(expanded.__class__, d.__class__)
1484*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(d.sample().shape, original_shape)
1485*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(expanded.log_prob(sample), d.log_prob(sample))
1486*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(actual_shape, expected_shape)
1487*da0073e9SAndroid Build Coastguard Worker
1488*da0073e9SAndroid Build Coastguard Worker    @set_default_dtype(torch.double)
1489*da0073e9SAndroid Build Coastguard Worker    def test_bernoulli(self):
1490*da0073e9SAndroid Build Coastguard Worker        p = torch.tensor([0.7, 0.2, 0.4], requires_grad=True)
1491*da0073e9SAndroid Build Coastguard Worker        r = torch.tensor(0.3, requires_grad=True)
1492*da0073e9SAndroid Build Coastguard Worker        s = 0.3
1493*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Bernoulli(p).sample((8,)).size(), (8, 3))
1494*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(Bernoulli(p).sample().requires_grad)
1495*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Bernoulli(r).sample((8,)).size(), (8,))
1496*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Bernoulli(r).sample().size(), ())
1497*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
1498*da0073e9SAndroid Build Coastguard Worker            Bernoulli(r).sample((3, 2)).size(),
1499*da0073e9SAndroid Build Coastguard Worker            (
1500*da0073e9SAndroid Build Coastguard Worker                3,
1501*da0073e9SAndroid Build Coastguard Worker                2,
1502*da0073e9SAndroid Build Coastguard Worker            ),
1503*da0073e9SAndroid Build Coastguard Worker        )
1504*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Bernoulli(s).sample().size(), ())
1505*da0073e9SAndroid Build Coastguard Worker        self._gradcheck_log_prob(Bernoulli, (p,))
1506*da0073e9SAndroid Build Coastguard Worker
1507*da0073e9SAndroid Build Coastguard Worker        def ref_log_prob(idx, val, log_prob):
1508*da0073e9SAndroid Build Coastguard Worker            prob = p[idx]
1509*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(log_prob, math.log(prob if val else 1 - prob))
1510*da0073e9SAndroid Build Coastguard Worker
1511*da0073e9SAndroid Build Coastguard Worker        self._check_log_prob(Bernoulli(p), ref_log_prob)
1512*da0073e9SAndroid Build Coastguard Worker        self._check_log_prob(Bernoulli(logits=p.log() - (-p).log1p()), ref_log_prob)
1513*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(NotImplementedError, Bernoulli(r).rsample)
1514*da0073e9SAndroid Build Coastguard Worker
1515*da0073e9SAndroid Build Coastguard Worker        # check entropy computation
1516*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
1517*da0073e9SAndroid Build Coastguard Worker            Bernoulli(p).entropy(),
1518*da0073e9SAndroid Build Coastguard Worker            torch.tensor([0.6108, 0.5004, 0.6730]),
1519*da0073e9SAndroid Build Coastguard Worker            atol=1e-4,
1520*da0073e9SAndroid Build Coastguard Worker            rtol=0,
1521*da0073e9SAndroid Build Coastguard Worker        )
1522*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Bernoulli(torch.tensor([0.0])).entropy(), torch.tensor([0.0]))
1523*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
1524*da0073e9SAndroid Build Coastguard Worker            Bernoulli(s).entropy(), torch.tensor(0.6108), atol=1e-4, rtol=0
1525*da0073e9SAndroid Build Coastguard Worker        )
1526*da0073e9SAndroid Build Coastguard Worker
1527*da0073e9SAndroid Build Coastguard Worker        self._check_forward_ad(torch.bernoulli)
1528*da0073e9SAndroid Build Coastguard Worker        self._check_forward_ad(lambda x: x.bernoulli_())
1529*da0073e9SAndroid Build Coastguard Worker        self._check_forward_ad(lambda x: x.bernoulli_(x.clone().detach()))
1530*da0073e9SAndroid Build Coastguard Worker        self._check_forward_ad(lambda x: x.bernoulli_(x))
1531*da0073e9SAndroid Build Coastguard Worker
1532*da0073e9SAndroid Build Coastguard Worker    def test_bernoulli_enumerate_support(self):
1533*da0073e9SAndroid Build Coastguard Worker        examples = [
1534*da0073e9SAndroid Build Coastguard Worker            ({"probs": [0.1]}, [[0], [1]]),
1535*da0073e9SAndroid Build Coastguard Worker            ({"probs": [0.1, 0.9]}, [[0], [1]]),
1536*da0073e9SAndroid Build Coastguard Worker            ({"probs": [[0.1, 0.2], [0.3, 0.4]]}, [[[0]], [[1]]]),
1537*da0073e9SAndroid Build Coastguard Worker        ]
1538*da0073e9SAndroid Build Coastguard Worker        self._check_enumerate_support(Bernoulli, examples)
1539*da0073e9SAndroid Build Coastguard Worker
1540*da0073e9SAndroid Build Coastguard Worker    def test_bernoulli_3d(self):
1541*da0073e9SAndroid Build Coastguard Worker        p = torch.full((2, 3, 5), 0.5).requires_grad_()
1542*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Bernoulli(p).sample().size(), (2, 3, 5))
1543*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
1544*da0073e9SAndroid Build Coastguard Worker            Bernoulli(p).sample(sample_shape=(2, 5)).size(), (2, 5, 2, 3, 5)
1545*da0073e9SAndroid Build Coastguard Worker        )
1546*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Bernoulli(p).sample((2,)).size(), (2, 2, 3, 5))
1547*da0073e9SAndroid Build Coastguard Worker
1548*da0073e9SAndroid Build Coastguard Worker    @set_default_dtype(torch.double)
1549*da0073e9SAndroid Build Coastguard Worker    def test_geometric(self):
1550*da0073e9SAndroid Build Coastguard Worker        p = torch.tensor([0.7, 0.2, 0.4], requires_grad=True)
1551*da0073e9SAndroid Build Coastguard Worker        r = torch.tensor(0.3, requires_grad=True)
1552*da0073e9SAndroid Build Coastguard Worker        s = 0.3
1553*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Geometric(p).sample((8,)).size(), (8, 3))
1554*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Geometric(1).sample(), 0)
1555*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Geometric(1).log_prob(torch.tensor(1.0)), -inf)
1556*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Geometric(1).log_prob(torch.tensor(0.0)), 0)
1557*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(Geometric(p).sample().requires_grad)
1558*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Geometric(r).sample((8,)).size(), (8,))
1559*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Geometric(r).sample().size(), ())
1560*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Geometric(r).sample((3, 2)).size(), (3, 2))
1561*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Geometric(s).sample().size(), ())
1562*da0073e9SAndroid Build Coastguard Worker        self._gradcheck_log_prob(Geometric, (p,))
1563*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(ValueError, lambda: Geometric(0))
1564*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(NotImplementedError, Geometric(r).rsample)
1565*da0073e9SAndroid Build Coastguard Worker
1566*da0073e9SAndroid Build Coastguard Worker        self._check_forward_ad(lambda x: x.geometric_(0.2))
1567*da0073e9SAndroid Build Coastguard Worker
1568*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
1569*da0073e9SAndroid Build Coastguard Worker    @set_default_dtype(torch.double)
1570*da0073e9SAndroid Build Coastguard Worker    def test_geometric_log_prob_and_entropy(self):
1571*da0073e9SAndroid Build Coastguard Worker        p = torch.tensor([0.7, 0.2, 0.4], requires_grad=True)
1572*da0073e9SAndroid Build Coastguard Worker        s = 0.3
1573*da0073e9SAndroid Build Coastguard Worker
1574*da0073e9SAndroid Build Coastguard Worker        def ref_log_prob(idx, val, log_prob):
1575*da0073e9SAndroid Build Coastguard Worker            prob = p[idx].detach()
1576*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(log_prob, scipy.stats.geom(prob, loc=-1).logpmf(val))
1577*da0073e9SAndroid Build Coastguard Worker
1578*da0073e9SAndroid Build Coastguard Worker        self._check_log_prob(Geometric(p), ref_log_prob)
1579*da0073e9SAndroid Build Coastguard Worker        self._check_log_prob(Geometric(logits=p.log() - (-p).log1p()), ref_log_prob)
1580*da0073e9SAndroid Build Coastguard Worker
1581*da0073e9SAndroid Build Coastguard Worker        # check entropy computation
1582*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
1583*da0073e9SAndroid Build Coastguard Worker            Geometric(p).entropy(),
1584*da0073e9SAndroid Build Coastguard Worker            scipy.stats.geom(p.detach().numpy(), loc=-1).entropy(),
1585*da0073e9SAndroid Build Coastguard Worker            atol=1e-3,
1586*da0073e9SAndroid Build Coastguard Worker            rtol=0,
1587*da0073e9SAndroid Build Coastguard Worker        )
1588*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
1589*da0073e9SAndroid Build Coastguard Worker            float(Geometric(s).entropy()),
1590*da0073e9SAndroid Build Coastguard Worker            scipy.stats.geom(s, loc=-1).entropy().item(),
1591*da0073e9SAndroid Build Coastguard Worker            atol=1e-3,
1592*da0073e9SAndroid Build Coastguard Worker            rtol=0,
1593*da0073e9SAndroid Build Coastguard Worker        )
1594*da0073e9SAndroid Build Coastguard Worker
1595*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
1596*da0073e9SAndroid Build Coastguard Worker    def test_geometric_sample(self):
1597*da0073e9SAndroid Build Coastguard Worker        set_rng_seed(0)  # see Note [Randomized statistical tests]
1598*da0073e9SAndroid Build Coastguard Worker        for prob in [0.01, 0.18, 0.8]:
1599*da0073e9SAndroid Build Coastguard Worker            self._check_sampler_discrete(
1600*da0073e9SAndroid Build Coastguard Worker                Geometric(prob),
1601*da0073e9SAndroid Build Coastguard Worker                scipy.stats.geom(p=prob, loc=-1),
1602*da0073e9SAndroid Build Coastguard Worker                f"Geometric(prob={prob})",
1603*da0073e9SAndroid Build Coastguard Worker            )
1604*da0073e9SAndroid Build Coastguard Worker
@set_default_dtype(torch.double)
def test_binomial(self):
    # Gradient-check Binomial.log_prob under both the probs and the logits
    # parameterisation, then confirm that rsample is not implemented.
    p = torch.arange(0.05, 1, 0.1).requires_grad_()

    for total_count in (1, 2, 10):

        def from_probs(p):
            return Binomial(total_count, p)

        def from_logits(p):
            return Binomial(total_count, None, p.log())

        # Both constructors are consumed immediately, so reading
        # `total_count` from the enclosing loop is safe here.
        self._gradcheck_log_prob(from_probs, [p])
        self._gradcheck_log_prob(from_logits, [p])

    self.assertRaises(NotImplementedError, Binomial(10, p).rsample)


# Variants of the same test registered under reduced-precision default dtypes.
# NOTE(review): test_binomial is already wrapped with
# set_default_dtype(torch.double); confirm that the outer wrappers below
# actually take effect inside the test body.
test_binomial_half = set_default_dtype(torch.float16)(test_binomial)
test_binomial_bfloat16 = set_default_dtype(torch.bfloat16)(test_binomial)
1617*da0073e9SAndroid Build Coastguard Worker
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_binomial_sample(self):
    # Compare Binomial samples against scipy.stats.binom over a grid of
    # success probabilities and trial counts, with a pinned seed.
    set_rng_seed(0)  # see Note [Randomized statistical tests]
    success_probs = (0.01, 0.1, 0.5, 0.8, 0.9)
    trial_counts = (2, 10, 100, 500)
    for prob in success_probs:
        for count in trial_counts:
            torch_dist = Binomial(total_count=count, probs=prob)
            reference = scipy.stats.binom(count, prob)
            label = f"Binomial(total_count={count}, probs={prob})"
            self._check_sampler_discrete(torch_dist, reference, label)
1628*da0073e9SAndroid Build Coastguard Worker
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
@set_default_dtype(torch.double)
def test_binomial_log_prob_and_entropy(self):
    # Check Binomial.log_prob (probs and logits parameterisations) and
    # Binomial.entropy against scipy.stats.binom for several trial counts.
    probs = torch.arange(0.05, 1, 0.1)

    for total_count in (1, 2, 10):

        def ref_log_prob(idx, x, log_prob):
            # Look up the success probability for this flat index and compare
            # the torch log-probability with scipy's logpmf.
            success_prob = probs.view(-1)[idx].item()
            reference = scipy.stats.binom(total_count, success_prob)
            self.assertEqual(log_prob, reference.logpmf(x), atol=1e-3, rtol=0)

        self._check_log_prob(Binomial(total_count, probs), ref_log_prob)

        logits = probs_to_logits(probs, is_binary=True)
        self._check_log_prob(Binomial(total_count, logits=logits), ref_log_prob)

        # A fresh logits-parameterised distribution is used for the entropy
        # comparison; its probabilities are handed to scipy as a numpy array.
        logit_dist = Binomial(total_count, logits=logits)
        reference_entropy = scipy.stats.binom(
            total_count, logit_dist.probs.detach().numpy(), loc=-1
        ).entropy()
        self.assertEqual(logit_dist.entropy(), reference_entropy, atol=1e-3, rtol=0)
1653*da0073e9SAndroid Build Coastguard Worker
def test_binomial_stable(self):
    # Binomial.log_prob must stay finite for logits of very large magnitude,
    # and its gradient with respect to the logits must be exact at logits=0.
    logits = torch.tensor([-100.0, 100.0], dtype=torch.float)
    total_count = 1.0
    x = torch.tensor([0.0, 0.0], dtype=torch.float)
    log_prob = Binomial(total_count, logits=logits).log_prob(x)
    self.assertTrue(torch.isfinite(log_prob).all())

    # make sure that the grad at logits=0, value=0 is -0.5: with total_count=1
    # and value=0 the log-probability is log(1 - sigmoid(logits)), whose
    # derivative is -sigmoid(logits), i.e. -0.5 at logits=0.
    zero_logit = torch.tensor(0.0, requires_grad=True)
    y = Binomial(total_count, logits=zero_logit).log_prob(torch.tensor(0.0))
    self.assertEqual(grad(y, zero_logit)[0], torch.tensor(-0.5))
1665*da0073e9SAndroid Build Coastguard Worker
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
@set_default_dtype(torch.double)
def test_binomial_log_prob_vectorized_count(self):
    # Binomial.log_prob with a tensor-valued total_count must agree with
    # scipy.stats.binom, both for a single count and for one count per prob.
    probs = torch.tensor([0.2, 0.7, 0.9])
    cases = (
        (torch.tensor([10]), torch.tensor([7.0, 3.0, 9.0])),
        (torch.tensor([1, 2, 10]), torch.tensor([0.0, 1.0, 9.0])),
    )
    for total_count, sample in cases:
        actual = Binomial(total_count, probs).log_prob(sample)
        reference = scipy.stats.binom(
            total_count.cpu().numpy(), probs.cpu().numpy()
        )
        self.assertEqual(actual, reference.logpmf(sample), atol=1e-4, rtol=0)
1679*da0073e9SAndroid Build Coastguard Worker
def test_binomial_enumerate_support(self):
    # Each case is a (constructor kwargs, expected value) pair handed to the
    # shared enumerate_support checker.
    single_prob = ({"probs": [0.1], "total_count": 2}, [[0], [1], [2]])
    two_probs = ({"probs": [0.1, 0.9], "total_count": 2}, [[0], [1], [2]])
    grid_probs = (
        {"probs": [[0.1, 0.2], [0.3, 0.4]], "total_count": 3},
        [[[0]], [[1]], [[2]], [[3]]],
    )
    self._check_enumerate_support(Binomial, [single_prob, two_probs, grid_probs])
1690*da0073e9SAndroid Build Coastguard Worker
@set_default_dtype(torch.double)
def test_binomial_extreme_vals(self):
    # Exercise Binomial at the boundary probabilities 0 and 1, and with a
    # total_count tensor that is all zeros.
    total_count = 100

    # probs == 0: every sample is 0, so value 0 has probability one and
    # value 1 has probability zero.
    never = Binomial(total_count, 0)
    self.assertEqual(never.sample(), 0)
    log_p_zero = never.log_prob(torch.tensor([0.0]))[0]
    self.assertEqual(log_p_zero, 0, atol=1e-3, rtol=0)
    p_one = float(never.log_prob(torch.tensor([1.0])).exp())
    self.assertEqual(p_one, 0)

    # probs == 1: every sample equals total_count.
    always = Binomial(total_count, 1)
    self.assertEqual(always.sample(), total_count)
    log_p_all = always.log_prob(torch.tensor([float(total_count)]))[0]
    self.assertEqual(log_p_all, 0, atol=1e-3, rtol=0)
    p_all_but_one = float(
        always.log_prob(torch.tensor([float(total_count - 1)])).exp()
    )
    self.assertEqual(p_all_but_one, 0)

    # total_count == 0 everywhere: samples and log-probabilities are all zero.
    zero_counts = torch.zeros(torch.Size((2, 2)))
    no_trials = Binomial(zero_counts, 1)
    self.assertEqual(no_trials.sample(), zero_counts)
    self.assertEqual(no_trials.log_prob(zero_counts), zero_counts)
1710*da0073e9SAndroid Build Coastguard Worker
@set_default_dtype(torch.double)
def test_binomial_vectorized_count(self):
    # Sampling with a tensor-valued total_count: probs=1 must return the
    # counts themselves, and probs=0.5 must respect the per-element upper
    # bound and match the analytic mean and variance.
    set_rng_seed(1)  # see Note [Randomized statistical tests]
    total_count = torch.tensor([[4, 7], [3, 8]], dtype=torch.float64)

    certain = Binomial(total_count, torch.tensor(1.0))
    self.assertEqual(certain.sample(), total_count)

    fair = Binomial(total_count, torch.tensor(0.5))
    samples = fair.sample(torch.Size((100000,)))
    upper_bound = total_count.type_as(samples)
    self.assertTrue((samples <= upper_bound).all())

    empirical_mean = samples.mean(dim=0)
    empirical_var = samples.var(dim=0)
    self.assertEqual(empirical_mean, fair.mean, atol=0.02, rtol=0)
    self.assertEqual(empirical_var, fair.variance, atol=0.02, rtol=0)
1722*da0073e9SAndroid Build Coastguard Worker
@set_default_dtype(torch.double)
def test_negative_binomial(self):
    # Gradient-check NegativeBinomial.log_prob under both parameterisations
    # and confirm that rsample and entropy raise NotImplementedError.
    p = torch.arange(0.05, 1, 0.1).requires_grad_()

    for total_count in (1, 2, 10):

        def from_probs(p):
            return NegativeBinomial(total_count, p)

        def from_logits(p):
            return NegativeBinomial(total_count, None, p.log())

        # Both constructors are consumed immediately, so reading
        # `total_count` from the enclosing loop is safe here.
        self._gradcheck_log_prob(from_probs, [p])
        self._gradcheck_log_prob(from_logits, [p])

    for unsupported in (
        NegativeBinomial(10, p).rsample,
        NegativeBinomial(10, p).entropy,
    ):
        self.assertRaises(NotImplementedError, unsupported)
1733*da0073e9SAndroid Build Coastguard Worker
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_negative_binomial_log_prob(self):
    # Check NegativeBinomial.log_prob (probs and logits parameterisations)
    # against scipy.stats.nbinom for several trial counts.
    probs = torch.arange(0.05, 1, 0.1)

    for total_count in (1, 2, 10):

        def ref_log_prob(idx, x, log_prob):
            success_prob = probs.view(-1)[idx].item()
            # The scipy reference is built with the complementary
            # probability, 1 - p.
            reference = scipy.stats.nbinom(total_count, 1 - success_prob)
            self.assertEqual(log_prob, reference.logpmf(x), atol=1e-3, rtol=0)

        from_probs = NegativeBinomial(total_count, probs)
        self._check_log_prob(from_probs, ref_log_prob)

        logits = probs_to_logits(probs, is_binary=True)
        from_logits = NegativeBinomial(total_count, logits=logits)
        self._check_log_prob(from_logits, ref_log_prob)
1749*da0073e9SAndroid Build Coastguard Worker
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
@set_default_dtype(torch.double)
def test_negative_binomial_log_prob_vectorized_count(self):
    # NegativeBinomial.log_prob with a tensor-valued total_count must agree
    # with scipy.stats.nbinom evaluated at the complementary probability.
    probs = torch.tensor([0.2, 0.7, 0.9])
    cases = (
        (torch.tensor([10]), torch.tensor([7.0, 3.0, 9.0])),
        (torch.tensor([1, 2, 10]), torch.tensor([0.0, 1.0, 9.0])),
    )
    for total_count, sample in cases:
        actual = NegativeBinomial(total_count, probs).log_prob(sample)
        reference = scipy.stats.nbinom(
            total_count.cpu().numpy(), 1 - probs.cpu().numpy()
        )
        self.assertEqual(actual, reference.logpmf(sample), atol=1e-4, rtol=0)
1763*da0073e9SAndroid Build Coastguard Worker
@unittest.skipIf(not TEST_CUDA, "CUDA not found")
def test_zero_excluded_binomial(self):
    # Draw single-trial Binomial samples on CUDA and check that they stay
    # inside {0, 1}, then that probs=0.5 yields a roughly even split.

    def draw(prob, num_samples):
        # Sample `num_samples` values from Binomial(1, prob) on the GPU.
        return Binomial(
            total_count=torch.tensor(1.0).cuda(), probs=torch.tensor(prob).cuda()
        ).sample(torch.Size((num_samples,)))

    vals = draw(0.9, 100000000)
    self.assertTrue((vals >= 0).all())

    vals = draw(0.1, 100000000)
    self.assertTrue((vals < 2).all())

    vals = draw(0.5, 10000)
    # vals should be roughly half zeroes, half ones. Use unittest assertions
    # rather than bare `assert` so the checks still run under `python -O`.
    self.assertGreater((vals == 0.0).sum().item(), 4000)
    self.assertGreater((vals == 1.0).sum().item(), 4000)
1780*da0073e9SAndroid Build Coastguard Worker
@set_default_dtype(torch.double)
def test_multinomial_1d(self):
    # Sample shapes, log_prob gradients and the missing rsample for a
    # Multinomial built from a 1-D probability vector.
    total_count = 10
    p = torch.tensor([0.1, 0.2, 0.3], requires_grad=True)

    unbatched = Multinomial(total_count, p).sample()
    self.assertEqual(unbatched.size(), (3,))
    grid = Multinomial(total_count, p).sample((2, 2))
    self.assertEqual(grid.size(), (2, 2, 3))
    single = Multinomial(total_count, p).sample((1,))
    self.assertEqual(single.size(), (1, 3))

    def from_probs(p):
        return Multinomial(total_count, p)

    def from_logits(p):
        return Multinomial(total_count, None, p.log())

    self._gradcheck_log_prob(from_probs, [p])
    self._gradcheck_log_prob(from_logits, [p])
    self.assertRaises(NotImplementedError, Multinomial(10, p).rsample)
1791*da0073e9SAndroid Build Coastguard Worker
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
@set_default_dtype(torch.double)
def test_multinomial_1d_log_prob_and_entropy(self):
    # Compare Multinomial.log_prob (probs and logits parameterisations) and
    # Multinomial.entropy against scipy.stats.multinomial.
    total_count = 10
    p = torch.tensor([0.1, 0.2, 0.3], requires_grad=True)

    def assert_log_prob_matches_scipy(dist):
        # Draw one sample and compare its log-probability with scipy's
        # logpmf evaluated at the distribution's own probabilities.
        x = dist.sample()
        actual = dist.log_prob(x)
        reference = scipy.stats.multinomial.logpmf(
            x.numpy(), n=total_count, p=dist.probs.detach().numpy()
        )
        self.assertEqual(actual, torch.tensor(reference))

    assert_log_prob_matches_scipy(Multinomial(total_count, probs=p))

    logit_dist = Multinomial(total_count, logits=p.log())
    assert_log_prob_matches_scipy(logit_dist)

    # The entropy is checked on the logits-parameterised distribution.
    reference_entropy = scipy.stats.multinomial.entropy(
        total_count, logit_dist.probs.detach().numpy()
    )
    self.assertEqual(logit_dist.entropy(), reference_entropy, atol=1e-3, rtol=0)
1821*da0073e9SAndroid Build Coastguard Worker
@set_default_dtype(torch.double)
def test_multinomial_2d(self):
    # Sample shapes, log_prob gradients and degenerate (one-hot) probabilities
    # for a Multinomial built from a batch of probability vectors.
    total_count = 10
    p = torch.tensor([[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True)
    s = torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True)

    unbatched = Multinomial(total_count, p).sample()
    self.assertEqual(unbatched.size(), (2, 3))
    grid = Multinomial(total_count, p).sample(sample_shape=(3, 4))
    self.assertEqual(grid.size(), (3, 4, 2, 3))
    stacked = Multinomial(total_count, p).sample((6,))
    self.assertEqual(stacked.size(), (6, 2, 3))

    def from_probs(p):
        return Multinomial(total_count, p)

    def from_logits(p):
        return Multinomial(total_count, None, p.log())

    set_rng_seed(0)
    self._gradcheck_log_prob(from_probs, [p])
    self._gradcheck_log_prob(from_logits, [p])

    # sample check for extreme value of probs: each row of `s` puts all mass
    # on one category, so that category must receive every trial.
    all_or_nothing = torch.tensor(
        [[total_count, 0], [0, total_count]], dtype=torch.float64
    )
    self.assertEqual(Multinomial(total_count, s).sample(), all_or_nothing)
1843*da0073e9SAndroid Build Coastguard Worker
@set_default_dtype(torch.double)
def test_categorical_1d(self):
    # Moments, sample shapes, log_prob gradients and the missing rsample for
    # a Categorical built from a 1-D probability vector.
    p = torch.tensor([0.1, 0.2, 0.3], requires_grad=True)

    # mean and variance are expected to be reported as NaN.
    self.assertTrue(is_all_nan(Categorical(p).mean))
    self.assertTrue(is_all_nan(Categorical(p).variance))

    scalar_sample = Categorical(p).sample()
    self.assertEqual(scalar_sample.size(), ())
    # Samples must not carry gradient even though `p` requires grad.
    self.assertFalse(Categorical(p).sample().requires_grad)

    grid = Categorical(p).sample((2, 2))
    self.assertEqual(grid.size(), (2, 2))
    single = Categorical(p).sample((1,))
    self.assertEqual(single.size(), (1,))

    self._gradcheck_log_prob(Categorical, (p,))
    self.assertRaises(NotImplementedError, Categorical(p).rsample)
1855*da0073e9SAndroid Build Coastguard Worker
@set_default_dtype(torch.double)
def test_categorical_2d(self):
    # Moments, sample shapes, log_prob, entropy and degenerate inputs for a
    # Categorical built from a batch of probability vectors.
    p = torch.tensor([[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True)
    s = torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True)

    # mean / variance have one entry per batch row and are all NaN.
    self.assertEqual(Categorical(p).mean.size(), (2,))
    self.assertEqual(Categorical(p).variance.size(), (2,))
    self.assertTrue(is_all_nan(Categorical(p).mean))
    self.assertTrue(is_all_nan(Categorical(p).variance))

    unbatched = Categorical(p).sample()
    self.assertEqual(unbatched.size(), (2,))
    grid = Categorical(p).sample(sample_shape=(3, 4))
    self.assertEqual(grid.size(), (3, 4, 2))
    stacked = Categorical(p).sample((6,))
    self.assertEqual(stacked.size(), (6, 2))
    self._gradcheck_log_prob(Categorical, (p,))

    # sample check for extreme value of probs
    set_rng_seed(0)
    one_hot_samples = Categorical(s).sample(sample_shape=(2,))
    self.assertEqual(one_hot_samples, torch.tensor([[0, 1], [0, 1]]))

    def ref_log_prob(idx, val, log_prob):
        # `p` is unnormalised, so divide by the row sum before taking the log.
        row = p[idx]
        sample_prob = row[val] / row.sum()
        self.assertEqual(log_prob, math.log(sample_prob))

    self._check_log_prob(Categorical(p), ref_log_prob)
    self._check_log_prob(Categorical(logits=p.log()), ref_log_prob)

    # check entropy computation
    expected_entropy = torch.tensor([1.0114, 1.0297])
    self.assertEqual(Categorical(p).entropy(), expected_entropy, atol=1e-4, rtol=0)
    self.assertEqual(Categorical(s).entropy(), torch.tensor([0.0, 0.0]))

    # issue gh-40553: logits of -inf must not break the entropy.
    logits = p.log()
    neg_inf = float("-inf")
    logits[1, 1] = neg_inf
    logits[0, 2] = neg_inf
    masked_entropy = Categorical(logits=logits).entropy()
    self.assertEqual(
        masked_entropy, torch.tensor([0.6365, 0.5983]), atol=1e-4, rtol=0
    )
1894*da0073e9SAndroid Build Coastguard Worker
def test_categorical_enumerate_support(self):
    # Each case is a (constructor kwargs, expected value) pair handed to the
    # shared enumerate_support checker.
    unbatched = ({"probs": [0.1, 0.2, 0.7]}, [0, 1, 2])
    batched = ({"probs": [[0.1, 0.9], [0.3, 0.7]]}, [[0], [1]])
    self._check_enumerate_support(Categorical, [unbatched, batched])
1901*da0073e9SAndroid Build Coastguard Worker
@set_default_dtype(torch.double)
def test_one_hot_categorical_1d(self):
    # Sample shapes, log_prob gradients and the missing rsample for a
    # OneHotCategorical built from a 1-D probability vector.
    p = torch.tensor([0.1, 0.2, 0.3], requires_grad=True)

    unbatched = OneHotCategorical(p).sample()
    self.assertEqual(unbatched.size(), (3,))
    # Samples must not carry gradient even though `p` requires grad.
    self.assertFalse(OneHotCategorical(p).sample().requires_grad)

    grid = OneHotCategorical(p).sample((2, 2))
    self.assertEqual(grid.size(), (2, 2, 3))
    single = OneHotCategorical(p).sample((1,))
    self.assertEqual(single.size(), (1, 3))

    self._gradcheck_log_prob(OneHotCategorical, (p,))
    self.assertRaises(NotImplementedError, OneHotCategorical(p).rsample)
1911*da0073e9SAndroid Build Coastguard Worker
@set_default_dtype(torch.double)
def test_one_hot_categorical_2d(self):
    # Sample shapes and log_prob behaviour for a OneHotCategorical built from
    # a batch of probability vectors.
    probabilities = [[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]]
    p = torch.tensor(probabilities, requires_grad=True)

    self.assertEqual(OneHotCategorical(p).sample().size(), (2, 3))
    self.assertEqual(
        OneHotCategorical(p).sample(sample_shape=(3, 4)).size(), (3, 4, 2, 3)
    )
    self.assertEqual(OneHotCategorical(p).sample((6,)).size(), (6, 2, 3))
    self._gradcheck_log_prob(OneHotCategorical, (p,))

    # A one-hot sample must score the same as the matching index (the argmax
    # along the last dimension) under a plain Categorical.
    dist = OneHotCategorical(p)
    x = dist.sample()
    self.assertEqual(dist.log_prob(x), Categorical(p).log_prob(x.max(-1)[1]))
1928*da0073e9SAndroid Build Coastguard Worker
def test_one_hot_categorical_enumerate_support(self):
    # Each case is a (constructor kwargs, expected value) pair handed to the
    # shared enumerate_support checker.
    unbatched = ({"probs": [0.1, 0.2, 0.7]}, [[1, 0, 0], [0, 1, 0], [0, 0, 1]])
    batched = ({"probs": [[0.1, 0.9], [0.3, 0.7]]}, [[[1, 0]], [[0, 1]]])
    self._check_enumerate_support(OneHotCategorical, [unbatched, batched])
1935*da0073e9SAndroid Build Coastguard Worker
def test_poisson_forward_ad(self):
    # Run the shared forward-mode AD check on the raw torch.poisson sampler.
    sampler = torch.poisson
    self._check_forward_ad(sampler)
1938*da0073e9SAndroid Build Coastguard Worker
def test_poisson_shape(self):
    # Poisson sample shapes for a 2-D rate, a 1-element rate and a Python
    # float rate, with and without an explicit sample shape.
    rate = torch.randn(2, 3).abs().requires_grad_()
    rate_1d = torch.randn(1).abs().requires_grad_()

    batched = Poisson(rate)
    self.assertEqual(batched.sample().size(), (2, 3))
    self.assertEqual(batched.sample((7,)).size(), (7, 2, 3))

    single = Poisson(rate_1d)
    self.assertEqual(single.sample().size(), (1,))
    self.assertEqual(single.sample((1,)).size(), (1, 1))

    scalar_rate_samples = Poisson(2.0).sample((2,))
    self.assertEqual(scalar_rate_samples.size(), (2,))
1947*da0073e9SAndroid Build Coastguard Worker
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
@set_default_dtype(torch.double)
def test_poisson_log_prob(self):
    # Check Poisson.log_prob against scipy.stats.poisson, gradient-check it
    # for positive rates, and verify the gradient at a zero rate by hand.
    rate = torch.randn(2, 3).abs().requires_grad_()
    rate_1d = torch.randn(1).abs().requires_grad_()
    rate_zero = torch.zeros([], requires_grad=True)

    def ref_log_prob(ref_rate, idx, x, log_prob):
        # Pick the rate for this flat index and compare with scipy's logpmf.
        rate_value = ref_rate.view(-1)[idx].detach()
        reference = scipy.stats.poisson.logpmf(x, rate_value)
        self.assertEqual(log_prob, reference, atol=1e-3, rtol=0)

    def check_against_positive_rate(*args):
        return ref_log_prob(rate, *args)

    def check_against_zero_rate(*args):
        return ref_log_prob(rate_zero, *args)

    set_rng_seed(0)
    self._check_log_prob(Poisson(rate), check_against_positive_rate)
    self._check_log_prob(Poisson(rate_zero), check_against_zero_rate)
    self._gradcheck_log_prob(Poisson, (rate,))
    self._gradcheck_log_prob(Poisson, (rate_1d,))

    # We cannot check gradients automatically for zero rates because the finite difference
    # approximation enters the forbidden parameter space. We instead compare with the
    # theoretical results.
    zero_rate_dist = Poisson(rate_zero)
    log_prob_of_one = zero_rate_dist.log_prob(torch.ones_like(rate_zero))
    log_prob_of_one.backward()
    self.assertEqual(rate_zero.grad, torch.inf)
1974*da0073e9SAndroid Build Coastguard Worker
1975*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
1976*da0073e9SAndroid Build Coastguard Worker    def test_poisson_sample(self):
1977*da0073e9SAndroid Build Coastguard Worker        set_rng_seed(1)  # see Note [Randomized statistical tests]
1978*da0073e9SAndroid Build Coastguard Worker        saved_dtype = torch.get_default_dtype()
1979*da0073e9SAndroid Build Coastguard Worker        for dtype in [torch.float, torch.double, torch.bfloat16, torch.half]:
1980*da0073e9SAndroid Build Coastguard Worker            torch.set_default_dtype(dtype)
1981*da0073e9SAndroid Build Coastguard Worker            for rate in [0.1, 1.0, 5.0]:
1982*da0073e9SAndroid Build Coastguard Worker                self._check_sampler_discrete(
1983*da0073e9SAndroid Build Coastguard Worker                    Poisson(rate),
1984*da0073e9SAndroid Build Coastguard Worker                    scipy.stats.poisson(rate),
1985*da0073e9SAndroid Build Coastguard Worker                    f"Poisson(lambda={rate})",
1986*da0073e9SAndroid Build Coastguard Worker                    failure_rate=1e-3,
1987*da0073e9SAndroid Build Coastguard Worker                )
1988*da0073e9SAndroid Build Coastguard Worker        torch.set_default_dtype(saved_dtype)
1989*da0073e9SAndroid Build Coastguard Worker
1990*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDA, "CUDA not found")
1991*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
1992*da0073e9SAndroid Build Coastguard Worker    def test_poisson_gpu_sample(self):
1993*da0073e9SAndroid Build Coastguard Worker        set_rng_seed(1)
1994*da0073e9SAndroid Build Coastguard Worker        for rate in [0.12, 0.9, 4.0]:
1995*da0073e9SAndroid Build Coastguard Worker            self._check_sampler_discrete(
1996*da0073e9SAndroid Build Coastguard Worker                Poisson(torch.tensor([rate]).cuda()),
1997*da0073e9SAndroid Build Coastguard Worker                scipy.stats.poisson(rate),
1998*da0073e9SAndroid Build Coastguard Worker                f"Poisson(lambda={rate}, cuda)",
1999*da0073e9SAndroid Build Coastguard Worker                failure_rate=1e-3,
2000*da0073e9SAndroid Build Coastguard Worker            )
2001*da0073e9SAndroid Build Coastguard Worker
2002*da0073e9SAndroid Build Coastguard Worker    @set_default_dtype(torch.double)
2003*da0073e9SAndroid Build Coastguard Worker    def test_relaxed_bernoulli(self):
2004*da0073e9SAndroid Build Coastguard Worker        p = torch.tensor([0.7, 0.2, 0.4], requires_grad=True)
2005*da0073e9SAndroid Build Coastguard Worker        r = torch.tensor(0.3, requires_grad=True)
2006*da0073e9SAndroid Build Coastguard Worker        s = 0.3
2007*da0073e9SAndroid Build Coastguard Worker        temp = torch.tensor(0.67, requires_grad=True)
2008*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(RelaxedBernoulli(temp, p).sample((8,)).size(), (8, 3))
2009*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(RelaxedBernoulli(temp, p).sample().requires_grad)
2010*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(RelaxedBernoulli(temp, r).sample((8,)).size(), (8,))
2011*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(RelaxedBernoulli(temp, r).sample().size(), ())
2012*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
2013*da0073e9SAndroid Build Coastguard Worker            RelaxedBernoulli(temp, r).sample((3, 2)).size(),
2014*da0073e9SAndroid Build Coastguard Worker            (
2015*da0073e9SAndroid Build Coastguard Worker                3,
2016*da0073e9SAndroid Build Coastguard Worker                2,
2017*da0073e9SAndroid Build Coastguard Worker            ),
2018*da0073e9SAndroid Build Coastguard Worker        )
2019*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(RelaxedBernoulli(temp, s).sample().size(), ())
2020*da0073e9SAndroid Build Coastguard Worker        self._gradcheck_log_prob(RelaxedBernoulli, (temp, p))
2021*da0073e9SAndroid Build Coastguard Worker        self._gradcheck_log_prob(RelaxedBernoulli, (temp, r))
2022*da0073e9SAndroid Build Coastguard Worker
2023*da0073e9SAndroid Build Coastguard Worker        # test that rsample doesn't fail
2024*da0073e9SAndroid Build Coastguard Worker        s = RelaxedBernoulli(temp, p).rsample()
2025*da0073e9SAndroid Build Coastguard Worker        s.backward(torch.ones_like(s))
2026*da0073e9SAndroid Build Coastguard Worker
2027*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
2028*da0073e9SAndroid Build Coastguard Worker    def test_rounded_relaxed_bernoulli(self):
2029*da0073e9SAndroid Build Coastguard Worker        set_rng_seed(0)  # see Note [Randomized statistical tests]
2030*da0073e9SAndroid Build Coastguard Worker
2031*da0073e9SAndroid Build Coastguard Worker        class Rounded:
2032*da0073e9SAndroid Build Coastguard Worker            def __init__(self, dist):
2033*da0073e9SAndroid Build Coastguard Worker                self.dist = dist
2034*da0073e9SAndroid Build Coastguard Worker
2035*da0073e9SAndroid Build Coastguard Worker            def sample(self, *args, **kwargs):
2036*da0073e9SAndroid Build Coastguard Worker                return torch.round(self.dist.sample(*args, **kwargs))
2037*da0073e9SAndroid Build Coastguard Worker
2038*da0073e9SAndroid Build Coastguard Worker        for probs, temp in product([0.1, 0.2, 0.8], [0.1, 1.0, 10.0]):
2039*da0073e9SAndroid Build Coastguard Worker            self._check_sampler_discrete(
2040*da0073e9SAndroid Build Coastguard Worker                Rounded(RelaxedBernoulli(temp, probs)),
2041*da0073e9SAndroid Build Coastguard Worker                scipy.stats.bernoulli(probs),
2042*da0073e9SAndroid Build Coastguard Worker                f"Rounded(RelaxedBernoulli(temp={temp}, probs={probs}))",
2043*da0073e9SAndroid Build Coastguard Worker                failure_rate=1e-3,
2044*da0073e9SAndroid Build Coastguard Worker            )
2045*da0073e9SAndroid Build Coastguard Worker
2046*da0073e9SAndroid Build Coastguard Worker        for probs in [0.001, 0.2, 0.999]:
2047*da0073e9SAndroid Build Coastguard Worker            equal_probs = torch.tensor(0.5)
2048*da0073e9SAndroid Build Coastguard Worker            dist = RelaxedBernoulli(1e10, probs)
2049*da0073e9SAndroid Build Coastguard Worker            s = dist.rsample()
2050*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(equal_probs, s)
2051*da0073e9SAndroid Build Coastguard Worker
2052*da0073e9SAndroid Build Coastguard Worker    @set_default_dtype(torch.double)
2053*da0073e9SAndroid Build Coastguard Worker    def test_relaxed_one_hot_categorical_1d(self):
2054*da0073e9SAndroid Build Coastguard Worker        p = torch.tensor([0.1, 0.2, 0.3], requires_grad=True)
2055*da0073e9SAndroid Build Coastguard Worker        temp = torch.tensor(0.67, requires_grad=True)
2056*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
2057*da0073e9SAndroid Build Coastguard Worker            RelaxedOneHotCategorical(probs=p, temperature=temp).sample().size(), (3,)
2058*da0073e9SAndroid Build Coastguard Worker        )
2059*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(
2060*da0073e9SAndroid Build Coastguard Worker            RelaxedOneHotCategorical(probs=p, temperature=temp).sample().requires_grad
2061*da0073e9SAndroid Build Coastguard Worker        )
2062*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
2063*da0073e9SAndroid Build Coastguard Worker            RelaxedOneHotCategorical(probs=p, temperature=temp).sample((2, 2)).size(),
2064*da0073e9SAndroid Build Coastguard Worker            (2, 2, 3),
2065*da0073e9SAndroid Build Coastguard Worker        )
2066*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
2067*da0073e9SAndroid Build Coastguard Worker            RelaxedOneHotCategorical(probs=p, temperature=temp).sample((1,)).size(),
2068*da0073e9SAndroid Build Coastguard Worker            (1, 3),
2069*da0073e9SAndroid Build Coastguard Worker        )
2070*da0073e9SAndroid Build Coastguard Worker        self._gradcheck_log_prob(
2071*da0073e9SAndroid Build Coastguard Worker            lambda t, p: RelaxedOneHotCategorical(t, p, validate_args=False), (temp, p)
2072*da0073e9SAndroid Build Coastguard Worker        )
2073*da0073e9SAndroid Build Coastguard Worker
2074*da0073e9SAndroid Build Coastguard Worker    @set_default_dtype(torch.double)
2075*da0073e9SAndroid Build Coastguard Worker    def test_relaxed_one_hot_categorical_2d(self):
2076*da0073e9SAndroid Build Coastguard Worker        probabilities = [[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]]
2077*da0073e9SAndroid Build Coastguard Worker        probabilities_1 = [[1.0, 0.0], [0.0, 1.0]]
2078*da0073e9SAndroid Build Coastguard Worker        temp = torch.tensor([3.0], requires_grad=True)
2079*da0073e9SAndroid Build Coastguard Worker        # The lower the temperature, the more unstable the log_prob gradcheck is
2080*da0073e9SAndroid Build Coastguard Worker        # w.r.t. the sample. Values below 0.25 empirically fail the default tol.
2081*da0073e9SAndroid Build Coastguard Worker        temp_2 = torch.tensor([0.25], requires_grad=True)
2082*da0073e9SAndroid Build Coastguard Worker        p = torch.tensor(probabilities, requires_grad=True)
2083*da0073e9SAndroid Build Coastguard Worker        s = torch.tensor(probabilities_1, requires_grad=True)
2084*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(RelaxedOneHotCategorical(temp, p).sample().size(), (2, 3))
2085*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
2086*da0073e9SAndroid Build Coastguard Worker            RelaxedOneHotCategorical(temp, p).sample(sample_shape=(3, 4)).size(),
2087*da0073e9SAndroid Build Coastguard Worker            (3, 4, 2, 3),
2088*da0073e9SAndroid Build Coastguard Worker        )
2089*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
2090*da0073e9SAndroid Build Coastguard Worker            RelaxedOneHotCategorical(temp, p).sample((6,)).size(), (6, 2, 3)
2091*da0073e9SAndroid Build Coastguard Worker        )
2092*da0073e9SAndroid Build Coastguard Worker        self._gradcheck_log_prob(
2093*da0073e9SAndroid Build Coastguard Worker            lambda t, p: RelaxedOneHotCategorical(t, p, validate_args=False), (temp, p)
2094*da0073e9SAndroid Build Coastguard Worker        )
2095*da0073e9SAndroid Build Coastguard Worker        self._gradcheck_log_prob(
2096*da0073e9SAndroid Build Coastguard Worker            lambda t, p: RelaxedOneHotCategorical(t, p, validate_args=False),
2097*da0073e9SAndroid Build Coastguard Worker            (temp_2, p),
2098*da0073e9SAndroid Build Coastguard Worker        )
2099*da0073e9SAndroid Build Coastguard Worker
2100*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
2101*da0073e9SAndroid Build Coastguard Worker    def test_argmax_relaxed_categorical(self):
2102*da0073e9SAndroid Build Coastguard Worker        set_rng_seed(0)  # see Note [Randomized statistical tests]
2103*da0073e9SAndroid Build Coastguard Worker
2104*da0073e9SAndroid Build Coastguard Worker        class ArgMax:
2105*da0073e9SAndroid Build Coastguard Worker            def __init__(self, dist):
2106*da0073e9SAndroid Build Coastguard Worker                self.dist = dist
2107*da0073e9SAndroid Build Coastguard Worker
2108*da0073e9SAndroid Build Coastguard Worker            def sample(self, *args, **kwargs):
2109*da0073e9SAndroid Build Coastguard Worker                s = self.dist.sample(*args, **kwargs)
2110*da0073e9SAndroid Build Coastguard Worker                _, idx = torch.max(s, -1)
2111*da0073e9SAndroid Build Coastguard Worker                return idx
2112*da0073e9SAndroid Build Coastguard Worker
2113*da0073e9SAndroid Build Coastguard Worker        class ScipyCategorical:
2114*da0073e9SAndroid Build Coastguard Worker            def __init__(self, dist):
2115*da0073e9SAndroid Build Coastguard Worker                self.dist = dist
2116*da0073e9SAndroid Build Coastguard Worker
2117*da0073e9SAndroid Build Coastguard Worker            def pmf(self, samples):
2118*da0073e9SAndroid Build Coastguard Worker                new_samples = np.zeros(samples.shape + self.dist.p.shape)
2119*da0073e9SAndroid Build Coastguard Worker                new_samples[np.arange(samples.shape[0]), samples] = 1
2120*da0073e9SAndroid Build Coastguard Worker                return self.dist.pmf(new_samples)
2121*da0073e9SAndroid Build Coastguard Worker
2122*da0073e9SAndroid Build Coastguard Worker        for probs, temp in product(
2123*da0073e9SAndroid Build Coastguard Worker            [torch.tensor([0.1, 0.9]), torch.tensor([0.2, 0.2, 0.6])], [0.1, 1.0, 10.0]
2124*da0073e9SAndroid Build Coastguard Worker        ):
2125*da0073e9SAndroid Build Coastguard Worker            self._check_sampler_discrete(
2126*da0073e9SAndroid Build Coastguard Worker                ArgMax(RelaxedOneHotCategorical(temp, probs)),
2127*da0073e9SAndroid Build Coastguard Worker                ScipyCategorical(scipy.stats.multinomial(1, probs)),
2128*da0073e9SAndroid Build Coastguard Worker                f"Rounded(RelaxedOneHotCategorical(temp={temp}, probs={probs}))",
2129*da0073e9SAndroid Build Coastguard Worker                failure_rate=1e-3,
2130*da0073e9SAndroid Build Coastguard Worker            )
2131*da0073e9SAndroid Build Coastguard Worker
2132*da0073e9SAndroid Build Coastguard Worker        for probs in [torch.tensor([0.1, 0.9]), torch.tensor([0.2, 0.2, 0.6])]:
2133*da0073e9SAndroid Build Coastguard Worker            equal_probs = torch.ones(probs.size()) / probs.size()[0]
2134*da0073e9SAndroid Build Coastguard Worker            dist = RelaxedOneHotCategorical(1e10, probs)
2135*da0073e9SAndroid Build Coastguard Worker            s = dist.rsample()
2136*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(equal_probs, s)
2137*da0073e9SAndroid Build Coastguard Worker
2138*da0073e9SAndroid Build Coastguard Worker    @set_default_dtype(torch.double)
2139*da0073e9SAndroid Build Coastguard Worker    def test_uniform(self):
2140*da0073e9SAndroid Build Coastguard Worker        low = torch.zeros(5, 5, requires_grad=True)
2141*da0073e9SAndroid Build Coastguard Worker        high = (torch.ones(5, 5) * 3).requires_grad_()
2142*da0073e9SAndroid Build Coastguard Worker        low_1d = torch.zeros(1, requires_grad=True)
2143*da0073e9SAndroid Build Coastguard Worker        high_1d = (torch.ones(1) * 3).requires_grad_()
2144*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Uniform(low, high).sample().size(), (5, 5))
2145*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Uniform(low, high).sample((7,)).size(), (7, 5, 5))
2146*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Uniform(low_1d, high_1d).sample().size(), (1,))
2147*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Uniform(low_1d, high_1d).sample((1,)).size(), (1, 1))
2148*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Uniform(0.0, 1.0).sample((1,)).size(), (1,))
2149*da0073e9SAndroid Build Coastguard Worker
2150*da0073e9SAndroid Build Coastguard Worker        # Check log_prob computation when value outside range
2151*da0073e9SAndroid Build Coastguard Worker        uniform = Uniform(low_1d, high_1d, validate_args=False)
2152*da0073e9SAndroid Build Coastguard Worker        above_high = torch.tensor([4.0])
2153*da0073e9SAndroid Build Coastguard Worker        below_low = torch.tensor([-1.0])
2154*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(uniform.log_prob(above_high).item(), -inf)
2155*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(uniform.log_prob(below_low).item(), -inf)
2156*da0073e9SAndroid Build Coastguard Worker
2157*da0073e9SAndroid Build Coastguard Worker        # check cdf computation when value outside range
2158*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(uniform.cdf(below_low).item(), 0)
2159*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(uniform.cdf(above_high).item(), 1)
2160*da0073e9SAndroid Build Coastguard Worker
2161*da0073e9SAndroid Build Coastguard Worker        set_rng_seed(1)
2162*da0073e9SAndroid Build Coastguard Worker        self._gradcheck_log_prob(Uniform, (low, high))
2163*da0073e9SAndroid Build Coastguard Worker        self._gradcheck_log_prob(Uniform, (low, 1.0))
2164*da0073e9SAndroid Build Coastguard Worker        self._gradcheck_log_prob(Uniform, (0.0, high))
2165*da0073e9SAndroid Build Coastguard Worker
2166*da0073e9SAndroid Build Coastguard Worker        state = torch.get_rng_state()
2167*da0073e9SAndroid Build Coastguard Worker        rand = low.new(low.size()).uniform_()
2168*da0073e9SAndroid Build Coastguard Worker        torch.set_rng_state(state)
2169*da0073e9SAndroid Build Coastguard Worker        u = Uniform(low, high).rsample()
2170*da0073e9SAndroid Build Coastguard Worker        u.backward(torch.ones_like(u))
2171*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(low.grad, 1 - rand)
2172*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(high.grad, rand)
2173*da0073e9SAndroid Build Coastguard Worker        low.grad.zero_()
2174*da0073e9SAndroid Build Coastguard Worker        high.grad.zero_()
2175*da0073e9SAndroid Build Coastguard Worker
2176*da0073e9SAndroid Build Coastguard Worker        self._check_forward_ad(lambda x: x.uniform_())
2177*da0073e9SAndroid Build Coastguard Worker
2178*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
2179*da0073e9SAndroid Build Coastguard Worker    def test_vonmises_sample(self):
2180*da0073e9SAndroid Build Coastguard Worker        for loc in [0.0, math.pi / 2.0]:
2181*da0073e9SAndroid Build Coastguard Worker            for concentration in [0.03, 0.3, 1.0, 10.0, 100.0]:
2182*da0073e9SAndroid Build Coastguard Worker                self._check_sampler_sampler(
2183*da0073e9SAndroid Build Coastguard Worker                    VonMises(loc, concentration),
2184*da0073e9SAndroid Build Coastguard Worker                    scipy.stats.vonmises(loc=loc, kappa=concentration),
2185*da0073e9SAndroid Build Coastguard Worker                    f"VonMises(loc={loc}, concentration={concentration})",
2186*da0073e9SAndroid Build Coastguard Worker                    num_samples=int(1e5),
2187*da0073e9SAndroid Build Coastguard Worker                    circular=True,
2188*da0073e9SAndroid Build Coastguard Worker                )
2189*da0073e9SAndroid Build Coastguard Worker
2190*da0073e9SAndroid Build Coastguard Worker    def test_vonmises_logprob(self):
2191*da0073e9SAndroid Build Coastguard Worker        concentrations = [0.01, 0.03, 0.1, 0.3, 1.0, 3.0, 10.0, 30.0, 100.0]
2192*da0073e9SAndroid Build Coastguard Worker        for concentration in concentrations:
2193*da0073e9SAndroid Build Coastguard Worker            grid = torch.arange(0.0, 2 * math.pi, 1e-4)
2194*da0073e9SAndroid Build Coastguard Worker            prob = VonMises(0.0, concentration).log_prob(grid).exp()
2195*da0073e9SAndroid Build Coastguard Worker            norm = prob.mean().item() * 2 * math.pi
2196*da0073e9SAndroid Build Coastguard Worker            self.assertLess(abs(norm - 1), 1e-3)
2197*da0073e9SAndroid Build Coastguard Worker
2198*da0073e9SAndroid Build Coastguard Worker    @set_default_dtype(torch.double)
2199*da0073e9SAndroid Build Coastguard Worker    def test_cauchy(self):
2200*da0073e9SAndroid Build Coastguard Worker        loc = torch.zeros(5, 5, requires_grad=True)
2201*da0073e9SAndroid Build Coastguard Worker        scale = torch.ones(5, 5, requires_grad=True)
2202*da0073e9SAndroid Build Coastguard Worker        loc_1d = torch.zeros(1, requires_grad=True)
2203*da0073e9SAndroid Build Coastguard Worker        scale_1d = torch.ones(1, requires_grad=True)
2204*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(is_all_nan(Cauchy(loc_1d, scale_1d).mean))
2205*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Cauchy(loc_1d, scale_1d).variance, inf)
2206*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Cauchy(loc, scale).sample().size(), (5, 5))
2207*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Cauchy(loc, scale).sample((7,)).size(), (7, 5, 5))
2208*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Cauchy(loc_1d, scale_1d).sample().size(), (1,))
2209*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Cauchy(loc_1d, scale_1d).sample((1,)).size(), (1, 1))
2210*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Cauchy(0.0, 1.0).sample((1,)).size(), (1,))
2211*da0073e9SAndroid Build Coastguard Worker
2212*da0073e9SAndroid Build Coastguard Worker        set_rng_seed(1)
2213*da0073e9SAndroid Build Coastguard Worker        self._gradcheck_log_prob(Cauchy, (loc, scale))
2214*da0073e9SAndroid Build Coastguard Worker        self._gradcheck_log_prob(Cauchy, (loc, 1.0))
2215*da0073e9SAndroid Build Coastguard Worker        self._gradcheck_log_prob(Cauchy, (0.0, scale))
2216*da0073e9SAndroid Build Coastguard Worker
2217*da0073e9SAndroid Build Coastguard Worker        state = torch.get_rng_state()
2218*da0073e9SAndroid Build Coastguard Worker        eps = loc.new(loc.size()).cauchy_()
2219*da0073e9SAndroid Build Coastguard Worker        torch.set_rng_state(state)
2220*da0073e9SAndroid Build Coastguard Worker        c = Cauchy(loc, scale).rsample()
2221*da0073e9SAndroid Build Coastguard Worker        c.backward(torch.ones_like(c))
2222*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(loc.grad, torch.ones_like(scale))
2223*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(scale.grad, eps)
2224*da0073e9SAndroid Build Coastguard Worker        loc.grad.zero_()
2225*da0073e9SAndroid Build Coastguard Worker        scale.grad.zero_()
2226*da0073e9SAndroid Build Coastguard Worker
2227*da0073e9SAndroid Build Coastguard Worker        self._check_forward_ad(lambda x: x.cauchy_())
2228*da0073e9SAndroid Build Coastguard Worker
2229*da0073e9SAndroid Build Coastguard Worker    @set_default_dtype(torch.double)
2230*da0073e9SAndroid Build Coastguard Worker    def test_halfcauchy(self):
2231*da0073e9SAndroid Build Coastguard Worker        scale = torch.ones(5, 5, requires_grad=True)
2232*da0073e9SAndroid Build Coastguard Worker        scale_1d = torch.ones(1, requires_grad=True)
2233*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.isinf(HalfCauchy(scale_1d).mean).all())
2234*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(HalfCauchy(scale_1d).variance, inf)
2235*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(HalfCauchy(scale).sample().size(), (5, 5))
2236*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(HalfCauchy(scale).sample((7,)).size(), (7, 5, 5))
2237*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(HalfCauchy(scale_1d).sample().size(), (1,))
2238*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(HalfCauchy(scale_1d).sample((1,)).size(), (1, 1))
2239*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(HalfCauchy(1.0).sample((1,)).size(), (1,))
2240*da0073e9SAndroid Build Coastguard Worker
2241*da0073e9SAndroid Build Coastguard Worker        set_rng_seed(1)
2242*da0073e9SAndroid Build Coastguard Worker        self._gradcheck_log_prob(HalfCauchy, (scale,))
2243*da0073e9SAndroid Build Coastguard Worker        self._gradcheck_log_prob(HalfCauchy, (1.0,))
2244*da0073e9SAndroid Build Coastguard Worker
2245*da0073e9SAndroid Build Coastguard Worker        state = torch.get_rng_state()
2246*da0073e9SAndroid Build Coastguard Worker        eps = scale.new(scale.size()).cauchy_().abs_()
2247*da0073e9SAndroid Build Coastguard Worker        torch.set_rng_state(state)
2248*da0073e9SAndroid Build Coastguard Worker        c = HalfCauchy(scale).rsample()
2249*da0073e9SAndroid Build Coastguard Worker        c.backward(torch.ones_like(c))
2250*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(scale.grad, eps)
2251*da0073e9SAndroid Build Coastguard Worker        scale.grad.zero_()
2252*da0073e9SAndroid Build Coastguard Worker
2253*da0073e9SAndroid Build Coastguard Worker    @set_default_dtype(torch.double)
2254*da0073e9SAndroid Build Coastguard Worker    def test_halfnormal(self):
2255*da0073e9SAndroid Build Coastguard Worker        std = torch.randn(5, 5).abs().requires_grad_()
2256*da0073e9SAndroid Build Coastguard Worker        std_1d = torch.randn(1).abs().requires_grad_()
2257*da0073e9SAndroid Build Coastguard Worker        std_delta = torch.tensor([1e-5, 1e-5])
2258*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(HalfNormal(std).sample().size(), (5, 5))
2259*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(HalfNormal(std).sample((7,)).size(), (7, 5, 5))
2260*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(HalfNormal(std_1d).sample((1,)).size(), (1, 1))
2261*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(HalfNormal(std_1d).sample().size(), (1,))
2262*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(HalfNormal(0.6).sample((1,)).size(), (1,))
2263*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(HalfNormal(50.0).sample((1,)).size(), (1,))
2264*da0073e9SAndroid Build Coastguard Worker
2265*da0073e9SAndroid Build Coastguard Worker        # sample check for extreme value of std
2266*da0073e9SAndroid Build Coastguard Worker        set_rng_seed(1)
2267*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
2268*da0073e9SAndroid Build Coastguard Worker            HalfNormal(std_delta).sample(sample_shape=(1, 2)),
2269*da0073e9SAndroid Build Coastguard Worker            torch.tensor([[[0.0, 0.0], [0.0, 0.0]]]),
2270*da0073e9SAndroid Build Coastguard Worker            atol=1e-4,
2271*da0073e9SAndroid Build Coastguard Worker            rtol=0,
2272*da0073e9SAndroid Build Coastguard Worker        )
2273*da0073e9SAndroid Build Coastguard Worker
2274*da0073e9SAndroid Build Coastguard Worker        self._gradcheck_log_prob(HalfNormal, (std,))
2275*da0073e9SAndroid Build Coastguard Worker        self._gradcheck_log_prob(HalfNormal, (1.0,))
2276*da0073e9SAndroid Build Coastguard Worker
2277*da0073e9SAndroid Build Coastguard Worker        # check .log_prob() can broadcast.
2278*da0073e9SAndroid Build Coastguard Worker        dist = HalfNormal(torch.ones(2, 1, 4))
2279*da0073e9SAndroid Build Coastguard Worker        log_prob = dist.log_prob(torch.ones(3, 1))
2280*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(log_prob.shape, (2, 3, 4))
2281*da0073e9SAndroid Build Coastguard Worker
2282*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
2283*da0073e9SAndroid Build Coastguard Worker    def test_halfnormal_logprob(self):
2284*da0073e9SAndroid Build Coastguard Worker        std = torch.randn(5, 1).abs().requires_grad_()
2285*da0073e9SAndroid Build Coastguard Worker
2286*da0073e9SAndroid Build Coastguard Worker        def ref_log_prob(idx, x, log_prob):
2287*da0073e9SAndroid Build Coastguard Worker            s = std.view(-1)[idx].detach()
2288*da0073e9SAndroid Build Coastguard Worker            expected = scipy.stats.halfnorm(scale=s).logpdf(x)
2289*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)
2290*da0073e9SAndroid Build Coastguard Worker
2291*da0073e9SAndroid Build Coastguard Worker        self._check_log_prob(HalfNormal(std), ref_log_prob)
2292*da0073e9SAndroid Build Coastguard Worker
2293*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
2294*da0073e9SAndroid Build Coastguard Worker    def test_halfnormal_sample(self):
2295*da0073e9SAndroid Build Coastguard Worker        set_rng_seed(0)  # see Note [Randomized statistical tests]
2296*da0073e9SAndroid Build Coastguard Worker        for std in [0.1, 1.0, 10.0]:
2297*da0073e9SAndroid Build Coastguard Worker            self._check_sampler_sampler(
2298*da0073e9SAndroid Build Coastguard Worker                HalfNormal(std),
2299*da0073e9SAndroid Build Coastguard Worker                scipy.stats.halfnorm(scale=std),
2300*da0073e9SAndroid Build Coastguard Worker                f"HalfNormal(scale={std})",
2301*da0073e9SAndroid Build Coastguard Worker            )
2302*da0073e9SAndroid Build Coastguard Worker
2303*da0073e9SAndroid Build Coastguard Worker    @set_default_dtype(torch.double)
2304*da0073e9SAndroid Build Coastguard Worker    def test_inversegamma(self):
2305*da0073e9SAndroid Build Coastguard Worker        alpha = torch.randn(2, 3).exp().requires_grad_()
2306*da0073e9SAndroid Build Coastguard Worker        beta = torch.randn(2, 3).exp().requires_grad_()
2307*da0073e9SAndroid Build Coastguard Worker        alpha_1d = torch.randn(1).exp().requires_grad_()
2308*da0073e9SAndroid Build Coastguard Worker        beta_1d = torch.randn(1).exp().requires_grad_()
2309*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(InverseGamma(alpha, beta).sample().size(), (2, 3))
2310*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(InverseGamma(alpha, beta).sample((5,)).size(), (5, 2, 3))
2311*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(InverseGamma(alpha_1d, beta_1d).sample((1,)).size(), (1, 1))
2312*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(InverseGamma(alpha_1d, beta_1d).sample().size(), (1,))
2313*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(InverseGamma(0.5, 0.5).sample().size(), ())
2314*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(InverseGamma(0.5, 0.5).sample((1,)).size(), (1,))
2315*da0073e9SAndroid Build Coastguard Worker
2316*da0073e9SAndroid Build Coastguard Worker        self._gradcheck_log_prob(InverseGamma, (alpha, beta))
2317*da0073e9SAndroid Build Coastguard Worker
2318*da0073e9SAndroid Build Coastguard Worker        dist = InverseGamma(torch.ones(4), torch.ones(2, 1, 1))
2319*da0073e9SAndroid Build Coastguard Worker        log_prob = dist.log_prob(torch.ones(3, 1))
2320*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(log_prob.shape, (2, 3, 4))
2321*da0073e9SAndroid Build Coastguard Worker
2322*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
2323*da0073e9SAndroid Build Coastguard Worker    def test_inversegamma_sample(self):
2324*da0073e9SAndroid Build Coastguard Worker        set_rng_seed(0)  # see Note [Randomized statistical tests]
2325*da0073e9SAndroid Build Coastguard Worker        for concentration, rate in product([2, 5], [0.1, 1.0, 10.0]):
2326*da0073e9SAndroid Build Coastguard Worker            self._check_sampler_sampler(
2327*da0073e9SAndroid Build Coastguard Worker                InverseGamma(concentration, rate),
2328*da0073e9SAndroid Build Coastguard Worker                scipy.stats.invgamma(concentration, scale=rate),
2329*da0073e9SAndroid Build Coastguard Worker                "InverseGamma()",
2330*da0073e9SAndroid Build Coastguard Worker            )
2331*da0073e9SAndroid Build Coastguard Worker
2332*da0073e9SAndroid Build Coastguard Worker    @set_default_dtype(torch.double)
2333*da0073e9SAndroid Build Coastguard Worker    def test_lognormal(self):
2334*da0073e9SAndroid Build Coastguard Worker        mean = torch.randn(5, 5, requires_grad=True)
2335*da0073e9SAndroid Build Coastguard Worker        std = torch.randn(5, 5).abs().requires_grad_()
2336*da0073e9SAndroid Build Coastguard Worker        mean_1d = torch.randn(1, requires_grad=True)
2337*da0073e9SAndroid Build Coastguard Worker        std_1d = torch.randn(1).abs().requires_grad_()
2338*da0073e9SAndroid Build Coastguard Worker        mean_delta = torch.tensor([1.0, 0.0])
2339*da0073e9SAndroid Build Coastguard Worker        std_delta = torch.tensor([1e-5, 1e-5])
2340*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(LogNormal(mean, std).sample().size(), (5, 5))
2341*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(LogNormal(mean, std).sample((7,)).size(), (7, 5, 5))
2342*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(LogNormal(mean_1d, std_1d).sample((1,)).size(), (1, 1))
2343*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(LogNormal(mean_1d, std_1d).sample().size(), (1,))
2344*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(LogNormal(0.2, 0.6).sample((1,)).size(), (1,))
2345*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(LogNormal(-0.7, 50.0).sample((1,)).size(), (1,))
2346*da0073e9SAndroid Build Coastguard Worker
2347*da0073e9SAndroid Build Coastguard Worker        # sample check for extreme value of mean, std
2348*da0073e9SAndroid Build Coastguard Worker        set_rng_seed(1)
2349*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
2350*da0073e9SAndroid Build Coastguard Worker            LogNormal(mean_delta, std_delta).sample(sample_shape=(1, 2)),
2351*da0073e9SAndroid Build Coastguard Worker            torch.tensor([[[math.exp(1), 1.0], [math.exp(1), 1.0]]]),
2352*da0073e9SAndroid Build Coastguard Worker            atol=1e-4,
2353*da0073e9SAndroid Build Coastguard Worker            rtol=0,
2354*da0073e9SAndroid Build Coastguard Worker        )
2355*da0073e9SAndroid Build Coastguard Worker
2356*da0073e9SAndroid Build Coastguard Worker        self._gradcheck_log_prob(LogNormal, (mean, std))
2357*da0073e9SAndroid Build Coastguard Worker        self._gradcheck_log_prob(LogNormal, (mean, 1.0))
2358*da0073e9SAndroid Build Coastguard Worker        self._gradcheck_log_prob(LogNormal, (0.0, std))
2359*da0073e9SAndroid Build Coastguard Worker
2360*da0073e9SAndroid Build Coastguard Worker        # check .log_prob() can broadcast.
2361*da0073e9SAndroid Build Coastguard Worker        dist = LogNormal(torch.zeros(4), torch.ones(2, 1, 1))
2362*da0073e9SAndroid Build Coastguard Worker        log_prob = dist.log_prob(torch.ones(3, 1))
2363*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(log_prob.shape, (2, 3, 4))
2364*da0073e9SAndroid Build Coastguard Worker
2365*da0073e9SAndroid Build Coastguard Worker        self._check_forward_ad(lambda x: x.log_normal_())
2366*da0073e9SAndroid Build Coastguard Worker
2367*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
2368*da0073e9SAndroid Build Coastguard Worker    def test_lognormal_logprob(self):
2369*da0073e9SAndroid Build Coastguard Worker        mean = torch.randn(5, 1, requires_grad=True)
2370*da0073e9SAndroid Build Coastguard Worker        std = torch.randn(5, 1).abs().requires_grad_()
2371*da0073e9SAndroid Build Coastguard Worker
2372*da0073e9SAndroid Build Coastguard Worker        def ref_log_prob(idx, x, log_prob):
2373*da0073e9SAndroid Build Coastguard Worker            m = mean.view(-1)[idx].detach()
2374*da0073e9SAndroid Build Coastguard Worker            s = std.view(-1)[idx].detach()
2375*da0073e9SAndroid Build Coastguard Worker            expected = scipy.stats.lognorm(s=s, scale=math.exp(m)).logpdf(x)
2376*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)
2377*da0073e9SAndroid Build Coastguard Worker
2378*da0073e9SAndroid Build Coastguard Worker        self._check_log_prob(LogNormal(mean, std), ref_log_prob)
2379*da0073e9SAndroid Build Coastguard Worker
2380*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
2381*da0073e9SAndroid Build Coastguard Worker    def test_lognormal_sample(self):
2382*da0073e9SAndroid Build Coastguard Worker        set_rng_seed(0)  # see Note [Randomized statistical tests]
2383*da0073e9SAndroid Build Coastguard Worker        for mean, std in product([-1.0, 0.0, 1.0], [0.1, 1.0, 10.0]):
2384*da0073e9SAndroid Build Coastguard Worker            self._check_sampler_sampler(
2385*da0073e9SAndroid Build Coastguard Worker                LogNormal(mean, std),
2386*da0073e9SAndroid Build Coastguard Worker                scipy.stats.lognorm(scale=math.exp(mean), s=std),
2387*da0073e9SAndroid Build Coastguard Worker                f"LogNormal(loc={mean}, scale={std})",
2388*da0073e9SAndroid Build Coastguard Worker            )
2389*da0073e9SAndroid Build Coastguard Worker
2390*da0073e9SAndroid Build Coastguard Worker    @set_default_dtype(torch.double)
2391*da0073e9SAndroid Build Coastguard Worker    def test_logisticnormal(self):
2392*da0073e9SAndroid Build Coastguard Worker        set_rng_seed(1)  # see Note [Randomized statistical tests]
2393*da0073e9SAndroid Build Coastguard Worker        mean = torch.randn(5, 5).requires_grad_()
2394*da0073e9SAndroid Build Coastguard Worker        std = torch.randn(5, 5).abs().requires_grad_()
2395*da0073e9SAndroid Build Coastguard Worker        mean_1d = torch.randn(1).requires_grad_()
2396*da0073e9SAndroid Build Coastguard Worker        std_1d = torch.randn(1).abs().requires_grad_()
2397*da0073e9SAndroid Build Coastguard Worker        mean_delta = torch.tensor([1.0, 0.0])
2398*da0073e9SAndroid Build Coastguard Worker        std_delta = torch.tensor([1e-5, 1e-5])
2399*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(LogisticNormal(mean, std).sample().size(), (5, 6))
2400*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(LogisticNormal(mean, std).sample((7,)).size(), (7, 5, 6))
2401*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(LogisticNormal(mean_1d, std_1d).sample((1,)).size(), (1, 2))
2402*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(LogisticNormal(mean_1d, std_1d).sample().size(), (2,))
2403*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(LogisticNormal(0.2, 0.6).sample().size(), (2,))
2404*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(LogisticNormal(-0.7, 50.0).sample().size(), (2,))
2405*da0073e9SAndroid Build Coastguard Worker
2406*da0073e9SAndroid Build Coastguard Worker        # sample check for extreme value of mean, std
2407*da0073e9SAndroid Build Coastguard Worker        set_rng_seed(1)
2408*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
2409*da0073e9SAndroid Build Coastguard Worker            LogisticNormal(mean_delta, std_delta).sample(),
2410*da0073e9SAndroid Build Coastguard Worker            torch.tensor(
2411*da0073e9SAndroid Build Coastguard Worker                [
2412*da0073e9SAndroid Build Coastguard Worker                    math.exp(1) / (1.0 + 1.0 + math.exp(1)),
2413*da0073e9SAndroid Build Coastguard Worker                    1.0 / (1.0 + 1.0 + math.exp(1)),
2414*da0073e9SAndroid Build Coastguard Worker                    1.0 / (1.0 + 1.0 + math.exp(1)),
2415*da0073e9SAndroid Build Coastguard Worker                ]
2416*da0073e9SAndroid Build Coastguard Worker            ),
2417*da0073e9SAndroid Build Coastguard Worker            atol=1e-4,
2418*da0073e9SAndroid Build Coastguard Worker            rtol=0,
2419*da0073e9SAndroid Build Coastguard Worker        )
2420*da0073e9SAndroid Build Coastguard Worker
2421*da0073e9SAndroid Build Coastguard Worker        # TODO: gradcheck seems to mutate the sample values so that the simplex
2422*da0073e9SAndroid Build Coastguard Worker        # constraint fails by a very small margin.
2423*da0073e9SAndroid Build Coastguard Worker        self._gradcheck_log_prob(
2424*da0073e9SAndroid Build Coastguard Worker            lambda m, s: LogisticNormal(m, s, validate_args=False), (mean, std)
2425*da0073e9SAndroid Build Coastguard Worker        )
2426*da0073e9SAndroid Build Coastguard Worker        self._gradcheck_log_prob(
2427*da0073e9SAndroid Build Coastguard Worker            lambda m, s: LogisticNormal(m, s, validate_args=False), (mean, 1.0)
2428*da0073e9SAndroid Build Coastguard Worker        )
2429*da0073e9SAndroid Build Coastguard Worker        self._gradcheck_log_prob(
2430*da0073e9SAndroid Build Coastguard Worker            lambda m, s: LogisticNormal(m, s, validate_args=False), (0.0, std)
2431*da0073e9SAndroid Build Coastguard Worker        )
2432*da0073e9SAndroid Build Coastguard Worker
2433*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
2434*da0073e9SAndroid Build Coastguard Worker    def test_logisticnormal_logprob(self):
2435*da0073e9SAndroid Build Coastguard Worker        mean = torch.randn(5, 7).requires_grad_()
2436*da0073e9SAndroid Build Coastguard Worker        std = torch.randn(5, 7).abs().requires_grad_()
2437*da0073e9SAndroid Build Coastguard Worker
2438*da0073e9SAndroid Build Coastguard Worker        # Smoke test for now
2439*da0073e9SAndroid Build Coastguard Worker        # TODO: Once _check_log_prob works with multidimensional distributions,
2440*da0073e9SAndroid Build Coastguard Worker        #       add proper testing of the log probabilities.
2441*da0073e9SAndroid Build Coastguard Worker        dist = LogisticNormal(mean, std)
2442*da0073e9SAndroid Build Coastguard Worker        assert dist.log_prob(dist.sample()).detach().cpu().numpy().shape == (5,)
2443*da0073e9SAndroid Build Coastguard Worker
2444*da0073e9SAndroid Build Coastguard Worker    def _get_logistic_normal_ref_sampler(self, base_dist):
2445*da0073e9SAndroid Build Coastguard Worker        def _sampler(num_samples):
2446*da0073e9SAndroid Build Coastguard Worker            x = base_dist.rvs(num_samples)
2447*da0073e9SAndroid Build Coastguard Worker            offset = np.log((x.shape[-1] + 1) - np.ones_like(x).cumsum(-1))
2448*da0073e9SAndroid Build Coastguard Worker            z = 1.0 / (1.0 + np.exp(offset - x))
2449*da0073e9SAndroid Build Coastguard Worker            z_cumprod = np.cumprod(1 - z, axis=-1)
2450*da0073e9SAndroid Build Coastguard Worker            y1 = np.pad(z, ((0, 0), (0, 1)), mode="constant", constant_values=1.0)
2451*da0073e9SAndroid Build Coastguard Worker            y2 = np.pad(
2452*da0073e9SAndroid Build Coastguard Worker                z_cumprod, ((0, 0), (1, 0)), mode="constant", constant_values=1.0
2453*da0073e9SAndroid Build Coastguard Worker            )
2454*da0073e9SAndroid Build Coastguard Worker            return y1 * y2
2455*da0073e9SAndroid Build Coastguard Worker
2456*da0073e9SAndroid Build Coastguard Worker        return _sampler
2457*da0073e9SAndroid Build Coastguard Worker
2458*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
2459*da0073e9SAndroid Build Coastguard Worker    def test_logisticnormal_sample(self):
2460*da0073e9SAndroid Build Coastguard Worker        set_rng_seed(0)  # see Note [Randomized statistical tests]
2461*da0073e9SAndroid Build Coastguard Worker        means = map(np.asarray, [(-1.0, -1.0), (0.0, 0.0), (1.0, 1.0)])
2462*da0073e9SAndroid Build Coastguard Worker        covs = map(np.diag, [(0.1, 0.1), (1.0, 1.0), (10.0, 10.0)])
2463*da0073e9SAndroid Build Coastguard Worker        for mean, cov in product(means, covs):
2464*da0073e9SAndroid Build Coastguard Worker            base_dist = scipy.stats.multivariate_normal(mean=mean, cov=cov)
2465*da0073e9SAndroid Build Coastguard Worker            ref_dist = scipy.stats.multivariate_normal(mean=mean, cov=cov)
2466*da0073e9SAndroid Build Coastguard Worker            ref_dist.rvs = self._get_logistic_normal_ref_sampler(base_dist)
2467*da0073e9SAndroid Build Coastguard Worker            mean_th = torch.tensor(mean)
2468*da0073e9SAndroid Build Coastguard Worker            std_th = torch.tensor(np.sqrt(np.diag(cov)))
2469*da0073e9SAndroid Build Coastguard Worker            self._check_sampler_sampler(
2470*da0073e9SAndroid Build Coastguard Worker                LogisticNormal(mean_th, std_th),
2471*da0073e9SAndroid Build Coastguard Worker                ref_dist,
2472*da0073e9SAndroid Build Coastguard Worker                f"LogisticNormal(loc={mean_th}, scale={std_th})",
2473*da0073e9SAndroid Build Coastguard Worker                multivariate=True,
2474*da0073e9SAndroid Build Coastguard Worker            )
2475*da0073e9SAndroid Build Coastguard Worker
2476*da0073e9SAndroid Build Coastguard Worker    def test_mixture_same_family_shape(self):
2477*da0073e9SAndroid Build Coastguard Worker        normal_case_1d = MixtureSameFamily(
2478*da0073e9SAndroid Build Coastguard Worker            Categorical(torch.rand(5)), Normal(torch.randn(5), torch.rand(5))
2479*da0073e9SAndroid Build Coastguard Worker        )
2480*da0073e9SAndroid Build Coastguard Worker        normal_case_1d_batch = MixtureSameFamily(
2481*da0073e9SAndroid Build Coastguard Worker            Categorical(torch.rand(3, 5)), Normal(torch.randn(3, 5), torch.rand(3, 5))
2482*da0073e9SAndroid Build Coastguard Worker        )
2483*da0073e9SAndroid Build Coastguard Worker        normal_case_1d_multi_batch = MixtureSameFamily(
2484*da0073e9SAndroid Build Coastguard Worker            Categorical(torch.rand(4, 3, 5)),
2485*da0073e9SAndroid Build Coastguard Worker            Normal(torch.randn(4, 3, 5), torch.rand(4, 3, 5)),
2486*da0073e9SAndroid Build Coastguard Worker        )
2487*da0073e9SAndroid Build Coastguard Worker        normal_case_2d = MixtureSameFamily(
2488*da0073e9SAndroid Build Coastguard Worker            Categorical(torch.rand(5)),
2489*da0073e9SAndroid Build Coastguard Worker            Independent(Normal(torch.randn(5, 2), torch.rand(5, 2)), 1),
2490*da0073e9SAndroid Build Coastguard Worker        )
2491*da0073e9SAndroid Build Coastguard Worker        normal_case_2d_batch = MixtureSameFamily(
2492*da0073e9SAndroid Build Coastguard Worker            Categorical(torch.rand(3, 5)),
2493*da0073e9SAndroid Build Coastguard Worker            Independent(Normal(torch.randn(3, 5, 2), torch.rand(3, 5, 2)), 1),
2494*da0073e9SAndroid Build Coastguard Worker        )
2495*da0073e9SAndroid Build Coastguard Worker        normal_case_2d_multi_batch = MixtureSameFamily(
2496*da0073e9SAndroid Build Coastguard Worker            Categorical(torch.rand(4, 3, 5)),
2497*da0073e9SAndroid Build Coastguard Worker            Independent(Normal(torch.randn(4, 3, 5, 2), torch.rand(4, 3, 5, 2)), 1),
2498*da0073e9SAndroid Build Coastguard Worker        )
2499*da0073e9SAndroid Build Coastguard Worker
2500*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(normal_case_1d.sample().size(), ())
2501*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(normal_case_1d.sample((2,)).size(), (2,))
2502*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(normal_case_1d.sample((2, 7)).size(), (2, 7))
2503*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(normal_case_1d_batch.sample().size(), (3,))
2504*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(normal_case_1d_batch.sample((2,)).size(), (2, 3))
2505*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(normal_case_1d_batch.sample((2, 7)).size(), (2, 7, 3))
2506*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(normal_case_1d_multi_batch.sample().size(), (4, 3))
2507*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(normal_case_1d_multi_batch.sample((2,)).size(), (2, 4, 3))
2508*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(normal_case_1d_multi_batch.sample((2, 7)).size(), (2, 7, 4, 3))
2509*da0073e9SAndroid Build Coastguard Worker
2510*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(normal_case_2d.sample().size(), (2,))
2511*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(normal_case_2d.sample((2,)).size(), (2, 2))
2512*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(normal_case_2d.sample((2, 7)).size(), (2, 7, 2))
2513*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(normal_case_2d_batch.sample().size(), (3, 2))
2514*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(normal_case_2d_batch.sample((2,)).size(), (2, 3, 2))
2515*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(normal_case_2d_batch.sample((2, 7)).size(), (2, 7, 3, 2))
2516*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(normal_case_2d_multi_batch.sample().size(), (4, 3, 2))
2517*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(normal_case_2d_multi_batch.sample((2,)).size(), (2, 4, 3, 2))
2518*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
2519*da0073e9SAndroid Build Coastguard Worker            normal_case_2d_multi_batch.sample((2, 7)).size(), (2, 7, 4, 3, 2)
2520*da0073e9SAndroid Build Coastguard Worker        )
2521*da0073e9SAndroid Build Coastguard Worker
2522*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
2523*da0073e9SAndroid Build Coastguard Worker    def test_mixture_same_family_log_prob(self):
2524*da0073e9SAndroid Build Coastguard Worker        probs = torch.rand(5, 5).softmax(dim=-1)
2525*da0073e9SAndroid Build Coastguard Worker        loc = torch.randn(5, 5)
2526*da0073e9SAndroid Build Coastguard Worker        scale = torch.rand(5, 5)
2527*da0073e9SAndroid Build Coastguard Worker
2528*da0073e9SAndroid Build Coastguard Worker        def ref_log_prob(idx, x, log_prob):
2529*da0073e9SAndroid Build Coastguard Worker            p = probs[idx].numpy()
2530*da0073e9SAndroid Build Coastguard Worker            m = loc[idx].numpy()
2531*da0073e9SAndroid Build Coastguard Worker            s = scale[idx].numpy()
2532*da0073e9SAndroid Build Coastguard Worker            mix = scipy.stats.multinomial(1, p)
2533*da0073e9SAndroid Build Coastguard Worker            comp = scipy.stats.norm(m, s)
2534*da0073e9SAndroid Build Coastguard Worker            expected = scipy.special.logsumexp(comp.logpdf(x) + np.log(mix.p))
2535*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)
2536*da0073e9SAndroid Build Coastguard Worker
2537*da0073e9SAndroid Build Coastguard Worker        self._check_log_prob(
2538*da0073e9SAndroid Build Coastguard Worker            MixtureSameFamily(Categorical(probs=probs), Normal(loc, scale)),
2539*da0073e9SAndroid Build Coastguard Worker            ref_log_prob,
2540*da0073e9SAndroid Build Coastguard Worker        )
2541*da0073e9SAndroid Build Coastguard Worker
2542*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
2543*da0073e9SAndroid Build Coastguard Worker    def test_mixture_same_family_sample(self):
2544*da0073e9SAndroid Build Coastguard Worker        probs = torch.rand(5).softmax(dim=-1)
2545*da0073e9SAndroid Build Coastguard Worker        loc = torch.randn(5)
2546*da0073e9SAndroid Build Coastguard Worker        scale = torch.rand(5)
2547*da0073e9SAndroid Build Coastguard Worker
2548*da0073e9SAndroid Build Coastguard Worker        class ScipyMixtureNormal:
2549*da0073e9SAndroid Build Coastguard Worker            def __init__(self, probs, mu, std):
2550*da0073e9SAndroid Build Coastguard Worker                self.probs = probs
2551*da0073e9SAndroid Build Coastguard Worker                self.mu = mu
2552*da0073e9SAndroid Build Coastguard Worker                self.std = std
2553*da0073e9SAndroid Build Coastguard Worker
2554*da0073e9SAndroid Build Coastguard Worker            def rvs(self, n_sample):
2555*da0073e9SAndroid Build Coastguard Worker                comp_samples = [
2556*da0073e9SAndroid Build Coastguard Worker                    scipy.stats.norm(m, s).rvs(n_sample)
2557*da0073e9SAndroid Build Coastguard Worker                    for m, s in zip(self.mu, self.std)
2558*da0073e9SAndroid Build Coastguard Worker                ]
2559*da0073e9SAndroid Build Coastguard Worker                mix_samples = scipy.stats.multinomial(1, self.probs).rvs(n_sample)
2560*da0073e9SAndroid Build Coastguard Worker                samples = []
2561*da0073e9SAndroid Build Coastguard Worker                for i in range(n_sample):
2562*da0073e9SAndroid Build Coastguard Worker                    samples.append(comp_samples[mix_samples[i].argmax()][i])
2563*da0073e9SAndroid Build Coastguard Worker                return np.asarray(samples)
2564*da0073e9SAndroid Build Coastguard Worker
2565*da0073e9SAndroid Build Coastguard Worker        self._check_sampler_sampler(
2566*da0073e9SAndroid Build Coastguard Worker            MixtureSameFamily(Categorical(probs=probs), Normal(loc, scale)),
2567*da0073e9SAndroid Build Coastguard Worker            ScipyMixtureNormal(probs.numpy(), loc.numpy(), scale.numpy()),
2568*da0073e9SAndroid Build Coastguard Worker            f"""MixtureSameFamily(Categorical(probs={probs}),
2569*da0073e9SAndroid Build Coastguard Worker            Normal(loc={loc}, scale={scale}))""",
2570*da0073e9SAndroid Build Coastguard Worker        )
2571*da0073e9SAndroid Build Coastguard Worker
2572*da0073e9SAndroid Build Coastguard Worker    @set_default_dtype(torch.double)
2573*da0073e9SAndroid Build Coastguard Worker    def test_normal(self):
2574*da0073e9SAndroid Build Coastguard Worker        loc = torch.randn(5, 5, requires_grad=True)
2575*da0073e9SAndroid Build Coastguard Worker        scale = torch.randn(5, 5).abs().requires_grad_()
2576*da0073e9SAndroid Build Coastguard Worker        loc_1d = torch.randn(1, requires_grad=True)
2577*da0073e9SAndroid Build Coastguard Worker        scale_1d = torch.randn(1).abs().requires_grad_()
2578*da0073e9SAndroid Build Coastguard Worker        loc_delta = torch.tensor([1.0, 0.0])
2579*da0073e9SAndroid Build Coastguard Worker        scale_delta = torch.tensor([1e-5, 1e-5])
2580*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Normal(loc, scale).sample().size(), (5, 5))
2581*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Normal(loc, scale).sample((7,)).size(), (7, 5, 5))
2582*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Normal(loc_1d, scale_1d).sample((1,)).size(), (1, 1))
2583*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Normal(loc_1d, scale_1d).sample().size(), (1,))
2584*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Normal(0.2, 0.6).sample((1,)).size(), (1,))
2585*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Normal(-0.7, 50.0).sample((1,)).size(), (1,))
2586*da0073e9SAndroid Build Coastguard Worker
2587*da0073e9SAndroid Build Coastguard Worker        # sample check for extreme value of mean, std
2588*da0073e9SAndroid Build Coastguard Worker        set_rng_seed(1)
2589*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
2590*da0073e9SAndroid Build Coastguard Worker            Normal(loc_delta, scale_delta).sample(sample_shape=(1, 2)),
2591*da0073e9SAndroid Build Coastguard Worker            torch.tensor([[[1.0, 0.0], [1.0, 0.0]]]),
2592*da0073e9SAndroid Build Coastguard Worker            atol=1e-4,
2593*da0073e9SAndroid Build Coastguard Worker            rtol=0,
2594*da0073e9SAndroid Build Coastguard Worker        )
2595*da0073e9SAndroid Build Coastguard Worker
2596*da0073e9SAndroid Build Coastguard Worker        self._gradcheck_log_prob(Normal, (loc, scale))
2597*da0073e9SAndroid Build Coastguard Worker        self._gradcheck_log_prob(Normal, (loc, 1.0))
2598*da0073e9SAndroid Build Coastguard Worker        self._gradcheck_log_prob(Normal, (0.0, scale))
2599*da0073e9SAndroid Build Coastguard Worker
2600*da0073e9SAndroid Build Coastguard Worker        state = torch.get_rng_state()
2601*da0073e9SAndroid Build Coastguard Worker        eps = torch.normal(torch.zeros_like(loc), torch.ones_like(scale))
2602*da0073e9SAndroid Build Coastguard Worker        torch.set_rng_state(state)
2603*da0073e9SAndroid Build Coastguard Worker        z = Normal(loc, scale).rsample()
2604*da0073e9SAndroid Build Coastguard Worker        z.backward(torch.ones_like(z))
2605*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(loc.grad, torch.ones_like(loc))
2606*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(scale.grad, eps)
2607*da0073e9SAndroid Build Coastguard Worker        loc.grad.zero_()
2608*da0073e9SAndroid Build Coastguard Worker        scale.grad.zero_()
2609*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(z.size(), (5, 5))
2610*da0073e9SAndroid Build Coastguard Worker
2611*da0073e9SAndroid Build Coastguard Worker        def ref_log_prob(idx, x, log_prob):
2612*da0073e9SAndroid Build Coastguard Worker            m = loc.view(-1)[idx]
2613*da0073e9SAndroid Build Coastguard Worker            s = scale.view(-1)[idx]
2614*da0073e9SAndroid Build Coastguard Worker            expected = math.exp(-((x - m) ** 2) / (2 * s**2)) / math.sqrt(
2615*da0073e9SAndroid Build Coastguard Worker                2 * math.pi * s**2
2616*da0073e9SAndroid Build Coastguard Worker            )
2617*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(log_prob, math.log(expected), atol=1e-3, rtol=0)
2618*da0073e9SAndroid Build Coastguard Worker
2619*da0073e9SAndroid Build Coastguard Worker        self._check_log_prob(Normal(loc, scale), ref_log_prob)
2620*da0073e9SAndroid Build Coastguard Worker        self._check_forward_ad(torch.normal)
2621*da0073e9SAndroid Build Coastguard Worker        self._check_forward_ad(lambda x: torch.normal(x, 0.5))
2622*da0073e9SAndroid Build Coastguard Worker        self._check_forward_ad(lambda x: torch.normal(0.2, x))
2623*da0073e9SAndroid Build Coastguard Worker        self._check_forward_ad(lambda x: torch.normal(x, x))
2624*da0073e9SAndroid Build Coastguard Worker        self._check_forward_ad(lambda x: x.normal_())
2625*da0073e9SAndroid Build Coastguard Worker
2626*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
2627*da0073e9SAndroid Build Coastguard Worker    def test_normal_sample(self):
2628*da0073e9SAndroid Build Coastguard Worker        set_rng_seed(0)  # see Note [Randomized statistical tests]
2629*da0073e9SAndroid Build Coastguard Worker        for loc, scale in product([-1.0, 0.0, 1.0], [0.1, 1.0, 10.0]):
2630*da0073e9SAndroid Build Coastguard Worker            self._check_sampler_sampler(
2631*da0073e9SAndroid Build Coastguard Worker                Normal(loc, scale),
2632*da0073e9SAndroid Build Coastguard Worker                scipy.stats.norm(loc=loc, scale=scale),
2633*da0073e9SAndroid Build Coastguard Worker                f"Normal(mean={loc}, std={scale})",
2634*da0073e9SAndroid Build Coastguard Worker            )
2635*da0073e9SAndroid Build Coastguard Worker
2636*da0073e9SAndroid Build Coastguard Worker    @set_default_dtype(torch.double)
2637*da0073e9SAndroid Build Coastguard Worker    def test_lowrank_multivariate_normal_shape(self):
2638*da0073e9SAndroid Build Coastguard Worker        mean = torch.randn(5, 3, requires_grad=True)
2639*da0073e9SAndroid Build Coastguard Worker        mean_no_batch = torch.randn(3, requires_grad=True)
2640*da0073e9SAndroid Build Coastguard Worker        mean_multi_batch = torch.randn(6, 5, 3, requires_grad=True)
2641*da0073e9SAndroid Build Coastguard Worker
2642*da0073e9SAndroid Build Coastguard Worker        # construct PSD covariance
2643*da0073e9SAndroid Build Coastguard Worker        cov_factor = torch.randn(3, 1, requires_grad=True)
2644*da0073e9SAndroid Build Coastguard Worker        cov_diag = torch.randn(3).abs().requires_grad_()
2645*da0073e9SAndroid Build Coastguard Worker
2646*da0073e9SAndroid Build Coastguard Worker        # construct batch of PSD covariances
2647*da0073e9SAndroid Build Coastguard Worker        cov_factor_batched = torch.randn(6, 5, 3, 2, requires_grad=True)
2648*da0073e9SAndroid Build Coastguard Worker        cov_diag_batched = torch.randn(6, 5, 3).abs().requires_grad_()
2649*da0073e9SAndroid Build Coastguard Worker
2650*da0073e9SAndroid Build Coastguard Worker        # ensure that sample, batch, event shapes all handled correctly
2651*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
2652*da0073e9SAndroid Build Coastguard Worker            LowRankMultivariateNormal(mean, cov_factor, cov_diag).sample().size(),
2653*da0073e9SAndroid Build Coastguard Worker            (5, 3),
2654*da0073e9SAndroid Build Coastguard Worker        )
2655*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
2656*da0073e9SAndroid Build Coastguard Worker            LowRankMultivariateNormal(mean_no_batch, cov_factor, cov_diag)
2657*da0073e9SAndroid Build Coastguard Worker            .sample()
2658*da0073e9SAndroid Build Coastguard Worker            .size(),
2659*da0073e9SAndroid Build Coastguard Worker            (3,),
2660*da0073e9SAndroid Build Coastguard Worker        )
2661*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
2662*da0073e9SAndroid Build Coastguard Worker            LowRankMultivariateNormal(mean_multi_batch, cov_factor, cov_diag)
2663*da0073e9SAndroid Build Coastguard Worker            .sample()
2664*da0073e9SAndroid Build Coastguard Worker            .size(),
2665*da0073e9SAndroid Build Coastguard Worker            (6, 5, 3),
2666*da0073e9SAndroid Build Coastguard Worker        )
2667*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
2668*da0073e9SAndroid Build Coastguard Worker            LowRankMultivariateNormal(mean, cov_factor, cov_diag).sample((2,)).size(),
2669*da0073e9SAndroid Build Coastguard Worker            (2, 5, 3),
2670*da0073e9SAndroid Build Coastguard Worker        )
2671*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
2672*da0073e9SAndroid Build Coastguard Worker            LowRankMultivariateNormal(mean_no_batch, cov_factor, cov_diag)
2673*da0073e9SAndroid Build Coastguard Worker            .sample((2,))
2674*da0073e9SAndroid Build Coastguard Worker            .size(),
2675*da0073e9SAndroid Build Coastguard Worker            (2, 3),
2676*da0073e9SAndroid Build Coastguard Worker        )
2677*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
2678*da0073e9SAndroid Build Coastguard Worker            LowRankMultivariateNormal(mean_multi_batch, cov_factor, cov_diag)
2679*da0073e9SAndroid Build Coastguard Worker            .sample((2,))
2680*da0073e9SAndroid Build Coastguard Worker            .size(),
2681*da0073e9SAndroid Build Coastguard Worker            (2, 6, 5, 3),
2682*da0073e9SAndroid Build Coastguard Worker        )
2683*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
2684*da0073e9SAndroid Build Coastguard Worker            LowRankMultivariateNormal(mean, cov_factor, cov_diag).sample((2, 7)).size(),
2685*da0073e9SAndroid Build Coastguard Worker            (2, 7, 5, 3),
2686*da0073e9SAndroid Build Coastguard Worker        )
2687*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
2688*da0073e9SAndroid Build Coastguard Worker            LowRankMultivariateNormal(mean_no_batch, cov_factor, cov_diag)
2689*da0073e9SAndroid Build Coastguard Worker            .sample((2, 7))
2690*da0073e9SAndroid Build Coastguard Worker            .size(),
2691*da0073e9SAndroid Build Coastguard Worker            (2, 7, 3),
2692*da0073e9SAndroid Build Coastguard Worker        )
2693*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
2694*da0073e9SAndroid Build Coastguard Worker            LowRankMultivariateNormal(mean_multi_batch, cov_factor, cov_diag)
2695*da0073e9SAndroid Build Coastguard Worker            .sample((2, 7))
2696*da0073e9SAndroid Build Coastguard Worker            .size(),
2697*da0073e9SAndroid Build Coastguard Worker            (2, 7, 6, 5, 3),
2698*da0073e9SAndroid Build Coastguard Worker        )
2699*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
2700*da0073e9SAndroid Build Coastguard Worker            LowRankMultivariateNormal(mean, cov_factor_batched, cov_diag_batched)
2701*da0073e9SAndroid Build Coastguard Worker            .sample((2, 7))
2702*da0073e9SAndroid Build Coastguard Worker            .size(),
2703*da0073e9SAndroid Build Coastguard Worker            (2, 7, 6, 5, 3),
2704*da0073e9SAndroid Build Coastguard Worker        )
2705*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
2706*da0073e9SAndroid Build Coastguard Worker            LowRankMultivariateNormal(
2707*da0073e9SAndroid Build Coastguard Worker                mean_no_batch, cov_factor_batched, cov_diag_batched
2708*da0073e9SAndroid Build Coastguard Worker            )
2709*da0073e9SAndroid Build Coastguard Worker            .sample((2, 7))
2710*da0073e9SAndroid Build Coastguard Worker            .size(),
2711*da0073e9SAndroid Build Coastguard Worker            (2, 7, 6, 5, 3),
2712*da0073e9SAndroid Build Coastguard Worker        )
2713*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
2714*da0073e9SAndroid Build Coastguard Worker            LowRankMultivariateNormal(
2715*da0073e9SAndroid Build Coastguard Worker                mean_multi_batch, cov_factor_batched, cov_diag_batched
2716*da0073e9SAndroid Build Coastguard Worker            )
2717*da0073e9SAndroid Build Coastguard Worker            .sample((2, 7))
2718*da0073e9SAndroid Build Coastguard Worker            .size(),
2719*da0073e9SAndroid Build Coastguard Worker            (2, 7, 6, 5, 3),
2720*da0073e9SAndroid Build Coastguard Worker        )
2721*da0073e9SAndroid Build Coastguard Worker
2722*da0073e9SAndroid Build Coastguard Worker        # check gradients
2723*da0073e9SAndroid Build Coastguard Worker        self._gradcheck_log_prob(
2724*da0073e9SAndroid Build Coastguard Worker            LowRankMultivariateNormal, (mean, cov_factor, cov_diag)
2725*da0073e9SAndroid Build Coastguard Worker        )
2726*da0073e9SAndroid Build Coastguard Worker        self._gradcheck_log_prob(
2727*da0073e9SAndroid Build Coastguard Worker            LowRankMultivariateNormal, (mean_multi_batch, cov_factor, cov_diag)
2728*da0073e9SAndroid Build Coastguard Worker        )
2729*da0073e9SAndroid Build Coastguard Worker        self._gradcheck_log_prob(
2730*da0073e9SAndroid Build Coastguard Worker            LowRankMultivariateNormal,
2731*da0073e9SAndroid Build Coastguard Worker            (mean_multi_batch, cov_factor_batched, cov_diag_batched),
2732*da0073e9SAndroid Build Coastguard Worker        )
2733*da0073e9SAndroid Build Coastguard Worker
2734*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
2735*da0073e9SAndroid Build Coastguard Worker    def test_lowrank_multivariate_normal_log_prob(self):
2736*da0073e9SAndroid Build Coastguard Worker        mean = torch.randn(3, requires_grad=True)
2737*da0073e9SAndroid Build Coastguard Worker        cov_factor = torch.randn(3, 1, requires_grad=True)
2738*da0073e9SAndroid Build Coastguard Worker        cov_diag = torch.randn(3).abs().requires_grad_()
2739*da0073e9SAndroid Build Coastguard Worker        cov = cov_factor.matmul(cov_factor.t()) + cov_diag.diag()
2740*da0073e9SAndroid Build Coastguard Worker
2741*da0073e9SAndroid Build Coastguard Worker        # check that logprob values match scipy logpdf,
2742*da0073e9SAndroid Build Coastguard Worker        # and that covariance and scale_tril parameters are equivalent
2743*da0073e9SAndroid Build Coastguard Worker        dist1 = LowRankMultivariateNormal(mean, cov_factor, cov_diag)
2744*da0073e9SAndroid Build Coastguard Worker        ref_dist = scipy.stats.multivariate_normal(
2745*da0073e9SAndroid Build Coastguard Worker            mean.detach().numpy(), cov.detach().numpy()
2746*da0073e9SAndroid Build Coastguard Worker        )
2747*da0073e9SAndroid Build Coastguard Worker
2748*da0073e9SAndroid Build Coastguard Worker        x = dist1.sample((10,))
2749*da0073e9SAndroid Build Coastguard Worker        expected = ref_dist.logpdf(x.numpy())
2750*da0073e9SAndroid Build Coastguard Worker
2751*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
2752*da0073e9SAndroid Build Coastguard Worker            0.0,
2753*da0073e9SAndroid Build Coastguard Worker            np.mean((dist1.log_prob(x).detach().numpy() - expected) ** 2),
2754*da0073e9SAndroid Build Coastguard Worker            atol=1e-3,
2755*da0073e9SAndroid Build Coastguard Worker            rtol=0,
2756*da0073e9SAndroid Build Coastguard Worker        )
2757*da0073e9SAndroid Build Coastguard Worker
2758*da0073e9SAndroid Build Coastguard Worker        # Double-check that batched versions behave the same as unbatched
2759*da0073e9SAndroid Build Coastguard Worker        mean = torch.randn(5, 3, requires_grad=True)
2760*da0073e9SAndroid Build Coastguard Worker        cov_factor = torch.randn(5, 3, 2, requires_grad=True)
2761*da0073e9SAndroid Build Coastguard Worker        cov_diag = torch.randn(5, 3).abs().requires_grad_()
2762*da0073e9SAndroid Build Coastguard Worker
2763*da0073e9SAndroid Build Coastguard Worker        dist_batched = LowRankMultivariateNormal(mean, cov_factor, cov_diag)
2764*da0073e9SAndroid Build Coastguard Worker        dist_unbatched = [
2765*da0073e9SAndroid Build Coastguard Worker            LowRankMultivariateNormal(mean[i], cov_factor[i], cov_diag[i])
2766*da0073e9SAndroid Build Coastguard Worker            for i in range(mean.size(0))
2767*da0073e9SAndroid Build Coastguard Worker        ]
2768*da0073e9SAndroid Build Coastguard Worker
2769*da0073e9SAndroid Build Coastguard Worker        x = dist_batched.sample((10,))
2770*da0073e9SAndroid Build Coastguard Worker        batched_prob = dist_batched.log_prob(x)
2771*da0073e9SAndroid Build Coastguard Worker        unbatched_prob = torch.stack(
2772*da0073e9SAndroid Build Coastguard Worker            [dist_unbatched[i].log_prob(x[:, i]) for i in range(5)]
2773*da0073e9SAndroid Build Coastguard Worker        ).t()
2774*da0073e9SAndroid Build Coastguard Worker
2775*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(batched_prob.shape, unbatched_prob.shape)
2776*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(batched_prob, unbatched_prob, atol=1e-3, rtol=0)
2777*da0073e9SAndroid Build Coastguard Worker
2778*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
2779*da0073e9SAndroid Build Coastguard Worker    def test_lowrank_multivariate_normal_sample(self):
2780*da0073e9SAndroid Build Coastguard Worker        set_rng_seed(0)  # see Note [Randomized statistical tests]
2781*da0073e9SAndroid Build Coastguard Worker        mean = torch.randn(5, requires_grad=True)
2782*da0073e9SAndroid Build Coastguard Worker        cov_factor = torch.randn(5, 1, requires_grad=True)
2783*da0073e9SAndroid Build Coastguard Worker        cov_diag = torch.randn(5).abs().requires_grad_()
2784*da0073e9SAndroid Build Coastguard Worker        cov = cov_factor.matmul(cov_factor.t()) + cov_diag.diag()
2785*da0073e9SAndroid Build Coastguard Worker
2786*da0073e9SAndroid Build Coastguard Worker        self._check_sampler_sampler(
2787*da0073e9SAndroid Build Coastguard Worker            LowRankMultivariateNormal(mean, cov_factor, cov_diag),
2788*da0073e9SAndroid Build Coastguard Worker            scipy.stats.multivariate_normal(
2789*da0073e9SAndroid Build Coastguard Worker                mean.detach().numpy(), cov.detach().numpy()
2790*da0073e9SAndroid Build Coastguard Worker            ),
2791*da0073e9SAndroid Build Coastguard Worker            f"LowRankMultivariateNormal(loc={mean}, cov_factor={cov_factor}, cov_diag={cov_diag})",
2792*da0073e9SAndroid Build Coastguard Worker            multivariate=True,
2793*da0073e9SAndroid Build Coastguard Worker        )
2794*da0073e9SAndroid Build Coastguard Worker
2795*da0073e9SAndroid Build Coastguard Worker    def test_lowrank_multivariate_normal_properties(self):
2796*da0073e9SAndroid Build Coastguard Worker        loc = torch.randn(5)
2797*da0073e9SAndroid Build Coastguard Worker        cov_factor = torch.randn(5, 2)
2798*da0073e9SAndroid Build Coastguard Worker        cov_diag = torch.randn(5).abs()
2799*da0073e9SAndroid Build Coastguard Worker        cov = cov_factor.matmul(cov_factor.t()) + cov_diag.diag()
2800*da0073e9SAndroid Build Coastguard Worker        m1 = LowRankMultivariateNormal(loc, cov_factor, cov_diag)
2801*da0073e9SAndroid Build Coastguard Worker        m2 = MultivariateNormal(loc=loc, covariance_matrix=cov)
2802*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(m1.mean, m2.mean)
2803*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(m1.variance, m2.variance)
2804*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(m1.covariance_matrix, m2.covariance_matrix)
2805*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(m1.scale_tril, m2.scale_tril)
2806*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(m1.precision_matrix, m2.precision_matrix)
2807*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(m1.entropy(), m2.entropy())
2808*da0073e9SAndroid Build Coastguard Worker
2809*da0073e9SAndroid Build Coastguard Worker    def test_lowrank_multivariate_normal_moments(self):
2810*da0073e9SAndroid Build Coastguard Worker        set_rng_seed(0)  # see Note [Randomized statistical tests]
2811*da0073e9SAndroid Build Coastguard Worker        mean = torch.randn(5)
2812*da0073e9SAndroid Build Coastguard Worker        cov_factor = torch.randn(5, 2)
2813*da0073e9SAndroid Build Coastguard Worker        cov_diag = torch.randn(5).abs()
2814*da0073e9SAndroid Build Coastguard Worker        d = LowRankMultivariateNormal(mean, cov_factor, cov_diag)
2815*da0073e9SAndroid Build Coastguard Worker        samples = d.rsample((100000,))
2816*da0073e9SAndroid Build Coastguard Worker        empirical_mean = samples.mean(0)
2817*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(d.mean, empirical_mean, atol=0.01, rtol=0)
2818*da0073e9SAndroid Build Coastguard Worker        empirical_var = samples.var(0)
2819*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(d.variance, empirical_var, atol=0.02, rtol=0)
2820*da0073e9SAndroid Build Coastguard Worker
2821*da0073e9SAndroid Build Coastguard Worker    @set_default_dtype(torch.double)
2822*da0073e9SAndroid Build Coastguard Worker    def test_multivariate_normal_shape(self):
2823*da0073e9SAndroid Build Coastguard Worker        mean = torch.randn(5, 3, requires_grad=True)
2824*da0073e9SAndroid Build Coastguard Worker        mean_no_batch = torch.randn(3, requires_grad=True)
2825*da0073e9SAndroid Build Coastguard Worker        mean_multi_batch = torch.randn(6, 5, 3, requires_grad=True)
2826*da0073e9SAndroid Build Coastguard Worker
2827*da0073e9SAndroid Build Coastguard Worker        # construct PSD covariance
2828*da0073e9SAndroid Build Coastguard Worker        tmp = torch.randn(3, 10)
2829*da0073e9SAndroid Build Coastguard Worker        cov = (torch.matmul(tmp, tmp.t()) / tmp.size(-1)).requires_grad_()
2830*da0073e9SAndroid Build Coastguard Worker        prec = cov.inverse().requires_grad_()
2831*da0073e9SAndroid Build Coastguard Worker        scale_tril = torch.linalg.cholesky(cov).requires_grad_()
2832*da0073e9SAndroid Build Coastguard Worker
2833*da0073e9SAndroid Build Coastguard Worker        # construct batch of PSD covariances
2834*da0073e9SAndroid Build Coastguard Worker        tmp = torch.randn(6, 5, 3, 10)
2835*da0073e9SAndroid Build Coastguard Worker        cov_batched = (tmp.unsqueeze(-2) * tmp.unsqueeze(-3)).mean(-1).requires_grad_()
2836*da0073e9SAndroid Build Coastguard Worker        prec_batched = cov_batched.inverse()
2837*da0073e9SAndroid Build Coastguard Worker        scale_tril_batched = torch.linalg.cholesky(cov_batched)
2838*da0073e9SAndroid Build Coastguard Worker
2839*da0073e9SAndroid Build Coastguard Worker        # ensure that sample, batch, event shapes all handled correctly
2840*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(MultivariateNormal(mean, cov).sample().size(), (5, 3))
2841*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(MultivariateNormal(mean_no_batch, cov).sample().size(), (3,))
2842*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
2843*da0073e9SAndroid Build Coastguard Worker            MultivariateNormal(mean_multi_batch, cov).sample().size(), (6, 5, 3)
2844*da0073e9SAndroid Build Coastguard Worker        )
2845*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(MultivariateNormal(mean, cov).sample((2,)).size(), (2, 5, 3))
2846*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
2847*da0073e9SAndroid Build Coastguard Worker            MultivariateNormal(mean_no_batch, cov).sample((2,)).size(), (2, 3)
2848*da0073e9SAndroid Build Coastguard Worker        )
2849*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
2850*da0073e9SAndroid Build Coastguard Worker            MultivariateNormal(mean_multi_batch, cov).sample((2,)).size(), (2, 6, 5, 3)
2851*da0073e9SAndroid Build Coastguard Worker        )
2852*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
2853*da0073e9SAndroid Build Coastguard Worker            MultivariateNormal(mean, cov).sample((2, 7)).size(), (2, 7, 5, 3)
2854*da0073e9SAndroid Build Coastguard Worker        )
2855*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
2856*da0073e9SAndroid Build Coastguard Worker            MultivariateNormal(mean_no_batch, cov).sample((2, 7)).size(), (2, 7, 3)
2857*da0073e9SAndroid Build Coastguard Worker        )
2858*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
2859*da0073e9SAndroid Build Coastguard Worker            MultivariateNormal(mean_multi_batch, cov).sample((2, 7)).size(),
2860*da0073e9SAndroid Build Coastguard Worker            (2, 7, 6, 5, 3),
2861*da0073e9SAndroid Build Coastguard Worker        )
2862*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
2863*da0073e9SAndroid Build Coastguard Worker            MultivariateNormal(mean, cov_batched).sample((2, 7)).size(), (2, 7, 6, 5, 3)
2864*da0073e9SAndroid Build Coastguard Worker        )
2865*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
2866*da0073e9SAndroid Build Coastguard Worker            MultivariateNormal(mean_no_batch, cov_batched).sample((2, 7)).size(),
2867*da0073e9SAndroid Build Coastguard Worker            (2, 7, 6, 5, 3),
2868*da0073e9SAndroid Build Coastguard Worker        )
2869*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
2870*da0073e9SAndroid Build Coastguard Worker            MultivariateNormal(mean_multi_batch, cov_batched).sample((2, 7)).size(),
2871*da0073e9SAndroid Build Coastguard Worker            (2, 7, 6, 5, 3),
2872*da0073e9SAndroid Build Coastguard Worker        )
2873*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
2874*da0073e9SAndroid Build Coastguard Worker            MultivariateNormal(mean, precision_matrix=prec).sample((2, 7)).size(),
2875*da0073e9SAndroid Build Coastguard Worker            (2, 7, 5, 3),
2876*da0073e9SAndroid Build Coastguard Worker        )
2877*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
2878*da0073e9SAndroid Build Coastguard Worker            MultivariateNormal(mean, precision_matrix=prec_batched)
2879*da0073e9SAndroid Build Coastguard Worker            .sample((2, 7))
2880*da0073e9SAndroid Build Coastguard Worker            .size(),
2881*da0073e9SAndroid Build Coastguard Worker            (2, 7, 6, 5, 3),
2882*da0073e9SAndroid Build Coastguard Worker        )
2883*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
2884*da0073e9SAndroid Build Coastguard Worker            MultivariateNormal(mean, scale_tril=scale_tril).sample((2, 7)).size(),
2885*da0073e9SAndroid Build Coastguard Worker            (2, 7, 5, 3),
2886*da0073e9SAndroid Build Coastguard Worker        )
2887*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
2888*da0073e9SAndroid Build Coastguard Worker            MultivariateNormal(mean, scale_tril=scale_tril_batched)
2889*da0073e9SAndroid Build Coastguard Worker            .sample((2, 7))
2890*da0073e9SAndroid Build Coastguard Worker            .size(),
2891*da0073e9SAndroid Build Coastguard Worker            (2, 7, 6, 5, 3),
2892*da0073e9SAndroid Build Coastguard Worker        )
2893*da0073e9SAndroid Build Coastguard Worker
2894*da0073e9SAndroid Build Coastguard Worker        # check gradients
2895*da0073e9SAndroid Build Coastguard Worker        # We write a custom gradcheck function to maintain the symmetry
2896*da0073e9SAndroid Build Coastguard Worker        # of the perturbed covariances and their inverses (precision)
2897*da0073e9SAndroid Build Coastguard Worker        def multivariate_normal_log_prob_gradcheck(
2898*da0073e9SAndroid Build Coastguard Worker            mean, covariance=None, precision=None, scale_tril=None
2899*da0073e9SAndroid Build Coastguard Worker        ):
2900*da0073e9SAndroid Build Coastguard Worker            mvn_samples = (
2901*da0073e9SAndroid Build Coastguard Worker                MultivariateNormal(mean, covariance, precision, scale_tril)
2902*da0073e9SAndroid Build Coastguard Worker                .sample()
2903*da0073e9SAndroid Build Coastguard Worker                .requires_grad_()
2904*da0073e9SAndroid Build Coastguard Worker            )
2905*da0073e9SAndroid Build Coastguard Worker
2906*da0073e9SAndroid Build Coastguard Worker            def gradcheck_func(samples, mu, sigma, prec, scale_tril):
2907*da0073e9SAndroid Build Coastguard Worker                if sigma is not None:
2908*da0073e9SAndroid Build Coastguard Worker                    sigma = 0.5 * (sigma + sigma.mT)  # Ensure symmetry of covariance
2909*da0073e9SAndroid Build Coastguard Worker                if prec is not None:
2910*da0073e9SAndroid Build Coastguard Worker                    prec = 0.5 * (prec + prec.mT)  # Ensure symmetry of precision
2911*da0073e9SAndroid Build Coastguard Worker                if scale_tril is not None:
2912*da0073e9SAndroid Build Coastguard Worker                    scale_tril = scale_tril.tril()
2913*da0073e9SAndroid Build Coastguard Worker                return MultivariateNormal(mu, sigma, prec, scale_tril).log_prob(samples)
2914*da0073e9SAndroid Build Coastguard Worker
2915*da0073e9SAndroid Build Coastguard Worker            gradcheck(
2916*da0073e9SAndroid Build Coastguard Worker                gradcheck_func,
2917*da0073e9SAndroid Build Coastguard Worker                (mvn_samples, mean, covariance, precision, scale_tril),
2918*da0073e9SAndroid Build Coastguard Worker                raise_exception=True,
2919*da0073e9SAndroid Build Coastguard Worker            )
2920*da0073e9SAndroid Build Coastguard Worker
2921*da0073e9SAndroid Build Coastguard Worker        multivariate_normal_log_prob_gradcheck(mean, cov)
2922*da0073e9SAndroid Build Coastguard Worker        multivariate_normal_log_prob_gradcheck(mean_multi_batch, cov)
2923*da0073e9SAndroid Build Coastguard Worker        multivariate_normal_log_prob_gradcheck(mean_multi_batch, cov_batched)
2924*da0073e9SAndroid Build Coastguard Worker        multivariate_normal_log_prob_gradcheck(mean, None, prec)
2925*da0073e9SAndroid Build Coastguard Worker        multivariate_normal_log_prob_gradcheck(mean_no_batch, None, prec_batched)
2926*da0073e9SAndroid Build Coastguard Worker        multivariate_normal_log_prob_gradcheck(mean, None, None, scale_tril)
2927*da0073e9SAndroid Build Coastguard Worker        multivariate_normal_log_prob_gradcheck(
2928*da0073e9SAndroid Build Coastguard Worker            mean_no_batch, None, None, scale_tril_batched
2929*da0073e9SAndroid Build Coastguard Worker        )
2930*da0073e9SAndroid Build Coastguard Worker
2931*da0073e9SAndroid Build Coastguard Worker    @set_default_dtype(torch.double)
2932*da0073e9SAndroid Build Coastguard Worker    def test_multivariate_normal_stable_with_precision_matrix(self):
2933*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(10)
2934*da0073e9SAndroid Build Coastguard Worker        P = torch.exp(-((x - x.unsqueeze(-1)) ** 2))  # RBF kernel
2935*da0073e9SAndroid Build Coastguard Worker        MultivariateNormal(x.new_zeros(10), precision_matrix=P)
2936*da0073e9SAndroid Build Coastguard Worker
2937*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
2938*da0073e9SAndroid Build Coastguard Worker    def test_multivariate_normal_log_prob(self):
2939*da0073e9SAndroid Build Coastguard Worker        mean = torch.randn(3, requires_grad=True)
2940*da0073e9SAndroid Build Coastguard Worker        tmp = torch.randn(3, 10)
2941*da0073e9SAndroid Build Coastguard Worker        cov = (torch.matmul(tmp, tmp.t()) / tmp.size(-1)).requires_grad_()
2942*da0073e9SAndroid Build Coastguard Worker        prec = cov.inverse().requires_grad_()
2943*da0073e9SAndroid Build Coastguard Worker        scale_tril = torch.linalg.cholesky(cov).requires_grad_()
2944*da0073e9SAndroid Build Coastguard Worker
2945*da0073e9SAndroid Build Coastguard Worker        # check that logprob values match scipy logpdf,
2946*da0073e9SAndroid Build Coastguard Worker        # and that covariance and scale_tril parameters are equivalent
2947*da0073e9SAndroid Build Coastguard Worker        dist1 = MultivariateNormal(mean, cov)
2948*da0073e9SAndroid Build Coastguard Worker        dist2 = MultivariateNormal(mean, precision_matrix=prec)
2949*da0073e9SAndroid Build Coastguard Worker        dist3 = MultivariateNormal(mean, scale_tril=scale_tril)
2950*da0073e9SAndroid Build Coastguard Worker        ref_dist = scipy.stats.multivariate_normal(
2951*da0073e9SAndroid Build Coastguard Worker            mean.detach().numpy(), cov.detach().numpy()
2952*da0073e9SAndroid Build Coastguard Worker        )
2953*da0073e9SAndroid Build Coastguard Worker
2954*da0073e9SAndroid Build Coastguard Worker        x = dist1.sample((10,))
2955*da0073e9SAndroid Build Coastguard Worker        expected = ref_dist.logpdf(x.numpy())
2956*da0073e9SAndroid Build Coastguard Worker
2957*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
2958*da0073e9SAndroid Build Coastguard Worker            0.0,
2959*da0073e9SAndroid Build Coastguard Worker            np.mean((dist1.log_prob(x).detach().numpy() - expected) ** 2),
2960*da0073e9SAndroid Build Coastguard Worker            atol=1e-3,
2961*da0073e9SAndroid Build Coastguard Worker            rtol=0,
2962*da0073e9SAndroid Build Coastguard Worker        )
2963*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
2964*da0073e9SAndroid Build Coastguard Worker            0.0,
2965*da0073e9SAndroid Build Coastguard Worker            np.mean((dist2.log_prob(x).detach().numpy() - expected) ** 2),
2966*da0073e9SAndroid Build Coastguard Worker            atol=1e-3,
2967*da0073e9SAndroid Build Coastguard Worker            rtol=0,
2968*da0073e9SAndroid Build Coastguard Worker        )
2969*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
2970*da0073e9SAndroid Build Coastguard Worker            0.0,
2971*da0073e9SAndroid Build Coastguard Worker            np.mean((dist3.log_prob(x).detach().numpy() - expected) ** 2),
2972*da0073e9SAndroid Build Coastguard Worker            atol=1e-3,
2973*da0073e9SAndroid Build Coastguard Worker            rtol=0,
2974*da0073e9SAndroid Build Coastguard Worker        )
2975*da0073e9SAndroid Build Coastguard Worker
2976*da0073e9SAndroid Build Coastguard Worker        # Double-check that batched versions behave the same as unbatched
2977*da0073e9SAndroid Build Coastguard Worker        mean = torch.randn(5, 3, requires_grad=True)
2978*da0073e9SAndroid Build Coastguard Worker        tmp = torch.randn(5, 3, 10)
2979*da0073e9SAndroid Build Coastguard Worker        cov = (tmp.unsqueeze(-2) * tmp.unsqueeze(-3)).mean(-1).requires_grad_()
2980*da0073e9SAndroid Build Coastguard Worker
2981*da0073e9SAndroid Build Coastguard Worker        dist_batched = MultivariateNormal(mean, cov)
2982*da0073e9SAndroid Build Coastguard Worker        dist_unbatched = [
2983*da0073e9SAndroid Build Coastguard Worker            MultivariateNormal(mean[i], cov[i]) for i in range(mean.size(0))
2984*da0073e9SAndroid Build Coastguard Worker        ]
2985*da0073e9SAndroid Build Coastguard Worker
2986*da0073e9SAndroid Build Coastguard Worker        x = dist_batched.sample((10,))
2987*da0073e9SAndroid Build Coastguard Worker        batched_prob = dist_batched.log_prob(x)
2988*da0073e9SAndroid Build Coastguard Worker        unbatched_prob = torch.stack(
2989*da0073e9SAndroid Build Coastguard Worker            [dist_unbatched[i].log_prob(x[:, i]) for i in range(5)]
2990*da0073e9SAndroid Build Coastguard Worker        ).t()
2991*da0073e9SAndroid Build Coastguard Worker
2992*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(batched_prob.shape, unbatched_prob.shape)
2993*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(batched_prob, unbatched_prob, atol=1e-3, rtol=0)
2994*da0073e9SAndroid Build Coastguard Worker
2995*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
2996*da0073e9SAndroid Build Coastguard Worker    def test_multivariate_normal_sample(self):
2997*da0073e9SAndroid Build Coastguard Worker        set_rng_seed(0)  # see Note [Randomized statistical tests]
2998*da0073e9SAndroid Build Coastguard Worker        mean = torch.randn(3, requires_grad=True)
2999*da0073e9SAndroid Build Coastguard Worker        tmp = torch.randn(3, 10)
3000*da0073e9SAndroid Build Coastguard Worker        cov = (torch.matmul(tmp, tmp.t()) / tmp.size(-1)).requires_grad_()
3001*da0073e9SAndroid Build Coastguard Worker        prec = cov.inverse().requires_grad_()
3002*da0073e9SAndroid Build Coastguard Worker        scale_tril = torch.linalg.cholesky(cov).requires_grad_()
3003*da0073e9SAndroid Build Coastguard Worker
3004*da0073e9SAndroid Build Coastguard Worker        self._check_sampler_sampler(
3005*da0073e9SAndroid Build Coastguard Worker            MultivariateNormal(mean, cov),
3006*da0073e9SAndroid Build Coastguard Worker            scipy.stats.multivariate_normal(
3007*da0073e9SAndroid Build Coastguard Worker                mean.detach().numpy(), cov.detach().numpy()
3008*da0073e9SAndroid Build Coastguard Worker            ),
3009*da0073e9SAndroid Build Coastguard Worker            f"MultivariateNormal(loc={mean}, cov={cov})",
3010*da0073e9SAndroid Build Coastguard Worker            multivariate=True,
3011*da0073e9SAndroid Build Coastguard Worker        )
3012*da0073e9SAndroid Build Coastguard Worker        self._check_sampler_sampler(
3013*da0073e9SAndroid Build Coastguard Worker            MultivariateNormal(mean, precision_matrix=prec),
3014*da0073e9SAndroid Build Coastguard Worker            scipy.stats.multivariate_normal(
3015*da0073e9SAndroid Build Coastguard Worker                mean.detach().numpy(), cov.detach().numpy()
3016*da0073e9SAndroid Build Coastguard Worker            ),
3017*da0073e9SAndroid Build Coastguard Worker            f"MultivariateNormal(loc={mean}, atol={prec})",
3018*da0073e9SAndroid Build Coastguard Worker            multivariate=True,
3019*da0073e9SAndroid Build Coastguard Worker        )
3020*da0073e9SAndroid Build Coastguard Worker        self._check_sampler_sampler(
3021*da0073e9SAndroid Build Coastguard Worker            MultivariateNormal(mean, scale_tril=scale_tril),
3022*da0073e9SAndroid Build Coastguard Worker            scipy.stats.multivariate_normal(
3023*da0073e9SAndroid Build Coastguard Worker                mean.detach().numpy(), cov.detach().numpy()
3024*da0073e9SAndroid Build Coastguard Worker            ),
3025*da0073e9SAndroid Build Coastguard Worker            f"MultivariateNormal(loc={mean}, scale_tril={scale_tril})",
3026*da0073e9SAndroid Build Coastguard Worker            multivariate=True,
3027*da0073e9SAndroid Build Coastguard Worker        )
3028*da0073e9SAndroid Build Coastguard Worker
3029*da0073e9SAndroid Build Coastguard Worker    @set_default_dtype(torch.double)
3030*da0073e9SAndroid Build Coastguard Worker    def test_multivariate_normal_properties(self):
3031*da0073e9SAndroid Build Coastguard Worker        loc = torch.randn(5)
3032*da0073e9SAndroid Build Coastguard Worker        scale_tril = transform_to(constraints.lower_cholesky)(torch.randn(5, 5))
3033*da0073e9SAndroid Build Coastguard Worker        m = MultivariateNormal(loc=loc, scale_tril=scale_tril)
3034*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(m.covariance_matrix, m.scale_tril.mm(m.scale_tril.t()))
3035*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
3036*da0073e9SAndroid Build Coastguard Worker            m.covariance_matrix.mm(m.precision_matrix), torch.eye(m.event_shape[0])
3037*da0073e9SAndroid Build Coastguard Worker        )
3038*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(m.scale_tril, torch.linalg.cholesky(m.covariance_matrix))
3039*da0073e9SAndroid Build Coastguard Worker
3040*da0073e9SAndroid Build Coastguard Worker    @set_default_dtype(torch.double)
3041*da0073e9SAndroid Build Coastguard Worker    def test_multivariate_normal_moments(self):
3042*da0073e9SAndroid Build Coastguard Worker        set_rng_seed(0)  # see Note [Randomized statistical tests]
3043*da0073e9SAndroid Build Coastguard Worker        mean = torch.randn(5)
3044*da0073e9SAndroid Build Coastguard Worker        scale_tril = transform_to(constraints.lower_cholesky)(torch.randn(5, 5))
3045*da0073e9SAndroid Build Coastguard Worker        d = MultivariateNormal(mean, scale_tril=scale_tril)
3046*da0073e9SAndroid Build Coastguard Worker        samples = d.rsample((100000,))
3047*da0073e9SAndroid Build Coastguard Worker        empirical_mean = samples.mean(0)
3048*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(d.mean, empirical_mean, atol=0.01, rtol=0)
3049*da0073e9SAndroid Build Coastguard Worker        empirical_var = samples.var(0)
3050*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(d.variance, empirical_var, atol=0.05, rtol=0)
3051*da0073e9SAndroid Build Coastguard Worker
    # The Wishart tests below apply the same checks used for the MultivariateNormal distribution
3053*da0073e9SAndroid Build Coastguard Worker    @set_default_dtype(torch.double)
3054*da0073e9SAndroid Build Coastguard Worker    def test_wishart_shape(self):
3055*da0073e9SAndroid Build Coastguard Worker        set_rng_seed(0)  # see Note [Randomized statistical tests]
3056*da0073e9SAndroid Build Coastguard Worker        ndim = 3
3057*da0073e9SAndroid Build Coastguard Worker
3058*da0073e9SAndroid Build Coastguard Worker        df = torch.rand(5, requires_grad=True) + ndim
3059*da0073e9SAndroid Build Coastguard Worker        df_no_batch = torch.rand([], requires_grad=True) + ndim
3060*da0073e9SAndroid Build Coastguard Worker        df_multi_batch = torch.rand(6, 5, requires_grad=True) + ndim
3061*da0073e9SAndroid Build Coastguard Worker
3062*da0073e9SAndroid Build Coastguard Worker        # construct PSD covariance
3063*da0073e9SAndroid Build Coastguard Worker        tmp = torch.randn(ndim, 10)
3064*da0073e9SAndroid Build Coastguard Worker        cov = (torch.matmul(tmp, tmp.t()) / tmp.size(-1)).requires_grad_()
3065*da0073e9SAndroid Build Coastguard Worker        prec = cov.inverse().requires_grad_()
3066*da0073e9SAndroid Build Coastguard Worker        scale_tril = torch.linalg.cholesky(cov).requires_grad_()
3067*da0073e9SAndroid Build Coastguard Worker
3068*da0073e9SAndroid Build Coastguard Worker        # construct batch of PSD covariances
3069*da0073e9SAndroid Build Coastguard Worker        tmp = torch.randn(6, 5, ndim, 10)
3070*da0073e9SAndroid Build Coastguard Worker        cov_batched = (tmp.unsqueeze(-2) * tmp.unsqueeze(-3)).mean(-1).requires_grad_()
3071*da0073e9SAndroid Build Coastguard Worker        prec_batched = cov_batched.inverse()
3072*da0073e9SAndroid Build Coastguard Worker        scale_tril_batched = torch.linalg.cholesky(cov_batched)
3073*da0073e9SAndroid Build Coastguard Worker
3074*da0073e9SAndroid Build Coastguard Worker        # ensure that sample, batch, event shapes all handled correctly
3075*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Wishart(df, cov).sample().size(), (5, ndim, ndim))
3076*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Wishart(df_no_batch, cov).sample().size(), (ndim, ndim))
3077*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
3078*da0073e9SAndroid Build Coastguard Worker            Wishart(df_multi_batch, cov).sample().size(), (6, 5, ndim, ndim)
3079*da0073e9SAndroid Build Coastguard Worker        )
3080*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Wishart(df, cov).sample((2,)).size(), (2, 5, ndim, ndim))
3081*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Wishart(df_no_batch, cov).sample((2,)).size(), (2, ndim, ndim))
3082*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
3083*da0073e9SAndroid Build Coastguard Worker            Wishart(df_multi_batch, cov).sample((2,)).size(), (2, 6, 5, ndim, ndim)
3084*da0073e9SAndroid Build Coastguard Worker        )
3085*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Wishart(df, cov).sample((2, 7)).size(), (2, 7, 5, ndim, ndim))
3086*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
3087*da0073e9SAndroid Build Coastguard Worker            Wishart(df_no_batch, cov).sample((2, 7)).size(), (2, 7, ndim, ndim)
3088*da0073e9SAndroid Build Coastguard Worker        )
3089*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
3090*da0073e9SAndroid Build Coastguard Worker            Wishart(df_multi_batch, cov).sample((2, 7)).size(), (2, 7, 6, 5, ndim, ndim)
3091*da0073e9SAndroid Build Coastguard Worker        )
3092*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
3093*da0073e9SAndroid Build Coastguard Worker            Wishart(df, cov_batched).sample((2, 7)).size(), (2, 7, 6, 5, ndim, ndim)
3094*da0073e9SAndroid Build Coastguard Worker        )
3095*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
3096*da0073e9SAndroid Build Coastguard Worker            Wishart(df_no_batch, cov_batched).sample((2, 7)).size(),
3097*da0073e9SAndroid Build Coastguard Worker            (2, 7, 6, 5, ndim, ndim),
3098*da0073e9SAndroid Build Coastguard Worker        )
3099*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
3100*da0073e9SAndroid Build Coastguard Worker            Wishart(df_multi_batch, cov_batched).sample((2, 7)).size(),
3101*da0073e9SAndroid Build Coastguard Worker            (2, 7, 6, 5, ndim, ndim),
3102*da0073e9SAndroid Build Coastguard Worker        )
3103*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
3104*da0073e9SAndroid Build Coastguard Worker            Wishart(df, precision_matrix=prec).sample((2, 7)).size(),
3105*da0073e9SAndroid Build Coastguard Worker            (2, 7, 5, ndim, ndim),
3106*da0073e9SAndroid Build Coastguard Worker        )
3107*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
3108*da0073e9SAndroid Build Coastguard Worker            Wishart(df, precision_matrix=prec_batched).sample((2, 7)).size(),
3109*da0073e9SAndroid Build Coastguard Worker            (2, 7, 6, 5, ndim, ndim),
3110*da0073e9SAndroid Build Coastguard Worker        )
3111*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
3112*da0073e9SAndroid Build Coastguard Worker            Wishart(df, scale_tril=scale_tril).sample((2, 7)).size(),
3113*da0073e9SAndroid Build Coastguard Worker            (2, 7, 5, ndim, ndim),
3114*da0073e9SAndroid Build Coastguard Worker        )
3115*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
3116*da0073e9SAndroid Build Coastguard Worker            Wishart(df, scale_tril=scale_tril_batched).sample((2, 7)).size(),
3117*da0073e9SAndroid Build Coastguard Worker            (2, 7, 6, 5, ndim, ndim),
3118*da0073e9SAndroid Build Coastguard Worker        )
3119*da0073e9SAndroid Build Coastguard Worker
3120*da0073e9SAndroid Build Coastguard Worker        # check gradients
3121*da0073e9SAndroid Build Coastguard Worker        # Modified and applied the same tests for multivariate_normal
3122*da0073e9SAndroid Build Coastguard Worker        def wishart_log_prob_gradcheck(
3123*da0073e9SAndroid Build Coastguard Worker            df=None, covariance=None, precision=None, scale_tril=None
3124*da0073e9SAndroid Build Coastguard Worker        ):
3125*da0073e9SAndroid Build Coastguard Worker            wishart_samples = (
3126*da0073e9SAndroid Build Coastguard Worker                Wishart(df, covariance, precision, scale_tril).sample().requires_grad_()
3127*da0073e9SAndroid Build Coastguard Worker            )
3128*da0073e9SAndroid Build Coastguard Worker
3129*da0073e9SAndroid Build Coastguard Worker            def gradcheck_func(samples, nu, sigma, prec, scale_tril):
3130*da0073e9SAndroid Build Coastguard Worker                if sigma is not None:
3131*da0073e9SAndroid Build Coastguard Worker                    sigma = 0.5 * (sigma + sigma.mT)  # Ensure symmetry of covariance
3132*da0073e9SAndroid Build Coastguard Worker                if prec is not None:
3133*da0073e9SAndroid Build Coastguard Worker                    prec = 0.5 * (prec + prec.mT)  # Ensure symmetry of precision
3134*da0073e9SAndroid Build Coastguard Worker                if scale_tril is not None:
3135*da0073e9SAndroid Build Coastguard Worker                    scale_tril = scale_tril.tril()
3136*da0073e9SAndroid Build Coastguard Worker                return Wishart(nu, sigma, prec, scale_tril).log_prob(samples)
3137*da0073e9SAndroid Build Coastguard Worker
3138*da0073e9SAndroid Build Coastguard Worker            gradcheck(
3139*da0073e9SAndroid Build Coastguard Worker                gradcheck_func,
3140*da0073e9SAndroid Build Coastguard Worker                (wishart_samples, df, covariance, precision, scale_tril),
3141*da0073e9SAndroid Build Coastguard Worker                raise_exception=True,
3142*da0073e9SAndroid Build Coastguard Worker            )
3143*da0073e9SAndroid Build Coastguard Worker
3144*da0073e9SAndroid Build Coastguard Worker        wishart_log_prob_gradcheck(df, cov)
3145*da0073e9SAndroid Build Coastguard Worker        wishart_log_prob_gradcheck(df_multi_batch, cov)
3146*da0073e9SAndroid Build Coastguard Worker        wishart_log_prob_gradcheck(df_multi_batch, cov_batched)
3147*da0073e9SAndroid Build Coastguard Worker        wishart_log_prob_gradcheck(df, None, prec)
3148*da0073e9SAndroid Build Coastguard Worker        wishart_log_prob_gradcheck(df_no_batch, None, prec_batched)
3149*da0073e9SAndroid Build Coastguard Worker        wishart_log_prob_gradcheck(df, None, None, scale_tril)
3150*da0073e9SAndroid Build Coastguard Worker        wishart_log_prob_gradcheck(df_no_batch, None, None, scale_tril_batched)
3151*da0073e9SAndroid Build Coastguard Worker
3152*da0073e9SAndroid Build Coastguard Worker    def test_wishart_stable_with_precision_matrix(self):
3153*da0073e9SAndroid Build Coastguard Worker        set_rng_seed(0)  # see Note [Randomized statistical tests]
3154*da0073e9SAndroid Build Coastguard Worker        ndim = 10
3155*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(ndim)
3156*da0073e9SAndroid Build Coastguard Worker        P = torch.exp(-((x - x.unsqueeze(-1)) ** 2))  # RBF kernel
3157*da0073e9SAndroid Build Coastguard Worker        Wishart(torch.tensor(ndim), precision_matrix=P)
3158*da0073e9SAndroid Build Coastguard Worker
3159*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
3160*da0073e9SAndroid Build Coastguard Worker    @set_default_dtype(torch.double)
3161*da0073e9SAndroid Build Coastguard Worker    def test_wishart_log_prob(self):
3162*da0073e9SAndroid Build Coastguard Worker        set_rng_seed(0)  # see Note [Randomized statistical tests]
3163*da0073e9SAndroid Build Coastguard Worker        ndim = 3
3164*da0073e9SAndroid Build Coastguard Worker        df = torch.rand([], requires_grad=True) + ndim - 1
3165*da0073e9SAndroid Build Coastguard Worker        # SciPy allowed ndim -1 < df < ndim for Wishar distribution after version 1.7.0
3166*da0073e9SAndroid Build Coastguard Worker        if version.parse(scipy.__version__) < version.parse("1.7.0"):
3167*da0073e9SAndroid Build Coastguard Worker            df += 1.0
3168*da0073e9SAndroid Build Coastguard Worker        tmp = torch.randn(ndim, 10)
3169*da0073e9SAndroid Build Coastguard Worker        cov = (torch.matmul(tmp, tmp.t()) / tmp.size(-1)).requires_grad_()
3170*da0073e9SAndroid Build Coastguard Worker        prec = cov.inverse().requires_grad_()
3171*da0073e9SAndroid Build Coastguard Worker        scale_tril = torch.linalg.cholesky(cov).requires_grad_()
3172*da0073e9SAndroid Build Coastguard Worker
3173*da0073e9SAndroid Build Coastguard Worker        # check that logprob values match scipy logpdf,
3174*da0073e9SAndroid Build Coastguard Worker        # and that covariance and scale_tril parameters are equivalent
3175*da0073e9SAndroid Build Coastguard Worker        dist1 = Wishart(df, cov)
3176*da0073e9SAndroid Build Coastguard Worker        dist2 = Wishart(df, precision_matrix=prec)
3177*da0073e9SAndroid Build Coastguard Worker        dist3 = Wishart(df, scale_tril=scale_tril)
3178*da0073e9SAndroid Build Coastguard Worker        ref_dist = scipy.stats.wishart(df.item(), cov.detach().numpy())
3179*da0073e9SAndroid Build Coastguard Worker
3180*da0073e9SAndroid Build Coastguard Worker        x = dist1.sample((1000,))
3181*da0073e9SAndroid Build Coastguard Worker        expected = ref_dist.logpdf(x.transpose(0, 2).numpy())
3182*da0073e9SAndroid Build Coastguard Worker
3183*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
3184*da0073e9SAndroid Build Coastguard Worker            0.0,
3185*da0073e9SAndroid Build Coastguard Worker            np.mean((dist1.log_prob(x).detach().numpy() - expected) ** 2),
3186*da0073e9SAndroid Build Coastguard Worker            atol=1e-3,
3187*da0073e9SAndroid Build Coastguard Worker            rtol=0,
3188*da0073e9SAndroid Build Coastguard Worker        )
3189*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
3190*da0073e9SAndroid Build Coastguard Worker            0.0,
3191*da0073e9SAndroid Build Coastguard Worker            np.mean((dist2.log_prob(x).detach().numpy() - expected) ** 2),
3192*da0073e9SAndroid Build Coastguard Worker            atol=1e-3,
3193*da0073e9SAndroid Build Coastguard Worker            rtol=0,
3194*da0073e9SAndroid Build Coastguard Worker        )
3195*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
3196*da0073e9SAndroid Build Coastguard Worker            0.0,
3197*da0073e9SAndroid Build Coastguard Worker            np.mean((dist3.log_prob(x).detach().numpy() - expected) ** 2),
3198*da0073e9SAndroid Build Coastguard Worker            atol=1e-3,
3199*da0073e9SAndroid Build Coastguard Worker            rtol=0,
3200*da0073e9SAndroid Build Coastguard Worker        )
3201*da0073e9SAndroid Build Coastguard Worker
3202*da0073e9SAndroid Build Coastguard Worker        # Double-check that batched versions behave the same as unbatched
3203*da0073e9SAndroid Build Coastguard Worker        df = torch.rand(5, requires_grad=True) + ndim - 1
3204*da0073e9SAndroid Build Coastguard Worker        # SciPy allowed ndim -1 < df < ndim for Wishar distribution after version 1.7.0
3205*da0073e9SAndroid Build Coastguard Worker        if version.parse(scipy.__version__) < version.parse("1.7.0"):
3206*da0073e9SAndroid Build Coastguard Worker            df += 1.0
3207*da0073e9SAndroid Build Coastguard Worker        tmp = torch.randn(5, ndim, 10)
3208*da0073e9SAndroid Build Coastguard Worker        cov = (tmp.unsqueeze(-2) * tmp.unsqueeze(-3)).mean(-1).requires_grad_()
3209*da0073e9SAndroid Build Coastguard Worker
3210*da0073e9SAndroid Build Coastguard Worker        dist_batched = Wishart(df, cov)
3211*da0073e9SAndroid Build Coastguard Worker        dist_unbatched = [Wishart(df[i], cov[i]) for i in range(df.size(0))]
3212*da0073e9SAndroid Build Coastguard Worker
3213*da0073e9SAndroid Build Coastguard Worker        x = dist_batched.sample((1000,))
3214*da0073e9SAndroid Build Coastguard Worker        batched_prob = dist_batched.log_prob(x)
3215*da0073e9SAndroid Build Coastguard Worker        unbatched_prob = torch.stack(
3216*da0073e9SAndroid Build Coastguard Worker            [dist_unbatched[i].log_prob(x[:, i]) for i in range(5)]
3217*da0073e9SAndroid Build Coastguard Worker        ).t()
3218*da0073e9SAndroid Build Coastguard Worker
3219*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(batched_prob.shape, unbatched_prob.shape)
3220*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(batched_prob, unbatched_prob, atol=1e-3, rtol=0)
3221*da0073e9SAndroid Build Coastguard Worker
3222*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
3223*da0073e9SAndroid Build Coastguard Worker    @set_default_dtype(torch.double)
3224*da0073e9SAndroid Build Coastguard Worker    def test_wishart_sample(self):
3225*da0073e9SAndroid Build Coastguard Worker        set_rng_seed(0)  # see Note [Randomized statistical tests]
3226*da0073e9SAndroid Build Coastguard Worker        ndim = 3
3227*da0073e9SAndroid Build Coastguard Worker        df = torch.rand([], requires_grad=True) + ndim - 1
3228*da0073e9SAndroid Build Coastguard Worker        # SciPy allowed ndim -1 < df < ndim for Wishar distribution after version 1.7.0
3229*da0073e9SAndroid Build Coastguard Worker        if version.parse(scipy.__version__) < version.parse("1.7.0"):
3230*da0073e9SAndroid Build Coastguard Worker            df += 1.0
3231*da0073e9SAndroid Build Coastguard Worker        tmp = torch.randn(ndim, 10)
3232*da0073e9SAndroid Build Coastguard Worker        cov = (torch.matmul(tmp, tmp.t()) / tmp.size(-1)).requires_grad_()
3233*da0073e9SAndroid Build Coastguard Worker        prec = cov.inverse().requires_grad_()
3234*da0073e9SAndroid Build Coastguard Worker        scale_tril = torch.linalg.cholesky(cov).requires_grad_()
3235*da0073e9SAndroid Build Coastguard Worker
3236*da0073e9SAndroid Build Coastguard Worker        ref_dist = scipy.stats.wishart(df.item(), cov.detach().numpy())
3237*da0073e9SAndroid Build Coastguard Worker
3238*da0073e9SAndroid Build Coastguard Worker        self._check_sampler_sampler(
3239*da0073e9SAndroid Build Coastguard Worker            Wishart(df, cov),
3240*da0073e9SAndroid Build Coastguard Worker            ref_dist,
3241*da0073e9SAndroid Build Coastguard Worker            f"Wishart(df={df}, covariance_matrix={cov})",
3242*da0073e9SAndroid Build Coastguard Worker            multivariate=True,
3243*da0073e9SAndroid Build Coastguard Worker        )
3244*da0073e9SAndroid Build Coastguard Worker        self._check_sampler_sampler(
3245*da0073e9SAndroid Build Coastguard Worker            Wishart(df, precision_matrix=prec),
3246*da0073e9SAndroid Build Coastguard Worker            ref_dist,
3247*da0073e9SAndroid Build Coastguard Worker            f"Wishart(df={df}, precision_matrix={prec})",
3248*da0073e9SAndroid Build Coastguard Worker            multivariate=True,
3249*da0073e9SAndroid Build Coastguard Worker        )
3250*da0073e9SAndroid Build Coastguard Worker        self._check_sampler_sampler(
3251*da0073e9SAndroid Build Coastguard Worker            Wishart(df, scale_tril=scale_tril),
3252*da0073e9SAndroid Build Coastguard Worker            ref_dist,
3253*da0073e9SAndroid Build Coastguard Worker            f"Wishart(df={df}, scale_tril={scale_tril})",
3254*da0073e9SAndroid Build Coastguard Worker            multivariate=True,
3255*da0073e9SAndroid Build Coastguard Worker        )
3256*da0073e9SAndroid Build Coastguard Worker
3257*da0073e9SAndroid Build Coastguard Worker    def test_wishart_properties(self):
3258*da0073e9SAndroid Build Coastguard Worker        set_rng_seed(0)  # see Note [Randomized statistical tests]
3259*da0073e9SAndroid Build Coastguard Worker        ndim = 5
3260*da0073e9SAndroid Build Coastguard Worker        df = torch.rand([]) + ndim - 1
3261*da0073e9SAndroid Build Coastguard Worker        scale_tril = transform_to(constraints.lower_cholesky)(torch.randn(ndim, ndim))
3262*da0073e9SAndroid Build Coastguard Worker        m = Wishart(df=df, scale_tril=scale_tril)
3263*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(m.covariance_matrix, m.scale_tril.mm(m.scale_tril.t()))
3264*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
3265*da0073e9SAndroid Build Coastguard Worker            m.covariance_matrix.mm(m.precision_matrix), torch.eye(m.event_shape[0])
3266*da0073e9SAndroid Build Coastguard Worker        )
3267*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(m.scale_tril, torch.linalg.cholesky(m.covariance_matrix))
3268*da0073e9SAndroid Build Coastguard Worker
3269*da0073e9SAndroid Build Coastguard Worker    def test_wishart_moments(self):
3270*da0073e9SAndroid Build Coastguard Worker        set_rng_seed(0)  # see Note [Randomized statistical tests]
3271*da0073e9SAndroid Build Coastguard Worker        ndim = 3
3272*da0073e9SAndroid Build Coastguard Worker        df = torch.rand([]) + ndim - 1
3273*da0073e9SAndroid Build Coastguard Worker        scale_tril = transform_to(constraints.lower_cholesky)(torch.randn(ndim, ndim))
3274*da0073e9SAndroid Build Coastguard Worker        d = Wishart(df=df, scale_tril=scale_tril)
3275*da0073e9SAndroid Build Coastguard Worker        samples = d.rsample((ndim * ndim * 100000,))
3276*da0073e9SAndroid Build Coastguard Worker        empirical_mean = samples.mean(0)
3277*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(d.mean, empirical_mean, atol=0.5, rtol=0)
3278*da0073e9SAndroid Build Coastguard Worker        empirical_var = samples.var(0)
3279*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(d.variance, empirical_var, atol=0.5, rtol=0)
3280*da0073e9SAndroid Build Coastguard Worker
3281*da0073e9SAndroid Build Coastguard Worker    @set_default_dtype(torch.double)
3282*da0073e9SAndroid Build Coastguard Worker    def test_exponential(self):
3283*da0073e9SAndroid Build Coastguard Worker        rate = torch.randn(5, 5).abs().requires_grad_()
3284*da0073e9SAndroid Build Coastguard Worker        rate_1d = torch.randn(1).abs().requires_grad_()
3285*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Exponential(rate).sample().size(), (5, 5))
3286*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Exponential(rate).sample((7,)).size(), (7, 5, 5))
3287*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Exponential(rate_1d).sample((1,)).size(), (1, 1))
3288*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Exponential(rate_1d).sample().size(), (1,))
3289*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Exponential(0.2).sample((1,)).size(), (1,))
3290*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Exponential(50.0).sample((1,)).size(), (1,))
3291*da0073e9SAndroid Build Coastguard Worker
3292*da0073e9SAndroid Build Coastguard Worker        self._gradcheck_log_prob(Exponential, (rate,))
3293*da0073e9SAndroid Build Coastguard Worker        state = torch.get_rng_state()
3294*da0073e9SAndroid Build Coastguard Worker        eps = rate.new(rate.size()).exponential_()
3295*da0073e9SAndroid Build Coastguard Worker        torch.set_rng_state(state)
3296*da0073e9SAndroid Build Coastguard Worker        z = Exponential(rate).rsample()
3297*da0073e9SAndroid Build Coastguard Worker        z.backward(torch.ones_like(z))
3298*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(rate.grad, -eps / rate**2)
3299*da0073e9SAndroid Build Coastguard Worker        rate.grad.zero_()
3300*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(z.size(), (5, 5))
3301*da0073e9SAndroid Build Coastguard Worker
3302*da0073e9SAndroid Build Coastguard Worker        def ref_log_prob(idx, x, log_prob):
3303*da0073e9SAndroid Build Coastguard Worker            m = rate.view(-1)[idx]
3304*da0073e9SAndroid Build Coastguard Worker            expected = math.log(m) - m * x
3305*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)
3306*da0073e9SAndroid Build Coastguard Worker
3307*da0073e9SAndroid Build Coastguard Worker        self._check_log_prob(Exponential(rate), ref_log_prob)
3308*da0073e9SAndroid Build Coastguard Worker        self._check_forward_ad(lambda x: x.exponential_())
3309*da0073e9SAndroid Build Coastguard Worker
3310*da0073e9SAndroid Build Coastguard Worker        def mean_var(lambd, sample):
3311*da0073e9SAndroid Build Coastguard Worker            sample.exponential_(lambd)
3312*da0073e9SAndroid Build Coastguard Worker            mean = sample.float().mean()
3313*da0073e9SAndroid Build Coastguard Worker            var = sample.float().var()
3314*da0073e9SAndroid Build Coastguard Worker            self.assertEqual((1.0 / lambd), mean, atol=2e-2, rtol=2e-2)
3315*da0073e9SAndroid Build Coastguard Worker            self.assertEqual((1.0 / lambd) ** 2, var, atol=2e-2, rtol=2e-2)
3316*da0073e9SAndroid Build Coastguard Worker
3317*da0073e9SAndroid Build Coastguard Worker        for dtype in [torch.float, torch.double, torch.bfloat16, torch.float16]:
3318*da0073e9SAndroid Build Coastguard Worker            for lambd in [0.2, 0.5, 1.0, 1.5, 2.0, 5.0]:
3319*da0073e9SAndroid Build Coastguard Worker                sample_len = 50000
3320*da0073e9SAndroid Build Coastguard Worker                mean_var(lambd, torch.rand(sample_len, dtype=dtype))
3321*da0073e9SAndroid Build Coastguard Worker
3322*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
3323*da0073e9SAndroid Build Coastguard Worker    def test_exponential_sample(self):
3324*da0073e9SAndroid Build Coastguard Worker        set_rng_seed(1)  # see Note [Randomized statistical tests]
3325*da0073e9SAndroid Build Coastguard Worker        for rate in [1e-5, 1.0, 10.0]:
3326*da0073e9SAndroid Build Coastguard Worker            self._check_sampler_sampler(
3327*da0073e9SAndroid Build Coastguard Worker                Exponential(rate),
3328*da0073e9SAndroid Build Coastguard Worker                scipy.stats.expon(scale=1.0 / rate),
3329*da0073e9SAndroid Build Coastguard Worker                f"Exponential(rate={rate})",
3330*da0073e9SAndroid Build Coastguard Worker            )
3331*da0073e9SAndroid Build Coastguard Worker
3332*da0073e9SAndroid Build Coastguard Worker    @set_default_dtype(torch.double)
3333*da0073e9SAndroid Build Coastguard Worker    def test_laplace(self):
3334*da0073e9SAndroid Build Coastguard Worker        loc = torch.randn(5, 5, requires_grad=True)
3335*da0073e9SAndroid Build Coastguard Worker        scale = torch.randn(5, 5).abs().requires_grad_()
3336*da0073e9SAndroid Build Coastguard Worker        loc_1d = torch.randn(1, requires_grad=True)
3337*da0073e9SAndroid Build Coastguard Worker        scale_1d = torch.randn(1, requires_grad=True)
3338*da0073e9SAndroid Build Coastguard Worker        loc_delta = torch.tensor([1.0, 0.0])
3339*da0073e9SAndroid Build Coastguard Worker        scale_delta = torch.tensor([1e-5, 1e-5])
3340*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Laplace(loc, scale).sample().size(), (5, 5))
3341*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Laplace(loc, scale).sample((7,)).size(), (7, 5, 5))
3342*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Laplace(loc_1d, scale_1d).sample((1,)).size(), (1, 1))
3343*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Laplace(loc_1d, scale_1d).sample().size(), (1,))
3344*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Laplace(0.2, 0.6).sample((1,)).size(), (1,))
3345*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Laplace(-0.7, 50.0).sample((1,)).size(), (1,))
3346*da0073e9SAndroid Build Coastguard Worker
3347*da0073e9SAndroid Build Coastguard Worker        # sample check for extreme value of mean, std
3348*da0073e9SAndroid Build Coastguard Worker        set_rng_seed(0)
3349*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
3350*da0073e9SAndroid Build Coastguard Worker            Laplace(loc_delta, scale_delta).sample(sample_shape=(1, 2)),
3351*da0073e9SAndroid Build Coastguard Worker            torch.tensor([[[1.0, 0.0], [1.0, 0.0]]]),
3352*da0073e9SAndroid Build Coastguard Worker            atol=1e-4,
3353*da0073e9SAndroid Build Coastguard Worker            rtol=0,
3354*da0073e9SAndroid Build Coastguard Worker        )
3355*da0073e9SAndroid Build Coastguard Worker
3356*da0073e9SAndroid Build Coastguard Worker        self._gradcheck_log_prob(Laplace, (loc, scale))
3357*da0073e9SAndroid Build Coastguard Worker        self._gradcheck_log_prob(Laplace, (loc, 1.0))
3358*da0073e9SAndroid Build Coastguard Worker        self._gradcheck_log_prob(Laplace, (0.0, scale))
3359*da0073e9SAndroid Build Coastguard Worker
3360*da0073e9SAndroid Build Coastguard Worker        state = torch.get_rng_state()
3361*da0073e9SAndroid Build Coastguard Worker        eps = torch.ones_like(loc).uniform_(-0.5, 0.5)
3362*da0073e9SAndroid Build Coastguard Worker        torch.set_rng_state(state)
3363*da0073e9SAndroid Build Coastguard Worker        z = Laplace(loc, scale).rsample()
3364*da0073e9SAndroid Build Coastguard Worker        z.backward(torch.ones_like(z))
3365*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(loc.grad, torch.ones_like(loc))
3366*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(scale.grad, -eps.sign() * torch.log1p(-2 * eps.abs()))
3367*da0073e9SAndroid Build Coastguard Worker        loc.grad.zero_()
3368*da0073e9SAndroid Build Coastguard Worker        scale.grad.zero_()
3369*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(z.size(), (5, 5))
3370*da0073e9SAndroid Build Coastguard Worker
3371*da0073e9SAndroid Build Coastguard Worker        def ref_log_prob(idx, x, log_prob):
3372*da0073e9SAndroid Build Coastguard Worker            m = loc.view(-1)[idx]
3373*da0073e9SAndroid Build Coastguard Worker            s = scale.view(-1)[idx]
3374*da0073e9SAndroid Build Coastguard Worker            expected = -math.log(2 * s) - abs(x - m) / s
3375*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)
3376*da0073e9SAndroid Build Coastguard Worker
3377*da0073e9SAndroid Build Coastguard Worker        self._check_log_prob(Laplace(loc, scale), ref_log_prob)
3378*da0073e9SAndroid Build Coastguard Worker
3379*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
3380*da0073e9SAndroid Build Coastguard Worker    @set_default_dtype(torch.double)
3381*da0073e9SAndroid Build Coastguard Worker    def test_laplace_sample(self):
3382*da0073e9SAndroid Build Coastguard Worker        set_rng_seed(1)  # see Note [Randomized statistical tests]
3383*da0073e9SAndroid Build Coastguard Worker        for loc, scale in product([-1.0, 0.0, 1.0], [0.1, 1.0, 10.0]):
3384*da0073e9SAndroid Build Coastguard Worker            self._check_sampler_sampler(
3385*da0073e9SAndroid Build Coastguard Worker                Laplace(loc, scale),
3386*da0073e9SAndroid Build Coastguard Worker                scipy.stats.laplace(loc=loc, scale=scale),
3387*da0073e9SAndroid Build Coastguard Worker                f"Laplace(loc={loc}, scale={scale})",
3388*da0073e9SAndroid Build Coastguard Worker            )
3389*da0073e9SAndroid Build Coastguard Worker
3390*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
3391*da0073e9SAndroid Build Coastguard Worker    def test_gamma_shape(self):
3392*da0073e9SAndroid Build Coastguard Worker        alpha = torch.randn(2, 3).exp().requires_grad_()
3393*da0073e9SAndroid Build Coastguard Worker        beta = torch.randn(2, 3).exp().requires_grad_()
3394*da0073e9SAndroid Build Coastguard Worker        alpha_1d = torch.randn(1).exp().requires_grad_()
3395*da0073e9SAndroid Build Coastguard Worker        beta_1d = torch.randn(1).exp().requires_grad_()
3396*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Gamma(alpha, beta).sample().size(), (2, 3))
3397*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Gamma(alpha, beta).sample((5,)).size(), (5, 2, 3))
3398*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Gamma(alpha_1d, beta_1d).sample((1,)).size(), (1, 1))
3399*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Gamma(alpha_1d, beta_1d).sample().size(), (1,))
3400*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Gamma(0.5, 0.5).sample().size(), ())
3401*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Gamma(0.5, 0.5).sample((1,)).size(), (1,))
3402*da0073e9SAndroid Build Coastguard Worker
3403*da0073e9SAndroid Build Coastguard Worker        def ref_log_prob(idx, x, log_prob):
3404*da0073e9SAndroid Build Coastguard Worker            a = alpha.view(-1)[idx].detach()
3405*da0073e9SAndroid Build Coastguard Worker            b = beta.view(-1)[idx].detach()
3406*da0073e9SAndroid Build Coastguard Worker            expected = scipy.stats.gamma.logpdf(x, a, scale=1 / b)
3407*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)
3408*da0073e9SAndroid Build Coastguard Worker
3409*da0073e9SAndroid Build Coastguard Worker        self._check_log_prob(Gamma(alpha, beta), ref_log_prob)
3410*da0073e9SAndroid Build Coastguard Worker
3411*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDA, "CUDA not found")
3412*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
3413*da0073e9SAndroid Build Coastguard Worker    def test_gamma_gpu_shape(self):
3414*da0073e9SAndroid Build Coastguard Worker        alpha = torch.randn(2, 3).cuda().exp().requires_grad_()
3415*da0073e9SAndroid Build Coastguard Worker        beta = torch.randn(2, 3).cuda().exp().requires_grad_()
3416*da0073e9SAndroid Build Coastguard Worker        alpha_1d = torch.randn(1).cuda().exp().requires_grad_()
3417*da0073e9SAndroid Build Coastguard Worker        beta_1d = torch.randn(1).cuda().exp().requires_grad_()
3418*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Gamma(alpha, beta).sample().size(), (2, 3))
3419*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Gamma(alpha, beta).sample((5,)).size(), (5, 2, 3))
3420*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Gamma(alpha_1d, beta_1d).sample((1,)).size(), (1, 1))
3421*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Gamma(alpha_1d, beta_1d).sample().size(), (1,))
3422*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Gamma(0.5, 0.5).sample().size(), ())
3423*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Gamma(0.5, 0.5).sample((1,)).size(), (1,))
3424*da0073e9SAndroid Build Coastguard Worker
3425*da0073e9SAndroid Build Coastguard Worker        def ref_log_prob(idx, x, log_prob):
3426*da0073e9SAndroid Build Coastguard Worker            a = alpha.view(-1)[idx].detach().cpu()
3427*da0073e9SAndroid Build Coastguard Worker            b = beta.view(-1)[idx].detach().cpu()
3428*da0073e9SAndroid Build Coastguard Worker            expected = scipy.stats.gamma.logpdf(x.cpu(), a, scale=1 / b)
3429*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)
3430*da0073e9SAndroid Build Coastguard Worker
3431*da0073e9SAndroid Build Coastguard Worker        self._check_log_prob(Gamma(alpha, beta), ref_log_prob)
3432*da0073e9SAndroid Build Coastguard Worker
3433*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
3434*da0073e9SAndroid Build Coastguard Worker    def test_gamma_sample(self):
3435*da0073e9SAndroid Build Coastguard Worker        set_rng_seed(0)  # see Note [Randomized statistical tests]
3436*da0073e9SAndroid Build Coastguard Worker        for alpha, beta in product([0.1, 1.0, 5.0], [0.1, 1.0, 10.0]):
3437*da0073e9SAndroid Build Coastguard Worker            self._check_sampler_sampler(
3438*da0073e9SAndroid Build Coastguard Worker                Gamma(alpha, beta),
3439*da0073e9SAndroid Build Coastguard Worker                scipy.stats.gamma(alpha, scale=1.0 / beta),
3440*da0073e9SAndroid Build Coastguard Worker                f"Gamma(concentration={alpha}, rate={beta})",
3441*da0073e9SAndroid Build Coastguard Worker            )
3442*da0073e9SAndroid Build Coastguard Worker
3443*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDA, "CUDA not found")
3444*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
3445*da0073e9SAndroid Build Coastguard Worker    def test_gamma_gpu_sample(self):
3446*da0073e9SAndroid Build Coastguard Worker        set_rng_seed(0)
3447*da0073e9SAndroid Build Coastguard Worker        for alpha, beta in product([0.1, 1.0, 5.0], [0.1, 1.0, 10.0]):
3448*da0073e9SAndroid Build Coastguard Worker            a, b = torch.tensor([alpha]).cuda(), torch.tensor([beta]).cuda()
3449*da0073e9SAndroid Build Coastguard Worker            self._check_sampler_sampler(
3450*da0073e9SAndroid Build Coastguard Worker                Gamma(a, b),
3451*da0073e9SAndroid Build Coastguard Worker                scipy.stats.gamma(alpha, scale=1.0 / beta),
3452*da0073e9SAndroid Build Coastguard Worker                f"Gamma(alpha={alpha}, beta={beta})",
3453*da0073e9SAndroid Build Coastguard Worker                failure_rate=1e-4,
3454*da0073e9SAndroid Build Coastguard Worker            )
3455*da0073e9SAndroid Build Coastguard Worker
3456*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
3457*da0073e9SAndroid Build Coastguard Worker    def test_pareto(self):
3458*da0073e9SAndroid Build Coastguard Worker        scale = torch.randn(2, 3).abs().requires_grad_()
3459*da0073e9SAndroid Build Coastguard Worker        alpha = torch.randn(2, 3).abs().requires_grad_()
3460*da0073e9SAndroid Build Coastguard Worker        scale_1d = torch.randn(1).abs().requires_grad_()
3461*da0073e9SAndroid Build Coastguard Worker        alpha_1d = torch.randn(1).abs().requires_grad_()
3462*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Pareto(scale_1d, 0.5).mean, inf)
3463*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Pareto(scale_1d, 0.5).variance, inf)
3464*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Pareto(scale, alpha).sample().size(), (2, 3))
3465*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Pareto(scale, alpha).sample((5,)).size(), (5, 2, 3))
3466*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Pareto(scale_1d, alpha_1d).sample((1,)).size(), (1, 1))
3467*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Pareto(scale_1d, alpha_1d).sample().size(), (1,))
3468*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Pareto(1.0, 1.0).sample().size(), ())
3469*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Pareto(1.0, 1.0).sample((1,)).size(), (1,))
3470*da0073e9SAndroid Build Coastguard Worker
3471*da0073e9SAndroid Build Coastguard Worker        def ref_log_prob(idx, x, log_prob):
3472*da0073e9SAndroid Build Coastguard Worker            s = scale.view(-1)[idx].detach()
3473*da0073e9SAndroid Build Coastguard Worker            a = alpha.view(-1)[idx].detach()
3474*da0073e9SAndroid Build Coastguard Worker            expected = scipy.stats.pareto.logpdf(x, a, scale=s)
3475*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)
3476*da0073e9SAndroid Build Coastguard Worker
3477*da0073e9SAndroid Build Coastguard Worker        self._check_log_prob(Pareto(scale, alpha), ref_log_prob)
3478*da0073e9SAndroid Build Coastguard Worker
3479*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
3480*da0073e9SAndroid Build Coastguard Worker    def test_pareto_sample(self):
3481*da0073e9SAndroid Build Coastguard Worker        set_rng_seed(1)  # see Note [Randomized statistical tests]
3482*da0073e9SAndroid Build Coastguard Worker        for scale, alpha in product([0.1, 1.0, 5.0], [0.1, 1.0, 10.0]):
3483*da0073e9SAndroid Build Coastguard Worker            self._check_sampler_sampler(
3484*da0073e9SAndroid Build Coastguard Worker                Pareto(scale, alpha),
3485*da0073e9SAndroid Build Coastguard Worker                scipy.stats.pareto(alpha, scale=scale),
3486*da0073e9SAndroid Build Coastguard Worker                f"Pareto(scale={scale}, alpha={alpha})",
3487*da0073e9SAndroid Build Coastguard Worker            )
3488*da0073e9SAndroid Build Coastguard Worker
3489*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
3490*da0073e9SAndroid Build Coastguard Worker    def test_gumbel(self):
3491*da0073e9SAndroid Build Coastguard Worker        loc = torch.randn(2, 3, requires_grad=True)
3492*da0073e9SAndroid Build Coastguard Worker        scale = torch.randn(2, 3).abs().requires_grad_()
3493*da0073e9SAndroid Build Coastguard Worker        loc_1d = torch.randn(1, requires_grad=True)
3494*da0073e9SAndroid Build Coastguard Worker        scale_1d = torch.randn(1).abs().requires_grad_()
3495*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Gumbel(loc, scale).sample().size(), (2, 3))
3496*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Gumbel(loc, scale).sample((5,)).size(), (5, 2, 3))
3497*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Gumbel(loc_1d, scale_1d).sample().size(), (1,))
3498*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Gumbel(loc_1d, scale_1d).sample((1,)).size(), (1, 1))
3499*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Gumbel(1.0, 1.0).sample().size(), ())
3500*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Gumbel(1.0, 1.0).sample((1,)).size(), (1,))
3501*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
3502*da0073e9SAndroid Build Coastguard Worker            Gumbel(
3503*da0073e9SAndroid Build Coastguard Worker                torch.tensor(0.0, dtype=torch.float32),
3504*da0073e9SAndroid Build Coastguard Worker                torch.tensor(1.0, dtype=torch.float32),
3505*da0073e9SAndroid Build Coastguard Worker                validate_args=False,
3506*da0073e9SAndroid Build Coastguard Worker            ).cdf(20.0),
3507*da0073e9SAndroid Build Coastguard Worker            1.0,
3508*da0073e9SAndroid Build Coastguard Worker            atol=1e-4,
3509*da0073e9SAndroid Build Coastguard Worker            rtol=0,
3510*da0073e9SAndroid Build Coastguard Worker        )
3511*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
3512*da0073e9SAndroid Build Coastguard Worker            Gumbel(
3513*da0073e9SAndroid Build Coastguard Worker                torch.tensor(0.0, dtype=torch.float64),
3514*da0073e9SAndroid Build Coastguard Worker                torch.tensor(1.0, dtype=torch.float64),
3515*da0073e9SAndroid Build Coastguard Worker                validate_args=False,
3516*da0073e9SAndroid Build Coastguard Worker            ).cdf(50.0),
3517*da0073e9SAndroid Build Coastguard Worker            1.0,
3518*da0073e9SAndroid Build Coastguard Worker            atol=1e-4,
3519*da0073e9SAndroid Build Coastguard Worker            rtol=0,
3520*da0073e9SAndroid Build Coastguard Worker        )
3521*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
3522*da0073e9SAndroid Build Coastguard Worker            Gumbel(
3523*da0073e9SAndroid Build Coastguard Worker                torch.tensor(0.0, dtype=torch.float32),
3524*da0073e9SAndroid Build Coastguard Worker                torch.tensor(1.0, dtype=torch.float32),
3525*da0073e9SAndroid Build Coastguard Worker                validate_args=False,
3526*da0073e9SAndroid Build Coastguard Worker            ).cdf(-5.0),
3527*da0073e9SAndroid Build Coastguard Worker            0.0,
3528*da0073e9SAndroid Build Coastguard Worker            atol=1e-4,
3529*da0073e9SAndroid Build Coastguard Worker            rtol=0,
3530*da0073e9SAndroid Build Coastguard Worker        )
3531*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
3532*da0073e9SAndroid Build Coastguard Worker            Gumbel(
3533*da0073e9SAndroid Build Coastguard Worker                torch.tensor(0.0, dtype=torch.float64),
3534*da0073e9SAndroid Build Coastguard Worker                torch.tensor(1.0, dtype=torch.float64),
3535*da0073e9SAndroid Build Coastguard Worker                validate_args=False,
3536*da0073e9SAndroid Build Coastguard Worker            ).cdf(-10.0),
3537*da0073e9SAndroid Build Coastguard Worker            0.0,
3538*da0073e9SAndroid Build Coastguard Worker            atol=1e-8,
3539*da0073e9SAndroid Build Coastguard Worker            rtol=0,
3540*da0073e9SAndroid Build Coastguard Worker        )
3541*da0073e9SAndroid Build Coastguard Worker
3542*da0073e9SAndroid Build Coastguard Worker        def ref_log_prob(idx, x, log_prob):
3543*da0073e9SAndroid Build Coastguard Worker            l = loc.view(-1)[idx].detach()
3544*da0073e9SAndroid Build Coastguard Worker            s = scale.view(-1)[idx].detach()
3545*da0073e9SAndroid Build Coastguard Worker            expected = scipy.stats.gumbel_r.logpdf(x, loc=l, scale=s)
3546*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)
3547*da0073e9SAndroid Build Coastguard Worker
3548*da0073e9SAndroid Build Coastguard Worker        self._check_log_prob(Gumbel(loc, scale), ref_log_prob)
3549*da0073e9SAndroid Build Coastguard Worker
3550*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
3551*da0073e9SAndroid Build Coastguard Worker    @set_default_dtype(torch.double)
3552*da0073e9SAndroid Build Coastguard Worker    def test_gumbel_sample(self):
3553*da0073e9SAndroid Build Coastguard Worker        set_rng_seed(1)  # see note [Randomized statistical tests]
3554*da0073e9SAndroid Build Coastguard Worker        for loc, scale in product([-5.0, -1.0, -0.1, 0.1, 1.0, 5.0], [0.1, 1.0, 10.0]):
3555*da0073e9SAndroid Build Coastguard Worker            self._check_sampler_sampler(
3556*da0073e9SAndroid Build Coastguard Worker                Gumbel(loc, scale),
3557*da0073e9SAndroid Build Coastguard Worker                scipy.stats.gumbel_r(loc=loc, scale=scale),
3558*da0073e9SAndroid Build Coastguard Worker                f"Gumbel(loc={loc}, scale={scale})",
3559*da0073e9SAndroid Build Coastguard Worker            )
3560*da0073e9SAndroid Build Coastguard Worker
3561*da0073e9SAndroid Build Coastguard Worker    def test_kumaraswamy_shape(self):
3562*da0073e9SAndroid Build Coastguard Worker        concentration1 = torch.randn(2, 3).abs().requires_grad_()
3563*da0073e9SAndroid Build Coastguard Worker        concentration0 = torch.randn(2, 3).abs().requires_grad_()
3564*da0073e9SAndroid Build Coastguard Worker        concentration1_1d = torch.randn(1).abs().requires_grad_()
3565*da0073e9SAndroid Build Coastguard Worker        concentration0_1d = torch.randn(1).abs().requires_grad_()
3566*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
3567*da0073e9SAndroid Build Coastguard Worker            Kumaraswamy(concentration1, concentration0).sample().size(), (2, 3)
3568*da0073e9SAndroid Build Coastguard Worker        )
3569*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
3570*da0073e9SAndroid Build Coastguard Worker            Kumaraswamy(concentration1, concentration0).sample((5,)).size(), (5, 2, 3)
3571*da0073e9SAndroid Build Coastguard Worker        )
3572*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
3573*da0073e9SAndroid Build Coastguard Worker            Kumaraswamy(concentration1_1d, concentration0_1d).sample().size(), (1,)
3574*da0073e9SAndroid Build Coastguard Worker        )
3575*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
3576*da0073e9SAndroid Build Coastguard Worker            Kumaraswamy(concentration1_1d, concentration0_1d).sample((1,)).size(),
3577*da0073e9SAndroid Build Coastguard Worker            (1, 1),
3578*da0073e9SAndroid Build Coastguard Worker        )
3579*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Kumaraswamy(1.0, 1.0).sample().size(), ())
3580*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Kumaraswamy(1.0, 1.0).sample((1,)).size(), (1,))
3581*da0073e9SAndroid Build Coastguard Worker
3582*da0073e9SAndroid Build Coastguard Worker    # Kumaraswamy distribution is not implemented in SciPy
3583*da0073e9SAndroid Build Coastguard Worker    # Hence these tests are explicit
3584*da0073e9SAndroid Build Coastguard Worker    def test_kumaraswamy_mean_variance(self):
3585*da0073e9SAndroid Build Coastguard Worker        c1_1 = torch.randn(2, 3).abs().requires_grad_()
3586*da0073e9SAndroid Build Coastguard Worker        c0_1 = torch.randn(2, 3).abs().requires_grad_()
3587*da0073e9SAndroid Build Coastguard Worker        c1_2 = torch.randn(4).abs().requires_grad_()
3588*da0073e9SAndroid Build Coastguard Worker        c0_2 = torch.randn(4).abs().requires_grad_()
3589*da0073e9SAndroid Build Coastguard Worker        cases = [(c1_1, c0_1), (c1_2, c0_2)]
3590*da0073e9SAndroid Build Coastguard Worker        for i, (a, b) in enumerate(cases):
3591*da0073e9SAndroid Build Coastguard Worker            m = Kumaraswamy(a, b)
3592*da0073e9SAndroid Build Coastguard Worker            samples = m.sample((60000,))
3593*da0073e9SAndroid Build Coastguard Worker            expected = samples.mean(0)
3594*da0073e9SAndroid Build Coastguard Worker            actual = m.mean
3595*da0073e9SAndroid Build Coastguard Worker            error = (expected - actual).abs()
3596*da0073e9SAndroid Build Coastguard Worker            max_error = max(error[error == error])
3597*da0073e9SAndroid Build Coastguard Worker            self.assertLess(
3598*da0073e9SAndroid Build Coastguard Worker                max_error,
3599*da0073e9SAndroid Build Coastguard Worker                0.01,
3600*da0073e9SAndroid Build Coastguard Worker                f"Kumaraswamy example {i + 1}/{len(cases)}, incorrect .mean",
3601*da0073e9SAndroid Build Coastguard Worker            )
3602*da0073e9SAndroid Build Coastguard Worker            expected = samples.var(0)
3603*da0073e9SAndroid Build Coastguard Worker            actual = m.variance
3604*da0073e9SAndroid Build Coastguard Worker            error = (expected - actual).abs()
3605*da0073e9SAndroid Build Coastguard Worker            max_error = max(error[error == error])
3606*da0073e9SAndroid Build Coastguard Worker            self.assertLess(
3607*da0073e9SAndroid Build Coastguard Worker                max_error,
3608*da0073e9SAndroid Build Coastguard Worker                0.01,
3609*da0073e9SAndroid Build Coastguard Worker                f"Kumaraswamy example {i + 1}/{len(cases)}, incorrect .variance",
3610*da0073e9SAndroid Build Coastguard Worker            )
3611*da0073e9SAndroid Build Coastguard Worker
3612*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
3613*da0073e9SAndroid Build Coastguard Worker    def test_fishersnedecor(self):
3614*da0073e9SAndroid Build Coastguard Worker        df1 = torch.randn(2, 3).abs().requires_grad_()
3615*da0073e9SAndroid Build Coastguard Worker        df2 = torch.randn(2, 3).abs().requires_grad_()
3616*da0073e9SAndroid Build Coastguard Worker        df1_1d = torch.randn(1).abs()
3617*da0073e9SAndroid Build Coastguard Worker        df2_1d = torch.randn(1).abs()
3618*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(is_all_nan(FisherSnedecor(1, 2).mean))
3619*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(is_all_nan(FisherSnedecor(1, 4).variance))
3620*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(FisherSnedecor(df1, df2).sample().size(), (2, 3))
3621*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(FisherSnedecor(df1, df2).sample((5,)).size(), (5, 2, 3))
3622*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(FisherSnedecor(df1_1d, df2_1d).sample().size(), (1,))
3623*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(FisherSnedecor(df1_1d, df2_1d).sample((1,)).size(), (1, 1))
3624*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(FisherSnedecor(1.0, 1.0).sample().size(), ())
3625*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(FisherSnedecor(1.0, 1.0).sample((1,)).size(), (1,))
3626*da0073e9SAndroid Build Coastguard Worker
3627*da0073e9SAndroid Build Coastguard Worker        def ref_log_prob(idx, x, log_prob):
3628*da0073e9SAndroid Build Coastguard Worker            f1 = df1.view(-1)[idx].detach()
3629*da0073e9SAndroid Build Coastguard Worker            f2 = df2.view(-1)[idx].detach()
3630*da0073e9SAndroid Build Coastguard Worker            expected = scipy.stats.f.logpdf(x, f1, f2)
3631*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)
3632*da0073e9SAndroid Build Coastguard Worker
3633*da0073e9SAndroid Build Coastguard Worker        self._check_log_prob(FisherSnedecor(df1, df2), ref_log_prob)
3634*da0073e9SAndroid Build Coastguard Worker
3635*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
3636*da0073e9SAndroid Build Coastguard Worker    def test_fishersnedecor_sample(self):
3637*da0073e9SAndroid Build Coastguard Worker        set_rng_seed(1)  # see note [Randomized statistical tests]
3638*da0073e9SAndroid Build Coastguard Worker        for df1, df2 in product([0.1, 0.5, 1.0, 5.0, 10.0], [0.1, 0.5, 1.0, 5.0, 10.0]):
3639*da0073e9SAndroid Build Coastguard Worker            self._check_sampler_sampler(
3640*da0073e9SAndroid Build Coastguard Worker                FisherSnedecor(df1, df2),
3641*da0073e9SAndroid Build Coastguard Worker                scipy.stats.f(df1, df2),
3642*da0073e9SAndroid Build Coastguard Worker                f"FisherSnedecor(loc={df1}, scale={df2})",
3643*da0073e9SAndroid Build Coastguard Worker            )
3644*da0073e9SAndroid Build Coastguard Worker
3645*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
3646*da0073e9SAndroid Build Coastguard Worker    def test_chi2_shape(self):
3647*da0073e9SAndroid Build Coastguard Worker        df = torch.randn(2, 3).exp().requires_grad_()
3648*da0073e9SAndroid Build Coastguard Worker        df_1d = torch.randn(1).exp().requires_grad_()
3649*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Chi2(df).sample().size(), (2, 3))
3650*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Chi2(df).sample((5,)).size(), (5, 2, 3))
3651*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Chi2(df_1d).sample((1,)).size(), (1, 1))
3652*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Chi2(df_1d).sample().size(), (1,))
3653*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
3654*da0073e9SAndroid Build Coastguard Worker            Chi2(torch.tensor(0.5, requires_grad=True)).sample().size(), ()
3655*da0073e9SAndroid Build Coastguard Worker        )
3656*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Chi2(0.5).sample().size(), ())
3657*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Chi2(0.5).sample((1,)).size(), (1,))
3658*da0073e9SAndroid Build Coastguard Worker
3659*da0073e9SAndroid Build Coastguard Worker        def ref_log_prob(idx, x, log_prob):
3660*da0073e9SAndroid Build Coastguard Worker            d = df.view(-1)[idx].detach()
3661*da0073e9SAndroid Build Coastguard Worker            expected = scipy.stats.chi2.logpdf(x, d)
3662*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)
3663*da0073e9SAndroid Build Coastguard Worker
3664*da0073e9SAndroid Build Coastguard Worker        self._check_log_prob(Chi2(df), ref_log_prob)
3665*da0073e9SAndroid Build Coastguard Worker
3666*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
3667*da0073e9SAndroid Build Coastguard Worker    def test_chi2_sample(self):
3668*da0073e9SAndroid Build Coastguard Worker        set_rng_seed(0)  # see Note [Randomized statistical tests]
3669*da0073e9SAndroid Build Coastguard Worker        for df in [0.1, 1.0, 5.0]:
3670*da0073e9SAndroid Build Coastguard Worker            self._check_sampler_sampler(
3671*da0073e9SAndroid Build Coastguard Worker                Chi2(df), scipy.stats.chi2(df), f"Chi2(df={df})"
3672*da0073e9SAndroid Build Coastguard Worker            )
3673*da0073e9SAndroid Build Coastguard Worker
3674*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
3675*da0073e9SAndroid Build Coastguard Worker    def test_studentT(self):
3676*da0073e9SAndroid Build Coastguard Worker        df = torch.randn(2, 3).exp().requires_grad_()
3677*da0073e9SAndroid Build Coastguard Worker        df_1d = torch.randn(1).exp().requires_grad_()
3678*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(is_all_nan(StudentT(1).mean))
3679*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(is_all_nan(StudentT(1).variance))
3680*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(StudentT(2).variance, inf)
3681*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(StudentT(df).sample().size(), (2, 3))
3682*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(StudentT(df).sample((5,)).size(), (5, 2, 3))
3683*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(StudentT(df_1d).sample((1,)).size(), (1, 1))
3684*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(StudentT(df_1d).sample().size(), (1,))
3685*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
3686*da0073e9SAndroid Build Coastguard Worker            StudentT(torch.tensor(0.5, requires_grad=True)).sample().size(), ()
3687*da0073e9SAndroid Build Coastguard Worker        )
3688*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(StudentT(0.5).sample().size(), ())
3689*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(StudentT(0.5).sample((1,)).size(), (1,))
3690*da0073e9SAndroid Build Coastguard Worker
3691*da0073e9SAndroid Build Coastguard Worker        def ref_log_prob(idx, x, log_prob):
3692*da0073e9SAndroid Build Coastguard Worker            d = df.view(-1)[idx].detach()
3693*da0073e9SAndroid Build Coastguard Worker            expected = scipy.stats.t.logpdf(x, d)
3694*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)
3695*da0073e9SAndroid Build Coastguard Worker
3696*da0073e9SAndroid Build Coastguard Worker        self._check_log_prob(StudentT(df), ref_log_prob)
3697*da0073e9SAndroid Build Coastguard Worker
3698*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
3699*da0073e9SAndroid Build Coastguard Worker    @set_default_dtype(torch.double)
3700*da0073e9SAndroid Build Coastguard Worker    def test_studentT_sample(self):
3701*da0073e9SAndroid Build Coastguard Worker        set_rng_seed(11)  # see Note [Randomized statistical tests]
3702*da0073e9SAndroid Build Coastguard Worker        for df, loc, scale in product(
3703*da0073e9SAndroid Build Coastguard Worker            [0.1, 1.0, 5.0, 10.0], [-1.0, 0.0, 1.0], [0.1, 1.0, 10.0]
3704*da0073e9SAndroid Build Coastguard Worker        ):
3705*da0073e9SAndroid Build Coastguard Worker            self._check_sampler_sampler(
3706*da0073e9SAndroid Build Coastguard Worker                StudentT(df=df, loc=loc, scale=scale),
3707*da0073e9SAndroid Build Coastguard Worker                scipy.stats.t(df=df, loc=loc, scale=scale),
3708*da0073e9SAndroid Build Coastguard Worker                f"StudentT(df={df}, loc={loc}, scale={scale})",
3709*da0073e9SAndroid Build Coastguard Worker            )
3710*da0073e9SAndroid Build Coastguard Worker
3711*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
3712*da0073e9SAndroid Build Coastguard Worker    def test_studentT_log_prob(self):
3713*da0073e9SAndroid Build Coastguard Worker        set_rng_seed(0)  # see Note [Randomized statistical tests]
3714*da0073e9SAndroid Build Coastguard Worker        num_samples = 10
3715*da0073e9SAndroid Build Coastguard Worker        for df, loc, scale in product(
3716*da0073e9SAndroid Build Coastguard Worker            [0.1, 1.0, 5.0, 10.0], [-1.0, 0.0, 1.0], [0.1, 1.0, 10.0]
3717*da0073e9SAndroid Build Coastguard Worker        ):
3718*da0073e9SAndroid Build Coastguard Worker            dist = StudentT(df=df, loc=loc, scale=scale)
3719*da0073e9SAndroid Build Coastguard Worker            x = dist.sample((num_samples,))
3720*da0073e9SAndroid Build Coastguard Worker            actual_log_prob = dist.log_prob(x)
3721*da0073e9SAndroid Build Coastguard Worker            for i in range(num_samples):
3722*da0073e9SAndroid Build Coastguard Worker                expected_log_prob = scipy.stats.t.logpdf(
3723*da0073e9SAndroid Build Coastguard Worker                    x[i], df=df, loc=loc, scale=scale
3724*da0073e9SAndroid Build Coastguard Worker                )
3725*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(
3726*da0073e9SAndroid Build Coastguard Worker                    float(actual_log_prob[i]),
3727*da0073e9SAndroid Build Coastguard Worker                    float(expected_log_prob),
3728*da0073e9SAndroid Build Coastguard Worker                    atol=1e-3,
3729*da0073e9SAndroid Build Coastguard Worker                    rtol=0,
3730*da0073e9SAndroid Build Coastguard Worker                )
3731*da0073e9SAndroid Build Coastguard Worker
3732*da0073e9SAndroid Build Coastguard Worker    def test_dirichlet_shape(self):
3733*da0073e9SAndroid Build Coastguard Worker        alpha = torch.randn(2, 3).exp().requires_grad_()
3734*da0073e9SAndroid Build Coastguard Worker        alpha_1d = torch.randn(4).exp().requires_grad_()
3735*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Dirichlet(alpha).sample().size(), (2, 3))
3736*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Dirichlet(alpha).sample((5,)).size(), (5, 2, 3))
3737*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Dirichlet(alpha_1d).sample().size(), (4,))
3738*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Dirichlet(alpha_1d).sample((1,)).size(), (1, 4))
3739*da0073e9SAndroid Build Coastguard Worker
3740*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
3741*da0073e9SAndroid Build Coastguard Worker    @set_default_dtype(torch.double)
3742*da0073e9SAndroid Build Coastguard Worker    def test_dirichlet_log_prob(self):
3743*da0073e9SAndroid Build Coastguard Worker        num_samples = 10
3744*da0073e9SAndroid Build Coastguard Worker        alpha = torch.exp(torch.randn(5))
3745*da0073e9SAndroid Build Coastguard Worker        dist = Dirichlet(alpha)
3746*da0073e9SAndroid Build Coastguard Worker        x = dist.sample((num_samples,))
3747*da0073e9SAndroid Build Coastguard Worker        actual_log_prob = dist.log_prob(x)
3748*da0073e9SAndroid Build Coastguard Worker        for i in range(num_samples):
3749*da0073e9SAndroid Build Coastguard Worker            expected_log_prob = scipy.stats.dirichlet.logpdf(
3750*da0073e9SAndroid Build Coastguard Worker                x[i].numpy(), alpha.numpy()
3751*da0073e9SAndroid Build Coastguard Worker            )
3752*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(actual_log_prob[i], expected_log_prob, atol=1e-3, rtol=0)
3753*da0073e9SAndroid Build Coastguard Worker
3754*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
3755*da0073e9SAndroid Build Coastguard Worker    def test_dirichlet_log_prob_zero(self):
3756*da0073e9SAndroid Build Coastguard Worker        # Specifically test the special case where x=0 and alpha=1.  The PDF is
3757*da0073e9SAndroid Build Coastguard Worker        # proportional to x**(alpha-1), which in this case works out to 0**0=1.
3758*da0073e9SAndroid Build Coastguard Worker        # The log PDF of this term should therefore be 0.  However, it's easy
3759*da0073e9SAndroid Build Coastguard Worker        # to accidentally introduce NaNs by calculating log(x) without regard
3760*da0073e9SAndroid Build Coastguard Worker        # for the value of alpha-1.
3761*da0073e9SAndroid Build Coastguard Worker        alpha = torch.tensor([1, 2])
3762*da0073e9SAndroid Build Coastguard Worker        dist = Dirichlet(alpha)
3763*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor([0, 1])
3764*da0073e9SAndroid Build Coastguard Worker        actual_log_prob = dist.log_prob(x)
3765*da0073e9SAndroid Build Coastguard Worker        expected_log_prob = scipy.stats.dirichlet.logpdf(x.numpy(), alpha.numpy())
3766*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(actual_log_prob, expected_log_prob, atol=1e-3, rtol=0)
3767*da0073e9SAndroid Build Coastguard Worker
3768*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
3769*da0073e9SAndroid Build Coastguard Worker    def test_dirichlet_sample(self):
3770*da0073e9SAndroid Build Coastguard Worker        set_rng_seed(0)  # see Note [Randomized statistical tests]
3771*da0073e9SAndroid Build Coastguard Worker        alpha = torch.exp(torch.randn(3))
3772*da0073e9SAndroid Build Coastguard Worker        self._check_sampler_sampler(
3773*da0073e9SAndroid Build Coastguard Worker            Dirichlet(alpha),
3774*da0073e9SAndroid Build Coastguard Worker            scipy.stats.dirichlet(alpha.numpy()),
3775*da0073e9SAndroid Build Coastguard Worker            f"Dirichlet(alpha={list(alpha)})",
3776*da0073e9SAndroid Build Coastguard Worker            multivariate=True,
3777*da0073e9SAndroid Build Coastguard Worker        )
3778*da0073e9SAndroid Build Coastguard Worker
3779*da0073e9SAndroid Build Coastguard Worker    def test_dirichlet_mode(self):
3780*da0073e9SAndroid Build Coastguard Worker        # Test a few edge cases for the Dirichlet distribution mode. This also covers beta distributions.
3781*da0073e9SAndroid Build Coastguard Worker        concentrations_and_modes = [
3782*da0073e9SAndroid Build Coastguard Worker            ([2, 2, 1], [0.5, 0.5, 0.0]),
3783*da0073e9SAndroid Build Coastguard Worker            ([3, 2, 1], [2 / 3, 1 / 3, 0]),
3784*da0073e9SAndroid Build Coastguard Worker            ([0.5, 0.2, 0.2], [1.0, 0.0, 0.0]),
3785*da0073e9SAndroid Build Coastguard Worker            ([1, 1, 1], [nan, nan, nan]),
3786*da0073e9SAndroid Build Coastguard Worker        ]
3787*da0073e9SAndroid Build Coastguard Worker        for concentration, mode in concentrations_and_modes:
3788*da0073e9SAndroid Build Coastguard Worker            dist = Dirichlet(torch.tensor(concentration))
3789*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(dist.mode, torch.tensor(mode))
3790*da0073e9SAndroid Build Coastguard Worker
3791*da0073e9SAndroid Build Coastguard Worker    def test_beta_shape(self):
3792*da0073e9SAndroid Build Coastguard Worker        con1 = torch.randn(2, 3).exp().requires_grad_()
3793*da0073e9SAndroid Build Coastguard Worker        con0 = torch.randn(2, 3).exp().requires_grad_()
3794*da0073e9SAndroid Build Coastguard Worker        con1_1d = torch.randn(4).exp().requires_grad_()
3795*da0073e9SAndroid Build Coastguard Worker        con0_1d = torch.randn(4).exp().requires_grad_()
3796*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Beta(con1, con0).sample().size(), (2, 3))
3797*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Beta(con1, con0).sample((5,)).size(), (5, 2, 3))
3798*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Beta(con1_1d, con0_1d).sample().size(), (4,))
3799*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Beta(con1_1d, con0_1d).sample((1,)).size(), (1, 4))
3800*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Beta(0.1, 0.3).sample().size(), ())
3801*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Beta(0.1, 0.3).sample((5,)).size(), (5,))
3802*da0073e9SAndroid Build Coastguard Worker
3803*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
3804*da0073e9SAndroid Build Coastguard Worker    def test_beta_log_prob(self):
3805*da0073e9SAndroid Build Coastguard Worker        for _ in range(100):
3806*da0073e9SAndroid Build Coastguard Worker            con1 = np.exp(np.random.normal())
3807*da0073e9SAndroid Build Coastguard Worker            con0 = np.exp(np.random.normal())
3808*da0073e9SAndroid Build Coastguard Worker            dist = Beta(con1, con0)
3809*da0073e9SAndroid Build Coastguard Worker            x = dist.sample()
3810*da0073e9SAndroid Build Coastguard Worker            actual_log_prob = dist.log_prob(x).sum()
3811*da0073e9SAndroid Build Coastguard Worker            expected_log_prob = scipy.stats.beta.logpdf(x, con1, con0)
3812*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
3813*da0073e9SAndroid Build Coastguard Worker                float(actual_log_prob), float(expected_log_prob), atol=1e-3, rtol=0
3814*da0073e9SAndroid Build Coastguard Worker            )
3815*da0073e9SAndroid Build Coastguard Worker
3816*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
3817*da0073e9SAndroid Build Coastguard Worker    @set_default_dtype(torch.double)
3818*da0073e9SAndroid Build Coastguard Worker    def test_beta_sample(self):
3819*da0073e9SAndroid Build Coastguard Worker        set_rng_seed(1)  # see Note [Randomized statistical tests]
3820*da0073e9SAndroid Build Coastguard Worker        for con1, con0 in product([0.1, 1.0, 10.0], [0.1, 1.0, 10.0]):
3821*da0073e9SAndroid Build Coastguard Worker            self._check_sampler_sampler(
3822*da0073e9SAndroid Build Coastguard Worker                Beta(con1, con0),
3823*da0073e9SAndroid Build Coastguard Worker                scipy.stats.beta(con1, con0),
3824*da0073e9SAndroid Build Coastguard Worker                f"Beta(alpha={con1}, beta={con0})",
3825*da0073e9SAndroid Build Coastguard Worker            )
3826*da0073e9SAndroid Build Coastguard Worker        # Check that small alphas do not cause NANs.
3827*da0073e9SAndroid Build Coastguard Worker        for Tensor in [torch.FloatTensor, torch.DoubleTensor]:
3828*da0073e9SAndroid Build Coastguard Worker            x = Beta(Tensor([1e-6]), Tensor([1e-6])).sample()[0]
3829*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(np.isfinite(x) and x > 0, f"Invalid Beta.sample(): {x}")
3830*da0073e9SAndroid Build Coastguard Worker
3831*da0073e9SAndroid Build Coastguard Worker    def test_beta_underflow(self):
3832*da0073e9SAndroid Build Coastguard Worker        # For low values of (alpha, beta), the gamma samples can underflow
3833*da0073e9SAndroid Build Coastguard Worker        # with float32 and result in a spurious mode at 0.5. To prevent this,
3834*da0073e9SAndroid Build Coastguard Worker        # torch._sample_dirichlet works with double precision for intermediate
3835*da0073e9SAndroid Build Coastguard Worker        # calculations.
3836*da0073e9SAndroid Build Coastguard Worker        set_rng_seed(1)
3837*da0073e9SAndroid Build Coastguard Worker        num_samples = 50000
3838*da0073e9SAndroid Build Coastguard Worker        for dtype in [torch.float, torch.double]:
3839*da0073e9SAndroid Build Coastguard Worker            conc = torch.tensor(1e-2, dtype=dtype)
3840*da0073e9SAndroid Build Coastguard Worker            beta_samples = Beta(conc, conc).sample([num_samples])
3841*da0073e9SAndroid Build Coastguard Worker            self.assertEqual((beta_samples == 0).sum(), 0)
3842*da0073e9SAndroid Build Coastguard Worker            self.assertEqual((beta_samples == 1).sum(), 0)
3843*da0073e9SAndroid Build Coastguard Worker            # assert support is concentrated around 0 and 1
3844*da0073e9SAndroid Build Coastguard Worker            frac_zeros = float((beta_samples < 0.1).sum()) / num_samples
3845*da0073e9SAndroid Build Coastguard Worker            frac_ones = float((beta_samples > 0.9).sum()) / num_samples
3846*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(frac_zeros, 0.5, atol=0.05, rtol=0)
3847*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(frac_ones, 0.5, atol=0.05, rtol=0)
3848*da0073e9SAndroid Build Coastguard Worker
3849*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDA, "CUDA not found")
3850*da0073e9SAndroid Build Coastguard Worker    def test_beta_underflow_gpu(self):
3851*da0073e9SAndroid Build Coastguard Worker        set_rng_seed(1)
3852*da0073e9SAndroid Build Coastguard Worker        num_samples = 50000
3853*da0073e9SAndroid Build Coastguard Worker        conc = torch.tensor(1e-2, dtype=torch.float64).cuda()
3854*da0073e9SAndroid Build Coastguard Worker        beta_samples = Beta(conc, conc).sample([num_samples])
3855*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((beta_samples == 0).sum(), 0)
3856*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((beta_samples == 1).sum(), 0)
3857*da0073e9SAndroid Build Coastguard Worker        # assert support is concentrated around 0 and 1
3858*da0073e9SAndroid Build Coastguard Worker        frac_zeros = float((beta_samples < 0.1).sum()) / num_samples
3859*da0073e9SAndroid Build Coastguard Worker        frac_ones = float((beta_samples > 0.9).sum()) / num_samples
3860*da0073e9SAndroid Build Coastguard Worker        # TODO: increase precision once imbalance on GPU is fixed.
3861*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(frac_zeros, 0.5, atol=0.12, rtol=0)
3862*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(frac_ones, 0.5, atol=0.12, rtol=0)
3863*da0073e9SAndroid Build Coastguard Worker
3864*da0073e9SAndroid Build Coastguard Worker    @set_default_dtype(torch.double)
3865*da0073e9SAndroid Build Coastguard Worker    def test_continuous_bernoulli(self):
3866*da0073e9SAndroid Build Coastguard Worker        p = torch.tensor([0.7, 0.2, 0.4], requires_grad=True)
3867*da0073e9SAndroid Build Coastguard Worker        r = torch.tensor(0.3, requires_grad=True)
3868*da0073e9SAndroid Build Coastguard Worker        s = 0.3
3869*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ContinuousBernoulli(p).sample((8,)).size(), (8, 3))
3870*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(ContinuousBernoulli(p).sample().requires_grad)
3871*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ContinuousBernoulli(r).sample((8,)).size(), (8,))
3872*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ContinuousBernoulli(r).sample().size(), ())
3873*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
3874*da0073e9SAndroid Build Coastguard Worker            ContinuousBernoulli(r).sample((3, 2)).size(),
3875*da0073e9SAndroid Build Coastguard Worker            (
3876*da0073e9SAndroid Build Coastguard Worker                3,
3877*da0073e9SAndroid Build Coastguard Worker                2,
3878*da0073e9SAndroid Build Coastguard Worker            ),
3879*da0073e9SAndroid Build Coastguard Worker        )
3880*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ContinuousBernoulli(s).sample().size(), ())
3881*da0073e9SAndroid Build Coastguard Worker        self._gradcheck_log_prob(ContinuousBernoulli, (p,))
3882*da0073e9SAndroid Build Coastguard Worker
3883*da0073e9SAndroid Build Coastguard Worker        def ref_log_prob(idx, val, log_prob):
3884*da0073e9SAndroid Build Coastguard Worker            prob = p[idx]
3885*da0073e9SAndroid Build Coastguard Worker            if prob > 0.499 and prob < 0.501:  # using default value of lim here
3886*da0073e9SAndroid Build Coastguard Worker                log_norm_const = (
3887*da0073e9SAndroid Build Coastguard Worker                    math.log(2.0)
3888*da0073e9SAndroid Build Coastguard Worker                    + 4.0 / 3.0 * math.pow(prob - 0.5, 2)
3889*da0073e9SAndroid Build Coastguard Worker                    + 104.0 / 45.0 * math.pow(prob - 0.5, 4)
3890*da0073e9SAndroid Build Coastguard Worker                )
3891*da0073e9SAndroid Build Coastguard Worker            else:
3892*da0073e9SAndroid Build Coastguard Worker                log_norm_const = math.log(
3893*da0073e9SAndroid Build Coastguard Worker                    2.0 * math.atanh(1.0 - 2.0 * prob) / (1.0 - 2.0 * prob)
3894*da0073e9SAndroid Build Coastguard Worker                )
3895*da0073e9SAndroid Build Coastguard Worker            res = (
3896*da0073e9SAndroid Build Coastguard Worker                val * math.log(prob) + (1.0 - val) * math.log1p(-prob) + log_norm_const
3897*da0073e9SAndroid Build Coastguard Worker            )
3898*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(log_prob, res)
3899*da0073e9SAndroid Build Coastguard Worker
3900*da0073e9SAndroid Build Coastguard Worker        self._check_log_prob(ContinuousBernoulli(p), ref_log_prob)
3901*da0073e9SAndroid Build Coastguard Worker        self._check_log_prob(
3902*da0073e9SAndroid Build Coastguard Worker            ContinuousBernoulli(logits=p.log() - (-p).log1p()), ref_log_prob
3903*da0073e9SAndroid Build Coastguard Worker        )
3904*da0073e9SAndroid Build Coastguard Worker
3905*da0073e9SAndroid Build Coastguard Worker        # check entropy computation
3906*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
3907*da0073e9SAndroid Build Coastguard Worker            ContinuousBernoulli(p).entropy(),
3908*da0073e9SAndroid Build Coastguard Worker            torch.tensor([-0.02938, -0.07641, -0.00682]),
3909*da0073e9SAndroid Build Coastguard Worker            atol=1e-4,
3910*da0073e9SAndroid Build Coastguard Worker            rtol=0,
3911*da0073e9SAndroid Build Coastguard Worker        )
3912*da0073e9SAndroid Build Coastguard Worker        # entropy below corresponds to the clamped value of prob when using float 64
3913*da0073e9SAndroid Build Coastguard Worker        # the value for float32 should be -1.76898
3914*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
3915*da0073e9SAndroid Build Coastguard Worker            ContinuousBernoulli(torch.tensor([0.0])).entropy(),
3916*da0073e9SAndroid Build Coastguard Worker            torch.tensor([-2.58473]),
3917*da0073e9SAndroid Build Coastguard Worker            atol=1e-5,
3918*da0073e9SAndroid Build Coastguard Worker            rtol=0,
3919*da0073e9SAndroid Build Coastguard Worker        )
3920*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
3921*da0073e9SAndroid Build Coastguard Worker            ContinuousBernoulli(s).entropy(), torch.tensor(-0.02938), atol=1e-4, rtol=0
3922*da0073e9SAndroid Build Coastguard Worker        )
3923*da0073e9SAndroid Build Coastguard Worker
3924*da0073e9SAndroid Build Coastguard Worker    def test_continuous_bernoulli_3d(self):
3925*da0073e9SAndroid Build Coastguard Worker        p = torch.full((2, 3, 5), 0.5).requires_grad_()
3926*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ContinuousBernoulli(p).sample().size(), (2, 3, 5))
3927*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
3928*da0073e9SAndroid Build Coastguard Worker            ContinuousBernoulli(p).sample(sample_shape=(2, 5)).size(), (2, 5, 2, 3, 5)
3929*da0073e9SAndroid Build Coastguard Worker        )
3930*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ContinuousBernoulli(p).sample((2,)).size(), (2, 2, 3, 5))
3931*da0073e9SAndroid Build Coastguard Worker
3932*da0073e9SAndroid Build Coastguard Worker    def test_lkj_cholesky_log_prob(self):
3933*da0073e9SAndroid Build Coastguard Worker        def tril_cholesky_to_tril_corr(x):
3934*da0073e9SAndroid Build Coastguard Worker            x = vec_to_tril_matrix(x, -1)
3935*da0073e9SAndroid Build Coastguard Worker            diag = (1 - (x * x).sum(-1)).sqrt().diag_embed()
3936*da0073e9SAndroid Build Coastguard Worker            x = x + diag
3937*da0073e9SAndroid Build Coastguard Worker            return tril_matrix_to_vec(x @ x.T, -1)
3938*da0073e9SAndroid Build Coastguard Worker
3939*da0073e9SAndroid Build Coastguard Worker        for dim in range(2, 5):
3940*da0073e9SAndroid Build Coastguard Worker            log_probs = []
3941*da0073e9SAndroid Build Coastguard Worker            lkj = LKJCholesky(dim, concentration=1.0, validate_args=True)
3942*da0073e9SAndroid Build Coastguard Worker            for i in range(2):
3943*da0073e9SAndroid Build Coastguard Worker                sample = lkj.sample()
3944*da0073e9SAndroid Build Coastguard Worker                sample_tril = tril_matrix_to_vec(sample, diag=-1)
3945*da0073e9SAndroid Build Coastguard Worker                log_prob = lkj.log_prob(sample)
3946*da0073e9SAndroid Build Coastguard Worker                log_abs_det_jacobian = torch.slogdet(
3947*da0073e9SAndroid Build Coastguard Worker                    jacobian(tril_cholesky_to_tril_corr, sample_tril)
3948*da0073e9SAndroid Build Coastguard Worker                ).logabsdet
3949*da0073e9SAndroid Build Coastguard Worker                log_probs.append(log_prob - log_abs_det_jacobian)
3950*da0073e9SAndroid Build Coastguard Worker            # for concentration=1., the density is uniform over the space of all
3951*da0073e9SAndroid Build Coastguard Worker            # correlation matrices.
3952*da0073e9SAndroid Build Coastguard Worker            if dim == 2:
3953*da0073e9SAndroid Build Coastguard Worker                # for dim=2, pdf = 0.5 (jacobian adjustment factor is 0.)
3954*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(
3955*da0073e9SAndroid Build Coastguard Worker                    all(
3956*da0073e9SAndroid Build Coastguard Worker                        torch.allclose(x, torch.tensor(0.5).log(), atol=1e-10)
3957*da0073e9SAndroid Build Coastguard Worker                        for x in log_probs
3958*da0073e9SAndroid Build Coastguard Worker                    )
3959*da0073e9SAndroid Build Coastguard Worker                )
3960*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(log_probs[0], log_probs[1])
3961*da0073e9SAndroid Build Coastguard Worker            invalid_sample = torch.cat([sample, sample.new_ones(1, dim)], dim=0)
3962*da0073e9SAndroid Build Coastguard Worker            self.assertRaises(ValueError, lambda: lkj.log_prob(invalid_sample))
3963*da0073e9SAndroid Build Coastguard Worker
3964*da0073e9SAndroid Build Coastguard Worker    def test_independent_shape(self):
3965*da0073e9SAndroid Build Coastguard Worker        for Dist, params in _get_examples():
3966*da0073e9SAndroid Build Coastguard Worker            for param in params:
3967*da0073e9SAndroid Build Coastguard Worker                base_dist = Dist(**param)
3968*da0073e9SAndroid Build Coastguard Worker                x = base_dist.sample()
3969*da0073e9SAndroid Build Coastguard Worker                base_log_prob_shape = base_dist.log_prob(x).shape
3970*da0073e9SAndroid Build Coastguard Worker                for reinterpreted_batch_ndims in range(len(base_dist.batch_shape) + 1):
3971*da0073e9SAndroid Build Coastguard Worker                    indep_dist = Independent(base_dist, reinterpreted_batch_ndims)
3972*da0073e9SAndroid Build Coastguard Worker                    indep_log_prob_shape = base_log_prob_shape[
3973*da0073e9SAndroid Build Coastguard Worker                        : len(base_log_prob_shape) - reinterpreted_batch_ndims
3974*da0073e9SAndroid Build Coastguard Worker                    ]
3975*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(indep_dist.log_prob(x).shape, indep_log_prob_shape)
3976*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(
3977*da0073e9SAndroid Build Coastguard Worker                        indep_dist.sample().shape, base_dist.sample().shape
3978*da0073e9SAndroid Build Coastguard Worker                    )
3979*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(indep_dist.has_rsample, base_dist.has_rsample)
3980*da0073e9SAndroid Build Coastguard Worker                    if indep_dist.has_rsample:
3981*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(
3982*da0073e9SAndroid Build Coastguard Worker                            indep_dist.sample().shape, base_dist.sample().shape
3983*da0073e9SAndroid Build Coastguard Worker                        )
3984*da0073e9SAndroid Build Coastguard Worker                    try:
3985*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(
3986*da0073e9SAndroid Build Coastguard Worker                            indep_dist.enumerate_support().shape,
3987*da0073e9SAndroid Build Coastguard Worker                            base_dist.enumerate_support().shape,
3988*da0073e9SAndroid Build Coastguard Worker                        )
3989*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(indep_dist.mean.shape, base_dist.mean.shape)
3990*da0073e9SAndroid Build Coastguard Worker                    except NotImplementedError:
3991*da0073e9SAndroid Build Coastguard Worker                        pass
3992*da0073e9SAndroid Build Coastguard Worker                    try:
3993*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(
3994*da0073e9SAndroid Build Coastguard Worker                            indep_dist.variance.shape, base_dist.variance.shape
3995*da0073e9SAndroid Build Coastguard Worker                        )
3996*da0073e9SAndroid Build Coastguard Worker                    except NotImplementedError:
3997*da0073e9SAndroid Build Coastguard Worker                        pass
3998*da0073e9SAndroid Build Coastguard Worker                    try:
3999*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(
4000*da0073e9SAndroid Build Coastguard Worker                            indep_dist.entropy().shape, indep_log_prob_shape
4001*da0073e9SAndroid Build Coastguard Worker                        )
4002*da0073e9SAndroid Build Coastguard Worker                    except NotImplementedError:
4003*da0073e9SAndroid Build Coastguard Worker                        pass
4004*da0073e9SAndroid Build Coastguard Worker
4005*da0073e9SAndroid Build Coastguard Worker    def test_independent_expand(self):
4006*da0073e9SAndroid Build Coastguard Worker        for Dist, params in _get_examples():
4007*da0073e9SAndroid Build Coastguard Worker            for param in params:
4008*da0073e9SAndroid Build Coastguard Worker                base_dist = Dist(**param)
4009*da0073e9SAndroid Build Coastguard Worker                for reinterpreted_batch_ndims in range(len(base_dist.batch_shape) + 1):
4010*da0073e9SAndroid Build Coastguard Worker                    for s in [torch.Size(), torch.Size((2,)), torch.Size((2, 3))]:
4011*da0073e9SAndroid Build Coastguard Worker                        indep_dist = Independent(base_dist, reinterpreted_batch_ndims)
4012*da0073e9SAndroid Build Coastguard Worker                        expanded_shape = s + indep_dist.batch_shape
4013*da0073e9SAndroid Build Coastguard Worker                        expanded = indep_dist.expand(expanded_shape)
4014*da0073e9SAndroid Build Coastguard Worker                        expanded_sample = expanded.sample()
4015*da0073e9SAndroid Build Coastguard Worker                        expected_shape = expanded_shape + indep_dist.event_shape
4016*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(expanded_sample.shape, expected_shape)
4017*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(
4018*da0073e9SAndroid Build Coastguard Worker                            expanded.log_prob(expanded_sample),
4019*da0073e9SAndroid Build Coastguard Worker                            indep_dist.log_prob(expanded_sample),
4020*da0073e9SAndroid Build Coastguard Worker                        )
4021*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(expanded.event_shape, indep_dist.event_shape)
4022*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(expanded.batch_shape, expanded_shape)
4023*da0073e9SAndroid Build Coastguard Worker
4024*da0073e9SAndroid Build Coastguard Worker    @set_default_dtype(torch.double)
4025*da0073e9SAndroid Build Coastguard Worker    def test_cdf_icdf_inverse(self):
4026*da0073e9SAndroid Build Coastguard Worker        # Tests the invertibility property on the distributions
4027*da0073e9SAndroid Build Coastguard Worker        for Dist, params in _get_examples():
4028*da0073e9SAndroid Build Coastguard Worker            for i, param in enumerate(params):
4029*da0073e9SAndroid Build Coastguard Worker                dist = Dist(**param)
4030*da0073e9SAndroid Build Coastguard Worker                samples = dist.sample(sample_shape=(20,))
4031*da0073e9SAndroid Build Coastguard Worker                try:
4032*da0073e9SAndroid Build Coastguard Worker                    cdf = dist.cdf(samples)
4033*da0073e9SAndroid Build Coastguard Worker                    actual = dist.icdf(cdf)
4034*da0073e9SAndroid Build Coastguard Worker                except NotImplementedError:
4035*da0073e9SAndroid Build Coastguard Worker                    continue
4036*da0073e9SAndroid Build Coastguard Worker                rel_error = torch.abs(actual - samples) / (1e-10 + torch.abs(samples))
4037*da0073e9SAndroid Build Coastguard Worker                self.assertLess(
4038*da0073e9SAndroid Build Coastguard Worker                    rel_error.max(),
4039*da0073e9SAndroid Build Coastguard Worker                    1e-4,
4040*da0073e9SAndroid Build Coastguard Worker                    msg="\n".join(
4041*da0073e9SAndroid Build Coastguard Worker                        [
4042*da0073e9SAndroid Build Coastguard Worker                            f"{Dist.__name__} example {i + 1}/{len(params)}, icdf(cdf(x)) != x",
4043*da0073e9SAndroid Build Coastguard Worker                            f"x = {samples}",
4044*da0073e9SAndroid Build Coastguard Worker                            f"cdf(x) = {cdf}",
4045*da0073e9SAndroid Build Coastguard Worker                            f"icdf(cdf(x)) = {actual}",
4046*da0073e9SAndroid Build Coastguard Worker                        ]
4047*da0073e9SAndroid Build Coastguard Worker                    ),
4048*da0073e9SAndroid Build Coastguard Worker                )
4049*da0073e9SAndroid Build Coastguard Worker
4050*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
4051*da0073e9SAndroid Build Coastguard Worker    def test_gamma_log_prob_at_boundary(self):
4052*da0073e9SAndroid Build Coastguard Worker        for concentration, log_prob in [(0.5, inf), (1, 0), (2, -inf)]:
4053*da0073e9SAndroid Build Coastguard Worker            dist = Gamma(concentration, 1)
4054*da0073e9SAndroid Build Coastguard Worker            scipy_dist = scipy.stats.gamma(concentration)
4055*da0073e9SAndroid Build Coastguard Worker            self.assertAlmostEqual(dist.log_prob(0), log_prob)
4056*da0073e9SAndroid Build Coastguard Worker            self.assertAlmostEqual(dist.log_prob(0), scipy_dist.logpdf(0))
4057*da0073e9SAndroid Build Coastguard Worker
4058*da0073e9SAndroid Build Coastguard Worker    @set_default_dtype(torch.double)
4059*da0073e9SAndroid Build Coastguard Worker    def test_cdf_log_prob(self):
4060*da0073e9SAndroid Build Coastguard Worker        # Tests if the differentiation of the CDF gives the PDF at a given value
4061*da0073e9SAndroid Build Coastguard Worker        for Dist, params in _get_examples():
4062*da0073e9SAndroid Build Coastguard Worker            for i, param in enumerate(params):
4063*da0073e9SAndroid Build Coastguard Worker                # We do not need grads wrt params here, e.g. shape of gamma distribution.
4064*da0073e9SAndroid Build Coastguard Worker                param = {
4065*da0073e9SAndroid Build Coastguard Worker                    key: value.detach() if isinstance(value, torch.Tensor) else value
4066*da0073e9SAndroid Build Coastguard Worker                    for key, value in param.items()
4067*da0073e9SAndroid Build Coastguard Worker                }
4068*da0073e9SAndroid Build Coastguard Worker                dist = Dist(**param)
4069*da0073e9SAndroid Build Coastguard Worker                samples = dist.sample()
4070*da0073e9SAndroid Build Coastguard Worker                if not dist.support.is_discrete:
4071*da0073e9SAndroid Build Coastguard Worker                    samples.requires_grad_()
4072*da0073e9SAndroid Build Coastguard Worker                try:
4073*da0073e9SAndroid Build Coastguard Worker                    cdfs = dist.cdf(samples)
4074*da0073e9SAndroid Build Coastguard Worker                    pdfs = dist.log_prob(samples).exp()
4075*da0073e9SAndroid Build Coastguard Worker                except NotImplementedError:
4076*da0073e9SAndroid Build Coastguard Worker                    continue
4077*da0073e9SAndroid Build Coastguard Worker                cdfs_derivative = grad(cdfs.sum(), [samples])[
4078*da0073e9SAndroid Build Coastguard Worker                    0
4079*da0073e9SAndroid Build Coastguard Worker                ]  # this should not be wrapped in torch.abs()
4080*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(
4081*da0073e9SAndroid Build Coastguard Worker                    cdfs_derivative,
4082*da0073e9SAndroid Build Coastguard Worker                    pdfs,
4083*da0073e9SAndroid Build Coastguard Worker                    msg="\n".join(
4084*da0073e9SAndroid Build Coastguard Worker                        [
4085*da0073e9SAndroid Build Coastguard Worker                            f"{Dist.__name__} example {i + 1}/{len(params)}, d(cdf)/dx != pdf(x)",
4086*da0073e9SAndroid Build Coastguard Worker                            f"x = {samples}",
4087*da0073e9SAndroid Build Coastguard Worker                            f"cdf = {cdfs}",
4088*da0073e9SAndroid Build Coastguard Worker                            f"pdf = {pdfs}",
4089*da0073e9SAndroid Build Coastguard Worker                            f"grad(cdf) = {cdfs_derivative}",
4090*da0073e9SAndroid Build Coastguard Worker                        ]
4091*da0073e9SAndroid Build Coastguard Worker                    ),
4092*da0073e9SAndroid Build Coastguard Worker                )
4093*da0073e9SAndroid Build Coastguard Worker
4094*da0073e9SAndroid Build Coastguard Worker    def test_valid_parameter_broadcasting(self):
4095*da0073e9SAndroid Build Coastguard Worker        # Test correct broadcasting of parameter sizes for distributions that have multiple
4096*da0073e9SAndroid Build Coastguard Worker        # parameters.
4097*da0073e9SAndroid Build Coastguard Worker        # example type (distribution instance, expected sample shape)
4098*da0073e9SAndroid Build Coastguard Worker        valid_examples = [
4099*da0073e9SAndroid Build Coastguard Worker            (Normal(loc=torch.tensor([0.0, 0.0]), scale=1), (2,)),
4100*da0073e9SAndroid Build Coastguard Worker            (Normal(loc=0, scale=torch.tensor([1.0, 1.0])), (2,)),
4101*da0073e9SAndroid Build Coastguard Worker            (Normal(loc=torch.tensor([0.0, 0.0]), scale=torch.tensor([1.0])), (2,)),
4102*da0073e9SAndroid Build Coastguard Worker            (
4103*da0073e9SAndroid Build Coastguard Worker                Normal(
4104*da0073e9SAndroid Build Coastguard Worker                    loc=torch.tensor([0.0, 0.0]), scale=torch.tensor([[1.0], [1.0]])
4105*da0073e9SAndroid Build Coastguard Worker                ),
4106*da0073e9SAndroid Build Coastguard Worker                (2, 2),
4107*da0073e9SAndroid Build Coastguard Worker            ),
4108*da0073e9SAndroid Build Coastguard Worker            (Normal(loc=torch.tensor([0.0, 0.0]), scale=torch.tensor([[1.0]])), (1, 2)),
4109*da0073e9SAndroid Build Coastguard Worker            (Normal(loc=torch.tensor([0.0]), scale=torch.tensor([[1.0]])), (1, 1)),
4110*da0073e9SAndroid Build Coastguard Worker            (FisherSnedecor(df1=torch.tensor([1.0, 1.0]), df2=1), (2,)),
4111*da0073e9SAndroid Build Coastguard Worker            (FisherSnedecor(df1=1, df2=torch.tensor([1.0, 1.0])), (2,)),
4112*da0073e9SAndroid Build Coastguard Worker            (
4113*da0073e9SAndroid Build Coastguard Worker                FisherSnedecor(df1=torch.tensor([1.0, 1.0]), df2=torch.tensor([1.0])),
4114*da0073e9SAndroid Build Coastguard Worker                (2,),
4115*da0073e9SAndroid Build Coastguard Worker            ),
4116*da0073e9SAndroid Build Coastguard Worker            (
4117*da0073e9SAndroid Build Coastguard Worker                FisherSnedecor(
4118*da0073e9SAndroid Build Coastguard Worker                    df1=torch.tensor([1.0, 1.0]), df2=torch.tensor([[1.0], [1.0]])
4119*da0073e9SAndroid Build Coastguard Worker                ),
4120*da0073e9SAndroid Build Coastguard Worker                (2, 2),
4121*da0073e9SAndroid Build Coastguard Worker            ),
4122*da0073e9SAndroid Build Coastguard Worker            (
4123*da0073e9SAndroid Build Coastguard Worker                FisherSnedecor(df1=torch.tensor([1.0, 1.0]), df2=torch.tensor([[1.0]])),
4124*da0073e9SAndroid Build Coastguard Worker                (1, 2),
4125*da0073e9SAndroid Build Coastguard Worker            ),
4126*da0073e9SAndroid Build Coastguard Worker            (
4127*da0073e9SAndroid Build Coastguard Worker                FisherSnedecor(df1=torch.tensor([1.0]), df2=torch.tensor([[1.0]])),
4128*da0073e9SAndroid Build Coastguard Worker                (1, 1),
4129*da0073e9SAndroid Build Coastguard Worker            ),
4130*da0073e9SAndroid Build Coastguard Worker            (Gamma(concentration=torch.tensor([1.0, 1.0]), rate=1), (2,)),
4131*da0073e9SAndroid Build Coastguard Worker            (Gamma(concentration=1, rate=torch.tensor([1.0, 1.0])), (2,)),
4132*da0073e9SAndroid Build Coastguard Worker            (
4133*da0073e9SAndroid Build Coastguard Worker                Gamma(
4134*da0073e9SAndroid Build Coastguard Worker                    concentration=torch.tensor([1.0, 1.0]),
4135*da0073e9SAndroid Build Coastguard Worker                    rate=torch.tensor([[1.0], [1.0], [1.0]]),
4136*da0073e9SAndroid Build Coastguard Worker                ),
4137*da0073e9SAndroid Build Coastguard Worker                (3, 2),
4138*da0073e9SAndroid Build Coastguard Worker            ),
4139*da0073e9SAndroid Build Coastguard Worker            (
4140*da0073e9SAndroid Build Coastguard Worker                Gamma(
4141*da0073e9SAndroid Build Coastguard Worker                    concentration=torch.tensor([1.0, 1.0]),
4142*da0073e9SAndroid Build Coastguard Worker                    rate=torch.tensor([[1.0], [1.0]]),
4143*da0073e9SAndroid Build Coastguard Worker                ),
4144*da0073e9SAndroid Build Coastguard Worker                (2, 2),
4145*da0073e9SAndroid Build Coastguard Worker            ),
4146*da0073e9SAndroid Build Coastguard Worker            (
4147*da0073e9SAndroid Build Coastguard Worker                Gamma(
4148*da0073e9SAndroid Build Coastguard Worker                    concentration=torch.tensor([1.0, 1.0]), rate=torch.tensor([[1.0]])
4149*da0073e9SAndroid Build Coastguard Worker                ),
4150*da0073e9SAndroid Build Coastguard Worker                (1, 2),
4151*da0073e9SAndroid Build Coastguard Worker            ),
4152*da0073e9SAndroid Build Coastguard Worker            (
4153*da0073e9SAndroid Build Coastguard Worker                Gamma(concentration=torch.tensor([1.0]), rate=torch.tensor([[1.0]])),
4154*da0073e9SAndroid Build Coastguard Worker                (1, 1),
4155*da0073e9SAndroid Build Coastguard Worker            ),
4156*da0073e9SAndroid Build Coastguard Worker            (Gumbel(loc=torch.tensor([0.0, 0.0]), scale=1), (2,)),
4157*da0073e9SAndroid Build Coastguard Worker            (Gumbel(loc=0, scale=torch.tensor([1.0, 1.0])), (2,)),
4158*da0073e9SAndroid Build Coastguard Worker            (Gumbel(loc=torch.tensor([0.0, 0.0]), scale=torch.tensor([1.0])), (2,)),
4159*da0073e9SAndroid Build Coastguard Worker            (
4160*da0073e9SAndroid Build Coastguard Worker                Gumbel(
4161*da0073e9SAndroid Build Coastguard Worker                    loc=torch.tensor([0.0, 0.0]), scale=torch.tensor([[1.0], [1.0]])
4162*da0073e9SAndroid Build Coastguard Worker                ),
4163*da0073e9SAndroid Build Coastguard Worker                (2, 2),
4164*da0073e9SAndroid Build Coastguard Worker            ),
4165*da0073e9SAndroid Build Coastguard Worker            (Gumbel(loc=torch.tensor([0.0, 0.0]), scale=torch.tensor([[1.0]])), (1, 2)),
4166*da0073e9SAndroid Build Coastguard Worker            (Gumbel(loc=torch.tensor([0.0]), scale=torch.tensor([[1.0]])), (1, 1)),
4167*da0073e9SAndroid Build Coastguard Worker            (
4168*da0073e9SAndroid Build Coastguard Worker                Kumaraswamy(
4169*da0073e9SAndroid Build Coastguard Worker                    concentration1=torch.tensor([1.0, 1.0]), concentration0=1.0
4170*da0073e9SAndroid Build Coastguard Worker                ),
4171*da0073e9SAndroid Build Coastguard Worker                (2,),
4172*da0073e9SAndroid Build Coastguard Worker            ),
4173*da0073e9SAndroid Build Coastguard Worker            (
4174*da0073e9SAndroid Build Coastguard Worker                Kumaraswamy(concentration1=1, concentration0=torch.tensor([1.0, 1.0])),
4175*da0073e9SAndroid Build Coastguard Worker                (2,),
4176*da0073e9SAndroid Build Coastguard Worker            ),
4177*da0073e9SAndroid Build Coastguard Worker            (
4178*da0073e9SAndroid Build Coastguard Worker                Kumaraswamy(
4179*da0073e9SAndroid Build Coastguard Worker                    concentration1=torch.tensor([1.0, 1.0]),
4180*da0073e9SAndroid Build Coastguard Worker                    concentration0=torch.tensor([1.0]),
4181*da0073e9SAndroid Build Coastguard Worker                ),
4182*da0073e9SAndroid Build Coastguard Worker                (2,),
4183*da0073e9SAndroid Build Coastguard Worker            ),
4184*da0073e9SAndroid Build Coastguard Worker            (
4185*da0073e9SAndroid Build Coastguard Worker                Kumaraswamy(
4186*da0073e9SAndroid Build Coastguard Worker                    concentration1=torch.tensor([1.0, 1.0]),
4187*da0073e9SAndroid Build Coastguard Worker                    concentration0=torch.tensor([[1.0], [1.0]]),
4188*da0073e9SAndroid Build Coastguard Worker                ),
4189*da0073e9SAndroid Build Coastguard Worker                (2, 2),
4190*da0073e9SAndroid Build Coastguard Worker            ),
4191*da0073e9SAndroid Build Coastguard Worker            (
4192*da0073e9SAndroid Build Coastguard Worker                Kumaraswamy(
4193*da0073e9SAndroid Build Coastguard Worker                    concentration1=torch.tensor([1.0, 1.0]),
4194*da0073e9SAndroid Build Coastguard Worker                    concentration0=torch.tensor([[1.0]]),
4195*da0073e9SAndroid Build Coastguard Worker                ),
4196*da0073e9SAndroid Build Coastguard Worker                (1, 2),
4197*da0073e9SAndroid Build Coastguard Worker            ),
4198*da0073e9SAndroid Build Coastguard Worker            (
4199*da0073e9SAndroid Build Coastguard Worker                Kumaraswamy(
4200*da0073e9SAndroid Build Coastguard Worker                    concentration1=torch.tensor([1.0]),
4201*da0073e9SAndroid Build Coastguard Worker                    concentration0=torch.tensor([[1.0]]),
4202*da0073e9SAndroid Build Coastguard Worker                ),
4203*da0073e9SAndroid Build Coastguard Worker                (1, 1),
4204*da0073e9SAndroid Build Coastguard Worker            ),
4205*da0073e9SAndroid Build Coastguard Worker            (Laplace(loc=torch.tensor([0.0, 0.0]), scale=1), (2,)),
4206*da0073e9SAndroid Build Coastguard Worker            (Laplace(loc=0, scale=torch.tensor([1.0, 1.0])), (2,)),
4207*da0073e9SAndroid Build Coastguard Worker            (Laplace(loc=torch.tensor([0.0, 0.0]), scale=torch.tensor([1.0])), (2,)),
4208*da0073e9SAndroid Build Coastguard Worker            (
4209*da0073e9SAndroid Build Coastguard Worker                Laplace(
4210*da0073e9SAndroid Build Coastguard Worker                    loc=torch.tensor([0.0, 0.0]), scale=torch.tensor([[1.0], [1.0]])
4211*da0073e9SAndroid Build Coastguard Worker                ),
4212*da0073e9SAndroid Build Coastguard Worker                (2, 2),
4213*da0073e9SAndroid Build Coastguard Worker            ),
4214*da0073e9SAndroid Build Coastguard Worker            (
4215*da0073e9SAndroid Build Coastguard Worker                Laplace(loc=torch.tensor([0.0, 0.0]), scale=torch.tensor([[1.0]])),
4216*da0073e9SAndroid Build Coastguard Worker                (1, 2),
4217*da0073e9SAndroid Build Coastguard Worker            ),
4218*da0073e9SAndroid Build Coastguard Worker            (Laplace(loc=torch.tensor([0.0]), scale=torch.tensor([[1.0]])), (1, 1)),
4219*da0073e9SAndroid Build Coastguard Worker            (Pareto(scale=torch.tensor([1.0, 1.0]), alpha=1), (2,)),
4220*da0073e9SAndroid Build Coastguard Worker            (Pareto(scale=1, alpha=torch.tensor([1.0, 1.0])), (2,)),
4221*da0073e9SAndroid Build Coastguard Worker            (Pareto(scale=torch.tensor([1.0, 1.0]), alpha=torch.tensor([1.0])), (2,)),
4222*da0073e9SAndroid Build Coastguard Worker            (
4223*da0073e9SAndroid Build Coastguard Worker                Pareto(
4224*da0073e9SAndroid Build Coastguard Worker                    scale=torch.tensor([1.0, 1.0]), alpha=torch.tensor([[1.0], [1.0]])
4225*da0073e9SAndroid Build Coastguard Worker                ),
4226*da0073e9SAndroid Build Coastguard Worker                (2, 2),
4227*da0073e9SAndroid Build Coastguard Worker            ),
4228*da0073e9SAndroid Build Coastguard Worker            (
4229*da0073e9SAndroid Build Coastguard Worker                Pareto(scale=torch.tensor([1.0, 1.0]), alpha=torch.tensor([[1.0]])),
4230*da0073e9SAndroid Build Coastguard Worker                (1, 2),
4231*da0073e9SAndroid Build Coastguard Worker            ),
4232*da0073e9SAndroid Build Coastguard Worker            (Pareto(scale=torch.tensor([1.0]), alpha=torch.tensor([[1.0]])), (1, 1)),
4233*da0073e9SAndroid Build Coastguard Worker            (StudentT(df=torch.tensor([1.0, 1.0]), loc=1), (2,)),
4234*da0073e9SAndroid Build Coastguard Worker            (StudentT(df=1, scale=torch.tensor([1.0, 1.0])), (2,)),
4235*da0073e9SAndroid Build Coastguard Worker            (StudentT(df=torch.tensor([1.0, 1.0]), loc=torch.tensor([1.0])), (2,)),
4236*da0073e9SAndroid Build Coastguard Worker            (
4237*da0073e9SAndroid Build Coastguard Worker                StudentT(
4238*da0073e9SAndroid Build Coastguard Worker                    df=torch.tensor([1.0, 1.0]), scale=torch.tensor([[1.0], [1.0]])
4239*da0073e9SAndroid Build Coastguard Worker                ),
4240*da0073e9SAndroid Build Coastguard Worker                (2, 2),
4241*da0073e9SAndroid Build Coastguard Worker            ),
4242*da0073e9SAndroid Build Coastguard Worker            (StudentT(df=torch.tensor([1.0, 1.0]), loc=torch.tensor([[1.0]])), (1, 2)),
4243*da0073e9SAndroid Build Coastguard Worker            (StudentT(df=torch.tensor([1.0]), scale=torch.tensor([[1.0]])), (1, 1)),
4244*da0073e9SAndroid Build Coastguard Worker            (StudentT(df=1.0, loc=torch.zeros(5, 1), scale=torch.ones(3)), (5, 3)),
4245*da0073e9SAndroid Build Coastguard Worker        ]
4246*da0073e9SAndroid Build Coastguard Worker
4247*da0073e9SAndroid Build Coastguard Worker        for dist, expected_size in valid_examples:
4248*da0073e9SAndroid Build Coastguard Worker            actual_size = dist.sample().size()
4249*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
4250*da0073e9SAndroid Build Coastguard Worker                actual_size,
4251*da0073e9SAndroid Build Coastguard Worker                expected_size,
4252*da0073e9SAndroid Build Coastguard Worker                msg=f"{dist} actual size: {actual_size} != expected size: {expected_size}",
4253*da0073e9SAndroid Build Coastguard Worker            )
4254*da0073e9SAndroid Build Coastguard Worker
4255*da0073e9SAndroid Build Coastguard Worker            sample_shape = torch.Size((2,))
4256*da0073e9SAndroid Build Coastguard Worker            expected_size = sample_shape + expected_size
4257*da0073e9SAndroid Build Coastguard Worker            actual_size = dist.sample(sample_shape).size()
4258*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
4259*da0073e9SAndroid Build Coastguard Worker                actual_size,
4260*da0073e9SAndroid Build Coastguard Worker                expected_size,
4261*da0073e9SAndroid Build Coastguard Worker                msg=f"{dist} actual size: {actual_size} != expected size: {expected_size}",
4262*da0073e9SAndroid Build Coastguard Worker            )
4263*da0073e9SAndroid Build Coastguard Worker
4264*da0073e9SAndroid Build Coastguard Worker    def test_invalid_parameter_broadcasting(self):
4265*da0073e9SAndroid Build Coastguard Worker        # invalid broadcasting cases; should throw error
4266*da0073e9SAndroid Build Coastguard Worker        # example type (distribution class, distribution params)
4267*da0073e9SAndroid Build Coastguard Worker        invalid_examples = [
4268*da0073e9SAndroid Build Coastguard Worker            (
4269*da0073e9SAndroid Build Coastguard Worker                Normal,
4270*da0073e9SAndroid Build Coastguard Worker                {"loc": torch.tensor([[0, 0]]), "scale": torch.tensor([1, 1, 1, 1])},
4271*da0073e9SAndroid Build Coastguard Worker            ),
4272*da0073e9SAndroid Build Coastguard Worker            (
4273*da0073e9SAndroid Build Coastguard Worker                Normal,
4274*da0073e9SAndroid Build Coastguard Worker                {
4275*da0073e9SAndroid Build Coastguard Worker                    "loc": torch.tensor([[[0, 0, 0], [0, 0, 0]]]),
4276*da0073e9SAndroid Build Coastguard Worker                    "scale": torch.tensor([1, 1]),
4277*da0073e9SAndroid Build Coastguard Worker                },
4278*da0073e9SAndroid Build Coastguard Worker            ),
4279*da0073e9SAndroid Build Coastguard Worker            (
4280*da0073e9SAndroid Build Coastguard Worker                FisherSnedecor,
4281*da0073e9SAndroid Build Coastguard Worker                {
4282*da0073e9SAndroid Build Coastguard Worker                    "df1": torch.tensor([1, 1]),
4283*da0073e9SAndroid Build Coastguard Worker                    "df2": torch.tensor([1, 1, 1]),
4284*da0073e9SAndroid Build Coastguard Worker                },
4285*da0073e9SAndroid Build Coastguard Worker            ),
4286*da0073e9SAndroid Build Coastguard Worker            (
4287*da0073e9SAndroid Build Coastguard Worker                Gumbel,
4288*da0073e9SAndroid Build Coastguard Worker                {"loc": torch.tensor([[0, 0]]), "scale": torch.tensor([1, 1, 1, 1])},
4289*da0073e9SAndroid Build Coastguard Worker            ),
4290*da0073e9SAndroid Build Coastguard Worker            (
4291*da0073e9SAndroid Build Coastguard Worker                Gumbel,
4292*da0073e9SAndroid Build Coastguard Worker                {
4293*da0073e9SAndroid Build Coastguard Worker                    "loc": torch.tensor([[[0, 0, 0], [0, 0, 0]]]),
4294*da0073e9SAndroid Build Coastguard Worker                    "scale": torch.tensor([1, 1]),
4295*da0073e9SAndroid Build Coastguard Worker                },
4296*da0073e9SAndroid Build Coastguard Worker            ),
4297*da0073e9SAndroid Build Coastguard Worker            (
4298*da0073e9SAndroid Build Coastguard Worker                Gamma,
4299*da0073e9SAndroid Build Coastguard Worker                {
4300*da0073e9SAndroid Build Coastguard Worker                    "concentration": torch.tensor([0, 0]),
4301*da0073e9SAndroid Build Coastguard Worker                    "rate": torch.tensor([1, 1, 1]),
4302*da0073e9SAndroid Build Coastguard Worker                },
4303*da0073e9SAndroid Build Coastguard Worker            ),
4304*da0073e9SAndroid Build Coastguard Worker            (
4305*da0073e9SAndroid Build Coastguard Worker                Kumaraswamy,
4306*da0073e9SAndroid Build Coastguard Worker                {
4307*da0073e9SAndroid Build Coastguard Worker                    "concentration1": torch.tensor([[1, 1]]),
4308*da0073e9SAndroid Build Coastguard Worker                    "concentration0": torch.tensor([1, 1, 1, 1]),
4309*da0073e9SAndroid Build Coastguard Worker                },
4310*da0073e9SAndroid Build Coastguard Worker            ),
4311*da0073e9SAndroid Build Coastguard Worker            (
4312*da0073e9SAndroid Build Coastguard Worker                Kumaraswamy,
4313*da0073e9SAndroid Build Coastguard Worker                {
4314*da0073e9SAndroid Build Coastguard Worker                    "concentration1": torch.tensor([[[1, 1, 1], [1, 1, 1]]]),
4315*da0073e9SAndroid Build Coastguard Worker                    "concentration0": torch.tensor([1, 1]),
4316*da0073e9SAndroid Build Coastguard Worker                },
4317*da0073e9SAndroid Build Coastguard Worker            ),
4318*da0073e9SAndroid Build Coastguard Worker            (Laplace, {"loc": torch.tensor([0, 0]), "scale": torch.tensor([1, 1, 1])}),
4319*da0073e9SAndroid Build Coastguard Worker            (Pareto, {"scale": torch.tensor([1, 1]), "alpha": torch.tensor([1, 1, 1])}),
4320*da0073e9SAndroid Build Coastguard Worker            (
4321*da0073e9SAndroid Build Coastguard Worker                StudentT,
4322*da0073e9SAndroid Build Coastguard Worker                {
4323*da0073e9SAndroid Build Coastguard Worker                    "df": torch.tensor([1.0, 1.0]),
4324*da0073e9SAndroid Build Coastguard Worker                    "scale": torch.tensor([1.0, 1.0, 1.0]),
4325*da0073e9SAndroid Build Coastguard Worker                },
4326*da0073e9SAndroid Build Coastguard Worker            ),
4327*da0073e9SAndroid Build Coastguard Worker            (
4328*da0073e9SAndroid Build Coastguard Worker                StudentT,
4329*da0073e9SAndroid Build Coastguard Worker                {"df": torch.tensor([1.0, 1.0]), "loc": torch.tensor([1.0, 1.0, 1.0])},
4330*da0073e9SAndroid Build Coastguard Worker            ),
4331*da0073e9SAndroid Build Coastguard Worker        ]
4332*da0073e9SAndroid Build Coastguard Worker
4333*da0073e9SAndroid Build Coastguard Worker        for dist, kwargs in invalid_examples:
4334*da0073e9SAndroid Build Coastguard Worker            self.assertRaises(RuntimeError, dist, **kwargs)
4335*da0073e9SAndroid Build Coastguard Worker
4336*da0073e9SAndroid Build Coastguard Worker    def _test_discrete_distribution_mode(self, dist, sanitized_mode, batch_isfinite):
4337*da0073e9SAndroid Build Coastguard Worker        # We cannot easily check the mode for discrete distributions, but we can look left and right
4338*da0073e9SAndroid Build Coastguard Worker        # to ensure the log probability is smaller than at the mode.
4339*da0073e9SAndroid Build Coastguard Worker        for step in [-1, 1]:
4340*da0073e9SAndroid Build Coastguard Worker            log_prob_mode = dist.log_prob(sanitized_mode)
4341*da0073e9SAndroid Build Coastguard Worker            if isinstance(dist, OneHotCategorical):
4342*da0073e9SAndroid Build Coastguard Worker                idx = (dist._categorical.mode + 1) % dist.probs.shape[-1]
4343*da0073e9SAndroid Build Coastguard Worker                other = torch.nn.functional.one_hot(
4344*da0073e9SAndroid Build Coastguard Worker                    idx, num_classes=dist.probs.shape[-1]
4345*da0073e9SAndroid Build Coastguard Worker                ).to(dist.mode)
4346*da0073e9SAndroid Build Coastguard Worker            else:
4347*da0073e9SAndroid Build Coastguard Worker                other = dist.mode + step
4348*da0073e9SAndroid Build Coastguard Worker            mask = batch_isfinite & dist.support.check(other)
4349*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(mask.any() or dist.mode.unique().numel() == 1)
4350*da0073e9SAndroid Build Coastguard Worker            # Add a dimension to the right if the event shape is not a scalar, e.g. OneHotCategorical.
4351*da0073e9SAndroid Build Coastguard Worker            other = torch.where(
4352*da0073e9SAndroid Build Coastguard Worker                mask[..., None] if mask.ndim < other.ndim else mask,
4353*da0073e9SAndroid Build Coastguard Worker                other,
4354*da0073e9SAndroid Build Coastguard Worker                dist.sample(),
4355*da0073e9SAndroid Build Coastguard Worker            )
4356*da0073e9SAndroid Build Coastguard Worker            log_prob_other = dist.log_prob(other)
4357*da0073e9SAndroid Build Coastguard Worker            delta = log_prob_mode - log_prob_other
4358*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(
4359*da0073e9SAndroid Build Coastguard Worker                (-1e-12 < delta[mask].detach()).all()
4360*da0073e9SAndroid Build Coastguard Worker            )  # Allow up to 1e-12 rounding error.
4361*da0073e9SAndroid Build Coastguard Worker
4362*da0073e9SAndroid Build Coastguard Worker    def _test_continuous_distribution_mode(self, dist, sanitized_mode, batch_isfinite):
4363*da0073e9SAndroid Build Coastguard Worker        # We perturb the mode in the unconstrained space and expect the log probability to decrease.
4364*da0073e9SAndroid Build Coastguard Worker        num_points = 10
4365*da0073e9SAndroid Build Coastguard Worker        transform = transform_to(dist.support)
4366*da0073e9SAndroid Build Coastguard Worker        unconstrained_mode = transform.inv(sanitized_mode)
4367*da0073e9SAndroid Build Coastguard Worker        perturbation = 1e-5 * (
4368*da0073e9SAndroid Build Coastguard Worker            torch.rand((num_points,) + unconstrained_mode.shape) - 0.5
4369*da0073e9SAndroid Build Coastguard Worker        )
4370*da0073e9SAndroid Build Coastguard Worker        perturbed_mode = transform(perturbation + unconstrained_mode)
4371*da0073e9SAndroid Build Coastguard Worker        log_prob_mode = dist.log_prob(sanitized_mode)
4372*da0073e9SAndroid Build Coastguard Worker        log_prob_other = dist.log_prob(perturbed_mode)
4373*da0073e9SAndroid Build Coastguard Worker        delta = log_prob_mode - log_prob_other
4374*da0073e9SAndroid Build Coastguard Worker
4375*da0073e9SAndroid Build Coastguard Worker        # We pass the test with a small tolerance to allow for rounding and manually set the
4376*da0073e9SAndroid Build Coastguard Worker        # difference to zero if both log probs are infinite with the same sign.
4377*da0073e9SAndroid Build Coastguard Worker        both_infinite_with_same_sign = (log_prob_mode == log_prob_other) & (
4378*da0073e9SAndroid Build Coastguard Worker            log_prob_mode.abs() == inf
4379*da0073e9SAndroid Build Coastguard Worker        )
4380*da0073e9SAndroid Build Coastguard Worker        delta[both_infinite_with_same_sign] = 0.0
4381*da0073e9SAndroid Build Coastguard Worker        ordering = (delta > -1e-12).all(axis=0)
4382*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(ordering[batch_isfinite].all())
4383*da0073e9SAndroid Build Coastguard Worker
4384*da0073e9SAndroid Build Coastguard Worker    @set_default_dtype(torch.double)
4385*da0073e9SAndroid Build Coastguard Worker    def test_mode(self):
4386*da0073e9SAndroid Build Coastguard Worker        discrete_distributions = (
4387*da0073e9SAndroid Build Coastguard Worker            Bernoulli,
4388*da0073e9SAndroid Build Coastguard Worker            Binomial,
4389*da0073e9SAndroid Build Coastguard Worker            Categorical,
4390*da0073e9SAndroid Build Coastguard Worker            Geometric,
4391*da0073e9SAndroid Build Coastguard Worker            NegativeBinomial,
4392*da0073e9SAndroid Build Coastguard Worker            OneHotCategorical,
4393*da0073e9SAndroid Build Coastguard Worker            Poisson,
4394*da0073e9SAndroid Build Coastguard Worker        )
4395*da0073e9SAndroid Build Coastguard Worker        no_mode_available = (
4396*da0073e9SAndroid Build Coastguard Worker            ContinuousBernoulli,
4397*da0073e9SAndroid Build Coastguard Worker            LKJCholesky,
4398*da0073e9SAndroid Build Coastguard Worker            LogisticNormal,
4399*da0073e9SAndroid Build Coastguard Worker            MixtureSameFamily,
4400*da0073e9SAndroid Build Coastguard Worker            Multinomial,
4401*da0073e9SAndroid Build Coastguard Worker            RelaxedBernoulli,
4402*da0073e9SAndroid Build Coastguard Worker            RelaxedOneHotCategorical,
4403*da0073e9SAndroid Build Coastguard Worker        )
4404*da0073e9SAndroid Build Coastguard Worker
4405*da0073e9SAndroid Build Coastguard Worker        for dist_cls, params in _get_examples():
4406*da0073e9SAndroid Build Coastguard Worker            for param in params:
4407*da0073e9SAndroid Build Coastguard Worker                dist = dist_cls(**param)
4408*da0073e9SAndroid Build Coastguard Worker                if (
4409*da0073e9SAndroid Build Coastguard Worker                    isinstance(dist, no_mode_available)
4410*da0073e9SAndroid Build Coastguard Worker                    or type(dist) is TransformedDistribution
4411*da0073e9SAndroid Build Coastguard Worker                ):
4412*da0073e9SAndroid Build Coastguard Worker                    with self.assertRaises(NotImplementedError):
4413*da0073e9SAndroid Build Coastguard Worker                        dist.mode
4414*da0073e9SAndroid Build Coastguard Worker                    continue
4415*da0073e9SAndroid Build Coastguard Worker
4416*da0073e9SAndroid Build Coastguard Worker                # Check that either all or no elements in the event shape are nan: the mode cannot be
4417*da0073e9SAndroid Build Coastguard Worker                # defined for part of an event.
4418*da0073e9SAndroid Build Coastguard Worker                isfinite = dist.mode.isfinite().reshape(
4419*da0073e9SAndroid Build Coastguard Worker                    dist.batch_shape + (dist.event_shape.numel(),)
4420*da0073e9SAndroid Build Coastguard Worker                )
4421*da0073e9SAndroid Build Coastguard Worker                batch_isfinite = isfinite.all(axis=-1)
4422*da0073e9SAndroid Build Coastguard Worker                self.assertTrue((batch_isfinite | ~isfinite.any(axis=-1)).all())
4423*da0073e9SAndroid Build Coastguard Worker
4424*da0073e9SAndroid Build Coastguard Worker                # We sanitize undefined modes by sampling from the distribution.
4425*da0073e9SAndroid Build Coastguard Worker                sanitized_mode = torch.where(
4426*da0073e9SAndroid Build Coastguard Worker                    ~dist.mode.isnan(), dist.mode, dist.sample()
4427*da0073e9SAndroid Build Coastguard Worker                )
4428*da0073e9SAndroid Build Coastguard Worker                if isinstance(dist, discrete_distributions):
4429*da0073e9SAndroid Build Coastguard Worker                    self._test_discrete_distribution_mode(
4430*da0073e9SAndroid Build Coastguard Worker                        dist, sanitized_mode, batch_isfinite
4431*da0073e9SAndroid Build Coastguard Worker                    )
4432*da0073e9SAndroid Build Coastguard Worker                else:
4433*da0073e9SAndroid Build Coastguard Worker                    self._test_continuous_distribution_mode(
4434*da0073e9SAndroid Build Coastguard Worker                        dist, sanitized_mode, batch_isfinite
4435*da0073e9SAndroid Build Coastguard Worker                    )
4436*da0073e9SAndroid Build Coastguard Worker
4437*da0073e9SAndroid Build Coastguard Worker                self.assertFalse(dist.log_prob(sanitized_mode).isnan().any())
4438*da0073e9SAndroid Build Coastguard Worker
4439*da0073e9SAndroid Build Coastguard Worker
4440*da0073e9SAndroid Build Coastguard Worker# These tests are only needed for a few distributions that implement custom
4441*da0073e9SAndroid Build Coastguard Worker# reparameterized gradients. Most .rsample() implementations simply rely on
4442*da0073e9SAndroid Build Coastguard Worker# the reparameterization trick and do not need to be tested for accuracy.
4443*da0073e9SAndroid Build Coastguard Worker@skipIfTorchDynamo("Not a TorchDynamo suitable test")
4444*da0073e9SAndroid Build Coastguard Workerclass TestRsample(DistributionsTestCase):
4445*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
4446*da0073e9SAndroid Build Coastguard Worker    def test_gamma(self):
4447*da0073e9SAndroid Build Coastguard Worker        num_samples = 100
4448*da0073e9SAndroid Build Coastguard Worker        for alpha in [1e-2, 1e-1, 1e0, 1e1, 1e2, 1e3, 1e4]:
4449*da0073e9SAndroid Build Coastguard Worker            alphas = torch.tensor(
4450*da0073e9SAndroid Build Coastguard Worker                [alpha] * num_samples, dtype=torch.float, requires_grad=True
4451*da0073e9SAndroid Build Coastguard Worker            )
4452*da0073e9SAndroid Build Coastguard Worker            betas = alphas.new_ones(num_samples)
4453*da0073e9SAndroid Build Coastguard Worker            x = Gamma(alphas, betas).rsample()
4454*da0073e9SAndroid Build Coastguard Worker            x.sum().backward()
4455*da0073e9SAndroid Build Coastguard Worker            x, ind = x.sort()
4456*da0073e9SAndroid Build Coastguard Worker            x = x.detach().numpy()
4457*da0073e9SAndroid Build Coastguard Worker            actual_grad = alphas.grad[ind].numpy()
4458*da0073e9SAndroid Build Coastguard Worker            # Compare with expected gradient dx/dalpha along constant cdf(x,alpha).
4459*da0073e9SAndroid Build Coastguard Worker            cdf = scipy.stats.gamma.cdf
4460*da0073e9SAndroid Build Coastguard Worker            pdf = scipy.stats.gamma.pdf
4461*da0073e9SAndroid Build Coastguard Worker            eps = 0.01 * alpha / (1.0 + alpha**0.5)
4462*da0073e9SAndroid Build Coastguard Worker            cdf_alpha = (cdf(x, alpha + eps) - cdf(x, alpha - eps)) / (2 * eps)
4463*da0073e9SAndroid Build Coastguard Worker            cdf_x = pdf(x, alpha)
4464*da0073e9SAndroid Build Coastguard Worker            expected_grad = -cdf_alpha / cdf_x
4465*da0073e9SAndroid Build Coastguard Worker            rel_error = np.abs(actual_grad - expected_grad) / (expected_grad + 1e-30)
4466*da0073e9SAndroid Build Coastguard Worker            self.assertLess(
4467*da0073e9SAndroid Build Coastguard Worker                np.max(rel_error),
4468*da0073e9SAndroid Build Coastguard Worker                0.0005,
4469*da0073e9SAndroid Build Coastguard Worker                "\n".join(
4470*da0073e9SAndroid Build Coastguard Worker                    [
4471*da0073e9SAndroid Build Coastguard Worker                        f"Bad gradient dx/alpha for x ~ Gamma({alpha}, 1)",
4472*da0073e9SAndroid Build Coastguard Worker                        f"x {x}",
4473*da0073e9SAndroid Build Coastguard Worker                        f"expected {expected_grad}",
4474*da0073e9SAndroid Build Coastguard Worker                        f"actual {actual_grad}",
4475*da0073e9SAndroid Build Coastguard Worker                        f"rel error {rel_error}",
4476*da0073e9SAndroid Build Coastguard Worker                        f"max error {rel_error.max()}",
4477*da0073e9SAndroid Build Coastguard Worker                        f"at alpha={alpha}, x={x[rel_error.argmax()]}",
4478*da0073e9SAndroid Build Coastguard Worker                    ]
4479*da0073e9SAndroid Build Coastguard Worker                ),
4480*da0073e9SAndroid Build Coastguard Worker            )
4481*da0073e9SAndroid Build Coastguard Worker
4482*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
4483*da0073e9SAndroid Build Coastguard Worker    def test_chi2(self):
4484*da0073e9SAndroid Build Coastguard Worker        num_samples = 100
4485*da0073e9SAndroid Build Coastguard Worker        for df in [1e-2, 1e-1, 1e0, 1e1, 1e2, 1e3, 1e4]:
4486*da0073e9SAndroid Build Coastguard Worker            dfs = torch.tensor(
4487*da0073e9SAndroid Build Coastguard Worker                [df] * num_samples, dtype=torch.float, requires_grad=True
4488*da0073e9SAndroid Build Coastguard Worker            )
4489*da0073e9SAndroid Build Coastguard Worker            x = Chi2(dfs).rsample()
4490*da0073e9SAndroid Build Coastguard Worker            x.sum().backward()
4491*da0073e9SAndroid Build Coastguard Worker            x, ind = x.sort()
4492*da0073e9SAndroid Build Coastguard Worker            x = x.detach().numpy()
4493*da0073e9SAndroid Build Coastguard Worker            actual_grad = dfs.grad[ind].numpy()
4494*da0073e9SAndroid Build Coastguard Worker            # Compare with expected gradient dx/ddf along constant cdf(x,df).
4495*da0073e9SAndroid Build Coastguard Worker            cdf = scipy.stats.chi2.cdf
4496*da0073e9SAndroid Build Coastguard Worker            pdf = scipy.stats.chi2.pdf
4497*da0073e9SAndroid Build Coastguard Worker            eps = 0.01 * df / (1.0 + df**0.5)
4498*da0073e9SAndroid Build Coastguard Worker            cdf_df = (cdf(x, df + eps) - cdf(x, df - eps)) / (2 * eps)
4499*da0073e9SAndroid Build Coastguard Worker            cdf_x = pdf(x, df)
4500*da0073e9SAndroid Build Coastguard Worker            expected_grad = -cdf_df / cdf_x
4501*da0073e9SAndroid Build Coastguard Worker            rel_error = np.abs(actual_grad - expected_grad) / (expected_grad + 1e-30)
4502*da0073e9SAndroid Build Coastguard Worker            self.assertLess(
4503*da0073e9SAndroid Build Coastguard Worker                np.max(rel_error),
4504*da0073e9SAndroid Build Coastguard Worker                0.001,
4505*da0073e9SAndroid Build Coastguard Worker                "\n".join(
4506*da0073e9SAndroid Build Coastguard Worker                    [
4507*da0073e9SAndroid Build Coastguard Worker                        f"Bad gradient dx/ddf for x ~ Chi2({df})",
4508*da0073e9SAndroid Build Coastguard Worker                        f"x {x}",
4509*da0073e9SAndroid Build Coastguard Worker                        f"expected {expected_grad}",
4510*da0073e9SAndroid Build Coastguard Worker                        f"actual {actual_grad}",
4511*da0073e9SAndroid Build Coastguard Worker                        f"rel error {rel_error}",
4512*da0073e9SAndroid Build Coastguard Worker                        f"max error {rel_error.max()}",
4513*da0073e9SAndroid Build Coastguard Worker                    ]
4514*da0073e9SAndroid Build Coastguard Worker                ),
4515*da0073e9SAndroid Build Coastguard Worker            )
4516*da0073e9SAndroid Build Coastguard Worker
4517*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
4518*da0073e9SAndroid Build Coastguard Worker    def test_dirichlet_on_diagonal(self):
4519*da0073e9SAndroid Build Coastguard Worker        num_samples = 20
4520*da0073e9SAndroid Build Coastguard Worker        grid = [1e-1, 1e0, 1e1]
4521*da0073e9SAndroid Build Coastguard Worker        for a0, a1, a2 in product(grid, grid, grid):
4522*da0073e9SAndroid Build Coastguard Worker            alphas = torch.tensor(
4523*da0073e9SAndroid Build Coastguard Worker                [[a0, a1, a2]] * num_samples, dtype=torch.float, requires_grad=True
4524*da0073e9SAndroid Build Coastguard Worker            )
4525*da0073e9SAndroid Build Coastguard Worker            x = Dirichlet(alphas).rsample()[:, 0]
4526*da0073e9SAndroid Build Coastguard Worker            x.sum().backward()
4527*da0073e9SAndroid Build Coastguard Worker            x, ind = x.sort()
4528*da0073e9SAndroid Build Coastguard Worker            x = x.detach().numpy()
4529*da0073e9SAndroid Build Coastguard Worker            actual_grad = alphas.grad[ind].numpy()[:, 0]
4530*da0073e9SAndroid Build Coastguard Worker            # Compare with expected gradient dx/dalpha0 along constant cdf(x,alpha).
4531*da0073e9SAndroid Build Coastguard Worker            # This reduces to a distribution Beta(alpha[0], alpha[1] + alpha[2]).
4532*da0073e9SAndroid Build Coastguard Worker            cdf = scipy.stats.beta.cdf
4533*da0073e9SAndroid Build Coastguard Worker            pdf = scipy.stats.beta.pdf
4534*da0073e9SAndroid Build Coastguard Worker            alpha, beta = a0, a1 + a2
4535*da0073e9SAndroid Build Coastguard Worker            eps = 0.01 * alpha / (1.0 + np.sqrt(alpha))
4536*da0073e9SAndroid Build Coastguard Worker            cdf_alpha = (cdf(x, alpha + eps, beta) - cdf(x, alpha - eps, beta)) / (
4537*da0073e9SAndroid Build Coastguard Worker                2 * eps
4538*da0073e9SAndroid Build Coastguard Worker            )
4539*da0073e9SAndroid Build Coastguard Worker            cdf_x = pdf(x, alpha, beta)
4540*da0073e9SAndroid Build Coastguard Worker            expected_grad = -cdf_alpha / cdf_x
4541*da0073e9SAndroid Build Coastguard Worker            rel_error = np.abs(actual_grad - expected_grad) / (expected_grad + 1e-30)
4542*da0073e9SAndroid Build Coastguard Worker            self.assertLess(
4543*da0073e9SAndroid Build Coastguard Worker                np.max(rel_error),
4544*da0073e9SAndroid Build Coastguard Worker                0.001,
4545*da0073e9SAndroid Build Coastguard Worker                "\n".join(
4546*da0073e9SAndroid Build Coastguard Worker                    [
4547*da0073e9SAndroid Build Coastguard Worker                        f"Bad gradient dx[0]/dalpha[0] for Dirichlet([{a0}, {a1}, {a2}])",
4548*da0073e9SAndroid Build Coastguard Worker                        f"x {x}",
4549*da0073e9SAndroid Build Coastguard Worker                        f"expected {expected_grad}",
4550*da0073e9SAndroid Build Coastguard Worker                        f"actual {actual_grad}",
4551*da0073e9SAndroid Build Coastguard Worker                        f"rel error {rel_error}",
4552*da0073e9SAndroid Build Coastguard Worker                        f"max error {rel_error.max()}",
4553*da0073e9SAndroid Build Coastguard Worker                        f"at x={x[rel_error.argmax()]}",
4554*da0073e9SAndroid Build Coastguard Worker                    ]
4555*da0073e9SAndroid Build Coastguard Worker                ),
4556*da0073e9SAndroid Build Coastguard Worker            )
4557*da0073e9SAndroid Build Coastguard Worker
4558*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
4559*da0073e9SAndroid Build Coastguard Worker    def test_beta_wrt_alpha(self):
4560*da0073e9SAndroid Build Coastguard Worker        num_samples = 20
4561*da0073e9SAndroid Build Coastguard Worker        grid = [1e-2, 1e-1, 1e0, 1e1, 1e2]
4562*da0073e9SAndroid Build Coastguard Worker        for con1, con0 in product(grid, grid):
4563*da0073e9SAndroid Build Coastguard Worker            con1s = torch.tensor(
4564*da0073e9SAndroid Build Coastguard Worker                [con1] * num_samples, dtype=torch.float, requires_grad=True
4565*da0073e9SAndroid Build Coastguard Worker            )
4566*da0073e9SAndroid Build Coastguard Worker            con0s = con1s.new_tensor([con0] * num_samples)
4567*da0073e9SAndroid Build Coastguard Worker            x = Beta(con1s, con0s).rsample()
4568*da0073e9SAndroid Build Coastguard Worker            x.sum().backward()
4569*da0073e9SAndroid Build Coastguard Worker            x, ind = x.sort()
4570*da0073e9SAndroid Build Coastguard Worker            x = x.detach().numpy()
4571*da0073e9SAndroid Build Coastguard Worker            actual_grad = con1s.grad[ind].numpy()
4572*da0073e9SAndroid Build Coastguard Worker            # Compare with expected gradient dx/dcon1 along constant cdf(x,con1,con0).
4573*da0073e9SAndroid Build Coastguard Worker            cdf = scipy.stats.beta.cdf
4574*da0073e9SAndroid Build Coastguard Worker            pdf = scipy.stats.beta.pdf
4575*da0073e9SAndroid Build Coastguard Worker            eps = 0.01 * con1 / (1.0 + np.sqrt(con1))
4576*da0073e9SAndroid Build Coastguard Worker            cdf_alpha = (cdf(x, con1 + eps, con0) - cdf(x, con1 - eps, con0)) / (
4577*da0073e9SAndroid Build Coastguard Worker                2 * eps
4578*da0073e9SAndroid Build Coastguard Worker            )
4579*da0073e9SAndroid Build Coastguard Worker            cdf_x = pdf(x, con1, con0)
4580*da0073e9SAndroid Build Coastguard Worker            expected_grad = -cdf_alpha / cdf_x
4581*da0073e9SAndroid Build Coastguard Worker            rel_error = np.abs(actual_grad - expected_grad) / (expected_grad + 1e-30)
4582*da0073e9SAndroid Build Coastguard Worker            self.assertLess(
4583*da0073e9SAndroid Build Coastguard Worker                np.max(rel_error),
4584*da0073e9SAndroid Build Coastguard Worker                0.005,
4585*da0073e9SAndroid Build Coastguard Worker                "\n".join(
4586*da0073e9SAndroid Build Coastguard Worker                    [
4587*da0073e9SAndroid Build Coastguard Worker                        f"Bad gradient dx/dcon1 for x ~ Beta({con1}, {con0})",
4588*da0073e9SAndroid Build Coastguard Worker                        f"x {x}",
4589*da0073e9SAndroid Build Coastguard Worker                        f"expected {expected_grad}",
4590*da0073e9SAndroid Build Coastguard Worker                        f"actual {actual_grad}",
4591*da0073e9SAndroid Build Coastguard Worker                        f"rel error {rel_error}",
4592*da0073e9SAndroid Build Coastguard Worker                        f"max error {rel_error.max()}",
4593*da0073e9SAndroid Build Coastguard Worker                        f"at x = {x[rel_error.argmax()]}",
4594*da0073e9SAndroid Build Coastguard Worker                    ]
4595*da0073e9SAndroid Build Coastguard Worker                ),
4596*da0073e9SAndroid Build Coastguard Worker            )
4597*da0073e9SAndroid Build Coastguard Worker
4598*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
4599*da0073e9SAndroid Build Coastguard Worker    def test_beta_wrt_beta(self):
4600*da0073e9SAndroid Build Coastguard Worker        num_samples = 20
4601*da0073e9SAndroid Build Coastguard Worker        grid = [1e-2, 1e-1, 1e0, 1e1, 1e2]
4602*da0073e9SAndroid Build Coastguard Worker        for con1, con0 in product(grid, grid):
4603*da0073e9SAndroid Build Coastguard Worker            con0s = torch.tensor(
4604*da0073e9SAndroid Build Coastguard Worker                [con0] * num_samples, dtype=torch.float, requires_grad=True
4605*da0073e9SAndroid Build Coastguard Worker            )
4606*da0073e9SAndroid Build Coastguard Worker            con1s = con0s.new_tensor([con1] * num_samples)
4607*da0073e9SAndroid Build Coastguard Worker            x = Beta(con1s, con0s).rsample()
4608*da0073e9SAndroid Build Coastguard Worker            x.sum().backward()
4609*da0073e9SAndroid Build Coastguard Worker            x, ind = x.sort()
4610*da0073e9SAndroid Build Coastguard Worker            x = x.detach().numpy()
4611*da0073e9SAndroid Build Coastguard Worker            actual_grad = con0s.grad[ind].numpy()
4612*da0073e9SAndroid Build Coastguard Worker            # Compare with expected gradient dx/dcon0 along constant cdf(x,con1,con0).
4613*da0073e9SAndroid Build Coastguard Worker            cdf = scipy.stats.beta.cdf
4614*da0073e9SAndroid Build Coastguard Worker            pdf = scipy.stats.beta.pdf
4615*da0073e9SAndroid Build Coastguard Worker            eps = 0.01 * con0 / (1.0 + np.sqrt(con0))
4616*da0073e9SAndroid Build Coastguard Worker            cdf_beta = (cdf(x, con1, con0 + eps) - cdf(x, con1, con0 - eps)) / (2 * eps)
4617*da0073e9SAndroid Build Coastguard Worker            cdf_x = pdf(x, con1, con0)
4618*da0073e9SAndroid Build Coastguard Worker            expected_grad = -cdf_beta / cdf_x
4619*da0073e9SAndroid Build Coastguard Worker            rel_error = np.abs(actual_grad - expected_grad) / (expected_grad + 1e-30)
4620*da0073e9SAndroid Build Coastguard Worker            self.assertLess(
4621*da0073e9SAndroid Build Coastguard Worker                np.max(rel_error),
4622*da0073e9SAndroid Build Coastguard Worker                0.005,
4623*da0073e9SAndroid Build Coastguard Worker                "\n".join(
4624*da0073e9SAndroid Build Coastguard Worker                    [
4625*da0073e9SAndroid Build Coastguard Worker                        f"Bad gradient dx/dcon0 for x ~ Beta({con1}, {con0})",
4626*da0073e9SAndroid Build Coastguard Worker                        f"x {x}",
4627*da0073e9SAndroid Build Coastguard Worker                        f"expected {expected_grad}",
4628*da0073e9SAndroid Build Coastguard Worker                        f"actual {actual_grad}",
4629*da0073e9SAndroid Build Coastguard Worker                        f"rel error {rel_error}",
4630*da0073e9SAndroid Build Coastguard Worker                        f"max error {rel_error.max()}",
4631*da0073e9SAndroid Build Coastguard Worker                        f"at x = {x[rel_error.argmax()]!r}",
4632*da0073e9SAndroid Build Coastguard Worker                    ]
4633*da0073e9SAndroid Build Coastguard Worker                ),
4634*da0073e9SAndroid Build Coastguard Worker            )
4635*da0073e9SAndroid Build Coastguard Worker
4636*da0073e9SAndroid Build Coastguard Worker    def test_dirichlet_multivariate(self):
4637*da0073e9SAndroid Build Coastguard Worker        alpha_crit = 0.25 * (5.0**0.5 - 1.0)
4638*da0073e9SAndroid Build Coastguard Worker        num_samples = 100000
4639*da0073e9SAndroid Build Coastguard Worker        for shift in [-0.1, -0.05, -0.01, 0.0, 0.01, 0.05, 0.10]:
4640*da0073e9SAndroid Build Coastguard Worker            alpha = alpha_crit + shift
4641*da0073e9SAndroid Build Coastguard Worker            alpha = torch.tensor([alpha], dtype=torch.float, requires_grad=True)
4642*da0073e9SAndroid Build Coastguard Worker            alpha_vec = torch.cat([alpha, alpha, alpha.new([1])])
4643*da0073e9SAndroid Build Coastguard Worker            z = Dirichlet(alpha_vec.expand(num_samples, 3)).rsample()
4644*da0073e9SAndroid Build Coastguard Worker            mean_z3 = 1.0 / (2.0 * alpha + 1.0)
4645*da0073e9SAndroid Build Coastguard Worker            loss = torch.pow(z[:, 2] - mean_z3, 2.0).mean()
4646*da0073e9SAndroid Build Coastguard Worker            actual_grad = grad(loss, [alpha])[0]
4647*da0073e9SAndroid Build Coastguard Worker            # Compute expected gradient by hand.
4648*da0073e9SAndroid Build Coastguard Worker            num = 1.0 - 2.0 * alpha - 4.0 * alpha**2
4649*da0073e9SAndroid Build Coastguard Worker            den = (1.0 + alpha) ** 2 * (1.0 + 2.0 * alpha) ** 3
4650*da0073e9SAndroid Build Coastguard Worker            expected_grad = num / den
4651*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
4652*da0073e9SAndroid Build Coastguard Worker                actual_grad,
4653*da0073e9SAndroid Build Coastguard Worker                expected_grad,
4654*da0073e9SAndroid Build Coastguard Worker                atol=0.002,
4655*da0073e9SAndroid Build Coastguard Worker                rtol=0,
4656*da0073e9SAndroid Build Coastguard Worker                msg="\n".join(
4657*da0073e9SAndroid Build Coastguard Worker                    [
4658*da0073e9SAndroid Build Coastguard Worker                        "alpha = alpha_c + %.2g" % shift,  # noqa: UP031
4659*da0073e9SAndroid Build Coastguard Worker                        "expected_grad: %.5g" % expected_grad,  # noqa: UP031
4660*da0073e9SAndroid Build Coastguard Worker                        "actual_grad: %.5g" % actual_grad,  # noqa: UP031
4661*da0073e9SAndroid Build Coastguard Worker                        "error = %.2g"  # noqa: UP031
4662*da0073e9SAndroid Build Coastguard Worker                        % torch.abs(expected_grad - actual_grad).max(),  # noqa: UP031
4663*da0073e9SAndroid Build Coastguard Worker                    ]
4664*da0073e9SAndroid Build Coastguard Worker                ),
4665*da0073e9SAndroid Build Coastguard Worker            )
4666*da0073e9SAndroid Build Coastguard Worker
    @set_default_dtype(torch.double)
    def test_dirichlet_tangent_field(self):
        # Checks that the reparameterized gradient of Dirichlet.rsample() with
        # respect to alpha[0] satisfies a continuity equation written in terms
        # of log_prob, for every combination of three concentrations taken from
        # alpha_grid. The residual must stay below 0.005 in absolute value.
        num_samples = 20
        alpha_grid = [0.5, 1.0, 2.0]

        # v = dx/dalpha[0] is the reparameterized gradient aka tangent field.
        # Built from _Dirichlet_backward with one-hot grad_output rows, so it can
        # be evaluated at arbitrary points x (needed for the finite differences
        # below).
        def compute_v(x, alpha):
            return torch.stack(
                [
                    _Dirichlet_backward(x, alpha, torch.eye(3, 3)[i].expand_as(x))[:, 0]
                    for i in range(3)
                ],
                dim=-1,
            )

        for a1, a2, a3 in product(alpha_grid, alpha_grid, alpha_grid):
            alpha = torch.tensor([a1, a2, a3], requires_grad=True).expand(
                num_samples, 3
            )
            x = Dirichlet(alpha).rsample()
            # Gradient of log_prob w.r.t. alpha[0] with the sample held fixed.
            dlogp_da = grad(
                [Dirichlet(alpha).log_prob(x.detach()).sum()],
                [alpha],
                retain_graph=True,
            )[0][:, 0]
            # Gradient of log_prob w.r.t. the sample with alpha held fixed.
            dlogp_dx = grad(
                [Dirichlet(alpha.detach()).log_prob(x).sum()], [x], retain_graph=True
            )[0]
            # Tangent field obtained through autograd: d x[:, i] / d alpha[0].
            v = torch.stack(
                [
                    grad([x[:, i].sum()], [alpha], retain_graph=True)[0][:, 0]
                    for i in range(3)
                ],
                dim=-1,
            )
            # Compute remaining properties by finite difference.
            # First make sure the helper agrees with autograd at the samples.
            self.assertEqual(compute_v(x, alpha), v, msg="Bug in compute_v() helper")
            # dx is an arbitrary orthonormal basis tangent to the simplex.
            dx = torch.tensor([[2.0, -1.0, -1.0], [0.0, 1.0, -1.0]])
            # Scale each row to unit Euclidean norm.
            dx /= dx.norm(2, -1, True)
            # Per-sample step size proportional to the smallest coordinate.
            eps = 1e-2 * x.min(-1, True)[0]  # avoid boundary
            # Central differences of v along each basis direction.
            dv0 = (
                compute_v(x + eps * dx[0], alpha) - compute_v(x - eps * dx[0], alpha)
            ) / (2 * eps)
            dv1 = (
                compute_v(x + eps * dx[1], alpha) - compute_v(x - eps * dx[1], alpha)
            ) / (2 * eps)
            # Divergence of v within the simplex: sum of the directional
            # derivatives projected back onto their own directions.
            div_v = (dv0 * dx[0] + dv1 * dx[1]).sum(-1)
            # This is a modification of the standard continuity equation, using the product rule to allow
            # expression in terms of log_prob rather than the less numerically stable log_prob.exp().
            error = dlogp_da + (dlogp_dx * v).sum(-1) + div_v
            self.assertLess(
                torch.abs(error).max(),
                0.005,
                "\n".join(
                    [
                        f"Dirichlet([{a1}, {a2}, {a3}]) gradient violates continuity equation:",
                        f"error = {error}",
                    ]
                ),
            )
4728*da0073e9SAndroid Build Coastguard Worker
4729*da0073e9SAndroid Build Coastguard Worker
4730*da0073e9SAndroid Build Coastguard Workerclass TestDistributionShapes(DistributionsTestCase):
4731*da0073e9SAndroid Build Coastguard Worker    def setUp(self):
4732*da0073e9SAndroid Build Coastguard Worker        super().setUp()
4733*da0073e9SAndroid Build Coastguard Worker        self.scalar_sample = 1
4734*da0073e9SAndroid Build Coastguard Worker        self.tensor_sample_1 = torch.ones(3, 2)
4735*da0073e9SAndroid Build Coastguard Worker        self.tensor_sample_2 = torch.ones(3, 2, 3)
4736*da0073e9SAndroid Build Coastguard Worker
4737*da0073e9SAndroid Build Coastguard Worker    def test_entropy_shape(self):
4738*da0073e9SAndroid Build Coastguard Worker        for Dist, params in _get_examples():
4739*da0073e9SAndroid Build Coastguard Worker            for i, param in enumerate(params):
4740*da0073e9SAndroid Build Coastguard Worker                dist = Dist(validate_args=False, **param)
4741*da0073e9SAndroid Build Coastguard Worker                try:
4742*da0073e9SAndroid Build Coastguard Worker                    actual_shape = dist.entropy().size()
4743*da0073e9SAndroid Build Coastguard Worker                    expected_shape = (
4744*da0073e9SAndroid Build Coastguard Worker                        dist.batch_shape if dist.batch_shape else torch.Size()
4745*da0073e9SAndroid Build Coastguard Worker                    )
4746*da0073e9SAndroid Build Coastguard Worker                    message = f"{Dist.__name__} example {i + 1}/{len(params)}, shape mismatch. expected {expected_shape}, actual {actual_shape}"  # noqa: B950
4747*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(actual_shape, expected_shape, msg=message)
4748*da0073e9SAndroid Build Coastguard Worker                except NotImplementedError:
4749*da0073e9SAndroid Build Coastguard Worker                    continue
4750*da0073e9SAndroid Build Coastguard Worker
4751*da0073e9SAndroid Build Coastguard Worker    def test_bernoulli_shape_scalar_params(self):
4752*da0073e9SAndroid Build Coastguard Worker        bernoulli = Bernoulli(0.3)
4753*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(bernoulli._batch_shape, torch.Size())
4754*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(bernoulli._event_shape, torch.Size())
4755*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(bernoulli.sample().size(), torch.Size())
4756*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(bernoulli.sample((3, 2)).size(), torch.Size((3, 2)))
4757*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(ValueError, bernoulli.log_prob, self.scalar_sample)
4758*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
4759*da0073e9SAndroid Build Coastguard Worker            bernoulli.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
4760*da0073e9SAndroid Build Coastguard Worker        )
4761*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
4762*da0073e9SAndroid Build Coastguard Worker            bernoulli.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
4763*da0073e9SAndroid Build Coastguard Worker        )
4764*da0073e9SAndroid Build Coastguard Worker
4765*da0073e9SAndroid Build Coastguard Worker    def test_bernoulli_shape_tensor_params(self):
4766*da0073e9SAndroid Build Coastguard Worker        bernoulli = Bernoulli(torch.tensor([[0.6, 0.3], [0.6, 0.3], [0.6, 0.3]]))
4767*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(bernoulli._batch_shape, torch.Size((3, 2)))
4768*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(bernoulli._event_shape, torch.Size(()))
4769*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(bernoulli.sample().size(), torch.Size((3, 2)))
4770*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(bernoulli.sample((3, 2)).size(), torch.Size((3, 2, 3, 2)))
4771*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
4772*da0073e9SAndroid Build Coastguard Worker            bernoulli.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
4773*da0073e9SAndroid Build Coastguard Worker        )
4774*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(ValueError, bernoulli.log_prob, self.tensor_sample_2)
4775*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
4776*da0073e9SAndroid Build Coastguard Worker            bernoulli.log_prob(torch.ones(3, 1, 1)).size(), torch.Size((3, 3, 2))
4777*da0073e9SAndroid Build Coastguard Worker        )
4778*da0073e9SAndroid Build Coastguard Worker
4779*da0073e9SAndroid Build Coastguard Worker    def test_geometric_shape_scalar_params(self):
4780*da0073e9SAndroid Build Coastguard Worker        geometric = Geometric(0.3)
4781*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(geometric._batch_shape, torch.Size())
4782*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(geometric._event_shape, torch.Size())
4783*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(geometric.sample().size(), torch.Size())
4784*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(geometric.sample((3, 2)).size(), torch.Size((3, 2)))
4785*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(ValueError, geometric.log_prob, self.scalar_sample)
4786*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
4787*da0073e9SAndroid Build Coastguard Worker            geometric.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
4788*da0073e9SAndroid Build Coastguard Worker        )
4789*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
4790*da0073e9SAndroid Build Coastguard Worker            geometric.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
4791*da0073e9SAndroid Build Coastguard Worker        )
4792*da0073e9SAndroid Build Coastguard Worker
4793*da0073e9SAndroid Build Coastguard Worker    def test_geometric_shape_tensor_params(self):
4794*da0073e9SAndroid Build Coastguard Worker        geometric = Geometric(torch.tensor([[0.6, 0.3], [0.6, 0.3], [0.6, 0.3]]))
4795*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(geometric._batch_shape, torch.Size((3, 2)))
4796*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(geometric._event_shape, torch.Size(()))
4797*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(geometric.sample().size(), torch.Size((3, 2)))
4798*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(geometric.sample((3, 2)).size(), torch.Size((3, 2, 3, 2)))
4799*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
4800*da0073e9SAndroid Build Coastguard Worker            geometric.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
4801*da0073e9SAndroid Build Coastguard Worker        )
4802*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(ValueError, geometric.log_prob, self.tensor_sample_2)
4803*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
4804*da0073e9SAndroid Build Coastguard Worker            geometric.log_prob(torch.ones(3, 1, 1)).size(), torch.Size((3, 3, 2))
4805*da0073e9SAndroid Build Coastguard Worker        )
4806*da0073e9SAndroid Build Coastguard Worker
4807*da0073e9SAndroid Build Coastguard Worker    def test_beta_shape_scalar_params(self):
4808*da0073e9SAndroid Build Coastguard Worker        dist = Beta(0.1, 0.1)
4809*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dist._batch_shape, torch.Size())
4810*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dist._event_shape, torch.Size())
4811*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dist.sample().size(), torch.Size())
4812*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2)))
4813*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(ValueError, dist.log_prob, self.scalar_sample)
4814*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dist.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
4815*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
4816*da0073e9SAndroid Build Coastguard Worker            dist.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
4817*da0073e9SAndroid Build Coastguard Worker        )
4818*da0073e9SAndroid Build Coastguard Worker
4819*da0073e9SAndroid Build Coastguard Worker    def test_beta_shape_tensor_params(self):
4820*da0073e9SAndroid Build Coastguard Worker        dist = Beta(
4821*da0073e9SAndroid Build Coastguard Worker            torch.tensor([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]),
4822*da0073e9SAndroid Build Coastguard Worker            torch.tensor([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]),
4823*da0073e9SAndroid Build Coastguard Worker        )
4824*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dist._batch_shape, torch.Size((3, 2)))
4825*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dist._event_shape, torch.Size(()))
4826*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dist.sample().size(), torch.Size((3, 2)))
4827*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2, 3, 2)))
4828*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dist.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
4829*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
4830*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
4831*da0073e9SAndroid Build Coastguard Worker            dist.log_prob(torch.ones(3, 1, 1)).size(), torch.Size((3, 3, 2))
4832*da0073e9SAndroid Build Coastguard Worker        )
4833*da0073e9SAndroid Build Coastguard Worker
4834*da0073e9SAndroid Build Coastguard Worker    def test_binomial_shape(self):
4835*da0073e9SAndroid Build Coastguard Worker        dist = Binomial(10, torch.tensor([0.6, 0.3]))
4836*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dist._batch_shape, torch.Size((2,)))
4837*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dist._event_shape, torch.Size(()))
4838*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dist.sample().size(), torch.Size((2,)))
4839*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2, 2)))
4840*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dist.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
4841*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
4842*da0073e9SAndroid Build Coastguard Worker
4843*da0073e9SAndroid Build Coastguard Worker    def test_binomial_shape_vectorized_n(self):
4844*da0073e9SAndroid Build Coastguard Worker        dist = Binomial(
4845*da0073e9SAndroid Build Coastguard Worker            torch.tensor([[10, 3, 1], [4, 8, 4]]), torch.tensor([0.6, 0.3, 0.1])
4846*da0073e9SAndroid Build Coastguard Worker        )
4847*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dist._batch_shape, torch.Size((2, 3)))
4848*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dist._event_shape, torch.Size(()))
4849*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dist.sample().size(), torch.Size((2, 3)))
4850*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2, 2, 3)))
4851*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
4852*da0073e9SAndroid Build Coastguard Worker            dist.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
4853*da0073e9SAndroid Build Coastguard Worker        )
4854*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_1)
4855*da0073e9SAndroid Build Coastguard Worker
4856*da0073e9SAndroid Build Coastguard Worker    def test_multinomial_shape(self):
4857*da0073e9SAndroid Build Coastguard Worker        dist = Multinomial(10, torch.tensor([[0.6, 0.3], [0.6, 0.3], [0.6, 0.3]]))
4858*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dist._batch_shape, torch.Size((3,)))
4859*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dist._event_shape, torch.Size((2,)))
4860*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dist.sample().size(), torch.Size((3, 2)))
4861*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2, 3, 2)))
4862*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dist.log_prob(self.tensor_sample_1).size(), torch.Size((3,)))
4863*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
4864*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dist.log_prob(torch.ones(3, 1, 2)).size(), torch.Size((3, 3)))
4865*da0073e9SAndroid Build Coastguard Worker
4866*da0073e9SAndroid Build Coastguard Worker    def test_categorical_shape(self):
4867*da0073e9SAndroid Build Coastguard Worker        # unbatched
4868*da0073e9SAndroid Build Coastguard Worker        dist = Categorical(torch.tensor([0.6, 0.3, 0.1]))
4869*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dist._batch_shape, torch.Size(()))
4870*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dist._event_shape, torch.Size(()))
4871*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dist.sample().size(), torch.Size())
4872*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
4873*da0073e9SAndroid Build Coastguard Worker            dist.sample((3, 2)).size(),
4874*da0073e9SAndroid Build Coastguard Worker            torch.Size(
4875*da0073e9SAndroid Build Coastguard Worker                (
4876*da0073e9SAndroid Build Coastguard Worker                    3,
4877*da0073e9SAndroid Build Coastguard Worker                    2,
4878*da0073e9SAndroid Build Coastguard Worker                )
4879*da0073e9SAndroid Build Coastguard Worker            ),
4880*da0073e9SAndroid Build Coastguard Worker        )
4881*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dist.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
4882*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
4883*da0073e9SAndroid Build Coastguard Worker            dist.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
4884*da0073e9SAndroid Build Coastguard Worker        )
4885*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dist.log_prob(torch.ones(3, 1)).size(), torch.Size((3, 1)))
4886*da0073e9SAndroid Build Coastguard Worker        # batched
4887*da0073e9SAndroid Build Coastguard Worker        dist = Categorical(torch.tensor([[0.6, 0.3], [0.6, 0.3], [0.6, 0.3]]))
4888*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dist._batch_shape, torch.Size((3,)))
4889*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dist._event_shape, torch.Size(()))
4890*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dist.sample().size(), torch.Size((3,)))
4891*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
4892*da0073e9SAndroid Build Coastguard Worker            dist.sample((3, 2)).size(),
4893*da0073e9SAndroid Build Coastguard Worker            torch.Size(
4894*da0073e9SAndroid Build Coastguard Worker                (
4895*da0073e9SAndroid Build Coastguard Worker                    3,
4896*da0073e9SAndroid Build Coastguard Worker                    2,
4897*da0073e9SAndroid Build Coastguard Worker                    3,
4898*da0073e9SAndroid Build Coastguard Worker                )
4899*da0073e9SAndroid Build Coastguard Worker            ),
4900*da0073e9SAndroid Build Coastguard Worker        )
4901*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_1)
4902*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
4903*da0073e9SAndroid Build Coastguard Worker            dist.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
4904*da0073e9SAndroid Build Coastguard Worker        )
4905*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dist.log_prob(torch.ones(3, 1)).size(), torch.Size((3, 3)))
4906*da0073e9SAndroid Build Coastguard Worker
4907*da0073e9SAndroid Build Coastguard Worker    def test_one_hot_categorical_shape(self):
4908*da0073e9SAndroid Build Coastguard Worker        # unbatched
4909*da0073e9SAndroid Build Coastguard Worker        dist = OneHotCategorical(torch.tensor([0.6, 0.3, 0.1]))
4910*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dist._batch_shape, torch.Size(()))
4911*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dist._event_shape, torch.Size((3,)))
4912*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dist.sample().size(), torch.Size((3,)))
4913*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2, 3)))
4914*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_1)
4915*da0073e9SAndroid Build Coastguard Worker        sample = torch.tensor([0.0, 1.0, 0.0]).expand(3, 2, 3)
4916*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
4917*da0073e9SAndroid Build Coastguard Worker            dist.log_prob(sample).size(),
4918*da0073e9SAndroid Build Coastguard Worker            torch.Size(
4919*da0073e9SAndroid Build Coastguard Worker                (
4920*da0073e9SAndroid Build Coastguard Worker                    3,
4921*da0073e9SAndroid Build Coastguard Worker                    2,
4922*da0073e9SAndroid Build Coastguard Worker                )
4923*da0073e9SAndroid Build Coastguard Worker            ),
4924*da0073e9SAndroid Build Coastguard Worker        )
4925*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
4926*da0073e9SAndroid Build Coastguard Worker            dist.log_prob(dist.enumerate_support()).size(), torch.Size((3,))
4927*da0073e9SAndroid Build Coastguard Worker        )
4928*da0073e9SAndroid Build Coastguard Worker        sample = torch.eye(3)
4929*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dist.log_prob(sample).size(), torch.Size((3,)))
4930*da0073e9SAndroid Build Coastguard Worker        # batched
4931*da0073e9SAndroid Build Coastguard Worker        dist = OneHotCategorical(torch.tensor([[0.6, 0.3], [0.6, 0.3], [0.6, 0.3]]))
4932*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dist._batch_shape, torch.Size((3,)))
4933*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dist._event_shape, torch.Size((2,)))
4934*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dist.sample().size(), torch.Size((3, 2)))
4935*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2, 3, 2)))
4936*da0073e9SAndroid Build Coastguard Worker        sample = torch.tensor([0.0, 1.0])
4937*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dist.log_prob(sample).size(), torch.Size((3,)))
4938*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
4939*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
4940*da0073e9SAndroid Build Coastguard Worker            dist.log_prob(dist.enumerate_support()).size(), torch.Size((2, 3))
4941*da0073e9SAndroid Build Coastguard Worker        )
4942*da0073e9SAndroid Build Coastguard Worker        sample = torch.tensor([0.0, 1.0]).expand(3, 1, 2)
4943*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dist.log_prob(sample).size(), torch.Size((3, 3)))
4944*da0073e9SAndroid Build Coastguard Worker
4945*da0073e9SAndroid Build Coastguard Worker    def test_cauchy_shape_scalar_params(self):
4946*da0073e9SAndroid Build Coastguard Worker        cauchy = Cauchy(0, 1)
4947*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cauchy._batch_shape, torch.Size())
4948*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cauchy._event_shape, torch.Size())
4949*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cauchy.sample().size(), torch.Size())
4950*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cauchy.sample(torch.Size((3, 2))).size(), torch.Size((3, 2)))
4951*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(ValueError, cauchy.log_prob, self.scalar_sample)
4952*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
4953*da0073e9SAndroid Build Coastguard Worker            cauchy.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
4954*da0073e9SAndroid Build Coastguard Worker        )
4955*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
4956*da0073e9SAndroid Build Coastguard Worker            cauchy.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
4957*da0073e9SAndroid Build Coastguard Worker        )
4958*da0073e9SAndroid Build Coastguard Worker
4959*da0073e9SAndroid Build Coastguard Worker    def test_cauchy_shape_tensor_params(self):
4960*da0073e9SAndroid Build Coastguard Worker        cauchy = Cauchy(torch.tensor([0.0, 0.0]), torch.tensor([1.0, 1.0]))
4961*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cauchy._batch_shape, torch.Size((2,)))
4962*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cauchy._event_shape, torch.Size(()))
4963*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cauchy.sample().size(), torch.Size((2,)))
4964*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
4965*da0073e9SAndroid Build Coastguard Worker            cauchy.sample(torch.Size((3, 2))).size(), torch.Size((3, 2, 2))
4966*da0073e9SAndroid Build Coastguard Worker        )
4967*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
4968*da0073e9SAndroid Build Coastguard Worker            cauchy.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
4969*da0073e9SAndroid Build Coastguard Worker        )
4970*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(ValueError, cauchy.log_prob, self.tensor_sample_2)
4971*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cauchy.log_prob(torch.ones(2, 1)).size(), torch.Size((2, 2)))
4972*da0073e9SAndroid Build Coastguard Worker
4973*da0073e9SAndroid Build Coastguard Worker    def test_halfcauchy_shape_scalar_params(self):
4974*da0073e9SAndroid Build Coastguard Worker        halfcauchy = HalfCauchy(1)
4975*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(halfcauchy._batch_shape, torch.Size())
4976*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(halfcauchy._event_shape, torch.Size())
4977*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(halfcauchy.sample().size(), torch.Size())
4978*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
4979*da0073e9SAndroid Build Coastguard Worker            halfcauchy.sample(torch.Size((3, 2))).size(), torch.Size((3, 2))
4980*da0073e9SAndroid Build Coastguard Worker        )
4981*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(ValueError, halfcauchy.log_prob, self.scalar_sample)
4982*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
4983*da0073e9SAndroid Build Coastguard Worker            halfcauchy.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
4984*da0073e9SAndroid Build Coastguard Worker        )
4985*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
4986*da0073e9SAndroid Build Coastguard Worker            halfcauchy.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
4987*da0073e9SAndroid Build Coastguard Worker        )
4988*da0073e9SAndroid Build Coastguard Worker
4989*da0073e9SAndroid Build Coastguard Worker    def test_halfcauchy_shape_tensor_params(self):
4990*da0073e9SAndroid Build Coastguard Worker        halfcauchy = HalfCauchy(torch.tensor([1.0, 1.0]))
4991*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(halfcauchy._batch_shape, torch.Size((2,)))
4992*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(halfcauchy._event_shape, torch.Size(()))
4993*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(halfcauchy.sample().size(), torch.Size((2,)))
4994*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
4995*da0073e9SAndroid Build Coastguard Worker            halfcauchy.sample(torch.Size((3, 2))).size(), torch.Size((3, 2, 2))
4996*da0073e9SAndroid Build Coastguard Worker        )
4997*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
4998*da0073e9SAndroid Build Coastguard Worker            halfcauchy.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
4999*da0073e9SAndroid Build Coastguard Worker        )
5000*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(ValueError, halfcauchy.log_prob, self.tensor_sample_2)
5001*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
5002*da0073e9SAndroid Build Coastguard Worker            halfcauchy.log_prob(torch.ones(2, 1)).size(), torch.Size((2, 2))
5003*da0073e9SAndroid Build Coastguard Worker        )
5004*da0073e9SAndroid Build Coastguard Worker
5005*da0073e9SAndroid Build Coastguard Worker    def test_dirichlet_shape(self):
5006*da0073e9SAndroid Build Coastguard Worker        dist = Dirichlet(torch.tensor([[0.6, 0.3], [1.6, 1.3], [2.6, 2.3]]))
5007*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dist._batch_shape, torch.Size((3,)))
5008*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dist._event_shape, torch.Size((2,)))
5009*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dist.sample().size(), torch.Size((3, 2)))
5010*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dist.sample((5, 4)).size(), torch.Size((5, 4, 3, 2)))
5011*da0073e9SAndroid Build Coastguard Worker        simplex_sample = self.tensor_sample_1 / self.tensor_sample_1.sum(
5012*da0073e9SAndroid Build Coastguard Worker            -1, keepdim=True
5013*da0073e9SAndroid Build Coastguard Worker        )
5014*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dist.log_prob(simplex_sample).size(), torch.Size((3,)))
5015*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
5016*da0073e9SAndroid Build Coastguard Worker        simplex_sample = torch.ones(3, 1, 2)
5017*da0073e9SAndroid Build Coastguard Worker        simplex_sample = simplex_sample / simplex_sample.sum(-1).unsqueeze(-1)
5018*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dist.log_prob(simplex_sample).size(), torch.Size((3, 3)))
5019*da0073e9SAndroid Build Coastguard Worker
5020*da0073e9SAndroid Build Coastguard Worker    def test_mixture_same_family_shape(self):
5021*da0073e9SAndroid Build Coastguard Worker        dist = MixtureSameFamily(
5022*da0073e9SAndroid Build Coastguard Worker            Categorical(torch.rand(5)), Normal(torch.randn(5), torch.rand(5))
5023*da0073e9SAndroid Build Coastguard Worker        )
5024*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dist._batch_shape, torch.Size())
5025*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dist._event_shape, torch.Size())
5026*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dist.sample().size(), torch.Size())
5027*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dist.sample((5, 4)).size(), torch.Size((5, 4)))
5028*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dist.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
5029*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
5030*da0073e9SAndroid Build Coastguard Worker            dist.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
5031*da0073e9SAndroid Build Coastguard Worker        )
5032*da0073e9SAndroid Build Coastguard Worker
5033*da0073e9SAndroid Build Coastguard Worker    def test_gamma_shape_scalar_params(self):
5034*da0073e9SAndroid Build Coastguard Worker        gamma = Gamma(1, 1)
5035*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(gamma._batch_shape, torch.Size())
5036*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(gamma._event_shape, torch.Size())
5037*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(gamma.sample().size(), torch.Size())
5038*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(gamma.sample((3, 2)).size(), torch.Size((3, 2)))
5039*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(gamma.log_prob(self.scalar_sample).size(), torch.Size())
5040*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
5041*da0073e9SAndroid Build Coastguard Worker            gamma.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
5042*da0073e9SAndroid Build Coastguard Worker        )
5043*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
5044*da0073e9SAndroid Build Coastguard Worker            gamma.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
5045*da0073e9SAndroid Build Coastguard Worker        )
5046*da0073e9SAndroid Build Coastguard Worker
5047*da0073e9SAndroid Build Coastguard Worker    def test_gamma_shape_tensor_params(self):
5048*da0073e9SAndroid Build Coastguard Worker        gamma = Gamma(torch.tensor([1.0, 1.0]), torch.tensor([1.0, 1.0]))
5049*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(gamma._batch_shape, torch.Size((2,)))
5050*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(gamma._event_shape, torch.Size(()))
5051*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(gamma.sample().size(), torch.Size((2,)))
5052*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(gamma.sample((3, 2)).size(), torch.Size((3, 2, 2)))
5053*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
5054*da0073e9SAndroid Build Coastguard Worker            gamma.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
5055*da0073e9SAndroid Build Coastguard Worker        )
5056*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(ValueError, gamma.log_prob, self.tensor_sample_2)
5057*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(gamma.log_prob(torch.ones(2, 1)).size(), torch.Size((2, 2)))
5058*da0073e9SAndroid Build Coastguard Worker
5059*da0073e9SAndroid Build Coastguard Worker    def test_chi2_shape_scalar_params(self):
5060*da0073e9SAndroid Build Coastguard Worker        chi2 = Chi2(1)
5061*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(chi2._batch_shape, torch.Size())
5062*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(chi2._event_shape, torch.Size())
5063*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(chi2.sample().size(), torch.Size())
5064*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(chi2.sample((3, 2)).size(), torch.Size((3, 2)))
5065*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(chi2.log_prob(self.scalar_sample).size(), torch.Size())
5066*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(chi2.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
5067*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
5068*da0073e9SAndroid Build Coastguard Worker            chi2.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
5069*da0073e9SAndroid Build Coastguard Worker        )
5070*da0073e9SAndroid Build Coastguard Worker
5071*da0073e9SAndroid Build Coastguard Worker    def test_chi2_shape_tensor_params(self):
5072*da0073e9SAndroid Build Coastguard Worker        chi2 = Chi2(torch.tensor([1.0, 1.0]))
5073*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(chi2._batch_shape, torch.Size((2,)))
5074*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(chi2._event_shape, torch.Size(()))
5075*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(chi2.sample().size(), torch.Size((2,)))
5076*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(chi2.sample((3, 2)).size(), torch.Size((3, 2, 2)))
5077*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(chi2.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
5078*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(ValueError, chi2.log_prob, self.tensor_sample_2)
5079*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(chi2.log_prob(torch.ones(2, 1)).size(), torch.Size((2, 2)))
5080*da0073e9SAndroid Build Coastguard Worker
5081*da0073e9SAndroid Build Coastguard Worker    def test_studentT_shape_scalar_params(self):
5082*da0073e9SAndroid Build Coastguard Worker        st = StudentT(1)
5083*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(st._batch_shape, torch.Size())
5084*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(st._event_shape, torch.Size())
5085*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(st.sample().size(), torch.Size())
5086*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(st.sample((3, 2)).size(), torch.Size((3, 2)))
5087*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(ValueError, st.log_prob, self.scalar_sample)
5088*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(st.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
5089*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
5090*da0073e9SAndroid Build Coastguard Worker            st.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
5091*da0073e9SAndroid Build Coastguard Worker        )
5092*da0073e9SAndroid Build Coastguard Worker
5093*da0073e9SAndroid Build Coastguard Worker    def test_studentT_shape_tensor_params(self):
5094*da0073e9SAndroid Build Coastguard Worker        st = StudentT(torch.tensor([1.0, 1.0]))
5095*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(st._batch_shape, torch.Size((2,)))
5096*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(st._event_shape, torch.Size(()))
5097*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(st.sample().size(), torch.Size((2,)))
5098*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(st.sample((3, 2)).size(), torch.Size((3, 2, 2)))
5099*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(st.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
5100*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(ValueError, st.log_prob, self.tensor_sample_2)
5101*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(st.log_prob(torch.ones(2, 1)).size(), torch.Size((2, 2)))
5102*da0073e9SAndroid Build Coastguard Worker
5103*da0073e9SAndroid Build Coastguard Worker    def test_pareto_shape_scalar_params(self):
5104*da0073e9SAndroid Build Coastguard Worker        pareto = Pareto(1, 1)
5105*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(pareto._batch_shape, torch.Size())
5106*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(pareto._event_shape, torch.Size())
5107*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(pareto.sample().size(), torch.Size())
5108*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(pareto.sample((3, 2)).size(), torch.Size((3, 2)))
5109*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
5110*da0073e9SAndroid Build Coastguard Worker            pareto.log_prob(self.tensor_sample_1 + 1).size(), torch.Size((3, 2))
5111*da0073e9SAndroid Build Coastguard Worker        )
5112*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
5113*da0073e9SAndroid Build Coastguard Worker            pareto.log_prob(self.tensor_sample_2 + 1).size(), torch.Size((3, 2, 3))
5114*da0073e9SAndroid Build Coastguard Worker        )
5115*da0073e9SAndroid Build Coastguard Worker
5116*da0073e9SAndroid Build Coastguard Worker    def test_gumbel_shape_scalar_params(self):
5117*da0073e9SAndroid Build Coastguard Worker        gumbel = Gumbel(1, 1)
5118*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(gumbel._batch_shape, torch.Size())
5119*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(gumbel._event_shape, torch.Size())
5120*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(gumbel.sample().size(), torch.Size())
5121*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(gumbel.sample((3, 2)).size(), torch.Size((3, 2)))
5122*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
5123*da0073e9SAndroid Build Coastguard Worker            gumbel.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
5124*da0073e9SAndroid Build Coastguard Worker        )
5125*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
5126*da0073e9SAndroid Build Coastguard Worker            gumbel.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
5127*da0073e9SAndroid Build Coastguard Worker        )
5128*da0073e9SAndroid Build Coastguard Worker
5129*da0073e9SAndroid Build Coastguard Worker    def test_kumaraswamy_shape_scalar_params(self):
5130*da0073e9SAndroid Build Coastguard Worker        kumaraswamy = Kumaraswamy(1, 1)
5131*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(kumaraswamy._batch_shape, torch.Size())
5132*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(kumaraswamy._event_shape, torch.Size())
5133*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(kumaraswamy.sample().size(), torch.Size())
5134*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(kumaraswamy.sample((3, 2)).size(), torch.Size((3, 2)))
5135*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
5136*da0073e9SAndroid Build Coastguard Worker            kumaraswamy.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
5137*da0073e9SAndroid Build Coastguard Worker        )
5138*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
5139*da0073e9SAndroid Build Coastguard Worker            kumaraswamy.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
5140*da0073e9SAndroid Build Coastguard Worker        )
5141*da0073e9SAndroid Build Coastguard Worker
5142*da0073e9SAndroid Build Coastguard Worker    def test_vonmises_shape_tensor_params(self):
5143*da0073e9SAndroid Build Coastguard Worker        von_mises = VonMises(torch.tensor([0.0, 0.0]), torch.tensor([1.0, 1.0]))
5144*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(von_mises._batch_shape, torch.Size((2,)))
5145*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(von_mises._event_shape, torch.Size(()))
5146*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(von_mises.sample().size(), torch.Size((2,)))
5147*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
5148*da0073e9SAndroid Build Coastguard Worker            von_mises.sample(torch.Size((3, 2))).size(), torch.Size((3, 2, 2))
5149*da0073e9SAndroid Build Coastguard Worker        )
5150*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
5151*da0073e9SAndroid Build Coastguard Worker            von_mises.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
5152*da0073e9SAndroid Build Coastguard Worker        )
5153*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
5154*da0073e9SAndroid Build Coastguard Worker            von_mises.log_prob(torch.ones(2, 1)).size(), torch.Size((2, 2))
5155*da0073e9SAndroid Build Coastguard Worker        )
5156*da0073e9SAndroid Build Coastguard Worker
5157*da0073e9SAndroid Build Coastguard Worker    def test_vonmises_shape_scalar_params(self):
5158*da0073e9SAndroid Build Coastguard Worker        von_mises = VonMises(0.0, 1.0)
5159*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(von_mises._batch_shape, torch.Size())
5160*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(von_mises._event_shape, torch.Size())
5161*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(von_mises.sample().size(), torch.Size())
5162*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
5163*da0073e9SAndroid Build Coastguard Worker            von_mises.sample(torch.Size((3, 2))).size(), torch.Size((3, 2))
5164*da0073e9SAndroid Build Coastguard Worker        )
5165*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
5166*da0073e9SAndroid Build Coastguard Worker            von_mises.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
5167*da0073e9SAndroid Build Coastguard Worker        )
5168*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
5169*da0073e9SAndroid Build Coastguard Worker            von_mises.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
5170*da0073e9SAndroid Build Coastguard Worker        )
5171*da0073e9SAndroid Build Coastguard Worker
5172*da0073e9SAndroid Build Coastguard Worker    def test_weibull_scale_scalar_params(self):
5173*da0073e9SAndroid Build Coastguard Worker        weibull = Weibull(1, 1)
5174*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(weibull._batch_shape, torch.Size())
5175*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(weibull._event_shape, torch.Size())
5176*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(weibull.sample().size(), torch.Size())
5177*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(weibull.sample((3, 2)).size(), torch.Size((3, 2)))
5178*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
5179*da0073e9SAndroid Build Coastguard Worker            weibull.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
5180*da0073e9SAndroid Build Coastguard Worker        )
5181*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
5182*da0073e9SAndroid Build Coastguard Worker            weibull.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
5183*da0073e9SAndroid Build Coastguard Worker        )
5184*da0073e9SAndroid Build Coastguard Worker
5185*da0073e9SAndroid Build Coastguard Worker    def test_wishart_shape_scalar_params(self):
5186*da0073e9SAndroid Build Coastguard Worker        wishart = Wishart(torch.tensor(1), torch.tensor([[1.0]]))
5187*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(wishart._batch_shape, torch.Size())
5188*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(wishart._event_shape, torch.Size((1, 1)))
5189*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(wishart.sample().size(), torch.Size((1, 1)))
5190*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(wishart.sample((3, 2)).size(), torch.Size((3, 2, 1, 1)))
5191*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(ValueError, wishart.log_prob, self.scalar_sample)
5192*da0073e9SAndroid Build Coastguard Worker
5193*da0073e9SAndroid Build Coastguard Worker    def test_wishart_shape_tensor_params(self):
5194*da0073e9SAndroid Build Coastguard Worker        wishart = Wishart(torch.tensor([1.0, 1.0]), torch.tensor([[[1.0]], [[1.0]]]))
5195*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(wishart._batch_shape, torch.Size((2,)))
5196*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(wishart._event_shape, torch.Size((1, 1)))
5197*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(wishart.sample().size(), torch.Size((2, 1, 1)))
5198*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(wishart.sample((3, 2)).size(), torch.Size((3, 2, 2, 1, 1)))
5199*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(ValueError, wishart.log_prob, self.tensor_sample_2)
5200*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(wishart.log_prob(torch.ones(2, 1, 1)).size(), torch.Size((2,)))
5201*da0073e9SAndroid Build Coastguard Worker
5202*da0073e9SAndroid Build Coastguard Worker    def test_normal_shape_scalar_params(self):
5203*da0073e9SAndroid Build Coastguard Worker        normal = Normal(0, 1)
5204*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(normal._batch_shape, torch.Size())
5205*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(normal._event_shape, torch.Size())
5206*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(normal.sample().size(), torch.Size())
5207*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(normal.sample((3, 2)).size(), torch.Size((3, 2)))
5208*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(ValueError, normal.log_prob, self.scalar_sample)
5209*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
5210*da0073e9SAndroid Build Coastguard Worker            normal.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
5211*da0073e9SAndroid Build Coastguard Worker        )
5212*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
5213*da0073e9SAndroid Build Coastguard Worker            normal.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
5214*da0073e9SAndroid Build Coastguard Worker        )
5215*da0073e9SAndroid Build Coastguard Worker
5216*da0073e9SAndroid Build Coastguard Worker    def test_normal_shape_tensor_params(self):
5217*da0073e9SAndroid Build Coastguard Worker        normal = Normal(torch.tensor([0.0, 0.0]), torch.tensor([1.0, 1.0]))
5218*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(normal._batch_shape, torch.Size((2,)))
5219*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(normal._event_shape, torch.Size(()))
5220*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(normal.sample().size(), torch.Size((2,)))
5221*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(normal.sample((3, 2)).size(), torch.Size((3, 2, 2)))
5222*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
5223*da0073e9SAndroid Build Coastguard Worker            normal.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
5224*da0073e9SAndroid Build Coastguard Worker        )
5225*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(ValueError, normal.log_prob, self.tensor_sample_2)
5226*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(normal.log_prob(torch.ones(2, 1)).size(), torch.Size((2, 2)))
5227*da0073e9SAndroid Build Coastguard Worker
5228*da0073e9SAndroid Build Coastguard Worker    def test_uniform_shape_scalar_params(self):
5229*da0073e9SAndroid Build Coastguard Worker        uniform = Uniform(0, 1)
5230*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(uniform._batch_shape, torch.Size())
5231*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(uniform._event_shape, torch.Size())
5232*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(uniform.sample().size(), torch.Size())
5233*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(uniform.sample(torch.Size((3, 2))).size(), torch.Size((3, 2)))
5234*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(ValueError, uniform.log_prob, self.scalar_sample)
5235*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
5236*da0073e9SAndroid Build Coastguard Worker            uniform.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
5237*da0073e9SAndroid Build Coastguard Worker        )
5238*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
5239*da0073e9SAndroid Build Coastguard Worker            uniform.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
5240*da0073e9SAndroid Build Coastguard Worker        )
5241*da0073e9SAndroid Build Coastguard Worker
5242*da0073e9SAndroid Build Coastguard Worker    def test_uniform_shape_tensor_params(self):
5243*da0073e9SAndroid Build Coastguard Worker        uniform = Uniform(torch.tensor([0.0, 0.0]), torch.tensor([1.0, 1.0]))
5244*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(uniform._batch_shape, torch.Size((2,)))
5245*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(uniform._event_shape, torch.Size(()))
5246*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(uniform.sample().size(), torch.Size((2,)))
5247*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
5248*da0073e9SAndroid Build Coastguard Worker            uniform.sample(torch.Size((3, 2))).size(), torch.Size((3, 2, 2))
5249*da0073e9SAndroid Build Coastguard Worker        )
5250*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
5251*da0073e9SAndroid Build Coastguard Worker            uniform.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
5252*da0073e9SAndroid Build Coastguard Worker        )
5253*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(ValueError, uniform.log_prob, self.tensor_sample_2)
5254*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(uniform.log_prob(torch.ones(2, 1)).size(), torch.Size((2, 2)))
5255*da0073e9SAndroid Build Coastguard Worker
5256*da0073e9SAndroid Build Coastguard Worker    def test_exponential_shape_scalar_param(self):
5257*da0073e9SAndroid Build Coastguard Worker        expon = Exponential(1.0)
5258*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expon._batch_shape, torch.Size())
5259*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expon._event_shape, torch.Size())
5260*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expon.sample().size(), torch.Size())
5261*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expon.sample((3, 2)).size(), torch.Size((3, 2)))
5262*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(ValueError, expon.log_prob, self.scalar_sample)
5263*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
5264*da0073e9SAndroid Build Coastguard Worker            expon.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
5265*da0073e9SAndroid Build Coastguard Worker        )
5266*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
5267*da0073e9SAndroid Build Coastguard Worker            expon.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
5268*da0073e9SAndroid Build Coastguard Worker        )
5269*da0073e9SAndroid Build Coastguard Worker
5270*da0073e9SAndroid Build Coastguard Worker    def test_exponential_shape_tensor_param(self):
5271*da0073e9SAndroid Build Coastguard Worker        expon = Exponential(torch.tensor([1.0, 1.0]))
5272*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expon._batch_shape, torch.Size((2,)))
5273*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expon._event_shape, torch.Size(()))
5274*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expon.sample().size(), torch.Size((2,)))
5275*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expon.sample((3, 2)).size(), torch.Size((3, 2, 2)))
5276*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
5277*da0073e9SAndroid Build Coastguard Worker            expon.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
5278*da0073e9SAndroid Build Coastguard Worker        )
5279*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(ValueError, expon.log_prob, self.tensor_sample_2)
5280*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expon.log_prob(torch.ones(2, 2)).size(), torch.Size((2, 2)))
5281*da0073e9SAndroid Build Coastguard Worker
5282*da0073e9SAndroid Build Coastguard Worker    def test_laplace_shape_scalar_params(self):
5283*da0073e9SAndroid Build Coastguard Worker        laplace = Laplace(0, 1)
5284*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(laplace._batch_shape, torch.Size())
5285*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(laplace._event_shape, torch.Size())
5286*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(laplace.sample().size(), torch.Size())
5287*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(laplace.sample((3, 2)).size(), torch.Size((3, 2)))
5288*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(ValueError, laplace.log_prob, self.scalar_sample)
5289*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
5290*da0073e9SAndroid Build Coastguard Worker            laplace.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
5291*da0073e9SAndroid Build Coastguard Worker        )
5292*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
5293*da0073e9SAndroid Build Coastguard Worker            laplace.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
5294*da0073e9SAndroid Build Coastguard Worker        )
5295*da0073e9SAndroid Build Coastguard Worker
5296*da0073e9SAndroid Build Coastguard Worker    def test_laplace_shape_tensor_params(self):
5297*da0073e9SAndroid Build Coastguard Worker        laplace = Laplace(torch.tensor([0.0, 0.0]), torch.tensor([1.0, 1.0]))
5298*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(laplace._batch_shape, torch.Size((2,)))
5299*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(laplace._event_shape, torch.Size(()))
5300*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(laplace.sample().size(), torch.Size((2,)))
5301*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(laplace.sample((3, 2)).size(), torch.Size((3, 2, 2)))
5302*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
5303*da0073e9SAndroid Build Coastguard Worker            laplace.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
5304*da0073e9SAndroid Build Coastguard Worker        )
5305*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(ValueError, laplace.log_prob, self.tensor_sample_2)
5306*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(laplace.log_prob(torch.ones(2, 1)).size(), torch.Size((2, 2)))
5307*da0073e9SAndroid Build Coastguard Worker
5308*da0073e9SAndroid Build Coastguard Worker    def test_continuous_bernoulli_shape_scalar_params(self):
5309*da0073e9SAndroid Build Coastguard Worker        continuous_bernoulli = ContinuousBernoulli(0.3)
5310*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(continuous_bernoulli._batch_shape, torch.Size())
5311*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(continuous_bernoulli._event_shape, torch.Size())
5312*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(continuous_bernoulli.sample().size(), torch.Size())
5313*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(continuous_bernoulli.sample((3, 2)).size(), torch.Size((3, 2)))
5314*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(ValueError, continuous_bernoulli.log_prob, self.scalar_sample)
5315*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
5316*da0073e9SAndroid Build Coastguard Worker            continuous_bernoulli.log_prob(self.tensor_sample_1).size(),
5317*da0073e9SAndroid Build Coastguard Worker            torch.Size((3, 2)),
5318*da0073e9SAndroid Build Coastguard Worker        )
5319*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
5320*da0073e9SAndroid Build Coastguard Worker            continuous_bernoulli.log_prob(self.tensor_sample_2).size(),
5321*da0073e9SAndroid Build Coastguard Worker            torch.Size((3, 2, 3)),
5322*da0073e9SAndroid Build Coastguard Worker        )
5323*da0073e9SAndroid Build Coastguard Worker
5324*da0073e9SAndroid Build Coastguard Worker    def test_continuous_bernoulli_shape_tensor_params(self):
5325*da0073e9SAndroid Build Coastguard Worker        continuous_bernoulli = ContinuousBernoulli(
5326*da0073e9SAndroid Build Coastguard Worker            torch.tensor([[0.6, 0.3], [0.6, 0.3], [0.6, 0.3]])
5327*da0073e9SAndroid Build Coastguard Worker        )
5328*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(continuous_bernoulli._batch_shape, torch.Size((3, 2)))
5329*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(continuous_bernoulli._event_shape, torch.Size(()))
5330*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(continuous_bernoulli.sample().size(), torch.Size((3, 2)))
5331*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
5332*da0073e9SAndroid Build Coastguard Worker            continuous_bernoulli.sample((3, 2)).size(), torch.Size((3, 2, 3, 2))
5333*da0073e9SAndroid Build Coastguard Worker        )
5334*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
5335*da0073e9SAndroid Build Coastguard Worker            continuous_bernoulli.log_prob(self.tensor_sample_1).size(),
5336*da0073e9SAndroid Build Coastguard Worker            torch.Size((3, 2)),
5337*da0073e9SAndroid Build Coastguard Worker        )
5338*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(
5339*da0073e9SAndroid Build Coastguard Worker            ValueError, continuous_bernoulli.log_prob, self.tensor_sample_2
5340*da0073e9SAndroid Build Coastguard Worker        )
5341*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
5342*da0073e9SAndroid Build Coastguard Worker            continuous_bernoulli.log_prob(torch.ones(3, 1, 1)).size(),
5343*da0073e9SAndroid Build Coastguard Worker            torch.Size((3, 3, 2)),
5344*da0073e9SAndroid Build Coastguard Worker        )
5345*da0073e9SAndroid Build Coastguard Worker
5346*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
5347*da0073e9SAndroid Build Coastguard Worker    def test_mixture_same_family_mean_shape(self):
5348*da0073e9SAndroid Build Coastguard Worker        mix_distribution = Categorical(torch.ones([3, 1, 3]))
5349*da0073e9SAndroid Build Coastguard Worker        component_distribution = Normal(torch.zeros([3, 3, 3]), torch.ones([3, 3, 3]))
5350*da0073e9SAndroid Build Coastguard Worker        gmm = MixtureSameFamily(mix_distribution, component_distribution)
5351*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(gmm.mean.shape), 2)
5352*da0073e9SAndroid Build Coastguard Worker
5353*da0073e9SAndroid Build Coastguard Worker
5354*da0073e9SAndroid Build Coastguard Worker@skipIfTorchDynamo("Not a TorchDynamo suitable test")
5355*da0073e9SAndroid Build Coastguard Workerclass TestKL(DistributionsTestCase):
    def setUp(self):
        """Build the distribution fixtures and Monte Carlo settings for the KL tests.

        Sets ``self.precision``, ``self.max_samples``, ``self.samples_per_batch``,
        ``self.finite_examples`` and ``self.infinite_examples``.

        Each entry of ``self.finite_examples`` is a pair of 2-tuples of
        distributions; ``test_kl_monte_carlo`` takes the first member of the
        first tuple as ``p`` and the second member of the second tuple as ``q``.
        Each entry of ``self.infinite_examples`` is a plain ``(p, q)`` tuple.
        """
        super().setUp()

        # Binomial with its first constructor argument fixed at 30, so that
        # ``probs`` is the only parameter left to supply.
        class Binomial30(Binomial):
            def __init__(self, probs):
                super().__init__(30, probs)

        # These are pairs of distributions with 4 x 4 parameters as specified.
        # The first of the pair e.g. bernoulli[0] varies column-wise and the second
        # e.g. bernoulli[1] varies row-wise; that way we test all param pairs.
        bernoulli = pairwise(Bernoulli, [0.1, 0.2, 0.6, 0.9])
        binomial30 = pairwise(Binomial30, [0.1, 0.2, 0.6, 0.9])
        # Built directly (not via pairwise) with a tensor-valued count of [3, 4].
        binomial_vectorized_count = (
            Binomial(torch.tensor([3, 4]), torch.tensor([0.4, 0.6])),
            Binomial(torch.tensor([3, 4]), torch.tensor([0.5, 0.8])),
        )
        beta = pairwise(Beta, [1.0, 2.5, 1.0, 2.5], [1.5, 1.5, 3.5, 3.5])
        categorical = pairwise(
            Categorical,
            [[0.4, 0.3, 0.3], [0.2, 0.7, 0.1], [0.33, 0.33, 0.34], [0.2, 0.2, 0.6]],
        )
        cauchy = pairwise(Cauchy, [-2.0, 2.0, -3.0, 3.0], [1.0, 2.0, 1.0, 2.0])
        chi2 = pairwise(Chi2, [1.0, 2.0, 2.5, 5.0])
        dirichlet = pairwise(
            Dirichlet,
            [[0.1, 0.2, 0.7], [0.5, 0.4, 0.1], [0.33, 0.33, 0.34], [0.2, 0.2, 0.4]],
        )
        exponential = pairwise(Exponential, [1.0, 2.5, 5.0, 10.0])
        gamma = pairwise(Gamma, [1.0, 2.5, 1.0, 2.5], [1.5, 1.5, 3.5, 3.5])
        gumbel = pairwise(Gumbel, [-2.0, 4.0, -3.0, 6.0], [1.0, 2.5, 1.0, 2.5])
        halfnormal = pairwise(HalfNormal, [1.0, 2.0, 1.0, 2.0])
        inversegamma = pairwise(
            InverseGamma, [1.0, 2.5, 1.0, 2.5], [1.5, 1.5, 3.5, 3.5]
        )
        laplace = pairwise(Laplace, [-2.0, 4.0, -3.0, 6.0], [1.0, 2.5, 1.0, 2.5])
        lognormal = pairwise(LogNormal, [-2.0, 2.0, -3.0, 3.0], [1.0, 2.0, 1.0, 2.0])
        normal = pairwise(Normal, [-2.0, 2.0, -3.0, 3.0], [1.0, 2.0, 1.0, 2.0])
        # Wrap each member of the normal pair in Independent, passing 1 as the
        # second argument.
        independent = (Independent(normal[0], 1), Independent(normal[1], 1))
        onehotcategorical = pairwise(
            OneHotCategorical,
            [[0.4, 0.3, 0.3], [0.2, 0.7, 0.1], [0.33, 0.33, 0.34], [0.2, 0.2, 0.6]],
        )
        # Built directly (not via pairwise): each length-4 parameter tensor is
        # expanded to 4 x 4.
        pareto = (
            Pareto(
                torch.tensor([2.5, 4.0, 2.5, 4.0]).expand(4, 4),
                torch.tensor([2.25, 3.75, 2.25, 3.75]).expand(4, 4),
            ),
            Pareto(
                torch.tensor([2.25, 3.75, 2.25, 3.8]).expand(4, 4),
                torch.tensor([2.25, 3.75, 2.25, 3.75]).expand(4, 4),
            ),
        )
        poisson = pairwise(Poisson, [0.3, 1.0, 5.0, 10.0])
        # Four Uniform fixtures with different bound ranges; each is paired
        # with specific second distributions in finite_examples below.
        # Bounds all lie strictly inside (0, 1); paired with beta.
        uniform_within_unit = pairwise(
            Uniform, [0.1, 0.9, 0.2, 0.75], [0.15, 0.95, 0.25, 0.8]
        )
        # Bounds are all positive; paired with chi2, exponential and gamma.
        uniform_positive = pairwise(Uniform, [1, 1.5, 2, 4], [1.2, 2.0, 3, 7])
        # Bounds include negative values; paired with gumbel and normal.
        uniform_real = pairwise(Uniform, [-2.0, -1, 0, 2], [-1.0, 1, 1, 4])
        # Paired with pareto.
        uniform_pareto = pairwise(Uniform, [6.5, 7.5, 6.5, 8.5], [7.5, 8.5, 9.5, 9.5])
        continuous_bernoulli = pairwise(ContinuousBernoulli, [0.1, 0.2, 0.5, 0.9])

        # These tests should pass with precision = 0.01, but that makes tests very expensive.
        # Instead, we test with precision = 0.1 and only test with higher precision locally
        # when adding a new KL implementation.
        # The following pairs are not tested due to very high variance of the monte carlo
        # estimator; their implementations have been reviewed with extra care:
        # - (pareto, normal)
        self.precision = 0.1  # Set this to 0.01 when testing a new KL implementation.
        self.max_samples = int(1e07)  # Increase this when testing at smaller precision.
        self.samples_per_batch = int(1e04)
        # NOTE: test_kl_monte_carlo seeds the RNG once and then walks this list
        # in order, so reordering entries changes which random draws each pair sees.
        self.finite_examples = [
            (bernoulli, bernoulli),
            (bernoulli, poisson),
            (beta, beta),
            (beta, chi2),
            (beta, exponential),
            (beta, gamma),
            (beta, normal),
            (binomial30, binomial30),
            (binomial_vectorized_count, binomial_vectorized_count),
            (categorical, categorical),
            (cauchy, cauchy),
            (chi2, chi2),
            (chi2, exponential),
            (chi2, gamma),
            (chi2, normal),
            (dirichlet, dirichlet),
            (exponential, chi2),
            (exponential, exponential),
            (exponential, gamma),
            (exponential, gumbel),
            (exponential, normal),
            (gamma, chi2),
            (gamma, exponential),
            (gamma, gamma),
            (gamma, gumbel),
            (gamma, normal),
            (gumbel, gumbel),
            (gumbel, normal),
            (halfnormal, halfnormal),
            (independent, independent),
            (inversegamma, inversegamma),
            (laplace, laplace),
            (lognormal, lognormal),
            (laplace, normal),
            (normal, gumbel),
            (normal, laplace),
            (normal, normal),
            (onehotcategorical, onehotcategorical),
            (pareto, chi2),
            (pareto, pareto),
            (pareto, exponential),
            (pareto, gamma),
            (poisson, poisson),
            (uniform_within_unit, beta),
            (uniform_positive, chi2),
            (uniform_positive, exponential),
            (uniform_positive, gamma),
            (uniform_real, gumbel),
            (uniform_real, normal),
            (uniform_pareto, pareto),
            (continuous_bernoulli, continuous_bernoulli),
            (continuous_bernoulli, exponential),
            (continuous_bernoulli, normal),
            (beta, continuous_bernoulli),
        ]

        # Plain (p, q) tuples. Presumably these are pairs whose KL divergence is
        # expected to be infinite -- the test that consumes this list is outside
        # this chunk, so confirm there.
        self.infinite_examples = [
            (Bernoulli(0), Bernoulli(1)),
            (Bernoulli(1), Bernoulli(0)),
            (
                Categorical(torch.tensor([0.9, 0.1])),
                Categorical(torch.tensor([1.0, 0.0])),
            ),
            (
                Categorical(torch.tensor([[0.9, 0.1], [0.9, 0.1]])),
                Categorical(torch.tensor([1.0, 0.0])),
            ),
            (Beta(1, 2), Uniform(0.25, 1)),
            (Beta(1, 2), Uniform(0, 0.75)),
            (Beta(1, 2), Uniform(0.25, 0.75)),
            (Beta(1, 2), Pareto(1, 2)),
            (Binomial(31, 0.7), Binomial(30, 0.3)),
            (
                Binomial(torch.tensor([3, 4]), torch.tensor([0.4, 0.6])),
                Binomial(torch.tensor([2, 3]), torch.tensor([0.5, 0.8])),
            ),
            (Chi2(1), Beta(2, 3)),
            (Chi2(1), Pareto(2, 3)),
            (Chi2(1), Uniform(-2, 3)),
            (Exponential(1), Beta(2, 3)),
            (Exponential(1), Pareto(2, 3)),
            (Exponential(1), Uniform(-2, 3)),
            (Gamma(1, 2), Beta(3, 4)),
            (Gamma(1, 2), Pareto(3, 4)),
            (Gamma(1, 2), Uniform(-3, 4)),
            (Gumbel(-1, 2), Beta(3, 4)),
            (Gumbel(-1, 2), Chi2(3)),
            (Gumbel(-1, 2), Exponential(3)),
            (Gumbel(-1, 2), Gamma(3, 4)),
            (Gumbel(-1, 2), Pareto(3, 4)),
            (Gumbel(-1, 2), Uniform(-3, 4)),
            (Laplace(-1, 2), Beta(3, 4)),
            (Laplace(-1, 2), Chi2(3)),
            (Laplace(-1, 2), Exponential(3)),
            (Laplace(-1, 2), Gamma(3, 4)),
            (Laplace(-1, 2), Pareto(3, 4)),
            (Laplace(-1, 2), Uniform(-3, 4)),
            (Normal(-1, 2), Beta(3, 4)),
            (Normal(-1, 2), Chi2(3)),
            (Normal(-1, 2), Exponential(3)),
            (Normal(-1, 2), Gamma(3, 4)),
            (Normal(-1, 2), Pareto(3, 4)),
            (Normal(-1, 2), Uniform(-3, 4)),
            (Pareto(2, 1), Chi2(3)),
            (Pareto(2, 1), Exponential(3)),
            (Pareto(2, 1), Gamma(3, 4)),
            (Pareto(1, 2), Normal(-3, 4)),
            (Pareto(1, 2), Pareto(3, 4)),
            (Poisson(2), Bernoulli(0.5)),
            (Poisson(2.3), Binomial(10, 0.2)),
            (Uniform(-1, 1), Beta(2, 2)),
            (Uniform(0, 2), Beta(3, 4)),
            (Uniform(-1, 2), Beta(3, 4)),
            (Uniform(-1, 2), Chi2(3)),
            (Uniform(-1, 2), Exponential(3)),
            (Uniform(-1, 2), Gamma(3, 4)),
            (Uniform(-1, 2), Pareto(3, 4)),
            (ContinuousBernoulli(0.25), Uniform(0.25, 1)),
            (ContinuousBernoulli(0.25), Uniform(0, 0.75)),
            (ContinuousBernoulli(0.25), Uniform(0.25, 0.75)),
            (ContinuousBernoulli(0.25), Pareto(1, 2)),
            (Exponential(1), ContinuousBernoulli(0.75)),
            (Gamma(1, 2), ContinuousBernoulli(0.75)),
            (Gumbel(-1, 2), ContinuousBernoulli(0.75)),
            (Laplace(-1, 2), ContinuousBernoulli(0.75)),
            (Normal(-1, 2), ContinuousBernoulli(0.75)),
            (Uniform(-1, 1), ContinuousBernoulli(0.75)),
            (Uniform(0, 2), ContinuousBernoulli(0.75)),
            (Uniform(-1, 2), ContinuousBernoulli(0.75)),
        ]
5557*da0073e9SAndroid Build Coastguard Worker
5558*da0073e9SAndroid Build Coastguard Worker    def test_kl_monte_carlo(self):
5559*da0073e9SAndroid Build Coastguard Worker        set_rng_seed(0)  # see Note [Randomized statistical tests]
5560*da0073e9SAndroid Build Coastguard Worker        for (p, _), (_, q) in self.finite_examples:
5561*da0073e9SAndroid Build Coastguard Worker            actual = kl_divergence(p, q)
5562*da0073e9SAndroid Build Coastguard Worker            numerator = 0
5563*da0073e9SAndroid Build Coastguard Worker            denominator = 0
5564*da0073e9SAndroid Build Coastguard Worker            while denominator < self.max_samples:
5565*da0073e9SAndroid Build Coastguard Worker                x = p.sample(sample_shape=(self.samples_per_batch,))
5566*da0073e9SAndroid Build Coastguard Worker                numerator += (p.log_prob(x) - q.log_prob(x)).sum(0)
5567*da0073e9SAndroid Build Coastguard Worker                denominator += x.size(0)
5568*da0073e9SAndroid Build Coastguard Worker                expected = numerator / denominator
5569*da0073e9SAndroid Build Coastguard Worker                error = torch.abs(expected - actual) / (1 + expected)
5570*da0073e9SAndroid Build Coastguard Worker                if error[error == error].max() < self.precision:
5571*da0073e9SAndroid Build Coastguard Worker                    break
5572*da0073e9SAndroid Build Coastguard Worker            self.assertLess(
5573*da0073e9SAndroid Build Coastguard Worker                error[error == error].max(),
5574*da0073e9SAndroid Build Coastguard Worker                self.precision,
5575*da0073e9SAndroid Build Coastguard Worker                "\n".join(
5576*da0073e9SAndroid Build Coastguard Worker                    [
5577*da0073e9SAndroid Build Coastguard Worker                        f"Incorrect KL({type(p).__name__}, {type(q).__name__}).",
5578*da0073e9SAndroid Build Coastguard Worker                        f"Expected ({denominator} Monte Carlo samples): {expected}",
5579*da0073e9SAndroid Build Coastguard Worker                        f"Actual (analytic): {actual}",
5580*da0073e9SAndroid Build Coastguard Worker                    ]
5581*da0073e9SAndroid Build Coastguard Worker                ),
5582*da0073e9SAndroid Build Coastguard Worker            )
5583*da0073e9SAndroid Build Coastguard Worker
5584*da0073e9SAndroid Build Coastguard Worker    # Multivariate normal has a separate Monte Carlo based test due to the requirement of random generation of
5585*da0073e9SAndroid Build Coastguard Worker    # positive (semi) definite matrices. n is set to 5, but can be increased during testing.
5586*da0073e9SAndroid Build Coastguard Worker    def test_kl_multivariate_normal(self):
5587*da0073e9SAndroid Build Coastguard Worker        set_rng_seed(0)  # see Note [Randomized statistical tests]
5588*da0073e9SAndroid Build Coastguard Worker        n = 5  # Number of tests for multivariate_normal
5589*da0073e9SAndroid Build Coastguard Worker        for i in range(0, n):
5590*da0073e9SAndroid Build Coastguard Worker            loc = [torch.randn(4) for _ in range(0, 2)]
5591*da0073e9SAndroid Build Coastguard Worker            scale_tril = [
5592*da0073e9SAndroid Build Coastguard Worker                transform_to(constraints.lower_cholesky)(torch.randn(4, 4))
5593*da0073e9SAndroid Build Coastguard Worker                for _ in range(0, 2)
5594*da0073e9SAndroid Build Coastguard Worker            ]
5595*da0073e9SAndroid Build Coastguard Worker            p = MultivariateNormal(loc=loc[0], scale_tril=scale_tril[0])
5596*da0073e9SAndroid Build Coastguard Worker            q = MultivariateNormal(loc=loc[1], scale_tril=scale_tril[1])
5597*da0073e9SAndroid Build Coastguard Worker            actual = kl_divergence(p, q)
5598*da0073e9SAndroid Build Coastguard Worker            numerator = 0
5599*da0073e9SAndroid Build Coastguard Worker            denominator = 0
5600*da0073e9SAndroid Build Coastguard Worker            while denominator < self.max_samples:
5601*da0073e9SAndroid Build Coastguard Worker                x = p.sample(sample_shape=(self.samples_per_batch,))
5602*da0073e9SAndroid Build Coastguard Worker                numerator += (p.log_prob(x) - q.log_prob(x)).sum(0)
5603*da0073e9SAndroid Build Coastguard Worker                denominator += x.size(0)
5604*da0073e9SAndroid Build Coastguard Worker                expected = numerator / denominator
5605*da0073e9SAndroid Build Coastguard Worker                error = torch.abs(expected - actual) / (1 + expected)
5606*da0073e9SAndroid Build Coastguard Worker                if error[error == error].max() < self.precision:
5607*da0073e9SAndroid Build Coastguard Worker                    break
5608*da0073e9SAndroid Build Coastguard Worker            self.assertLess(
5609*da0073e9SAndroid Build Coastguard Worker                error[error == error].max(),
5610*da0073e9SAndroid Build Coastguard Worker                self.precision,
5611*da0073e9SAndroid Build Coastguard Worker                "\n".join(
5612*da0073e9SAndroid Build Coastguard Worker                    [
5613*da0073e9SAndroid Build Coastguard Worker                        f"Incorrect KL(MultivariateNormal, MultivariateNormal) instance {i + 1}/{n}",
5614*da0073e9SAndroid Build Coastguard Worker                        f"Expected ({denominator} Monte Carlo sample): {expected}",
5615*da0073e9SAndroid Build Coastguard Worker                        f"Actual (analytic): {actual}",
5616*da0073e9SAndroid Build Coastguard Worker                    ]
5617*da0073e9SAndroid Build Coastguard Worker                ),
5618*da0073e9SAndroid Build Coastguard Worker            )
5619*da0073e9SAndroid Build Coastguard Worker
5620*da0073e9SAndroid Build Coastguard Worker    def test_kl_multivariate_normal_batched(self):
5621*da0073e9SAndroid Build Coastguard Worker        b = 7  # Number of batches
5622*da0073e9SAndroid Build Coastguard Worker        loc = [torch.randn(b, 3) for _ in range(0, 2)]
5623*da0073e9SAndroid Build Coastguard Worker        scale_tril = [
5624*da0073e9SAndroid Build Coastguard Worker            transform_to(constraints.lower_cholesky)(torch.randn(b, 3, 3))
5625*da0073e9SAndroid Build Coastguard Worker            for _ in range(0, 2)
5626*da0073e9SAndroid Build Coastguard Worker        ]
5627*da0073e9SAndroid Build Coastguard Worker        expected_kl = torch.stack(
5628*da0073e9SAndroid Build Coastguard Worker            [
5629*da0073e9SAndroid Build Coastguard Worker                kl_divergence(
5630*da0073e9SAndroid Build Coastguard Worker                    MultivariateNormal(loc[0][i], scale_tril=scale_tril[0][i]),
5631*da0073e9SAndroid Build Coastguard Worker                    MultivariateNormal(loc[1][i], scale_tril=scale_tril[1][i]),
5632*da0073e9SAndroid Build Coastguard Worker                )
5633*da0073e9SAndroid Build Coastguard Worker                for i in range(0, b)
5634*da0073e9SAndroid Build Coastguard Worker            ]
5635*da0073e9SAndroid Build Coastguard Worker        )
5636*da0073e9SAndroid Build Coastguard Worker        actual_kl = kl_divergence(
5637*da0073e9SAndroid Build Coastguard Worker            MultivariateNormal(loc[0], scale_tril=scale_tril[0]),
5638*da0073e9SAndroid Build Coastguard Worker            MultivariateNormal(loc[1], scale_tril=scale_tril[1]),
5639*da0073e9SAndroid Build Coastguard Worker        )
5640*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected_kl, actual_kl)
5641*da0073e9SAndroid Build Coastguard Worker
5642*da0073e9SAndroid Build Coastguard Worker    def test_kl_multivariate_normal_batched_broadcasted(self):
5643*da0073e9SAndroid Build Coastguard Worker        b = 7  # Number of batches
5644*da0073e9SAndroid Build Coastguard Worker        loc = [torch.randn(b, 3) for _ in range(0, 2)]
5645*da0073e9SAndroid Build Coastguard Worker        scale_tril = [
5646*da0073e9SAndroid Build Coastguard Worker            transform_to(constraints.lower_cholesky)(torch.randn(b, 3, 3)),
5647*da0073e9SAndroid Build Coastguard Worker            transform_to(constraints.lower_cholesky)(torch.randn(3, 3)),
5648*da0073e9SAndroid Build Coastguard Worker        ]
5649*da0073e9SAndroid Build Coastguard Worker        expected_kl = torch.stack(
5650*da0073e9SAndroid Build Coastguard Worker            [
5651*da0073e9SAndroid Build Coastguard Worker                kl_divergence(
5652*da0073e9SAndroid Build Coastguard Worker                    MultivariateNormal(loc[0][i], scale_tril=scale_tril[0][i]),
5653*da0073e9SAndroid Build Coastguard Worker                    MultivariateNormal(loc[1][i], scale_tril=scale_tril[1]),
5654*da0073e9SAndroid Build Coastguard Worker                )
5655*da0073e9SAndroid Build Coastguard Worker                for i in range(0, b)
5656*da0073e9SAndroid Build Coastguard Worker            ]
5657*da0073e9SAndroid Build Coastguard Worker        )
5658*da0073e9SAndroid Build Coastguard Worker        actual_kl = kl_divergence(
5659*da0073e9SAndroid Build Coastguard Worker            MultivariateNormal(loc[0], scale_tril=scale_tril[0]),
5660*da0073e9SAndroid Build Coastguard Worker            MultivariateNormal(loc[1], scale_tril=scale_tril[1]),
5661*da0073e9SAndroid Build Coastguard Worker        )
5662*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected_kl, actual_kl)
5663*da0073e9SAndroid Build Coastguard Worker
5664*da0073e9SAndroid Build Coastguard Worker    def test_kl_lowrank_multivariate_normal(self):
5665*da0073e9SAndroid Build Coastguard Worker        set_rng_seed(0)  # see Note [Randomized statistical tests]
5666*da0073e9SAndroid Build Coastguard Worker        n = 5  # Number of tests for lowrank_multivariate_normal
5667*da0073e9SAndroid Build Coastguard Worker        for i in range(0, n):
5668*da0073e9SAndroid Build Coastguard Worker            loc = [torch.randn(4) for _ in range(0, 2)]
5669*da0073e9SAndroid Build Coastguard Worker            cov_factor = [torch.randn(4, 3) for _ in range(0, 2)]
5670*da0073e9SAndroid Build Coastguard Worker            cov_diag = [
5671*da0073e9SAndroid Build Coastguard Worker                transform_to(constraints.positive)(torch.randn(4)) for _ in range(0, 2)
5672*da0073e9SAndroid Build Coastguard Worker            ]
5673*da0073e9SAndroid Build Coastguard Worker            covariance_matrix = [
5674*da0073e9SAndroid Build Coastguard Worker                cov_factor[i].matmul(cov_factor[i].t()) + cov_diag[i].diag()
5675*da0073e9SAndroid Build Coastguard Worker                for i in range(0, 2)
5676*da0073e9SAndroid Build Coastguard Worker            ]
5677*da0073e9SAndroid Build Coastguard Worker            p = LowRankMultivariateNormal(loc[0], cov_factor[0], cov_diag[0])
5678*da0073e9SAndroid Build Coastguard Worker            q = LowRankMultivariateNormal(loc[1], cov_factor[1], cov_diag[1])
5679*da0073e9SAndroid Build Coastguard Worker            p_full = MultivariateNormal(loc[0], covariance_matrix[0])
5680*da0073e9SAndroid Build Coastguard Worker            q_full = MultivariateNormal(loc[1], covariance_matrix[1])
5681*da0073e9SAndroid Build Coastguard Worker            expected = kl_divergence(p_full, q_full)
5682*da0073e9SAndroid Build Coastguard Worker
5683*da0073e9SAndroid Build Coastguard Worker            actual_lowrank_lowrank = kl_divergence(p, q)
5684*da0073e9SAndroid Build Coastguard Worker            actual_lowrank_full = kl_divergence(p, q_full)
5685*da0073e9SAndroid Build Coastguard Worker            actual_full_lowrank = kl_divergence(p_full, q)
5686*da0073e9SAndroid Build Coastguard Worker
5687*da0073e9SAndroid Build Coastguard Worker            error_lowrank_lowrank = torch.abs(actual_lowrank_lowrank - expected).max()
5688*da0073e9SAndroid Build Coastguard Worker            self.assertLess(
5689*da0073e9SAndroid Build Coastguard Worker                error_lowrank_lowrank,
5690*da0073e9SAndroid Build Coastguard Worker                self.precision,
5691*da0073e9SAndroid Build Coastguard Worker                "\n".join(
5692*da0073e9SAndroid Build Coastguard Worker                    [
5693*da0073e9SAndroid Build Coastguard Worker                        f"Incorrect KL(LowRankMultivariateNormal, LowRankMultivariateNormal) instance {i + 1}/{n}",
5694*da0073e9SAndroid Build Coastguard Worker                        f"Expected (from KL MultivariateNormal): {expected}",
5695*da0073e9SAndroid Build Coastguard Worker                        f"Actual (analytic): {actual_lowrank_lowrank}",
5696*da0073e9SAndroid Build Coastguard Worker                    ]
5697*da0073e9SAndroid Build Coastguard Worker                ),
5698*da0073e9SAndroid Build Coastguard Worker            )
5699*da0073e9SAndroid Build Coastguard Worker
5700*da0073e9SAndroid Build Coastguard Worker            error_lowrank_full = torch.abs(actual_lowrank_full - expected).max()
5701*da0073e9SAndroid Build Coastguard Worker            self.assertLess(
5702*da0073e9SAndroid Build Coastguard Worker                error_lowrank_full,
5703*da0073e9SAndroid Build Coastguard Worker                self.precision,
5704*da0073e9SAndroid Build Coastguard Worker                "\n".join(
5705*da0073e9SAndroid Build Coastguard Worker                    [
5706*da0073e9SAndroid Build Coastguard Worker                        f"Incorrect KL(LowRankMultivariateNormal, MultivariateNormal) instance {i + 1}/{n}",
5707*da0073e9SAndroid Build Coastguard Worker                        f"Expected (from KL MultivariateNormal): {expected}",
5708*da0073e9SAndroid Build Coastguard Worker                        f"Actual (analytic): {actual_lowrank_full}",
5709*da0073e9SAndroid Build Coastguard Worker                    ]
5710*da0073e9SAndroid Build Coastguard Worker                ),
5711*da0073e9SAndroid Build Coastguard Worker            )
5712*da0073e9SAndroid Build Coastguard Worker
5713*da0073e9SAndroid Build Coastguard Worker            error_full_lowrank = torch.abs(actual_full_lowrank - expected).max()
5714*da0073e9SAndroid Build Coastguard Worker            self.assertLess(
5715*da0073e9SAndroid Build Coastguard Worker                error_full_lowrank,
5716*da0073e9SAndroid Build Coastguard Worker                self.precision,
5717*da0073e9SAndroid Build Coastguard Worker                "\n".join(
5718*da0073e9SAndroid Build Coastguard Worker                    [
5719*da0073e9SAndroid Build Coastguard Worker                        f"Incorrect KL(MultivariateNormal, LowRankMultivariateNormal) instance {i + 1}/{n}",
5720*da0073e9SAndroid Build Coastguard Worker                        f"Expected (from KL MultivariateNormal): {expected}",
5721*da0073e9SAndroid Build Coastguard Worker                        f"Actual (analytic): {actual_full_lowrank}",
5722*da0073e9SAndroid Build Coastguard Worker                    ]
5723*da0073e9SAndroid Build Coastguard Worker                ),
5724*da0073e9SAndroid Build Coastguard Worker            )
5725*da0073e9SAndroid Build Coastguard Worker
5726*da0073e9SAndroid Build Coastguard Worker    def test_kl_lowrank_multivariate_normal_batched(self):
5727*da0073e9SAndroid Build Coastguard Worker        b = 7  # Number of batches
5728*da0073e9SAndroid Build Coastguard Worker        loc = [torch.randn(b, 3) for _ in range(0, 2)]
5729*da0073e9SAndroid Build Coastguard Worker        cov_factor = [torch.randn(b, 3, 2) for _ in range(0, 2)]
5730*da0073e9SAndroid Build Coastguard Worker        cov_diag = [
5731*da0073e9SAndroid Build Coastguard Worker            transform_to(constraints.positive)(torch.randn(b, 3)) for _ in range(0, 2)
5732*da0073e9SAndroid Build Coastguard Worker        ]
5733*da0073e9SAndroid Build Coastguard Worker        expected_kl = torch.stack(
5734*da0073e9SAndroid Build Coastguard Worker            [
5735*da0073e9SAndroid Build Coastguard Worker                kl_divergence(
5736*da0073e9SAndroid Build Coastguard Worker                    LowRankMultivariateNormal(
5737*da0073e9SAndroid Build Coastguard Worker                        loc[0][i], cov_factor[0][i], cov_diag[0][i]
5738*da0073e9SAndroid Build Coastguard Worker                    ),
5739*da0073e9SAndroid Build Coastguard Worker                    LowRankMultivariateNormal(
5740*da0073e9SAndroid Build Coastguard Worker                        loc[1][i], cov_factor[1][i], cov_diag[1][i]
5741*da0073e9SAndroid Build Coastguard Worker                    ),
5742*da0073e9SAndroid Build Coastguard Worker                )
5743*da0073e9SAndroid Build Coastguard Worker                for i in range(0, b)
5744*da0073e9SAndroid Build Coastguard Worker            ]
5745*da0073e9SAndroid Build Coastguard Worker        )
5746*da0073e9SAndroid Build Coastguard Worker        actual_kl = kl_divergence(
5747*da0073e9SAndroid Build Coastguard Worker            LowRankMultivariateNormal(loc[0], cov_factor[0], cov_diag[0]),
5748*da0073e9SAndroid Build Coastguard Worker            LowRankMultivariateNormal(loc[1], cov_factor[1], cov_diag[1]),
5749*da0073e9SAndroid Build Coastguard Worker        )
5750*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected_kl, actual_kl)
5751*da0073e9SAndroid Build Coastguard Worker
5752*da0073e9SAndroid Build Coastguard Worker    def test_kl_exponential_family(self):
5753*da0073e9SAndroid Build Coastguard Worker        for (p, _), (_, q) in self.finite_examples:
5754*da0073e9SAndroid Build Coastguard Worker            if type(p) == type(q) and issubclass(type(p), ExponentialFamily):
5755*da0073e9SAndroid Build Coastguard Worker                actual = kl_divergence(p, q)
5756*da0073e9SAndroid Build Coastguard Worker                expected = _kl_expfamily_expfamily(p, q)
5757*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(
5758*da0073e9SAndroid Build Coastguard Worker                    actual,
5759*da0073e9SAndroid Build Coastguard Worker                    expected,
5760*da0073e9SAndroid Build Coastguard Worker                    msg="\n".join(
5761*da0073e9SAndroid Build Coastguard Worker                        [
5762*da0073e9SAndroid Build Coastguard Worker                            f"Incorrect KL({type(p).__name__}, {type(q).__name__}).",
5763*da0073e9SAndroid Build Coastguard Worker                            f"Expected (using Bregman Divergence) {expected}",
5764*da0073e9SAndroid Build Coastguard Worker                            f"Actual (analytic) {actual}",
5765*da0073e9SAndroid Build Coastguard Worker                            f"max error = {torch.abs(actual - expected).max()}",
5766*da0073e9SAndroid Build Coastguard Worker                        ]
5767*da0073e9SAndroid Build Coastguard Worker                    ),
5768*da0073e9SAndroid Build Coastguard Worker                )
5769*da0073e9SAndroid Build Coastguard Worker
5770*da0073e9SAndroid Build Coastguard Worker    def test_kl_infinite(self):
5771*da0073e9SAndroid Build Coastguard Worker        for p, q in self.infinite_examples:
5772*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(
5773*da0073e9SAndroid Build Coastguard Worker                (kl_divergence(p, q) == inf).all(),
5774*da0073e9SAndroid Build Coastguard Worker                f"Incorrect KL({type(p).__name__}, {type(q).__name__})",
5775*da0073e9SAndroid Build Coastguard Worker            )
5776*da0073e9SAndroid Build Coastguard Worker
5777*da0073e9SAndroid Build Coastguard Worker    def test_kl_edgecases(self):
5778*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(kl_divergence(Bernoulli(0), Bernoulli(0)), 0)
5779*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(kl_divergence(Bernoulli(1), Bernoulli(1)), 0)
5780*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
5781*da0073e9SAndroid Build Coastguard Worker            kl_divergence(
5782*da0073e9SAndroid Build Coastguard Worker                Categorical(torch.tensor([0.0, 1.0])),
5783*da0073e9SAndroid Build Coastguard Worker                Categorical(torch.tensor([0.0, 1.0])),
5784*da0073e9SAndroid Build Coastguard Worker            ),
5785*da0073e9SAndroid Build Coastguard Worker            0,
5786*da0073e9SAndroid Build Coastguard Worker        )
5787*da0073e9SAndroid Build Coastguard Worker
5788*da0073e9SAndroid Build Coastguard Worker    def test_kl_shape(self):
5789*da0073e9SAndroid Build Coastguard Worker        for Dist, params in _get_examples():
5790*da0073e9SAndroid Build Coastguard Worker            for i, param in enumerate(params):
5791*da0073e9SAndroid Build Coastguard Worker                dist = Dist(**param)
5792*da0073e9SAndroid Build Coastguard Worker                try:
5793*da0073e9SAndroid Build Coastguard Worker                    kl = kl_divergence(dist, dist)
5794*da0073e9SAndroid Build Coastguard Worker                except NotImplementedError:
5795*da0073e9SAndroid Build Coastguard Worker                    continue
5796*da0073e9SAndroid Build Coastguard Worker                expected_shape = dist.batch_shape if dist.batch_shape else torch.Size()
5797*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(
5798*da0073e9SAndroid Build Coastguard Worker                    kl.shape,
5799*da0073e9SAndroid Build Coastguard Worker                    expected_shape,
5800*da0073e9SAndroid Build Coastguard Worker                    msg="\n".join(
5801*da0073e9SAndroid Build Coastguard Worker                        [
5802*da0073e9SAndroid Build Coastguard Worker                            f"{Dist.__name__} example {i + 1}/{len(params)}",
5803*da0073e9SAndroid Build Coastguard Worker                            f"Expected {expected_shape}",
5804*da0073e9SAndroid Build Coastguard Worker                            f"Actual {kl.shape}",
5805*da0073e9SAndroid Build Coastguard Worker                        ]
5806*da0073e9SAndroid Build Coastguard Worker                    ),
5807*da0073e9SAndroid Build Coastguard Worker                )
5808*da0073e9SAndroid Build Coastguard Worker
5809*da0073e9SAndroid Build Coastguard Worker    def test_kl_transformed(self):
5810*da0073e9SAndroid Build Coastguard Worker        # Regression test for https://github.com/pytorch/pytorch/issues/34859
5811*da0073e9SAndroid Build Coastguard Worker        scale = torch.ones(2, 3)
5812*da0073e9SAndroid Build Coastguard Worker        loc = torch.zeros(2, 3)
5813*da0073e9SAndroid Build Coastguard Worker        normal = Normal(loc=loc, scale=scale)
5814*da0073e9SAndroid Build Coastguard Worker        diag_normal = Independent(normal, reinterpreted_batch_ndims=1)
5815*da0073e9SAndroid Build Coastguard Worker        trans_dist = TransformedDistribution(
5816*da0073e9SAndroid Build Coastguard Worker            diag_normal, AffineTransform(loc=0.0, scale=2.0)
5817*da0073e9SAndroid Build Coastguard Worker        )
5818*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(kl_divergence(diag_normal, diag_normal).shape, (2,))
5819*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(kl_divergence(trans_dist, trans_dist).shape, (2,))
5820*da0073e9SAndroid Build Coastguard Worker
5821*da0073e9SAndroid Build Coastguard Worker    @set_default_dtype(torch.double)
5822*da0073e9SAndroid Build Coastguard Worker    def test_entropy_monte_carlo(self):
5823*da0073e9SAndroid Build Coastguard Worker        set_rng_seed(0)  # see Note [Randomized statistical tests]
5824*da0073e9SAndroid Build Coastguard Worker        for Dist, params in _get_examples():
5825*da0073e9SAndroid Build Coastguard Worker            for i, param in enumerate(params):
5826*da0073e9SAndroid Build Coastguard Worker                dist = Dist(**param)
5827*da0073e9SAndroid Build Coastguard Worker                try:
5828*da0073e9SAndroid Build Coastguard Worker                    actual = dist.entropy()
5829*da0073e9SAndroid Build Coastguard Worker                except NotImplementedError:
5830*da0073e9SAndroid Build Coastguard Worker                    continue
5831*da0073e9SAndroid Build Coastguard Worker                x = dist.sample(sample_shape=(60000,))
5832*da0073e9SAndroid Build Coastguard Worker                expected = -dist.log_prob(x).mean(0)
5833*da0073e9SAndroid Build Coastguard Worker                ignore = (expected == inf) | (expected == -inf)
5834*da0073e9SAndroid Build Coastguard Worker                expected[ignore] = actual[ignore]
5835*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(
5836*da0073e9SAndroid Build Coastguard Worker                    actual,
5837*da0073e9SAndroid Build Coastguard Worker                    expected,
5838*da0073e9SAndroid Build Coastguard Worker                    atol=0.2,
5839*da0073e9SAndroid Build Coastguard Worker                    rtol=0,
5840*da0073e9SAndroid Build Coastguard Worker                    msg="\n".join(
5841*da0073e9SAndroid Build Coastguard Worker                        [
5842*da0073e9SAndroid Build Coastguard Worker                            f"{Dist.__name__} example {i + 1}/{len(params)}, incorrect .entropy().",
5843*da0073e9SAndroid Build Coastguard Worker                            f"Expected (monte carlo) {expected}",
5844*da0073e9SAndroid Build Coastguard Worker                            f"Actual (analytic) {actual}",
5845*da0073e9SAndroid Build Coastguard Worker                            f"max error = {torch.abs(actual - expected).max()}",
5846*da0073e9SAndroid Build Coastguard Worker                        ]
5847*da0073e9SAndroid Build Coastguard Worker                    ),
5848*da0073e9SAndroid Build Coastguard Worker                )
5849*da0073e9SAndroid Build Coastguard Worker
5850*da0073e9SAndroid Build Coastguard Worker    @set_default_dtype(torch.double)
5851*da0073e9SAndroid Build Coastguard Worker    def test_entropy_exponential_family(self):
5852*da0073e9SAndroid Build Coastguard Worker        for Dist, params in _get_examples():
5853*da0073e9SAndroid Build Coastguard Worker            if not issubclass(Dist, ExponentialFamily):
5854*da0073e9SAndroid Build Coastguard Worker                continue
5855*da0073e9SAndroid Build Coastguard Worker            for i, param in enumerate(params):
5856*da0073e9SAndroid Build Coastguard Worker                dist = Dist(**param)
5857*da0073e9SAndroid Build Coastguard Worker                try:
5858*da0073e9SAndroid Build Coastguard Worker                    actual = dist.entropy()
5859*da0073e9SAndroid Build Coastguard Worker                except NotImplementedError:
5860*da0073e9SAndroid Build Coastguard Worker                    continue
5861*da0073e9SAndroid Build Coastguard Worker                try:
5862*da0073e9SAndroid Build Coastguard Worker                    expected = ExponentialFamily.entropy(dist)
5863*da0073e9SAndroid Build Coastguard Worker                except NotImplementedError:
5864*da0073e9SAndroid Build Coastguard Worker                    continue
5865*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(
5866*da0073e9SAndroid Build Coastguard Worker                    actual,
5867*da0073e9SAndroid Build Coastguard Worker                    expected,
5868*da0073e9SAndroid Build Coastguard Worker                    msg="\n".join(
5869*da0073e9SAndroid Build Coastguard Worker                        [
5870*da0073e9SAndroid Build Coastguard Worker                            f"{Dist.__name__} example {i + 1}/{len(params)}, incorrect .entropy().",
5871*da0073e9SAndroid Build Coastguard Worker                            f"Expected (Bregman Divergence) {expected}",
5872*da0073e9SAndroid Build Coastguard Worker                            f"Actual (analytic) {actual}",
5873*da0073e9SAndroid Build Coastguard Worker                            f"max error = {torch.abs(actual - expected).max()}",
5874*da0073e9SAndroid Build Coastguard Worker                        ]
5875*da0073e9SAndroid Build Coastguard Worker                    ),
5876*da0073e9SAndroid Build Coastguard Worker                )
5877*da0073e9SAndroid Build Coastguard Worker
5878*da0073e9SAndroid Build Coastguard Worker
class TestConstraints(DistributionsTestCase):
    # Checks that example parameters and drawn samples satisfy the constraints
    # that each distribution declares for them.

    def test_params_constraints(self):
        # Every example parameter that has an entry in dist.arg_constraints must
        # have a shape compatible with the distribution's batch shape and, unless
        # the constraint is dependent, must pass constraint.check().
        normalize_probs_dists = (
            Categorical,
            Multinomial,
            OneHotCategorical,
            OneHotCategoricalStraightThrough,
            RelaxedOneHotCategorical,
        )

        for Dist, params in _get_examples():
            needs_normalized_probs = Dist in normalize_probs_dists
            for i, param in enumerate(params):
                dist = Dist(**param)
                for name, value in param.items():
                    if isinstance(value, numbers.Number):
                        value = torch.tensor([value])
                    if needs_normalized_probs and name == "probs":
                        # These distributions accept positive probs, but elsewhere we
                        # use a stricter constraint to the simplex.
                        value = value / value.sum(-1, True)
                    try:
                        constraint = dist.arg_constraints[name]
                    except KeyError:
                        continue  # ignore optional parameters

                    # Check param shape is compatible with distribution shape.
                    self.assertGreaterEqual(value.dim(), constraint.event_dim)
                    batch_ndim = value.dim() - constraint.event_dim
                    value_batch_shape = value.shape[:batch_ndim]
                    # Called only for its validation: raises when the shapes
                    # cannot be broadcast together.
                    torch.broadcast_shapes(dist.batch_shape, value_batch_shape)

                    if is_dependent(constraint):
                        continue

                    message = f"{Dist.__name__} example {i + 1}/{len(params)} parameter {name} = {value}"
                    self.assertTrue(constraint.check(value).all(), msg=message)

    def test_support_constraints(self):
        # A sample from each example distribution must lie in its support, and
        # the support's event_dim and check() result shape must match the
        # distribution's event and batch shapes.
        for Dist, params in _get_examples():
            self.assertIsInstance(Dist.support, Constraint)
            total = len(params)
            for position, param in enumerate(params, start=1):
                dist = Dist(**param)
                value = dist.sample()
                constraint = dist.support
                message = f"{Dist.__name__} example {position}/{total} sample = {value}"
                event_ndim = len(dist.event_shape)
                self.assertEqual(constraint.event_dim, event_ndim, msg=message)
                ok = constraint.check(value)
                self.assertEqual(ok.shape, dist.batch_shape, msg=message)
                self.assertTrue(ok.all(), msg=message)
5933*da0073e9SAndroid Build Coastguard Worker
5934*da0073e9SAndroid Build Coastguard Worker
@skipIfTorchDynamo("Not a TorchDynamo suitable test")
class TestNumericalStability(DistributionsTestCase):
    """Check ``log_prob`` values and gradients at extreme parameter values."""

    def _test_pdf_score(
        self,
        dist_class,
        x,
        expected_value,
        probs=None,
        logits=None,
        expected_gradient=None,
        atol=1e-5,
    ):
        """Compare ``log_prob(x)`` and its parameter gradient with expectations.

        Args:
            dist_class: distribution class, built as ``dist_class(p)`` when
                ``probs`` is given and as ``dist_class(logits=p)`` otherwise.
            x: value at which ``log_prob`` is evaluated.
            expected_value: expected result of ``log_prob(x)``.
            probs: parameter tensor; takes precedence over ``logits``.
            logits: parameter tensor used when ``probs`` is None.
            expected_gradient: expected gradient of ``log_prob(x).sum()`` with
                respect to the parameter; the check is skipped when None.
            atol: absolute tolerance of both comparisons (``rtol`` is 0).

        Raises:
            ValueError: if neither ``probs`` nor ``logits`` is given.
        """
        if probs is not None:
            # Detach so that the gradient lands on a fresh leaf tensor.
            p = probs.detach().requires_grad_()
            dist = dist_class(p)
        elif logits is not None:
            p = logits.detach().requires_grad_()
            dist = dist_class(logits=p)
        else:
            raise ValueError("Either `probs` or `logits` must be specified.")
        log_pdf = dist.log_prob(x)
        log_pdf.sum().backward()
        # ``x.type()`` names the concrete tensor type (e.g. torch.FloatTensor);
        # ``type(x)`` would be torch.Tensor for every input.
        self.assertEqual(
            log_pdf,
            expected_value,
            atol=atol,
            rtol=0,
            msg=f"Incorrect value for tensor type: {x.type()}. Expected = {expected_value}, Actual = {log_pdf}",
        )
        if expected_gradient is not None:
            self.assertEqual(
                p.grad,
                expected_gradient,
                atol=atol,
                rtol=0,
                msg=f"Incorrect gradient for tensor type: {x.type()}. Expected = {expected_gradient}, Actual = {p.grad}",
            )

    def test_bernoulli_gradient(self):
        """Bernoulli ``log_prob`` and gradient at and near the boundary of probs."""
        for tensor_type in [torch.FloatTensor, torch.DoubleTensor]:
            self._test_pdf_score(
                dist_class=Bernoulli,
                probs=tensor_type([0]),
                x=tensor_type([0]),
                expected_value=tensor_type([0]),
                expected_gradient=tensor_type([0]),
            )

            # With probs == 0 the log-probability of x == 1 is expected to be
            # log(eps) of the dtype instead of -inf.
            self._test_pdf_score(
                dist_class=Bernoulli,
                probs=tensor_type([0]),
                x=tensor_type([1]),
                expected_value=tensor_type(
                    [torch.finfo(tensor_type([]).dtype).eps]
                ).log(),
                expected_gradient=tensor_type([0]),
            )

            self._test_pdf_score(
                dist_class=Bernoulli,
                probs=tensor_type([1e-4]),
                x=tensor_type([1]),
                expected_value=tensor_type([math.log(1e-4)]),
                expected_gradient=tensor_type([10000]),
            )

            # Lower precision due to:
            # >>> 1 / (1 - torch.FloatTensor([0.9999]))
            # 9998.3408
            # [torch.FloatTensor of size 1]
            self._test_pdf_score(
                dist_class=Bernoulli,
                probs=tensor_type([1 - 1e-4]),
                x=tensor_type([0]),
                expected_value=tensor_type([math.log(1e-4)]),
                expected_gradient=tensor_type([-10000]),
                atol=2,
            )

            self._test_pdf_score(
                dist_class=Bernoulli,
                logits=tensor_type([math.log(9999)]),
                x=tensor_type([0]),
                expected_value=tensor_type([math.log(1e-4)]),
                expected_gradient=tensor_type([-1]),
                atol=1e-3,
            )

    def test_bernoulli_with_logits_underflow(self):
        """Hugely negative logits give ``log_prob(0) == 0`` with zero gradient."""
        for tensor_type, lim in [
            (torch.FloatTensor, -1e38),
            (torch.DoubleTensor, -1e308),
        ]:
            self._test_pdf_score(
                dist_class=Bernoulli,
                logits=tensor_type([lim]),
                x=tensor_type([0]),
                expected_value=tensor_type([0]),
                expected_gradient=tensor_type([0]),
            )

    def test_bernoulli_with_logits_overflow(self):
        """Hugely positive logits give ``log_prob(1) == 0`` with zero gradient."""
        for tensor_type, lim in [
            (torch.FloatTensor, 1e38),
            (torch.DoubleTensor, 1e308),
        ]:
            self._test_pdf_score(
                dist_class=Bernoulli,
                logits=tensor_type([lim]),
                x=tensor_type([1]),
                expected_value=tensor_type([0]),
                expected_gradient=tensor_type([0]),
            )

    def test_categorical_log_prob(self):
        """OneHotCategorical with probs ``[0, 1]`` assigns log-prob 0 to ``[0, 1]``."""
        for dtype in [torch.float, torch.double]:
            p = torch.tensor([0, 1], dtype=dtype, requires_grad=True)
            categorical = OneHotCategorical(p)
            log_pdf = categorical.log_prob(torch.tensor([0, 1], dtype=dtype))
            self.assertEqual(log_pdf.item(), 0)

    def test_categorical_log_prob_with_logits(self):
        """OneHotCategorical with a ``-inf`` logit yields exact 0 / ``-inf`` log-probs."""
        for dtype in [torch.float, torch.double]:
            p = torch.tensor([-inf, 0], dtype=dtype, requires_grad=True)
            categorical = OneHotCategorical(logits=p)
            log_pdf_prob_1 = categorical.log_prob(torch.tensor([0, 1], dtype=dtype))
            self.assertEqual(log_pdf_prob_1.item(), 0)
            log_pdf_prob_0 = categorical.log_prob(torch.tensor([1, 0], dtype=dtype))
            self.assertEqual(log_pdf_prob_0.item(), -inf)

    def test_multinomial_log_prob(self):
        """Multinomial with probs ``[0, 1]`` assigns log-prob 0 to counts ``[0, 10]``."""
        for dtype in [torch.float, torch.double]:
            p = torch.tensor([0, 1], dtype=dtype, requires_grad=True)
            s = torch.tensor([0, 10], dtype=dtype)
            multinomial = Multinomial(10, p)
            log_pdf = multinomial.log_prob(s)
            self.assertEqual(log_pdf.item(), 0)

    def test_multinomial_log_prob_with_logits(self):
        """Multinomial with a ``-inf`` logit yields exact 0 / ``-inf`` log-probs."""
        for dtype in [torch.float, torch.double]:
            p = torch.tensor([-inf, 0], dtype=dtype, requires_grad=True)
            multinomial = Multinomial(10, logits=p)
            log_pdf_prob_1 = multinomial.log_prob(torch.tensor([0, 10], dtype=dtype))
            self.assertEqual(log_pdf_prob_1.item(), 0)
            log_pdf_prob_0 = multinomial.log_prob(torch.tensor([10, 0], dtype=dtype))
            self.assertEqual(log_pdf_prob_0.item(), -inf)

    def test_continuous_bernoulli_gradient(self):
        """Compare ContinuousBernoulli ``log_prob`` and gradient with closed forms.

        The reference values are computed in pure Python by ``expec_val`` and
        ``expec_grad``, which switch to a Taylor expansion around probs = 0.5.
        """

        def to_probs(probs, logits):
            # Resolve the parameterization to a probability; logits win when
            # given and are mapped through the logistic function.
            assert not (probs is None and logits is None)
            if logits is not None:
                return 1.0 / (1.0 + math.exp(-logits))
            return probs

        def expec_val(x, probs=None, logits=None):
            # Reference log-likelihood: Bernoulli term plus log normalizer.
            probs = to_probs(probs, logits)
            bern_log_lik = x * math.log(probs) + (1.0 - x) * math.log1p(-probs)
            if probs < 0.499 or probs > 0.501:  # using default values of lims here
                log_norm_const = (
                    math.log(math.fabs(math.atanh(1.0 - 2.0 * probs)))
                    - math.log(math.fabs(1.0 - 2.0 * probs))
                    + math.log(2.0)
                )
            else:
                # Taylor expansion of the log normalizer around probs = 0.5.
                aux = math.pow(probs - 0.5, 2)
                log_norm_const = math.log(2.0) + (4.0 / 3.0 + 104.0 / 45.0 * aux) * aux
            return bern_log_lik + log_norm_const

        def expec_grad(x, probs=None, logits=None):
            # Reference derivative of the log-likelihood w.r.t. the parameter
            # that was passed in (probs or logits).
            probs = to_probs(probs, logits)
            grad_bern_log_lik = x / probs - (1.0 - x) / (1.0 - probs)
            if probs < 0.499 or probs > 0.501:  # using default values of lims here
                atanh_term = math.atanh(1.0 - 2.0 * probs)
                grad_log_c = (
                    2.0 * probs
                    - 4.0 * (probs - 1.0) * probs * atanh_term
                    - 1.0
                )
                grad_log_c /= (
                    2.0
                    * (probs - 1.0)
                    * probs
                    * (2.0 * probs - 1.0)
                    * atanh_term
                )
            else:
                # Derivative of the Taylor expansion used by expec_val.
                grad_log_c = 8.0 / 3.0 * (probs - 0.5) + 416.0 / 45.0 * math.pow(
                    probs - 0.5, 3
                )
            grad = grad_bern_log_lik + grad_log_c
            if logits is not None:
                # Chain rule: d probs / d logits = sigmoid * (1 - sigmoid).
                grad *= 1.0 / (1.0 + math.exp(logits)) - 1.0 / math.pow(
                    1.0 + math.exp(logits), 2
                )
            return grad

        for tensor_type in [torch.FloatTensor, torch.DoubleTensor]:
            self._test_pdf_score(
                dist_class=ContinuousBernoulli,
                probs=tensor_type([0.1]),
                x=tensor_type([0.1]),
                expected_value=tensor_type([expec_val(0.1, probs=0.1)]),
                expected_gradient=tensor_type([expec_grad(0.1, probs=0.1)]),
            )

            self._test_pdf_score(
                dist_class=ContinuousBernoulli,
                probs=tensor_type([0.1]),
                x=tensor_type([1.0]),
                expected_value=tensor_type([expec_val(1.0, probs=0.1)]),
                expected_gradient=tensor_type([expec_grad(1.0, probs=0.1)]),
            )

            # probs inside (0.499, 0.501) exercises the Taylor-expansion branch.
            self._test_pdf_score(
                dist_class=ContinuousBernoulli,
                probs=tensor_type([0.4999]),
                x=tensor_type([0.9]),
                expected_value=tensor_type([expec_val(0.9, probs=0.4999)]),
                expected_gradient=tensor_type([expec_grad(0.9, probs=0.4999)]),
            )

            self._test_pdf_score(
                dist_class=ContinuousBernoulli,
                probs=tensor_type([1e-4]),
                x=tensor_type([1]),
                expected_value=tensor_type([expec_val(1, probs=1e-4)]),
                expected_gradient=tensor_type([expec_grad(1, probs=1e-4)]),
                atol=1e-3,
            )

            self._test_pdf_score(
                dist_class=ContinuousBernoulli,
                probs=tensor_type([1 - 1e-4]),
                x=tensor_type([0.1]),
                expected_value=tensor_type([expec_val(0.1, probs=1 - 1e-4)]),
                expected_gradient=tensor_type([expec_grad(0.1, probs=1 - 1e-4)]),
                atol=2,
            )

            self._test_pdf_score(
                dist_class=ContinuousBernoulli,
                logits=tensor_type([math.log(9999)]),
                x=tensor_type([0]),
                expected_value=tensor_type([expec_val(0, logits=math.log(9999))]),
                expected_gradient=tensor_type([expec_grad(0, logits=math.log(9999))]),
                atol=1e-3,
            )

            self._test_pdf_score(
                dist_class=ContinuousBernoulli,
                logits=tensor_type([0.001]),
                x=tensor_type([0.5]),
                expected_value=tensor_type([expec_val(0.5, logits=0.001)]),
                expected_gradient=tensor_type([expec_grad(0.5, logits=0.001)]),
            )

    def test_continuous_bernoulli_with_logits_underflow(self):
        """Hugely negative logits give a finite ``log_prob(0)`` and zero gradient."""
        # NOTE(review): the expected constants are close to log(log(1 / eps))
        # for each dtype, presumably because probs are clamped away from 0 and 1
        # inside the distribution -- confirm against ContinuousBernoulli.
        for tensor_type, lim, expected in [
            (torch.FloatTensor, -1e38, 2.76898),
            (torch.DoubleTensor, -1e308, 3.58473),
        ]:
            self._test_pdf_score(
                dist_class=ContinuousBernoulli,
                logits=tensor_type([lim]),
                x=tensor_type([0]),
                expected_value=tensor_type([expected]),
                expected_gradient=tensor_type([0.0]),
            )

    def test_continuous_bernoulli_with_logits_overflow(self):
        """Hugely positive logits give a finite ``log_prob(1)`` and zero gradient."""
        # NOTE(review): same constants as the underflow case; presumably the
        # mirror image of the same clamping -- confirm against the distribution.
        for tensor_type, lim, expected in [
            (torch.FloatTensor, 1e38, 2.76898),
            (torch.DoubleTensor, 1e308, 3.58473),
        ]:
            self._test_pdf_score(
                dist_class=ContinuousBernoulli,
                logits=tensor_type([lim]),
                x=tensor_type([1]),
                expected_value=tensor_type([expected]),
                expected_gradient=tensor_type([0.0]),
            )
6212*da0073e9SAndroid Build Coastguard Worker
6213*da0073e9SAndroid Build Coastguard Worker
# TODO: make this a pytest parameterized test
class TestLazyLogitsInitialization(DistributionsTestCase):
    """Ensure the parameterization that was not supplied is never materialized."""

    def setUp(self):
        super().setUp()
        # ContinuousBernoulli is left out on purpose: its log_prob cannot be
        # computed from 'logits' alone, 'probs' is needed as well.
        lazy_dists = (Categorical, OneHotCategorical, Bernoulli, Binomial, Multinomial)
        self.examples = [
            example for example in _get_examples() if example.Dist in lazy_dists
        ]

    def test_lazy_logits_initialization(self):
        """A distribution built from logits must not cache ``probs``."""
        for Dist, params in self.examples:
            kwargs = params[0].copy()
            if "probs" not in kwargs:
                continue
            # Re-express the first example in terms of logits.
            kwargs["logits"] = probs_to_logits(kwargs.pop("probs"))
            dist = Dist(**kwargs)
            # A second, throwaway instance produces the sample so that
            # sampling cannot populate anything on the instance under test.
            valid_sample = Dist(**kwargs).sample()
            dist.log_prob(valid_sample)
            message = f"Failed for {Dist.__name__} example 0/{len(params)}"
            self.assertNotIn("probs", dist.__dict__, msg=message)
            try:
                dist.enumerate_support()
            except NotImplementedError:
                pass
            self.assertNotIn("probs", dist.__dict__, msg=message)
            # Reading the shapes must not trigger the lazy attribute either.
            _ = (dist.batch_shape, dist.event_shape)
            self.assertNotIn("probs", dist.__dict__, msg=message)

    def test_lazy_probs_initialization(self):
        """A distribution built from probs must not cache ``logits``."""
        for Dist, params in self.examples:
            kwargs = params[0].copy()
            if "probs" not in kwargs:
                continue
            dist = Dist(**kwargs)
            dist.sample()
            message = f"Failed for {Dist.__name__} example 0/{len(params)}"
            self.assertNotIn("logits", dist.__dict__, msg=message)
            try:
                dist.enumerate_support()
            except NotImplementedError:
                pass
            self.assertNotIn("logits", dist.__dict__, msg=message)
            # Reading the shapes must not trigger the lazy attribute either.
            _ = (dist.batch_shape, dist.event_shape)
            self.assertNotIn("logits", dist.__dict__, msg=message)
6263*da0073e9SAndroid Build Coastguard Worker
6264*da0073e9SAndroid Build Coastguard Worker
6265*da0073e9SAndroid Build Coastguard Worker@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
6266*da0073e9SAndroid Build Coastguard Worker@skipIfTorchDynamo("FIXME: Tries to trace through SciPy and fails")
6267*da0073e9SAndroid Build Coastguard Workerclass TestAgainstScipy(DistributionsTestCase):
6268*da0073e9SAndroid Build Coastguard Worker    def setUp(self):
6269*da0073e9SAndroid Build Coastguard Worker        super().setUp()
6270*da0073e9SAndroid Build Coastguard Worker        positive_var = torch.randn(20, dtype=torch.double).exp()
6271*da0073e9SAndroid Build Coastguard Worker        positive_var2 = torch.randn(20, dtype=torch.double).exp()
6272*da0073e9SAndroid Build Coastguard Worker        random_var = torch.randn(20, dtype=torch.double)
6273*da0073e9SAndroid Build Coastguard Worker        simplex_tensor = softmax(torch.randn(20, dtype=torch.double), dim=-1)
6274*da0073e9SAndroid Build Coastguard Worker        cov_tensor = torch.randn(20, 20, dtype=torch.double)
6275*da0073e9SAndroid Build Coastguard Worker        cov_tensor = cov_tensor @ cov_tensor.mT
6276*da0073e9SAndroid Build Coastguard Worker        self.distribution_pairs = [
6277*da0073e9SAndroid Build Coastguard Worker            (Bernoulli(simplex_tensor), scipy.stats.bernoulli(simplex_tensor)),
6278*da0073e9SAndroid Build Coastguard Worker            (
6279*da0073e9SAndroid Build Coastguard Worker                Beta(positive_var, positive_var2),
6280*da0073e9SAndroid Build Coastguard Worker                scipy.stats.beta(positive_var, positive_var2),
6281*da0073e9SAndroid Build Coastguard Worker            ),
6282*da0073e9SAndroid Build Coastguard Worker            (
6283*da0073e9SAndroid Build Coastguard Worker                Binomial(10, simplex_tensor),
6284*da0073e9SAndroid Build Coastguard Worker                scipy.stats.binom(
6285*da0073e9SAndroid Build Coastguard Worker                    10 * np.ones(simplex_tensor.shape), simplex_tensor.numpy()
6286*da0073e9SAndroid Build Coastguard Worker                ),
6287*da0073e9SAndroid Build Coastguard Worker            ),
6288*da0073e9SAndroid Build Coastguard Worker            (
6289*da0073e9SAndroid Build Coastguard Worker                Cauchy(random_var, positive_var),
6290*da0073e9SAndroid Build Coastguard Worker                scipy.stats.cauchy(loc=random_var, scale=positive_var),
6291*da0073e9SAndroid Build Coastguard Worker            ),
6292*da0073e9SAndroid Build Coastguard Worker            (Dirichlet(positive_var), scipy.stats.dirichlet(positive_var)),
6293*da0073e9SAndroid Build Coastguard Worker            (
6294*da0073e9SAndroid Build Coastguard Worker                Exponential(positive_var),
6295*da0073e9SAndroid Build Coastguard Worker                scipy.stats.expon(scale=positive_var.reciprocal()),
6296*da0073e9SAndroid Build Coastguard Worker            ),
6297*da0073e9SAndroid Build Coastguard Worker            (
6298*da0073e9SAndroid Build Coastguard Worker                FisherSnedecor(
6299*da0073e9SAndroid Build Coastguard Worker                    positive_var, 4 + positive_var2
6300*da0073e9SAndroid Build Coastguard Worker                ),  # var for df2<=4 is undefined
6301*da0073e9SAndroid Build Coastguard Worker                scipy.stats.f(positive_var, 4 + positive_var2),
6302*da0073e9SAndroid Build Coastguard Worker            ),
6303*da0073e9SAndroid Build Coastguard Worker            (
6304*da0073e9SAndroid Build Coastguard Worker                Gamma(positive_var, positive_var2),
6305*da0073e9SAndroid Build Coastguard Worker                scipy.stats.gamma(positive_var, scale=positive_var2.reciprocal()),
6306*da0073e9SAndroid Build Coastguard Worker            ),
6307*da0073e9SAndroid Build Coastguard Worker            (Geometric(simplex_tensor), scipy.stats.geom(simplex_tensor, loc=-1)),
6308*da0073e9SAndroid Build Coastguard Worker            (
6309*da0073e9SAndroid Build Coastguard Worker                Gumbel(random_var, positive_var2),
6310*da0073e9SAndroid Build Coastguard Worker                scipy.stats.gumbel_r(random_var, positive_var2),
6311*da0073e9SAndroid Build Coastguard Worker            ),
6312*da0073e9SAndroid Build Coastguard Worker            (HalfCauchy(positive_var), scipy.stats.halfcauchy(scale=positive_var)),
6313*da0073e9SAndroid Build Coastguard Worker            (HalfNormal(positive_var2), scipy.stats.halfnorm(scale=positive_var2)),
6314*da0073e9SAndroid Build Coastguard Worker            (
6315*da0073e9SAndroid Build Coastguard Worker                InverseGamma(positive_var, positive_var2),
6316*da0073e9SAndroid Build Coastguard Worker                scipy.stats.invgamma(positive_var, scale=positive_var2),
6317*da0073e9SAndroid Build Coastguard Worker            ),
6318*da0073e9SAndroid Build Coastguard Worker            (
6319*da0073e9SAndroid Build Coastguard Worker                Laplace(random_var, positive_var2),
6320*da0073e9SAndroid Build Coastguard Worker                scipy.stats.laplace(random_var, positive_var2),
6321*da0073e9SAndroid Build Coastguard Worker            ),
6322*da0073e9SAndroid Build Coastguard Worker            (
6323*da0073e9SAndroid Build Coastguard Worker                # Tests fail 1e-5 threshold if scale > 3
6324*da0073e9SAndroid Build Coastguard Worker                LogNormal(random_var, positive_var.clamp(max=3)),
6325*da0073e9SAndroid Build Coastguard Worker                scipy.stats.lognorm(
6326*da0073e9SAndroid Build Coastguard Worker                    s=positive_var.clamp(max=3), scale=random_var.exp()
6327*da0073e9SAndroid Build Coastguard Worker                ),
6328*da0073e9SAndroid Build Coastguard Worker            ),
6329*da0073e9SAndroid Build Coastguard Worker            (
6330*da0073e9SAndroid Build Coastguard Worker                LowRankMultivariateNormal(
6331*da0073e9SAndroid Build Coastguard Worker                    random_var, torch.zeros(20, 1, dtype=torch.double), positive_var2
6332*da0073e9SAndroid Build Coastguard Worker                ),
6333*da0073e9SAndroid Build Coastguard Worker                scipy.stats.multivariate_normal(random_var, torch.diag(positive_var2)),
6334*da0073e9SAndroid Build Coastguard Worker            ),
6335*da0073e9SAndroid Build Coastguard Worker            (
6336*da0073e9SAndroid Build Coastguard Worker                Multinomial(10, simplex_tensor),
6337*da0073e9SAndroid Build Coastguard Worker                scipy.stats.multinomial(10, simplex_tensor),
6338*da0073e9SAndroid Build Coastguard Worker            ),
6339*da0073e9SAndroid Build Coastguard Worker            (
6340*da0073e9SAndroid Build Coastguard Worker                MultivariateNormal(random_var, torch.diag(positive_var2)),
6341*da0073e9SAndroid Build Coastguard Worker                scipy.stats.multivariate_normal(random_var, torch.diag(positive_var2)),
6342*da0073e9SAndroid Build Coastguard Worker            ),
6343*da0073e9SAndroid Build Coastguard Worker            (
6344*da0073e9SAndroid Build Coastguard Worker                MultivariateNormal(random_var, cov_tensor),
6345*da0073e9SAndroid Build Coastguard Worker                scipy.stats.multivariate_normal(random_var, cov_tensor),
6346*da0073e9SAndroid Build Coastguard Worker            ),
6347*da0073e9SAndroid Build Coastguard Worker            (
6348*da0073e9SAndroid Build Coastguard Worker                Normal(random_var, positive_var2),
6349*da0073e9SAndroid Build Coastguard Worker                scipy.stats.norm(random_var, positive_var2),
6350*da0073e9SAndroid Build Coastguard Worker            ),
6351*da0073e9SAndroid Build Coastguard Worker            (
6352*da0073e9SAndroid Build Coastguard Worker                OneHotCategorical(simplex_tensor),
6353*da0073e9SAndroid Build Coastguard Worker                scipy.stats.multinomial(1, simplex_tensor),
6354*da0073e9SAndroid Build Coastguard Worker            ),
6355*da0073e9SAndroid Build Coastguard Worker            (
6356*da0073e9SAndroid Build Coastguard Worker                Pareto(positive_var, 2 + positive_var2),
6357*da0073e9SAndroid Build Coastguard Worker                scipy.stats.pareto(2 + positive_var2, scale=positive_var),
6358*da0073e9SAndroid Build Coastguard Worker            ),
6359*da0073e9SAndroid Build Coastguard Worker            (Poisson(positive_var), scipy.stats.poisson(positive_var)),
6360*da0073e9SAndroid Build Coastguard Worker            (
6361*da0073e9SAndroid Build Coastguard Worker                StudentT(2 + positive_var, random_var, positive_var2),
6362*da0073e9SAndroid Build Coastguard Worker                scipy.stats.t(2 + positive_var, random_var, positive_var2),
6363*da0073e9SAndroid Build Coastguard Worker            ),
6364*da0073e9SAndroid Build Coastguard Worker            (
6365*da0073e9SAndroid Build Coastguard Worker                Uniform(random_var, random_var + positive_var),
6366*da0073e9SAndroid Build Coastguard Worker                scipy.stats.uniform(random_var, positive_var),
6367*da0073e9SAndroid Build Coastguard Worker            ),
6368*da0073e9SAndroid Build Coastguard Worker            (
6369*da0073e9SAndroid Build Coastguard Worker                VonMises(random_var, positive_var),
6370*da0073e9SAndroid Build Coastguard Worker                scipy.stats.vonmises(positive_var, loc=random_var),
6371*da0073e9SAndroid Build Coastguard Worker            ),
6372*da0073e9SAndroid Build Coastguard Worker            (
6373*da0073e9SAndroid Build Coastguard Worker                Weibull(
6374*da0073e9SAndroid Build Coastguard Worker                    positive_var[0], positive_var2[0]
6375*da0073e9SAndroid Build Coastguard Worker                ),  # scipy var for Weibull only supports scalars
6376*da0073e9SAndroid Build Coastguard Worker                scipy.stats.weibull_min(c=positive_var2[0], scale=positive_var[0]),
6377*da0073e9SAndroid Build Coastguard Worker            ),
6378*da0073e9SAndroid Build Coastguard Worker            (
6379*da0073e9SAndroid Build Coastguard Worker                # scipy var for Wishart only supports scalars
6380*da0073e9SAndroid Build Coastguard Worker                # SciPy allowed ndim - 1 < df < ndim for the Wishart distribution after version 1.7.0
6381*da0073e9SAndroid Build Coastguard Worker                Wishart(
6382*da0073e9SAndroid Build Coastguard Worker                    (
6383*da0073e9SAndroid Build Coastguard Worker                        20
6384*da0073e9SAndroid Build Coastguard Worker                        if version.parse(scipy.__version__) < version.parse("1.7.0")
6385*da0073e9SAndroid Build Coastguard Worker                        else 19
6386*da0073e9SAndroid Build Coastguard Worker                    )
6387*da0073e9SAndroid Build Coastguard Worker                    + positive_var[0],
6388*da0073e9SAndroid Build Coastguard Worker                    cov_tensor,
6389*da0073e9SAndroid Build Coastguard Worker                ),
6390*da0073e9SAndroid Build Coastguard Worker                scipy.stats.wishart(
6391*da0073e9SAndroid Build Coastguard Worker                    (
6392*da0073e9SAndroid Build Coastguard Worker                        20
6393*da0073e9SAndroid Build Coastguard Worker                        if version.parse(scipy.__version__) < version.parse("1.7.0")
6394*da0073e9SAndroid Build Coastguard Worker                        else 19
6395*da0073e9SAndroid Build Coastguard Worker                    )
6396*da0073e9SAndroid Build Coastguard Worker                    + positive_var[0].item(),
6397*da0073e9SAndroid Build Coastguard Worker                    cov_tensor,
6398*da0073e9SAndroid Build Coastguard Worker                ),
6399*da0073e9SAndroid Build Coastguard Worker            ),
6400*da0073e9SAndroid Build Coastguard Worker        ]
6401*da0073e9SAndroid Build Coastguard Worker
6402*da0073e9SAndroid Build Coastguard Worker    def test_mean(self):
6403*da0073e9SAndroid Build Coastguard Worker        for pytorch_dist, scipy_dist in self.distribution_pairs:
6404*da0073e9SAndroid Build Coastguard Worker            if isinstance(pytorch_dist, (Cauchy, HalfCauchy)):
6405*da0073e9SAndroid Build Coastguard Worker                # Cauchy, HalfCauchy distributions' mean is nan, skipping check
6406*da0073e9SAndroid Build Coastguard Worker                continue
6407*da0073e9SAndroid Build Coastguard Worker            elif isinstance(
6408*da0073e9SAndroid Build Coastguard Worker                pytorch_dist, (LowRankMultivariateNormal, MultivariateNormal)
6409*da0073e9SAndroid Build Coastguard Worker            ):
6410*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(pytorch_dist.mean, scipy_dist.mean, msg=pytorch_dist)
6411*da0073e9SAndroid Build Coastguard Worker            else:
6412*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(pytorch_dist.mean, scipy_dist.mean(), msg=pytorch_dist)
6413*da0073e9SAndroid Build Coastguard Worker
6414*da0073e9SAndroid Build Coastguard Worker    def test_variance_stddev(self):
6415*da0073e9SAndroid Build Coastguard Worker        for pytorch_dist, scipy_dist in self.distribution_pairs:
6416*da0073e9SAndroid Build Coastguard Worker            if isinstance(pytorch_dist, (Cauchy, HalfCauchy, VonMises)):
6417*da0073e9SAndroid Build Coastguard Worker                # Cauchy, HalfCauchy distributions' standard deviation is nan, skipping check
6418*da0073e9SAndroid Build Coastguard Worker                # VonMises variance is circular and scipy doesn't produce a correct result
6419*da0073e9SAndroid Build Coastguard Worker                continue
6420*da0073e9SAndroid Build Coastguard Worker            elif isinstance(pytorch_dist, (Multinomial, OneHotCategorical)):
6421*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(
6422*da0073e9SAndroid Build Coastguard Worker                    pytorch_dist.variance, np.diag(scipy_dist.cov()), msg=pytorch_dist
6423*da0073e9SAndroid Build Coastguard Worker                )
6424*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(
6425*da0073e9SAndroid Build Coastguard Worker                    pytorch_dist.stddev,
6426*da0073e9SAndroid Build Coastguard Worker                    np.diag(scipy_dist.cov()) ** 0.5,
6427*da0073e9SAndroid Build Coastguard Worker                    msg=pytorch_dist,
6428*da0073e9SAndroid Build Coastguard Worker                )
6429*da0073e9SAndroid Build Coastguard Worker            elif isinstance(
6430*da0073e9SAndroid Build Coastguard Worker                pytorch_dist, (LowRankMultivariateNormal, MultivariateNormal)
6431*da0073e9SAndroid Build Coastguard Worker            ):
6432*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(
6433*da0073e9SAndroid Build Coastguard Worker                    pytorch_dist.variance, np.diag(scipy_dist.cov), msg=pytorch_dist
6434*da0073e9SAndroid Build Coastguard Worker                )
6435*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(
6436*da0073e9SAndroid Build Coastguard Worker                    pytorch_dist.stddev,
6437*da0073e9SAndroid Build Coastguard Worker                    np.diag(scipy_dist.cov) ** 0.5,
6438*da0073e9SAndroid Build Coastguard Worker                    msg=pytorch_dist,
6439*da0073e9SAndroid Build Coastguard Worker                )
6440*da0073e9SAndroid Build Coastguard Worker            else:
6441*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(
6442*da0073e9SAndroid Build Coastguard Worker                    pytorch_dist.variance, scipy_dist.var(), msg=pytorch_dist
6443*da0073e9SAndroid Build Coastguard Worker                )
6444*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(
6445*da0073e9SAndroid Build Coastguard Worker                    pytorch_dist.stddev, scipy_dist.var() ** 0.5, msg=pytorch_dist
6446*da0073e9SAndroid Build Coastguard Worker                )
6447*da0073e9SAndroid Build Coastguard Worker
6448*da0073e9SAndroid Build Coastguard Worker    @set_default_dtype(torch.double)
6449*da0073e9SAndroid Build Coastguard Worker    def test_cdf(self):
6450*da0073e9SAndroid Build Coastguard Worker        for pytorch_dist, scipy_dist in self.distribution_pairs:
6451*da0073e9SAndroid Build Coastguard Worker            samples = pytorch_dist.sample((5,))
6452*da0073e9SAndroid Build Coastguard Worker            try:
6453*da0073e9SAndroid Build Coastguard Worker                cdf = pytorch_dist.cdf(samples)
6454*da0073e9SAndroid Build Coastguard Worker            except NotImplementedError:
6455*da0073e9SAndroid Build Coastguard Worker                continue
6456*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(cdf, scipy_dist.cdf(samples), msg=pytorch_dist)
6457*da0073e9SAndroid Build Coastguard Worker
6458*da0073e9SAndroid Build Coastguard Worker    def test_icdf(self):
6459*da0073e9SAndroid Build Coastguard Worker        for pytorch_dist, scipy_dist in self.distribution_pairs:
6460*da0073e9SAndroid Build Coastguard Worker            samples = torch.rand((5,) + pytorch_dist.batch_shape, dtype=torch.double)
6461*da0073e9SAndroid Build Coastguard Worker            try:
6462*da0073e9SAndroid Build Coastguard Worker                icdf = pytorch_dist.icdf(samples)
6463*da0073e9SAndroid Build Coastguard Worker            except NotImplementedError:
6464*da0073e9SAndroid Build Coastguard Worker                continue
6465*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(icdf, scipy_dist.ppf(samples), msg=pytorch_dist)
6466*da0073e9SAndroid Build Coastguard Worker
6467*da0073e9SAndroid Build Coastguard Worker
class TestFunctors(DistributionsTestCase):
    """Tests for the composite transforms ``CatTransform`` and ``StackTransform``.

    Each test builds a composite transform from component transforms and
    checks that its domain/codomain checks, forward map, inverse map and
    ``log_abs_det_jacobian`` agree with applying the components separately
    and combining the per-component results.
    """

    def test_cat_transform(self):
        """CatTransform over three equally sized slices along dim 0."""
        x1 = -1 * torch.arange(1, 101, dtype=torch.float).view(-1, 100)
        x2 = (torch.arange(1, 101, dtype=torch.float).view(-1, 100) - 1) / 100
        x3 = torch.arange(1, 101, dtype=torch.float).view(-1, 100)
        t1, t2, t3 = ExpTransform(), AffineTransform(1, 100), identity_transform
        dim = 0
        x = torch.cat([x1, x2, x3], dim=dim)
        t = CatTransform([t1, t2, t3], dim=dim)

        # Domain check must match the concatenated per-component checks.
        actual_dom_check = t.domain.check(x)
        expected_dom_check = torch.cat(
            [t1.domain.check(x1), t2.domain.check(x2), t3.domain.check(x3)], dim=dim
        )
        self.assertEqual(expected_dom_check, actual_dom_check)

        # Forward map.
        actual = t(x)
        expected = torch.cat([t1(x1), t2(x2), t3(x3)], dim=dim)
        self.assertEqual(expected, actual)

        # Codomain values used for the codomain check, inverse and Jacobian.
        y1 = torch.arange(1, 101, dtype=torch.float).view(-1, 100)
        y2 = torch.arange(1, 101, dtype=torch.float).view(-1, 100)
        y3 = torch.arange(1, 101, dtype=torch.float).view(-1, 100)
        y = torch.cat([y1, y2, y3], dim=dim)
        actual_cod_check = t.codomain.check(y)
        expected_cod_check = torch.cat(
            [t1.codomain.check(y1), t2.codomain.check(y2), t3.codomain.check(y3)],
            dim=dim,
        )
        self.assertEqual(actual_cod_check, expected_cod_check)

        # Inverse map.
        actual_inv = t.inv(y)
        expected_inv = torch.cat([t1.inv(y1), t2.inv(y2), t3.inv(y3)], dim=dim)
        self.assertEqual(expected_inv, actual_inv)

        # Log-abs-det-Jacobian.
        actual_jac = t.log_abs_det_jacobian(x, y)
        expected_jac = torch.cat(
            [
                t1.log_abs_det_jacobian(x1, y1),
                t2.log_abs_det_jacobian(x2, y2),
                t3.log_abs_det_jacobian(x3, y3),
            ],
            dim=dim,
        )
        self.assertEqual(actual_jac, expected_jac)

    def test_cat_transform_non_uniform(self):
        """CatTransform with explicit, unequal ``lengths`` and a nested CatTransform."""
        x1 = -1 * torch.arange(1, 101, dtype=torch.float).view(-1, 100)
        x2 = torch.cat(
            [
                (torch.arange(1, 101, dtype=torch.float).view(-1, 100) - 1) / 100,
                torch.arange(1, 101, dtype=torch.float).view(-1, 100),
            ]
        )
        t1 = ExpTransform()
        t2 = CatTransform([AffineTransform(1, 100), identity_transform], dim=0)
        dim = 0
        x = torch.cat([x1, x2], dim=dim)
        # First component covers one row, the nested component covers two.
        t = CatTransform([t1, t2], dim=dim, lengths=[1, 2])

        actual_dom_check = t.domain.check(x)
        expected_dom_check = torch.cat(
            [t1.domain.check(x1), t2.domain.check(x2)], dim=dim
        )
        self.assertEqual(expected_dom_check, actual_dom_check)

        actual = t(x)
        expected = torch.cat([t1(x1), t2(x2)], dim=dim)
        self.assertEqual(expected, actual)

        y1 = torch.arange(1, 101, dtype=torch.float).view(-1, 100)
        y2 = torch.cat(
            [
                torch.arange(1, 101, dtype=torch.float).view(-1, 100),
                torch.arange(1, 101, dtype=torch.float).view(-1, 100),
            ]
        )
        y = torch.cat([y1, y2], dim=dim)
        actual_cod_check = t.codomain.check(y)
        expected_cod_check = torch.cat(
            [t1.codomain.check(y1), t2.codomain.check(y2)], dim=dim
        )
        self.assertEqual(actual_cod_check, expected_cod_check)

        actual_inv = t.inv(y)
        expected_inv = torch.cat([t1.inv(y1), t2.inv(y2)], dim=dim)
        self.assertEqual(expected_inv, actual_inv)

        actual_jac = t.log_abs_det_jacobian(x, y)
        expected_jac = torch.cat(
            [t1.log_abs_det_jacobian(x1, y1), t2.log_abs_det_jacobian(x2, y2)], dim=dim
        )
        self.assertEqual(actual_jac, expected_jac)

    def test_cat_event_dim(self):
        """CatTransform over components that were built with ``event_dim=1``.

        The expected log-abs-det-Jacobian is the sum of the two component
        Jacobians.
        """
        t1 = AffineTransform(0, 2 * torch.ones(2), event_dim=1)
        t2 = AffineTransform(0, 2 * torch.ones(2), event_dim=1)
        dim = 1
        bs = 16
        x1 = torch.randn(bs, 2)
        x2 = torch.randn(bs, 2)
        x = torch.cat([x1, x2], dim=1)
        t = CatTransform([t1, t2], dim=dim, lengths=[2, 2])
        y1 = t1(x1)
        y2 = t2(x2)
        y = t(x)
        actual_jac = t.log_abs_det_jacobian(x, y)
        expected_jac = sum(
            [t1.log_abs_det_jacobian(x1, y1), t2.log_abs_det_jacobian(x2, y2)]
        )
        # Previously both values were computed but never compared, so the
        # test could not fail on a wrong Jacobian.
        self.assertEqual(actual_jac, expected_jac)

    def test_stack_transform(self):
        """StackTransform over three 1-D inputs stacked along dim 0."""
        x1 = -1 * torch.arange(1, 101, dtype=torch.float)
        x2 = (torch.arange(1, 101, dtype=torch.float) - 1) / 100
        x3 = torch.arange(1, 101, dtype=torch.float)
        t1, t2, t3 = ExpTransform(), AffineTransform(1, 100), identity_transform
        dim = 0
        x = torch.stack([x1, x2, x3], dim=dim)
        t = StackTransform([t1, t2, t3], dim=dim)

        actual_dom_check = t.domain.check(x)
        expected_dom_check = torch.stack(
            [t1.domain.check(x1), t2.domain.check(x2), t3.domain.check(x3)], dim=dim
        )
        self.assertEqual(expected_dom_check, actual_dom_check)

        actual = t(x)
        expected = torch.stack([t1(x1), t2(x2), t3(x3)], dim=dim)
        self.assertEqual(expected, actual)

        y1 = torch.arange(1, 101, dtype=torch.float)
        y2 = torch.arange(1, 101, dtype=torch.float)
        y3 = torch.arange(1, 101, dtype=torch.float)
        y = torch.stack([y1, y2, y3], dim=dim)
        actual_cod_check = t.codomain.check(y)
        expected_cod_check = torch.stack(
            [t1.codomain.check(y1), t2.codomain.check(y2), t3.codomain.check(y3)],
            dim=dim,
        )
        self.assertEqual(actual_cod_check, expected_cod_check)

        # Invert the codomain values (y), as the CatTransform tests do.
        # The x values are the wrong input here: x1 is negative, so it is not
        # a valid output of ExpTransform.
        actual_inv = t.inv(y)
        expected_inv = torch.stack([t1.inv(y1), t2.inv(y2), t3.inv(y3)], dim=dim)
        self.assertEqual(expected_inv, actual_inv)

        actual_jac = t.log_abs_det_jacobian(x, y)
        expected_jac = torch.stack(
            [
                t1.log_abs_det_jacobian(x1, y1),
                t2.log_abs_det_jacobian(x2, y2),
                t3.log_abs_det_jacobian(x3, y3),
            ],
            dim=dim,
        )
        self.assertEqual(actual_jac, expected_jac)
6608*da0073e9SAndroid Build Coastguard Worker
6609*da0073e9SAndroid Build Coastguard Worker
6610*da0073e9SAndroid Build Coastguard Workerclass TestValidation(DistributionsTestCase):
6611*da0073e9SAndroid Build Coastguard Worker    def test_valid(self):
6612*da0073e9SAndroid Build Coastguard Worker        for Dist, params in _get_examples():
6613*da0073e9SAndroid Build Coastguard Worker            for param in params:
6614*da0073e9SAndroid Build Coastguard Worker                Dist(validate_args=True, **param)
6615*da0073e9SAndroid Build Coastguard Worker
6616*da0073e9SAndroid Build Coastguard Worker    @set_default_dtype(torch.double)
6617*da0073e9SAndroid Build Coastguard Worker    def test_invalid_log_probs_arg(self):
6618*da0073e9SAndroid Build Coastguard Worker        # Check that validation errors are indeed disabled,
6619*da0073e9SAndroid Build Coastguard Worker        # but they might raise another error
6620*da0073e9SAndroid Build Coastguard Worker        for Dist, params in _get_examples():
6621*da0073e9SAndroid Build Coastguard Worker            if Dist == TransformedDistribution:
6622*da0073e9SAndroid Build Coastguard Worker                # TransformedDistribution has a distribution instance
6623*da0073e9SAndroid Build Coastguard Worker                # as the argument, so we cannot do much about that
6624*da0073e9SAndroid Build Coastguard Worker                continue
6625*da0073e9SAndroid Build Coastguard Worker            for i, param in enumerate(params):
6626*da0073e9SAndroid Build Coastguard Worker                d_nonval = Dist(validate_args=False, **param)
6627*da0073e9SAndroid Build Coastguard Worker                d_val = Dist(validate_args=True, **param)
6628*da0073e9SAndroid Build Coastguard Worker                for v in torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0]):
6629*da0073e9SAndroid Build Coastguard Worker                    # samples with incorrect shape must throw ValueError only
6630*da0073e9SAndroid Build Coastguard Worker                    try:
6631*da0073e9SAndroid Build Coastguard Worker                        log_prob = d_val.log_prob(v)
6632*da0073e9SAndroid Build Coastguard Worker                    except ValueError:
6633*da0073e9SAndroid Build Coastguard Worker                        pass
6634*da0073e9SAndroid Build Coastguard Worker                    # get sample of correct shape
6635*da0073e9SAndroid Build Coastguard Worker                    val = torch.full(d_val.batch_shape + d_val.event_shape, v)
6636*da0073e9SAndroid Build Coastguard Worker                    # check samples with incorrect support
6637*da0073e9SAndroid Build Coastguard Worker                    try:
6638*da0073e9SAndroid Build Coastguard Worker                        log_prob = d_val.log_prob(val)
6639*da0073e9SAndroid Build Coastguard Worker                    except ValueError as e:
6640*da0073e9SAndroid Build Coastguard Worker                        if e.args and "must be within the support" in e.args[0]:
6641*da0073e9SAndroid Build Coastguard Worker                            try:
6642*da0073e9SAndroid Build Coastguard Worker                                log_prob = d_nonval.log_prob(val)
6643*da0073e9SAndroid Build Coastguard Worker                            except RuntimeError:
6644*da0073e9SAndroid Build Coastguard Worker                                pass
6645*da0073e9SAndroid Build Coastguard Worker
6646*da0073e9SAndroid Build Coastguard Worker                # check correct samples are ok
6647*da0073e9SAndroid Build Coastguard Worker                valid_value = d_val.sample()
6648*da0073e9SAndroid Build Coastguard Worker                d_val.log_prob(valid_value)
6649*da0073e9SAndroid Build Coastguard Worker                # check invalid values raise ValueError
6650*da0073e9SAndroid Build Coastguard Worker                if valid_value.dtype == torch.long:
6651*da0073e9SAndroid Build Coastguard Worker                    valid_value = valid_value.float()
6652*da0073e9SAndroid Build Coastguard Worker                invalid_value = torch.full_like(valid_value, math.nan)
6653*da0073e9SAndroid Build Coastguard Worker                try:
6654*da0073e9SAndroid Build Coastguard Worker                    with self.assertRaisesRegex(
6655*da0073e9SAndroid Build Coastguard Worker                        ValueError,
6656*da0073e9SAndroid Build Coastguard Worker                        "Expected value argument .* to be within the support .*",
6657*da0073e9SAndroid Build Coastguard Worker                    ):
6658*da0073e9SAndroid Build Coastguard Worker                        d_val.log_prob(invalid_value)
6659*da0073e9SAndroid Build Coastguard Worker                except AssertionError as e:
6660*da0073e9SAndroid Build Coastguard Worker                    fail_string = "Support ValueError not raised for {} example {}/{}"
6661*da0073e9SAndroid Build Coastguard Worker                    raise AssertionError(
6662*da0073e9SAndroid Build Coastguard Worker                        fail_string.format(Dist.__name__, i + 1, len(params))
6663*da0073e9SAndroid Build Coastguard Worker                    ) from e
6664*da0073e9SAndroid Build Coastguard Worker
6665*da0073e9SAndroid Build Coastguard Worker    @set_default_dtype(torch.double)
6666*da0073e9SAndroid Build Coastguard Worker    def test_invalid(self):
6667*da0073e9SAndroid Build Coastguard Worker        for Dist, params in _get_bad_examples():
6668*da0073e9SAndroid Build Coastguard Worker            for i, param in enumerate(params):
6669*da0073e9SAndroid Build Coastguard Worker                try:
6670*da0073e9SAndroid Build Coastguard Worker                    with self.assertRaises(ValueError):
6671*da0073e9SAndroid Build Coastguard Worker                        Dist(validate_args=True, **param)
6672*da0073e9SAndroid Build Coastguard Worker                except AssertionError as e:
6673*da0073e9SAndroid Build Coastguard Worker                    fail_string = "ValueError not raised for {} example {}/{}"
6674*da0073e9SAndroid Build Coastguard Worker                    raise AssertionError(
6675*da0073e9SAndroid Build Coastguard Worker                        fail_string.format(Dist.__name__, i + 1, len(params))
6676*da0073e9SAndroid Build Coastguard Worker                    ) from e
6677*da0073e9SAndroid Build Coastguard Worker
6678*da0073e9SAndroid Build Coastguard Worker    def test_warning_unimplemented_constraints(self):
6679*da0073e9SAndroid Build Coastguard Worker        class Delta(Distribution):
6680*da0073e9SAndroid Build Coastguard Worker            def __init__(self, validate_args=True):
6681*da0073e9SAndroid Build Coastguard Worker                super().__init__(validate_args=validate_args)
6682*da0073e9SAndroid Build Coastguard Worker
6683*da0073e9SAndroid Build Coastguard Worker            def sample(self, sample_shape=torch.Size()):
6684*da0073e9SAndroid Build Coastguard Worker                return torch.tensor(0.0).expand(sample_shape)
6685*da0073e9SAndroid Build Coastguard Worker
6686*da0073e9SAndroid Build Coastguard Worker            def log_prob(self, value):
6687*da0073e9SAndroid Build Coastguard Worker                if self._validate_args:
6688*da0073e9SAndroid Build Coastguard Worker                    self._validate_sample(value)
6689*da0073e9SAndroid Build Coastguard Worker                value[value != 0.0] = -float("inf")
6690*da0073e9SAndroid Build Coastguard Worker                value[value == 0.0] = 0.0
6691*da0073e9SAndroid Build Coastguard Worker                return value
6692*da0073e9SAndroid Build Coastguard Worker
6693*da0073e9SAndroid Build Coastguard Worker        with self.assertWarns(UserWarning):
6694*da0073e9SAndroid Build Coastguard Worker            d = Delta()
6695*da0073e9SAndroid Build Coastguard Worker        sample = d.sample((2,))
6696*da0073e9SAndroid Build Coastguard Worker        with self.assertWarns(UserWarning):
6697*da0073e9SAndroid Build Coastguard Worker            d.log_prob(sample)
6698*da0073e9SAndroid Build Coastguard Worker
6699*da0073e9SAndroid Build Coastguard Worker
class TestJit(DistributionsTestCase):
    """Compare ``torch.jit.trace``-d distribution code with eager execution.

    Each test wraps one distribution method or property in a function of the
    distribution's tensor parameters, traces that function, and checks that
    the traced version returns the same result as the eager one.
    """

    def _examples(self):
        """Yield ``(Dist, keys, values, sample)`` for all-tensor examples.

        Examples from ``_get_examples()`` are skipped unless every parameter
        value is a ``torch.Tensor``; the tests pass ``values`` to
        ``torch.jit.trace`` as example inputs. ``sample`` is one draw from the
        distribution built from the example's parameters.
        """
        for Dist, params in _get_examples():
            for param in params:
                keys = param.keys()
                values = tuple(param[key] for key in keys)
                if not all(isinstance(x, torch.Tensor) for x in values):
                    continue
                sample = Dist(**param).sample()
                yield Dist, keys, values, sample

    def _perturb_tensor(self, value, constraint):
        """Return a perturbed version of the parameter tensor ``value``.

        The rule is chosen from ``constraint`` first and from the dtype of
        ``value`` second; the first matching rule wins.

        Raises:
            NotImplementedError: if no rule applies to ``value``.
        """
        if isinstance(constraint, constraints._IntegerGreaterThan):
            return value + 1
        if isinstance(
            constraint,
            (constraints._PositiveDefinite, constraints._PositiveSemidefinite),
        ):
            return value + torch.eye(value.shape[-1])
        if value.dtype in [torch.float, torch.double]:
            # Add standard-normal noise in the unconstrained space, then map
            # the result back through the transform for ``constraint``.
            transform = transform_to(constraint)
            delta = value.new(value.shape).normal_()
            return transform(transform.inv(value) + delta)
        if value.dtype == torch.long:
            # Swap zeros and ones. Both masks are computed from the original
            # ``value`` so the second assignment does not undo the first.
            result = value.clone()
            result[value == 0] = 1
            result[value == 1] = 0
            return result
        raise NotImplementedError

    def _perturb(self, Dist, keys, values, sample):
        """Return ``(new_values, new_sample)`` for re-checking a traced function.

        Args:
            Dist: distribution class under test.
            keys: parameter names, aligned with ``values``.
            values: current parameter tensors.
            sample: ignored; it is overwritten before being read.

        Returns:
            A list of perturbed parameter tensors (in ``keys`` order) and a
            fresh sample drawn from ``Dist`` built with them.
        """
        with torch.no_grad():
            if Dist is Uniform:
                # Move ``low`` down and ``high`` up by a random amount in
                # [0, 1), so the interval never shrinks.
                param = dict(zip(keys, values))
                param["low"] = param["low"] - torch.rand(param["low"].shape)
                param["high"] = param["high"] + torch.rand(param["high"].shape)
                values = [param[key] for key in keys]
            else:
                values = [
                    self._perturb_tensor(
                        value, Dist.arg_constraints.get(key, constraints.real)
                    )
                    for key, value in zip(keys, values)
                ]
            param = dict(zip(keys, values))
            sample = Dist(**param).sample()
            return values, sample

    @staticmethod
    def _zero_positive_inf(tensor):
        """Return a copy of ``tensor`` with every ``+inf`` entry set to 0.0.

        Working on a clone keeps the masked write away from ``tensor`` itself,
        which may share memory with the parameters handed to the traced
        function.
        """
        result = tensor.clone()
        result[result == float("inf")] = 0.0
        return result

    def _assert_traced_matches_eager(self, Dist, expected, actual):
        """Assert ``expected == actual`` with a message naming ``Dist``."""
        self.assertEqual(
            expected,
            actual,
            msg=f"{Dist.__name__}\nExpected:\n{expected}\nActual:\n{actual}",
        )

    @set_default_dtype(torch.double)
    def test_sample(self):
        """Traced ``sample()`` must equal eager ``sample()`` from the same RNG state."""
        for Dist, keys, values, sample in self._examples():

            def f(*values):
                param = dict(zip(keys, values))
                dist = Dist(**param)
                return dist.sample()

            traced_f = torch.jit.trace(f, values, check_trace=False)

            # FIXME Schema not found for node
            xfail = [
                Cauchy,  # aten::cauchy(Double(2,1), float, float, Generator)
                HalfCauchy,  # aten::cauchy(Double(2, 1), float, float, Generator)
                VonMises,  # Variance is not Euclidean
            ]
            if Dist in xfail:
                continue

            # Draw the eager sample inside a forked RNG so that the traced
            # call below starts from the same generator state.
            with torch.random.fork_rng():
                sample = f(*values)
            traced_sample = traced_f(*values)
            self.assertEqual(sample, traced_sample)

            # FIXME no nondeterministic nodes found in trace
            xfail = [Beta, Dirichlet]
            if Dist not in xfail:
                self.assertTrue(
                    any(n.isNondeterministic() for n in traced_f.graph.nodes())
                )

    def test_rsample(self):
        """Traced ``rsample()`` must equal eager ``rsample()`` from the same RNG state."""
        for Dist, keys, values, sample in self._examples():
            if not Dist.has_rsample:
                continue

            def f(*values):
                param = dict(zip(keys, values))
                dist = Dist(**param)
                return dist.rsample()

            traced_f = torch.jit.trace(f, values, check_trace=False)

            # FIXME Schema not found for node
            xfail = [
                Cauchy,  # aten::cauchy(Double(2,1), float, float, Generator)
                HalfCauchy,  # aten::cauchy(Double(2, 1), float, float, Generator)
            ]
            if Dist in xfail:
                continue

            # Same forked-RNG scheme as in test_sample.
            with torch.random.fork_rng():
                sample = f(*values)
            traced_sample = traced_f(*values)
            self.assertEqual(sample, traced_sample)

            # FIXME no nondeterministic nodes found in trace
            xfail = [Beta, Dirichlet]
            if Dist not in xfail:
                self.assertTrue(
                    any(n.isNondeterministic() for n in traced_f.graph.nodes())
                )

    @set_default_dtype(torch.double)
    def test_log_prob(self):
        """Traced ``log_prob`` must match eager ``log_prob`` on perturbed inputs."""
        for Dist, keys, values, sample in self._examples():
            # FIXME traced functions produce incorrect results
            xfail = [LowRankMultivariateNormal, MultivariateNormal]
            if Dist in xfail:
                continue

            def f(sample, *values):
                param = dict(zip(keys, values))
                dist = Dist(**param)
                return dist.log_prob(sample)

            traced_f = torch.jit.trace(f, (sample,) + values)

            # check on different data
            values, sample = self._perturb(Dist, keys, values, sample)
            expected = f(sample, *values)
            actual = traced_f(sample, *values)
            self._assert_traced_matches_eager(Dist, expected, actual)

    def test_enumerate_support(self):
        """Traced ``enumerate_support`` must match eager on perturbed inputs."""
        for Dist, keys, values, sample in self._examples():
            # FIXME traced functions produce incorrect results
            xfail = [Binomial]
            if Dist in xfail:
                continue

            def f(*values):
                param = dict(zip(keys, values))
                dist = Dist(**param)
                return dist.enumerate_support()

            # Distributions raising NotImplementedError while being traced
            # are skipped.
            try:
                traced_f = torch.jit.trace(f, values)
            except NotImplementedError:
                continue

            # check on different data
            values, sample = self._perturb(Dist, keys, values, sample)
            expected = f(*values)
            actual = traced_f(*values)
            self._assert_traced_matches_eager(Dist, expected, actual)

    def test_mean(self):
        """Traced ``mean`` must match eager ``mean`` on perturbed inputs.

        ``+inf`` entries are replaced by 0.0 on both sides before comparing.
        """
        for Dist, keys, values, sample in self._examples():

            def f(*values):
                param = dict(zip(keys, values))
                dist = Dist(**param)
                return dist.mean

            try:
                traced_f = torch.jit.trace(f, values)
            except NotImplementedError:
                continue

            # check on different data
            values, sample = self._perturb(Dist, keys, values, sample)
            # The helper clones before masking (as test_variance does), so
            # the returned mean is never modified in place.
            expected = self._zero_positive_inf(f(*values))
            actual = self._zero_positive_inf(traced_f(*values))
            self._assert_traced_matches_eager(Dist, expected, actual)

    def test_variance(self):
        """Traced ``variance`` must match eager ``variance`` on perturbed inputs.

        ``+inf`` entries are replaced by 0.0 on both sides before comparing.
        """
        for Dist, keys, values, sample in self._examples():
            if Dist in [Cauchy, HalfCauchy]:
                continue  # infinite variance

            def f(*values):
                param = dict(zip(keys, values))
                dist = Dist(**param)
                return dist.variance

            try:
                traced_f = torch.jit.trace(f, values)
            except NotImplementedError:
                continue

            # check on different data
            values, sample = self._perturb(Dist, keys, values, sample)
            expected = self._zero_positive_inf(f(*values))
            actual = self._zero_positive_inf(traced_f(*values))
            self._assert_traced_matches_eager(Dist, expected, actual)

    @set_default_dtype(torch.double)
    def test_entropy(self):
        """Traced ``entropy()`` must match eager ``entropy()`` on perturbed inputs."""
        for Dist, keys, values, sample in self._examples():
            # FIXME traced functions produce incorrect results
            xfail = [LowRankMultivariateNormal, MultivariateNormal]
            if Dist in xfail:
                continue

            def f(*values):
                param = dict(zip(keys, values))
                dist = Dist(**param)
                return dist.entropy()

            try:
                traced_f = torch.jit.trace(f, values)
            except NotImplementedError:
                continue

            # check on different data
            values, sample = self._perturb(Dist, keys, values, sample)
            expected = f(*values)
            actual = traced_f(*values)
            self._assert_traced_matches_eager(Dist, expected, actual)

    @set_default_dtype(torch.double)
    def test_cdf(self):
        """Traced ``icdf(cdf(sample))`` must match eager on perturbed inputs."""
        for Dist, keys, values, sample in self._examples():

            def f(sample, *values):
                param = dict(zip(keys, values))
                dist = Dist(**param)
                cdf = dist.cdf(sample)
                return dist.icdf(cdf)

            try:
                traced_f = torch.jit.trace(f, (sample,) + values)
            except NotImplementedError:
                continue

            # check on different data
            values, sample = self._perturb(Dist, keys, values, sample)
            expected = f(sample, *values)
            actual = traced_f(sample, *values)
            self._assert_traced_matches_eager(Dist, expected, actual)
6968*da0073e9SAndroid Build Coastguard Worker
6969*da0073e9SAndroid Build Coastguard Worker
# Script entry point: the suite is started only when this file is executed
# directly and the torch build reports LAPACK support.
if __name__ == "__main__":
    if torch._C.has_lapack:
        TestCase._default_dtype_check_enabled = True
        run_tests()
6973