# Owner(s): ["module: distributions"]

"""
Note [Randomized statistical tests]
-----------------------------------

This note describes how to maintain the tests in this file as random sources
change. This file contains two types of randomized tests, described in items 1
and 2 below; item 3 covers a caveat for tests that compare against scipy.

1. The easier type of randomized test is one that should always pass but is
   initialized with random data. If such a test fails, something is wrong, but
   it's fine to use a fixed seed by inheriting from common.TestCase.

2. The trickier tests are statistical tests. These tests explicitly call
   set_rng_seed(n) and are marked "see Note [Randomized statistical tests]".
   They have a known, nonzero false-failure rate (we set failure_rate=1e-3 by
   default), so we must balance the strength of these tests against the
   annoyance of false alarms. One approach that works is to set a specific
   seed in each randomized test. When a random generator occasionally changes
   (as in #4312, which vectorized the Box-Muller sampler), some of these
   statistical tests may (rarely) fail. If one fails in this case, it's fine
   to increment the seed of the failing test, but you shouldn't need to
   increment it more than once; if you do, something is probably actually
   wrong.

3. `test_geometric_sample`, `test_binomial_sample` and `test_poisson_sample`
   are validated against `scipy.stats` distributions, which are not guaranteed
   to behave identically across scipy versions (in particular, they yield
   invalid results in scipy 1.7+).
"""

import math
import numbers
import unittest
from collections import namedtuple
from itertools import product
from random import shuffle

from packaging import version

import torch
import torch.autograd.forward_ad as fwAD
from torch import inf, nan
from torch.autograd import grad
from torch.autograd.functional import jacobian
from torch.distributions import (
    Bernoulli,
    Beta,
    Binomial,
    Categorical,
    Cauchy,
    Chi2,
    constraints,
    ContinuousBernoulli,
    Dirichlet,
    Distribution,
    Exponential,
    ExponentialFamily,
    FisherSnedecor,
    Gamma,
    Geometric,
    Gumbel,
    HalfCauchy,
    HalfNormal,
    Independent,
    InverseGamma,
    kl_divergence,
    Kumaraswamy,
    Laplace,
    LKJCholesky,
    LogisticNormal,
    LogNormal,
    LowRankMultivariateNormal,
    MixtureSameFamily,
    Multinomial,
    MultivariateNormal,
    NegativeBinomial,
    Normal,
    OneHotCategorical,
    OneHotCategoricalStraightThrough,
    Pareto,
    Poisson,
    RelaxedBernoulli,
    RelaxedOneHotCategorical,
    StudentT,
    TransformedDistribution,
    Uniform,
    VonMises,
    Weibull,
    Wishart,
)
from torch.distributions.constraint_registry import transform_to
from torch.distributions.constraints import Constraint, is_dependent
from torch.distributions.dirichlet import _Dirichlet_backward
from torch.distributions.kl import _kl_expfamily_expfamily
from torch.distributions.transforms import (
    AffineTransform,
    CatTransform,
    ExpTransform,
    identity_transform,
    StackTransform,
)
from torch.distributions.utils import (
    lazy_property,
    probs_to_logits,
    tril_matrix_to_vec,
    vec_to_tril_matrix,
)
from torch.nn.functional import softmax
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_utils import (
    gradcheck,
    load_tests,
    run_tests,
    set_default_dtype,
    set_rng_seed,
    skipIfTorchDynamo,
    TestCase,
)


# load_tests from torch.testing._internal.common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake8 warnings.
load_tests = load_tests

TEST_NUMPY = True
try:
    import numpy as np
    import scipy.special
    import scipy.stats
except ImportError:
    TEST_NUMPY = False


def pairwise(Dist, *params):
    """
    Creates a pair of distributions `Dist` initialized to test each element of
    each param against every other element of that param.
    """
    params1 = [torch.tensor([p] * len(p)) for p in params]
    params2 = [p.transpose(0, 1) for p in params1]
    return Dist(*params1), Dist(*params2)


def is_all_nan(tensor):
    """
    Checks whether all entries of a tensor are NaN.
    """
    return (tensor != tensor).all()


Example = namedtuple("Example", ["Dist", "params"])


# Register all distributions for generic tests.
def _get_examples():
    return [
        Example(
            Bernoulli,
            [
                {"probs": torch.tensor([0.7, 0.2, 0.4], requires_grad=True)},
                {"probs": torch.tensor([0.3], requires_grad=True)},
                {"probs": 0.3},
                {"logits": torch.tensor([0.0], requires_grad=True)},
            ],
        ),
        Example(
            Geometric,
            [
                {"probs": torch.tensor([0.7, 0.2, 0.4], requires_grad=True)},
                {"probs": torch.tensor([0.3], requires_grad=True)},
                {"probs": 0.3},
            ],
        ),
        Example(
            Beta,
            [
                {
                    "concentration1": torch.randn(2, 3).exp().requires_grad_(),
                    "concentration0": torch.randn(2, 3).exp().requires_grad_(),
                },
                {
                    "concentration1": torch.randn(4).exp().requires_grad_(),
                    "concentration0": torch.randn(4).exp().requires_grad_(),
                },
            ],
        ),
        Example(
            Categorical,
            [
                {
                    "probs": torch.tensor(
                        [[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True
                    )
                },
                {"probs": torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True)},
                {"logits": torch.tensor([[0.0, 0.0], [0.0, 0.0]], requires_grad=True)},
            ],
        ),
        Example(
            Binomial,
            [
                {
                    "probs": torch.tensor(
                        [[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True
                    ),
                    "total_count": 10,
                },
                {
                    "probs": torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True),
                    "total_count": 10,
                },
                {
                    "probs": torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True),
                    "total_count": torch.tensor([10]),
                },
                {
                    "probs": torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True),
                    "total_count": torch.tensor([10, 8]),
                },
                {
                    "probs": torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True),
                    "total_count": torch.tensor([[10.0, 8.0], [5.0, 3.0]]),
                },
                {
                    "probs": torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True),
                    "total_count": torch.tensor(0.0),
                },
            ],
        ),
        Example(
            NegativeBinomial,
            [
                {
                    "probs": torch.tensor(
                        [[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True
                    ),
                    "total_count": 10,
                },
                {
                    "probs": torch.tensor([[0.9, 0.0], [0.0, 0.9]], requires_grad=True),
                    "total_count": 10,
                },
                {
                    "probs": torch.tensor([[0.9, 0.0], [0.0, 0.9]], requires_grad=True),
                    "total_count": torch.tensor([10]),
                },
                {
                    "probs": torch.tensor([[0.9, 0.0], [0.0, 0.9]], requires_grad=True),
                    "total_count": torch.tensor([10, 8]),
                },
                {
                    "probs": torch.tensor([[0.9, 0.0], [0.0, 0.9]], requires_grad=True),
                    "total_count": torch.tensor([[10.0, 8.0], [5.0, 3.0]]),
                },
                {
                    "probs": torch.tensor([[0.9, 0.0], [0.0, 0.9]], requires_grad=True),
                    "total_count": torch.tensor(0.0),
                },
            ],
        ),
        Example(
            Multinomial,
            [
                {
                    "probs": torch.tensor(
                        [[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True
                    ),
                    "total_count": 10,
                },
                {
                    "probs": torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True),
                    "total_count": 10,
                },
            ],
        ),
        Example(
            Cauchy,
            [
                {"loc": 0.0, "scale": 1.0},
                {"loc": torch.tensor([0.0]), "scale": 1.0},
                {
                    "loc": torch.tensor([[0.0], [0.0]]),
                    "scale": torch.tensor([[1.0], [1.0]]),
                },
            ],
        ),
        Example(
            Chi2,
            [
                {"df": torch.randn(2, 3).exp().requires_grad_()},
                {"df": torch.randn(1).exp().requires_grad_()},
            ],
        ),
        Example(
            StudentT,
            [
                {"df": torch.randn(2, 3).exp().requires_grad_()},
                {"df": torch.randn(1).exp().requires_grad_()},
            ],
        ),
        Example(
            Dirichlet,
            [
                {"concentration": torch.randn(2, 3).exp().requires_grad_()},
                {"concentration": torch.randn(4).exp().requires_grad_()},
            ],
        ),
        Example(
            Exponential,
            [
                {"rate": torch.randn(5, 5).abs().requires_grad_()},
                {"rate": torch.randn(1).abs().requires_grad_()},
            ],
        ),
        Example(
            FisherSnedecor,
            [
                {
                    "df1": torch.randn(5, 5).abs().requires_grad_(),
                    "df2": torch.randn(5, 5).abs().requires_grad_(),
                },
                {
                    "df1": torch.randn(1).abs().requires_grad_(),
                    "df2": torch.randn(1).abs().requires_grad_(),
                },
                {
                    "df1": torch.tensor([1.0]),
                    "df2": 1.0,
                },
            ],
        ),
        Example(
            Gamma,
            [
                {
                    "concentration": torch.randn(2, 3).exp().requires_grad_(),
                    "rate": torch.randn(2, 3).exp().requires_grad_(),
                },
                {
                    "concentration": torch.randn(1).exp().requires_grad_(),
                    "rate": torch.randn(1).exp().requires_grad_(),
                },
            ],
        ),
        Example(
            Gumbel,
            [
                {
                    "loc": torch.randn(5, 5, requires_grad=True),
                    "scale": torch.randn(5, 5).abs().requires_grad_(),
                },
                {
                    "loc": torch.randn(1, requires_grad=True),
                    "scale": torch.randn(1).abs().requires_grad_(),
                },
            ],
        ),
        Example(HalfCauchy, [{"scale": 1.0}, {"scale": torch.tensor([[1.0], [1.0]])}]),
        Example(
            HalfNormal,
            [
                {"scale": torch.randn(5, 5).abs().requires_grad_()},
                {"scale": torch.randn(1).abs().requires_grad_()},
                {"scale": torch.tensor([1e-5, 1e-5], requires_grad=True)},
            ],
        ),
        Example(
            Independent,
            [
                {
                    "base_distribution": Normal(
                        torch.randn(2, 3, requires_grad=True),
                        torch.randn(2, 3).abs().requires_grad_(),
                    ),
                    "reinterpreted_batch_ndims": 0,
                },
                {
                    "base_distribution": Normal(
                        torch.randn(2, 3, requires_grad=True),
                        torch.randn(2, 3).abs().requires_grad_(),
                    ),
                    "reinterpreted_batch_ndims": 1,
                },
                {
                    "base_distribution": Normal(
                        torch.randn(2, 3, requires_grad=True),
                        torch.randn(2, 3).abs().requires_grad_(),
                    ),
                    "reinterpreted_batch_ndims": 2,
                },
                {
                    "base_distribution": Normal(
                        torch.randn(2, 3, 5, requires_grad=True),
                        torch.randn(2, 3, 5).abs().requires_grad_(),
                    ),
                    "reinterpreted_batch_ndims": 2,
                },
                {
                    "base_distribution": Normal(
                        torch.randn(2, 3, 5, requires_grad=True),
                        torch.randn(2, 3, 5).abs().requires_grad_(),
                    ),
                    "reinterpreted_batch_ndims": 3,
                },
            ],
        ),
        Example(
            Kumaraswamy,
            [
                {
                    "concentration1": torch.empty(2, 3).uniform_(1, 2).requires_grad_(),
                    "concentration0": torch.empty(2, 3).uniform_(1, 2).requires_grad_(),
                },
                {
                    "concentration1": torch.rand(4).uniform_(1, 2).requires_grad_(),
                    "concentration0": torch.rand(4).uniform_(1, 2).requires_grad_(),
                },
            ],
        ),
        Example(
            LKJCholesky,
            [
                {"dim": 2, "concentration": 0.5},
                {
                    "dim": 3,
                    "concentration": torch.tensor([0.5, 1.0, 2.0]),
                },
                {"dim": 100, "concentration": 4.0},
            ],
        ),
        Example(
            Laplace,
            [
                {
                    "loc": torch.randn(5, 5, requires_grad=True),
                    "scale": torch.randn(5, 5).abs().requires_grad_(),
                },
                {
                    "loc": torch.randn(1, requires_grad=True),
                    "scale": torch.randn(1).abs().requires_grad_(),
                },
                {
                    "loc": torch.tensor([1.0, 0.0], requires_grad=True),
                    "scale": torch.tensor([1e-5, 1e-5], requires_grad=True),
                },
            ],
        ),
        Example(
            LogNormal,
            [
                {
                    "loc": torch.randn(5, 5, requires_grad=True),
                    "scale": torch.randn(5, 5).abs().requires_grad_(),
                },
                {
                    "loc": torch.randn(1, requires_grad=True),
                    "scale": torch.randn(1).abs().requires_grad_(),
                },
                {
                    "loc": torch.tensor([1.0, 0.0], requires_grad=True),
                    "scale": torch.tensor([1e-5, 1e-5], requires_grad=True),
                },
            ],
        ),
        Example(
            LogisticNormal,
            [
                {
                    "loc": torch.randn(5, 5).requires_grad_(),
                    "scale": torch.randn(5, 5).abs().requires_grad_(),
                },
                {
                    "loc": torch.randn(1).requires_grad_(),
                    "scale": torch.randn(1).abs().requires_grad_(),
                },
                {
                    "loc": torch.tensor([1.0, 0.0], requires_grad=True),
                    "scale": torch.tensor([1e-5, 1e-5], requires_grad=True),
                },
            ],
        ),
        Example(
            LowRankMultivariateNormal,
            [
                {
                    "loc": torch.randn(5, 2, requires_grad=True),
                    "cov_factor": torch.randn(5, 2, 1, requires_grad=True),
                    "cov_diag": torch.tensor([2.0, 0.25], requires_grad=True),
                },
                {
                    "loc": torch.randn(4, 3, requires_grad=True),
                    "cov_factor": torch.randn(3, 2, requires_grad=True),
                    "cov_diag": torch.tensor([5.0, 1.5, 3.0], requires_grad=True),
                },
            ],
        ),
        Example(
            MultivariateNormal,
            [
                {
                    "loc": torch.randn(5, 2, requires_grad=True),
                    "covariance_matrix": torch.tensor(
                        [[2.0, 0.3], [0.3, 0.25]], requires_grad=True
                    ),
                },
                {
                    "loc": torch.randn(2, 3, requires_grad=True),
                    "precision_matrix": torch.tensor(
                        [[2.0, 0.1, 0.0], [0.1, 0.25, 0.0], [0.0, 0.0, 0.3]],
                        requires_grad=True,
                    ),
                },
                {
                    "loc": torch.randn(5, 3, 2, requires_grad=True),
                    "scale_tril": torch.tensor(
                        [
                            [[2.0, 0.0], [-0.5, 0.25]],
                            [[2.0, 0.0], [0.3, 0.25]],
                            [[5.0, 0.0], [-0.5, 1.5]],
                        ],
                        requires_grad=True,
                    ),
                },
                {
                    "loc": torch.tensor([1.0, -1.0]),
                    "covariance_matrix": torch.tensor([[5.0, -0.5], [-0.5, 1.5]]),
                },
            ],
        ),
        Example(
            Normal,
            [
                {
                    "loc": torch.randn(5, 5, requires_grad=True),
                    "scale": torch.randn(5, 5).abs().requires_grad_(),
                },
                {
                    "loc": torch.randn(1, requires_grad=True),
                    "scale": torch.randn(1).abs().requires_grad_(),
                },
                {
                    "loc": torch.tensor([1.0, 0.0], requires_grad=True),
                    "scale": torch.tensor([1e-5, 1e-5], requires_grad=True),
                },
            ],
        ),
        Example(
            OneHotCategorical,
            [
                {
                    "probs": torch.tensor(
                        [[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True
                    )
                },
                {"probs": torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True)},
                {"logits": torch.tensor([[0.0, 0.0], [0.0, 0.0]], requires_grad=True)},
            ],
        ),
        Example(
            OneHotCategoricalStraightThrough,
            [
                {
                    "probs": torch.tensor(
                        [[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True
                    )
                },
                {"probs": torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True)},
                {"logits": torch.tensor([[0.0, 0.0], [0.0, 0.0]], requires_grad=True)},
            ],
        ),
        Example(
            Pareto,
            [
                {"scale": 1.0, "alpha": 1.0},
                {
                    "scale": (torch.randn(5, 5).abs() + 0.1).requires_grad_(),
                    "alpha": (torch.randn(5, 5).abs() + 0.1).requires_grad_(),
                },
                {"scale": torch.tensor([1.0]), "alpha": 1.0},
            ],
        ),
        Example(
            Poisson,
            [
                {
                    "rate": torch.randn(5, 5).abs().requires_grad_(),
                },
                {
                    "rate": torch.randn(3).abs().requires_grad_(),
                },
                {
                    "rate": 0.2,
                },
                {
                    "rate": torch.tensor([0.0], requires_grad=True),
                },
                {
                    "rate": 0.0,
                },
            ],
        ),
        Example(
            RelaxedBernoulli,
            [
                {
                    "temperature": torch.tensor([0.5], requires_grad=True),
                    "probs": torch.tensor([0.7, 0.2, 0.4], requires_grad=True),
                },
                {
                    "temperature": torch.tensor([2.0]),
                    "probs": torch.tensor([0.3]),
                },
                {
                    "temperature": torch.tensor([7.2]),
                    "logits": torch.tensor([-2.0, 2.0, 1.0, 5.0]),
                },
            ],
        ),
        Example(
            RelaxedOneHotCategorical,
            [
                {
                    "temperature": torch.tensor([0.5], requires_grad=True),
                    "probs": torch.tensor(
                        [[0.1, 0.2, 0.7], [0.5, 0.3, 0.2]], requires_grad=True
                    ),
                },
                {
                    "temperature": torch.tensor([2.0]),
                    "probs": torch.tensor([[1.0, 0.0], [0.0, 1.0]]),
                },
                {
                    "temperature": torch.tensor([7.2]),
                    "logits": torch.tensor([[-2.0, 2.0], [1.0, 5.0]]),
                },
            ],
        ),
        Example(
            TransformedDistribution,
            [
                {
                    "base_distribution": Normal(
                        torch.randn(2, 3, requires_grad=True),
                        torch.randn(2, 3).abs().requires_grad_(),
                    ),
                    "transforms": [],
                },
                {
                    "base_distribution": Normal(
                        torch.randn(2, 3, requires_grad=True),
                        torch.randn(2, 3).abs().requires_grad_(),
                    ),
                    "transforms": ExpTransform(),
                },
                {
                    "base_distribution": Normal(
                        torch.randn(2, 3, 5, requires_grad=True),
                        torch.randn(2, 3, 5).abs().requires_grad_(),
                    ),
                    "transforms": [
                        AffineTransform(torch.randn(3, 5), torch.randn(3, 5)),
                        ExpTransform(),
                    ],
                },
                {
                    "base_distribution": Normal(
                        torch.randn(2, 3, 5, requires_grad=True),
                        torch.randn(2, 3, 5).abs().requires_grad_(),
                    ),
                    "transforms": AffineTransform(1, 2),
                },
                {
                    "base_distribution": Uniform(
                        torch.tensor(1e8).log(), torch.tensor(1e10).log()
                    ),
                    "transforms": ExpTransform(),
                },
            ],
        ),
        Example(
            Uniform,
            [
                {
                    "low": torch.zeros(5, 5, requires_grad=True),
                    "high": torch.ones(5, 5, requires_grad=True),
                },
                {
                    "low": torch.zeros(1, requires_grad=True),
                    "high": torch.ones(1, requires_grad=True),
                },
                {
                    "low": torch.tensor([1.0, 1.0], requires_grad=True),
                    "high": torch.tensor([2.0, 3.0], requires_grad=True),
                },
            ],
        ),
        Example(
            Weibull,
            [
                {
                    "scale": torch.randn(5, 5).abs().requires_grad_(),
                    "concentration": torch.randn(1).abs().requires_grad_(),
                }
            ],
        ),
        Example(
            Wishart,
            [
                {
                    "covariance_matrix": torch.tensor(
                        [[2.0, 0.3], [0.3, 0.25]], requires_grad=True
                    ),
                    "df": torch.tensor([3.0], requires_grad=True),
                },
                {
                    "precision_matrix": torch.tensor(
                        [[2.0, 0.1, 0.0], [0.1, 0.25, 0.0], [0.0, 0.0, 0.3]],
                        requires_grad=True,
                    ),
                    "df": torch.tensor([5.0, 4], requires_grad=True),
                },
                {
                    "scale_tril": torch.tensor(
                        [
                            [[2.0, 0.0], [-0.5, 0.25]],
                            [[2.0, 0.0], [0.3, 0.25]],
                            [[5.0, 0.0], [-0.5, 1.5]],
                        ],
                        requires_grad=True,
                    ),
                    "df": torch.tensor([5.0, 3.5, 3], requires_grad=True),
                },
                {
                    "covariance_matrix": torch.tensor([[5.0, -0.5], [-0.5, 1.5]]),
                    "df": torch.tensor([3.0]),
                },
                {
                    "covariance_matrix": torch.tensor([[5.0, -0.5], [-0.5, 1.5]]),
                    "df": 3.0,
                },
            ],
        ),
        Example(
            MixtureSameFamily,
            [
                {
                    "mixture_distribution": Categorical(
                        torch.rand(5, requires_grad=True)
                    ),
                    "component_distribution": Normal(
                        torch.randn(5, requires_grad=True),
                        torch.rand(5, requires_grad=True),
                    ),
                },
                {
                    "mixture_distribution": Categorical(
                        torch.rand(5, requires_grad=True)
                    ),
                    "component_distribution": MultivariateNormal(
                        loc=torch.randn(5, 2, requires_grad=True),
                        covariance_matrix=torch.tensor(
                            [[2.0, 0.3], [0.3, 0.25]], requires_grad=True
                        ),
                    ),
                },
            ],
        ),
        Example(
            VonMises,
            [
                {
                    "loc": torch.tensor(1.0, requires_grad=True),
                    "concentration": torch.tensor(10.0, requires_grad=True),
                },
                {
                    "loc": torch.tensor([0.0, math.pi / 2], requires_grad=True),
                    "concentration": torch.tensor([1.0, 10.0], requires_grad=True),
                },
            ],
        ),
        Example(
            ContinuousBernoulli,
            [
                {"probs": torch.tensor([0.7, 0.2, 0.4], requires_grad=True)},
                {"probs": torch.tensor([0.3], requires_grad=True)},
                {"probs": 0.3},
                {"logits": torch.tensor([0.0], requires_grad=True)},
            ],
        ),
        Example(
            InverseGamma,
            [
                {
                    "concentration": torch.randn(2, 3).exp().requires_grad_(),
                    "rate": torch.randn(2, 3).exp().requires_grad_(),
                },
                {
                    "concentration": torch.randn(1).exp().requires_grad_(),
                    "rate": torch.randn(1).exp().requires_grad_(),
                },
            ],
        ),
    ]


def _get_bad_examples():
    return [
        Example(
            Bernoulli,
            [
                {"probs": torch.tensor([1.1, 0.2, 0.4], requires_grad=True)},
                {"probs": torch.tensor([-0.5], requires_grad=True)},
                {"probs": 1.00001},
            ],
        ),
        Example(
            Beta,
            [
                {
                    "concentration1": torch.tensor([0.0], requires_grad=True),
                    "concentration0": torch.tensor([0.0], requires_grad=True),
                },
                {
                    "concentration1": torch.tensor([-1.0], requires_grad=True),
                    "concentration0": torch.tensor([-2.0], requires_grad=True),
                },
            ],
        ),
        Example(
            Geometric,
            [
                {"probs": torch.tensor([1.1, 0.2, 0.4], requires_grad=True)},
                {"probs": torch.tensor([-0.3], requires_grad=True)},
                {"probs": 1.00000001},
            ],
        ),
        Example(
            Categorical,
            [
                {
                    "probs": torch.tensor(
                        [[-0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True
                    )
                },
                {
                    "probs": torch.tensor(
                        [[-1.0, 10.0], [0.0, -1.0]], requires_grad=True
                    )
                },
            ],
        ),
        Example(
            Binomial,
            [
                {
                    "probs": torch.tensor(
                        [[-0.0000001, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True
                    ),
                    "total_count": 10,
                },
                {
                    "probs": torch.tensor([[1.0, 0.0], [0.0, 2.0]], requires_grad=True),
                    "total_count": 10,
                },
            ],
        ),
        Example(
            NegativeBinomial,
            [
                {
                    "probs": torch.tensor(
                        [[-0.0000001, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True
                    ),
                    "total_count": 10,
                },
                {
                    "probs": torch.tensor([[1.0, 0.0], [0.0, 2.0]], requires_grad=True),
                    "total_count": 10,
                },
            ],
        ),
        Example(
            Cauchy,
            [
                {"loc": 0.0, "scale": -1.0},
                {"loc": torch.tensor([0.0]), "scale": 0.0},
                {
                    "loc": torch.tensor([[0.0], [-2.0]]),
                    "scale": torch.tensor([[-0.000001], [1.0]]),
                },
            ],
        ),
        Example(
            Chi2,
            [
                {"df": torch.tensor([0.0], requires_grad=True)},
                {"df": torch.tensor([-2.0], requires_grad=True)},
            ],
        ),
        Example(
            StudentT,
            [
                {"df": torch.tensor([0.0], requires_grad=True)},
                {"df": torch.tensor([-2.0], requires_grad=True)},
            ],
        ),
        Example(
            Dirichlet,
            [
                {"concentration": torch.tensor([0.0], requires_grad=True)},
                {"concentration": torch.tensor([-2.0], requires_grad=True)},
            ],
        ),
        Example(
            Exponential,
            [
                {"rate": torch.tensor([0.0, 0.0], requires_grad=True)},
                {"rate": torch.tensor([-2.0], requires_grad=True)},
            ],
        ),
        Example(
            FisherSnedecor,
            [
                {
                    "df1": torch.tensor([0.0, 0.0], requires_grad=True),
                    "df2": torch.tensor([-1.0, -100.0], requires_grad=True),
                },
                {
                    "df1": torch.tensor([1.0, 1.0], requires_grad=True),
                    "df2": torch.tensor([0.0, 0.0], requires_grad=True),
                },
            ],
        ),
        Example(
            Gamma,
            [
                {
                    "concentration": torch.tensor([0.0, 0.0], requires_grad=True),
                    "rate": torch.tensor([-1.0, -100.0], requires_grad=True),
                },
                {
                    "concentration": torch.tensor([1.0, 1.0], requires_grad=True),
                    "rate": torch.tensor([0.0, 0.0], requires_grad=True),
                },
            ],
        ),
        Example(
            Gumbel,
            [
                {
                    "loc": torch.tensor([1.0, 1.0], requires_grad=True),
                    "scale": torch.tensor([0.0, 1.0], requires_grad=True),
                },
                {
                    "loc": torch.tensor([1.0, 1.0], requires_grad=True),
                    "scale": torch.tensor([1.0, -1.0], requires_grad=True),
                },
            ],
        ),
        Example(
            HalfCauchy,
            [
                {"scale": -1.0},
                {"scale": 0.0},
                {"scale": torch.tensor([[-0.000001], [1.0]])},
            ],
        ),
        Example(
            HalfNormal,
            [
                {"scale": torch.tensor([0.0, 1.0], requires_grad=True)},
                {"scale": torch.tensor([1.0, -1.0], requires_grad=True)},
            ],
        ),
        Example(
            LKJCholesky,
            [
                {"dim": -2, "concentration": 0.1},
                {
                    "dim": 1,
                    "concentration": 2.0,
                },
                {
                    "dim": 2,
                    "concentration": 0.0,
                },
            ],
        ),
        Example(
            Laplace,
            [
                {
                    "loc": torch.tensor([1.0, 1.0], requires_grad=True),
                    "scale": torch.tensor([0.0, 1.0], requires_grad=True),
                },
                {
                    "loc": torch.tensor([1.0, 1.0], requires_grad=True),
                    "scale": torch.tensor([1.0, -1.0], requires_grad=True),
                },
            ],
        ),
        Example(
            LogNormal,
            [
                {
                    "loc": torch.tensor([1.0, 1.0], requires_grad=True),
                    "scale": torch.tensor([0.0, 1.0], requires_grad=True),
                },
                {
                    "loc": torch.tensor([1.0, 1.0], requires_grad=True),
                    "scale": torch.tensor([1.0, -1.0], requires_grad=True),
                },
            ],
        ),
        Example(
            MultivariateNormal,
            [
                {
                    "loc": torch.tensor([1.0, 1.0], requires_grad=True),
                    "covariance_matrix": torch.tensor(
                        [[1.0, 0.0], [0.0, -2.0]], requires_grad=True
                    ),
                },
            ],
        ),
        Example(
            Normal,
            [
                {
                    "loc": torch.tensor([1.0, 1.0], requires_grad=True),
                    "scale": torch.tensor([0.0, 1.0], requires_grad=True),
                },
                {
                    "loc": torch.tensor([1.0, 1.0], requires_grad=True),
                    "scale": torch.tensor([1.0, -1.0], requires_grad=True),
                },
                {
                    "loc": torch.tensor([1.0, 0.0], requires_grad=True),
                    "scale": torch.tensor([1e-5, -1e-5], requires_grad=True),
                },
            ],
        ),
        Example(
            OneHotCategorical,
            [
                {
                    "probs": torch.tensor(
                        [[0.1, 0.2, 0.3], [0.1, -10.0, 0.2]], requires_grad=True
                    )
                },
                {"probs": torch.tensor([[0.0, 0.0], [0.0, 0.0]], requires_grad=True)},
            ],
        ),
        Example(
            OneHotCategoricalStraightThrough,
            [
                {
                    "probs": torch.tensor(
                        [[0.1, 0.2, 0.3], [0.1, -10.0, 0.2]], requires_grad=True
                    )
                },
                {"probs": torch.tensor([[0.0, 0.0], [0.0, 0.0]], requires_grad=True)},
            ],
        ),
        Example(
            Pareto,
            [
                {"scale": 0.0, "alpha": 0.0},
                {
                    "scale": torch.tensor([0.0, 0.0], requires_grad=True),
                    "alpha": torch.tensor([-1e-5, 0.0], requires_grad=True),
                },
                {"scale": torch.tensor([1.0]), "alpha": -1.0},
            ],
        ),
        Example(
            Poisson,
            [
                {
                    "rate": torch.tensor([-0.1], requires_grad=True),
                },
                {
                    "rate": -1.0,
                },
            ],
        ),
        Example(
            RelaxedBernoulli,
            [
                {
                    "temperature": torch.tensor([1.5], requires_grad=True),
                    "probs": torch.tensor([1.7, 0.2, 0.4], requires_grad=True),
                },
                {
                    "temperature": torch.tensor([2.0]),
                    "probs": torch.tensor([-1.0]),
                },
            ],
        ),
        Example(
            RelaxedOneHotCategorical,
            [
                {
                    "temperature": torch.tensor([0.5], requires_grad=True),
                    "probs": torch.tensor(
                        [[-0.1, 0.2, 0.7], [0.5, 0.3, 0.2]], requires_grad=True
                    ),
                },
                {
                    "temperature": torch.tensor([2.0]),
                    "probs": torch.tensor([[-1.0, 0.0], [-1.0, 1.1]]),
                },
            ],
        ),
        Example(
            TransformedDistribution,
            [
                {
                    "base_distribution": Normal(0, 1),
                    "transforms": lambda x: x,
                },
                {
                    "base_distribution": Normal(0, 1),
                    "transforms": [lambda x: x],
                },
            ],
        ),
        Example(
            Uniform,
            [
                {
                    "low": torch.tensor([2.0], requires_grad=True),
                    "high": torch.tensor([2.0], requires_grad=True),
                },
                {
                    "low": torch.tensor([0.0], requires_grad=True),
                    "high": torch.tensor([0.0], requires_grad=True),
                },
                {
                    "low": torch.tensor([1.0], requires_grad=True),
                    "high": torch.tensor([0.0], requires_grad=True),
                },
            ],
        ),
        Example(
            Weibull,
            [
                {
                    "scale": torch.tensor([0.0], requires_grad=True),
                    "concentration": torch.tensor([0.0], requires_grad=True),
                },
                {
                    "scale": torch.tensor([1.0], requires_grad=True),
                    "concentration": torch.tensor([-1.0], requires_grad=True),
                },
            ],
        ),
        Example(
            Wishart,
            [
                {
                    "covariance_matrix": torch.tensor(
                        [[1.0, 0.0], [0.0, -2.0]], requires_grad=True
                    ),
                    "df": torch.tensor([1.5], requires_grad=True),
                },
                {
                    "covariance_matrix": torch.tensor(
                        [[1.0, 1.0], [1.0, -2.0]], requires_grad=True
                    ),
                    "df": torch.tensor([3.0], requires_grad=True),
                },
                {
                    "covariance_matrix": torch.tensor(
                        [[1.0, 1.0], [1.0, -2.0]], requires_grad=True
                    ),
                    "df": 3.0,
                },
            ],
        ),
        Example(
            ContinuousBernoulli,
            [
                {"probs": torch.tensor([1.1, 0.2, 0.4], requires_grad=True)},
                {"probs": torch.tensor([-0.5], requires_grad=True)},
                {"probs": 1.00001},
            ],
        ),
        Example(
            InverseGamma,
            [
                {
                    "concentration": torch.tensor([0.0, 0.0], requires_grad=True),
                    "rate": torch.tensor([-1.0, -100.0], requires_grad=True),
                },
                {
                    "concentration": torch.tensor([1.0, 1.0], requires_grad=True),
                    "rate": torch.tensor([0.0, 0.0], requires_grad=True),
                },
            ],
        ),
    ]


class DistributionsTestCase(TestCase):
    def setUp(self):
        """The tests assume that the validation flag is set."""
        torch.distributions.Distribution.set_default_validate_args(True)
        super().setUp()

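
# Sketch (illustrative only, not part of the original suite): a stdlib-only
# mirror of the binned two-sample bias check that
# TestDistributions._check_sampler_sampler performs below. Samples from the
# sampler under test are labelled +1 and reference samples -1. After sorting
# by value, each bin's mean label should stay near zero when both samplers
# agree. The helper name `_sketch_binned_bias_check` is hypothetical, and
# statistics.NormalDist.inv_cdf is swapped in for scipy.special.erfinv using
# the identity erfinv(1 - a) == inv_cdf(1 - a / 2) / sqrt(2).
def _sketch_binned_bias_check(xs, ys, num_bins=10, failure_rate=1e-3, seed=0):
    import math
    import random
    from statistics import NormalDist

    samples = [(x, +1) for x in xs] + [(y, -1) for y in ys]
    # Shuffle before the stable sort so tied values are not grouped by source.
    random.Random(seed).shuffle(samples)
    samples.sort(key=lambda pair: pair[0])
    labels = [label for _, label in samples]

    samples_per_bin = len(labels) // num_bins
    stddev = samples_per_bin**-0.5
    a = 2 * failure_rate / num_bins
    # Same threshold as the test: stddev * erfinv(1 - 2 * failure_rate / num_bins).
    threshold = stddev * NormalDist().inv_cdf(1 - a / 2) / math.sqrt(2)

    bins = [
        sum(labels[b * samples_per_bin : (b + 1) * samples_per_bin]) / samples_per_bin
        for b in range(num_bins)
    ]
    # Pass only if every bin mean lies strictly inside (-threshold, threshold).
    return all(-threshold < bias < threshold for bias in bins)


# Usage: perfectly interleaved inputs give bins whose means are exactly zero,
# while disjoint inputs fill the lower bins with +1 labels only.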
@skipIfTorchDynamo("Not a TorchDynamo suitable test")
class TestDistributions(DistributionsTestCase):
    _do_cuda_memory_leak_check = True
    _do_cuda_non_default_stream = True

    def _gradcheck_log_prob(self, dist_ctor, ctor_params):
        # Runs gradcheck on log_prob w.r.t. the constructor parameters and,
        # for continuous supports, the sample itself.
        distribution = dist_ctor(*ctor_params)
        s = distribution.sample()
        if not distribution.support.is_discrete:
            s = s.detach().requires_grad_()

        expected_shape = distribution.batch_shape + distribution.event_shape
        self.assertEqual(s.size(), expected_shape)

        def apply_fn(s, *params):
            return dist_ctor(*params).log_prob(s)

        gradcheck(apply_fn, (s,) + tuple(ctor_params), raise_exception=True)

    def _check_forward_ad(self, fn):
        # Checks that fn produces an all-zero forward-mode tangent
        # when given a unit input tangent.
        with fwAD.dual_level():
            x = torch.tensor(1.0)
            t = torch.tensor(1.0)
            dual = fwAD.make_dual(x, t)
            dual_out = fn(dual)
            self.assertEqual(
                torch.count_nonzero(fwAD.unpack_dual(dual_out).tangent).item(), 0
            )

    def _check_log_prob(self, dist, assert_fn):
        # Checks that log_prob matches a reference function, element by element.
        s = dist.sample()
        log_probs = dist.log_prob(s)
        log_probs_data_flat = log_probs.view(-1)
        s_data_flat = s.view(len(log_probs_data_flat), -1)
        for i, (val, log_prob) in enumerate(zip(s_data_flat, log_probs_data_flat)):
            assert_fn(i, val.squeeze(), log_prob)

    def _check_sampler_sampler(
        self,
        torch_dist,
        ref_dist,
        message,
        multivariate=False,
        circular=False,
        num_samples=10000,
        failure_rate=1e-3,
    ):
        # Checks that .sample() matches the reference sampler ref_dist.rvs().
        torch_samples = torch_dist.sample((num_samples,)).squeeze()
        torch_samples = torch_samples.cpu().numpy()
        ref_samples = ref_dist.rvs(num_samples).astype(np.float64)
        if multivariate:
            # Project onto a random axis.
            axis = np.random.normal(size=(1,) + torch_samples.shape[1:])
            axis /= np.linalg.norm(axis)
            torch_samples = (axis * torch_samples).reshape(num_samples, -1).sum(-1)
            ref_samples = (axis * ref_samples).reshape(num_samples, -1).sum(-1)
        samples = [(x, +1) for x in torch_samples] + [(x, -1) for x in ref_samples]
        if circular:
            samples = [(np.cos(x), v) for (x, v) in samples]
        # Shuffle before the stable sort so that tied (discrete) values are not
        # grouped by source, which would make the bins uneven.
        shuffle(samples)
        samples.sort(key=lambda x: x[0])
        samples = np.array(samples)[:, 1]

        # Aggregate the +1/-1 labels into bins. When the two samplers agree, each
        # label has roughly zero mean and unit variance, so each bin mean has a
        # standard deviation of roughly samples_per_bin**-0.5.
        num_bins = 10
        samples_per_bin = len(samples) // num_bins
        bins = samples.reshape((num_bins, samples_per_bin)).mean(axis=1)
        stddev = samples_per_bin**-0.5
        threshold = stddev * scipy.special.erfinv(1 - 2 * failure_rate / num_bins)
        message = f"{message}.sample() is biased:\n{bins}"
        for bias in bins:
            self.assertLess(-threshold, bias, message)
            self.assertLess(bias, threshold, message)

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def _check_sampler_discrete(
        self, torch_dist, ref_dist, message, num_samples=10000, failure_rate=1e-3
    ):
        """Runs a chi-squared test on the observed support.

        The reference pmf is renormalized over the observed values, and values
        with too few observed or expected counts are pooled into a single
        remainder bucket.
        """
        torch_samples = torch_dist.sample((num_samples,)).squeeze()
        torch_samples = (
            torch_samples.float()
            if torch_samples.dtype == torch.bfloat16
            else torch_samples
        )
        torch_samples = torch_samples.cpu().numpy()
        unique, counts = np.unique(torch_samples, return_counts=True)
        pmf = ref_dist.pmf(unique)
        pmf = pmf / pmf.sum()  # renormalize to 1.0 for the chi-squared test
        msk = (counts > 5) & ((pmf * num_samples) > 5)
        self.assertGreater(
            pmf[msk].sum(),
            0.9,
            "Distribution is too sparse for test; try increasing num_samples",
        )
        # Add a remainder bucket that combines counts for all values
        # below threshold, if such values exist (i.e. mask has False entries).
        if not msk.all():
            counts = np.concatenate([counts[msk], np.sum(counts[~msk], keepdims=True)])
            pmf = np.concatenate([pmf[msk], np.sum(pmf[~msk], keepdims=True)])
        chisq, p = scipy.stats.chisquare(counts, pmf * num_samples)
        self.assertGreater(p, failure_rate, message)

    def _check_enumerate_support(self, dist, examples):
        for params, expected in examples:
            params = {k: torch.tensor(v) for k, v in params.items()}
            d = dist(**params)
            actual = d.enumerate_support(expand=False)
            expected = torch.tensor(expected, dtype=actual.dtype)
            self.assertEqual(actual, expected)
            actual = d.enumerate_support(expand=True)
            expected_with_expand = expected.expand(
                (-1,) + d.batch_shape + d.event_shape
            )
            self.assertEqual(actual, expected_with_expand)

    def test_repr(self):
        for Dist, params in _get_examples():
            for param in params:
                dist = Dist(**param)
                self.assertTrue(repr(dist).startswith(dist.__class__.__name__))

    def test_sample_detached(self):
        for Dist, params in _get_examples():
            for i, param in enumerate(params):
                variable_params = [
                    p for p in param.values() if getattr(p, "requires_grad", False)
                ]
                if not variable_params:
                    continue
                dist = Dist(**param)
                sample = dist.sample()
                self.assertFalse(
                    sample.requires_grad,
                    msg=f"{Dist.__name__} example {i + 1}/{len(params)}, .sample() is not detached",
                )

    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
    def test_rsample_requires_grad(self):
        for Dist, params in _get_examples():
            for i, param in enumerate(params):
                if not any(getattr(p, "requires_grad", False) for p in param.values()):
                    continue
                dist = Dist(**param)
                if not dist.has_rsample:
                    continue
                sample = dist.rsample()
                self.assertTrue(
                    sample.requires_grad,
                    msg=f"{Dist.__name__} example {i + 1}/{len(params)}, .rsample() does not require grad",
                )

    def test_enumerate_support_type(self):
        for Dist, params in _get_examples():
            for i, param in enumerate(params):
                dist = Dist(**param)
                try:
                    self.assertTrue(
                        type(dist.sample()) is type(dist.enumerate_support()),
                        msg=(
                            "{} example {}/{}, return type mismatch between "
                            + "sample and enumerate_support."
                        ).format(Dist.__name__, i + 1, len(params)),
                    )
                except NotImplementedError:
                    pass

    def test_lazy_property_grad(self):
        x = torch.randn(1, requires_grad=True)

        class Dummy:
            @lazy_property
            def y(self):
                return x + 1

        def test():
            x.grad = None
            Dummy().y.backward()
            self.assertEqual(x.grad, torch.ones(1))

        test()
        with torch.no_grad():
            test()

        mean = torch.randn(2)
        cov = torch.eye(2, requires_grad=True)
        distn = MultivariateNormal(mean, cov)
        # Populate the lazy scale_tril under no_grad; lazy_property evaluates
        # with grad enabled, so the cached value must still support backward.
        with torch.no_grad():
            distn.scale_tril
        distn.scale_tril.sum().backward()
        self.assertIsNotNone(cov.grad)

    def test_has_examples(self):
        distributions_with_examples = {e.Dist for e in _get_examples()}
        for Dist in globals().values():
            if (
                isinstance(Dist, type)
                and issubclass(Dist, Distribution)
                and Dist is not Distribution
                and Dist is not ExponentialFamily
            ):
                self.assertIn(
                    Dist,
                    distributions_with_examples,
                    f"Please add {Dist.__name__} to the _get_examples list in test_distributions.py",
                )

    def test_support_attributes(self):
        for Dist, params in _get_examples():
            for param in params:
                d = Dist(**param)
                event_dim = len(d.event_shape)
                self.assertEqual(d.support.event_dim, event_dim)
                try:
                    self.assertEqual(Dist.support.event_dim, event_dim)
                except NotImplementedError:
                    pass
                is_discrete = d.support.is_discrete
                try:
                    self.assertEqual(Dist.support.is_discrete, is_discrete)
                except NotImplementedError:
                    pass

    def test_distribution_expand(self):
        shapes = [torch.Size(), torch.Size((2,)), torch.Size((2, 1))]
        for Dist, params in _get_examples():
            for param in params:
                for shape in shapes:
                    d = Dist(**param)
                    expanded_shape = shape + d.batch_shape
                    original_shape = d.batch_shape + d.event_shape
                    expected_shape = shape + original_shape
                    expanded = d.expand(batch_shape=list(expanded_shape))
                    sample = expanded.sample()
                    actual_shape = expanded.sample().shape
                    self.assertEqual(expanded.__class__, d.__class__)
                    self.assertEqual(d.sample().shape, original_shape)
                    self.assertEqual(expanded.log_prob(sample), d.log_prob(sample))
                    self.assertEqual(actual_shape, expected_shape)
                    self.assertEqual(expanded.batch_shape, expanded_shape)
                    try:
                        self.assertEqual(
                            expanded.mean, d.mean.expand(expanded_shape + d.event_shape)
                        )
                        self.assertEqual(
                            expanded.variance,
                            d.variance.expand(expanded_shape + d.event_shape),
                        )
                    except NotImplementedError:
                        pass

    def test_distribution_subclass_expand(self):
        expand_by = torch.Size((2,))
        for Dist, params in _get_examples():

            class SubClass(Dist):
                pass

            for param in params:
                d = SubClass(**param)
                expanded_shape = expand_by + d.batch_shape
                original_shape = d.batch_shape + d.event_shape
                expected_shape = expand_by + original_shape
                expanded = d.expand(batch_shape=expanded_shape)
                sample = expanded.sample()
                actual_shape = expanded.sample().shape
                self.assertEqual(expanded.__class__, d.__class__)
                self.assertEqual(d.sample().shape, original_shape)
                self.assertEqual(expanded.log_prob(sample), d.log_prob(sample))
                self.assertEqual(actual_shape, expected_shape)

    @set_default_dtype(torch.double)
    def test_bernoulli(self):
        p = torch.tensor([0.7, 0.2, 0.4], requires_grad=True)
        r = torch.tensor(0.3, requires_grad=True)
        s = 0.3
        self.assertEqual(Bernoulli(p).sample((8,)).size(), (8, 3))
        self.assertFalse(Bernoulli(p).sample().requires_grad)
        self.assertEqual(Bernoulli(r).sample((8,)).size(), (8,))
        self.assertEqual(Bernoulli(r).sample().size(), ())
        self.assertEqual(Bernoulli(r).sample((3, 2)).size(), (3, 2))
        self.assertEqual(Bernoulli(s).sample().size(), ())
        self._gradcheck_log_prob(Bernoulli, (p,))

        def ref_log_prob(idx, val, log_prob):
            prob = p[idx]
            self.assertEqual(log_prob, math.log(prob if val else 1 - prob))

        self._check_log_prob(Bernoulli(p), ref_log_prob)
        self._check_log_prob(Bernoulli(logits=p.log() - (-p).log1p()), ref_log_prob)
        self.assertRaises(NotImplementedError, Bernoulli(r).rsample)

        # check entropy computation
        self.assertEqual(
            Bernoulli(p).entropy(),
            torch.tensor([0.6108, 0.5004, 0.6730]),
            atol=1e-4,
            rtol=0,
        )
        self.assertEqual(Bernoulli(torch.tensor([0.0])).entropy(), torch.tensor([0.0]))
        self.assertEqual(
            Bernoulli(s).entropy(), torch.tensor(0.6108), atol=1e-4, rtol=0
        )

        self._check_forward_ad(torch.bernoulli)
        self._check_forward_ad(lambda x: x.bernoulli_())
        self._check_forward_ad(lambda x: x.bernoulli_(x.clone().detach()))
        self._check_forward_ad(lambda x: x.bernoulli_(x))

    def test_bernoulli_enumerate_support(self):
        examples = [
            ({"probs": [0.1]}, [[0], [1]]),
            ({"probs": [0.1, 0.9]}, [[0], [1]]),
            ({"probs": [[0.1, 0.2], [0.3, 0.4]]}, [[[0]], [[1]]]),
        ]
        self._check_enumerate_support(Bernoulli, examples)

    def test_bernoulli_3d(self):
        p = torch.full((2, 3, 5), 0.5).requires_grad_()
        self.assertEqual(Bernoulli(p).sample().size(), (2, 3, 5))
        self.assertEqual(
            Bernoulli(p).sample(sample_shape=(2, 5)).size(), (2, 5, 2, 3, 5)
        )
        self.assertEqual(Bernoulli(p).sample((2,)).size(), (2, 2, 3, 5))

    @set_default_dtype(torch.double)
    def test_geometric(self):
        p = torch.tensor([0.7, 0.2, 0.4], requires_grad=True)
        r = torch.tensor(0.3, requires_grad=True)
        s = 0.3
        self.assertEqual(Geometric(p).sample((8,)).size(), (8, 3))
        self.assertEqual(Geometric(1).sample(), 0)
        self.assertEqual(Geometric(1).log_prob(torch.tensor(1.0)), -inf)
        self.assertEqual(Geometric(1).log_prob(torch.tensor(0.0)), 0)
        self.assertFalse(Geometric(p).sample().requires_grad)
        self.assertEqual(Geometric(r).sample((8,)).size(), (8,))
        self.assertEqual(Geometric(r).sample().size(), ())
        self.assertEqual(Geometric(r).sample((3, 2)).size(), (3, 2))
        self.assertEqual(Geometric(s).sample().size(), ())
        self._gradcheck_log_prob(Geometric, (p,))
        self.assertRaises(ValueError, lambda: Geometric(0))
        self.assertRaises(NotImplementedError, Geometric(r).rsample)

        self._check_forward_ad(lambda x: x.geometric_(0.2))

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    @set_default_dtype(torch.double)
    def test_geometric_log_prob_and_entropy(self):
        p = torch.tensor([0.7, 0.2, 0.4], requires_grad=True)
        s = 0.3

        def ref_log_prob(idx, val, log_prob):
            prob = p[idx].detach()
            self.assertEqual(log_prob, scipy.stats.geom(prob, loc=-1).logpmf(val))

        self._check_log_prob(Geometric(p), ref_log_prob)
        self._check_log_prob(Geometric(logits=p.log() - (-p).log1p()), ref_log_prob)

        # check entropy computation
        self.assertEqual(
            Geometric(p).entropy(),
            scipy.stats.geom(p.detach().numpy(), loc=-1).entropy(),
            atol=1e-3,
            rtol=0,
        )
        self.assertEqual(
            float(Geometric(s).entropy()),
            scipy.stats.geom(s, loc=-1).entropy().item(),
            atol=1e-3,
            rtol=0,
        )

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_geometric_sample(self):
        set_rng_seed(0)  # see Note [Randomized statistical tests]
        for prob in [0.01, 0.18, 0.8]:
            self._check_sampler_discrete(
                Geometric(prob),
                scipy.stats.geom(p=prob, loc=-1),
                f"Geometric(prob={prob})",
            )

    @set_default_dtype(torch.double)
    def test_binomial(self):
        p = torch.arange(0.05, 1, 0.1).requires_grad_()
        for total_count in [1, 2, 10]:
            self._gradcheck_log_prob(lambda p: Binomial(total_count, p), [p])
            self._gradcheck_log_prob(
                lambda p: Binomial(total_count, None, p.log()), [p]
            )
        self.assertRaises(NotImplementedError, Binomial(10, p).rsample)

    test_binomial_half = set_default_dtype(torch.float16)(test_binomial)
    test_binomial_bfloat16 = set_default_dtype(torch.bfloat16)(test_binomial)

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_binomial_sample(self):
        set_rng_seed(0)  # see Note [Randomized statistical tests]
        for prob in [0.01, 0.1, 0.5, 0.8, 0.9]:
            for count in [2, 10, 100, 500]:
                self._check_sampler_discrete(
                    Binomial(total_count=count, probs=prob),
                    scipy.stats.binom(count, prob),
                    f"Binomial(total_count={count}, probs={prob})",
                )

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    @set_default_dtype(torch.double)
    def test_binomial_log_prob_and_entropy(self):
        probs = torch.arange(0.05, 1, 0.1)
        for total_count in [1, 2, 10]:

            def ref_log_prob(idx, x, log_prob):
                p = probs.view(-1)[idx].item()
                expected = scipy.stats.binom(total_count, p).logpmf(x)
                self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)

            self._check_log_prob(Binomial(total_count, probs), ref_log_prob)
            logits = probs_to_logits(probs, is_binary=True)
            self._check_log_prob(Binomial(total_count, logits=logits), ref_log_prob)

            bin = Binomial(total_count, logits=logits)
            self.assertEqual(
                bin.entropy(),
                scipy.stats.binom(
                    total_count, bin.probs.detach().numpy(), loc=-1
                ).entropy(),
                atol=1e-3,
                rtol=0,
            )

    def test_binomial_stable(self):
        logits = torch.tensor([-100.0, 100.0], dtype=torch.float)
        total_count = 1.0
        x = torch.tensor([0.0, 0.0], dtype=torch.float)
        log_prob = Binomial(total_count, logits=logits).log_prob(x)
        self.assertTrue(torch.isfinite(log_prob).all())

        # make sure that the grad at logits=0, value=0 is -0.5
        # (log P(0) = -total_count * softplus(logits), whose derivative at
        # logits=0 is -total_count * sigmoid(0) = -0.5)
        x = torch.tensor(0.0, requires_grad=True)
        y = Binomial(total_count, logits=x).log_prob(torch.tensor(0.0))
        self.assertEqual(grad(y, x)[0], torch.tensor(-0.5))

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    @set_default_dtype(torch.double)
    def test_binomial_log_prob_vectorized_count(self):
        probs = torch.tensor([0.2, 0.7, 0.9])
        for total_count, sample in [
            (torch.tensor([10]), torch.tensor([7.0, 3.0, 9.0])),
            (torch.tensor([1, 2, 10]), torch.tensor([0.0, 1.0, 9.0])),
        ]:
            log_prob = Binomial(total_count, probs).log_prob(sample)
            expected = scipy.stats.binom(
                total_count.cpu().numpy(), probs.cpu().numpy()
            ).logpmf(sample)
            self.assertEqual(log_prob, expected, atol=1e-4, rtol=0)

    def test_binomial_enumerate_support(self):
        examples = [
            ({"probs": [0.1], "total_count": 2}, [[0], [1], [2]]),
            ({"probs": [0.1, 0.9], "total_count": 2}, [[0], [1], [2]]),
            (
                {"probs": [[0.1, 0.2], [0.3, 0.4]], "total_count": 3},
                [[[0]], [[1]], [[2]], [[3]]],
            ),
        ]
        self._check_enumerate_support(Binomial, examples)

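    # The extreme-value checks in test_binomial_extreme_vals below follow from
    # the pmf P(k) = C(n, k) * p**k * (1 - p)**(n - k), taking 0**0 = 1:
    #   p = 0: P(0) = 1, so log_prob(0) = 0 and P(k) = 0 for k >= 1;
    #   p = 1: P(n) = 1, so log_prob(n) = 0 and P(k) = 0 for k < n;
    #   n = 0: the support is {0}, so log_prob(0) = 0 for any p.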
    @set_default_dtype(torch.double)
    def test_binomial_extreme_vals(self):
        total_count = 100
        bin0 = Binomial(total_count, 0)
        self.assertEqual(bin0.sample(), 0)
        self.assertEqual(bin0.log_prob(torch.tensor([0.0]))[0], 0, atol=1e-3, rtol=0)
        self.assertEqual(float(bin0.log_prob(torch.tensor([1.0])).exp()), 0)
        bin1 = Binomial(total_count, 1)
        self.assertEqual(bin1.sample(), total_count)
        self.assertEqual(
            bin1.log_prob(torch.tensor([float(total_count)]))[0], 0, atol=1e-3, rtol=0
        )
        self.assertEqual(
            float(bin1.log_prob(torch.tensor([float(total_count - 1)])).exp()), 0
        )
        zero_counts = torch.zeros(torch.Size((2, 2)))
        bin2 = Binomial(zero_counts, 1)
        self.assertEqual(bin2.sample(), zero_counts)
        self.assertEqual(bin2.log_prob(zero_counts), zero_counts)

    @set_default_dtype(torch.double)
    def test_binomial_vectorized_count(self):
        set_rng_seed(1)  # see Note [Randomized statistical tests]
        total_count = torch.tensor([[4, 7], [3, 8]], dtype=torch.float64)
        bin0 = Binomial(total_count, torch.tensor(1.0))
        self.assertEqual(bin0.sample(), total_count)
        bin1 = Binomial(total_count, torch.tensor(0.5))
        samples = bin1.sample(torch.Size((100000,)))
        self.assertTrue((samples <= total_count.type_as(samples)).all())
        self.assertEqual(samples.mean(dim=0), bin1.mean, atol=0.02, rtol=0)
        self.assertEqual(samples.var(dim=0), bin1.variance, atol=0.02, rtol=0)

    @set_default_dtype(torch.double)
    def test_negative_binomial(self):
        p = torch.arange(0.05, 1, 0.1).requires_grad_()
        for total_count in [1, 2, 10]:
            self._gradcheck_log_prob(lambda p: NegativeBinomial(total_count, p), [p])
            self._gradcheck_log_prob(
                lambda p: NegativeBinomial(total_count, None, p.log()), [p]
            )
        self.assertRaises(NotImplementedError, NegativeBinomial(10, p).rsample)
        self.assertRaises(NotImplementedError, NegativeBinomial(10, p).entropy)

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_negative_binomial_log_prob(self):
        probs = torch.arange(0.05, 1, 0.1)
        for total_count in [1, 2, 10]:

            def ref_log_prob(idx, x, log_prob):
                p = probs.view(-1)[idx].item()
                expected = scipy.stats.nbinom(total_count, 1 - p).logpmf(x)
                self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)

            self._check_log_prob(NegativeBinomial(total_count, probs), ref_log_prob)
            logits = probs_to_logits(probs, is_binary=True)
            self._check_log_prob(
                NegativeBinomial(total_count, logits=logits), ref_log_prob
            )

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    @set_default_dtype(torch.double)
    def test_negative_binomial_log_prob_vectorized_count(self):
        probs = torch.tensor([0.2, 0.7, 0.9])
        for total_count, sample in [
            (torch.tensor([10]), torch.tensor([7.0, 3.0, 9.0])),
            (torch.tensor([1, 2, 10]), torch.tensor([0.0, 1.0, 9.0])),
        ]:
            log_prob = NegativeBinomial(total_count, probs).log_prob(sample)
            expected = scipy.stats.nbinom(
                total_count.cpu().numpy(), 1 - probs.cpu().numpy()
            ).logpmf(sample)
            self.assertEqual(log_prob, expected, atol=1e-4, rtol=0)

    @unittest.skipIf(not TEST_CUDA, "CUDA not found")
    def test_zero_excluded_binomial(self):
        vals = Binomial(
            total_count=torch.tensor(1.0).cuda(), probs=torch.tensor(0.9).cuda()
        ).sample(torch.Size((100000000,)))
        self.assertTrue((vals >= 0).all())
        vals = Binomial(
            total_count=torch.tensor(1.0).cuda(), probs=torch.tensor(0.1).cuda()
        ).sample(torch.Size((100000000,)))
        self.assertTrue((vals < 2).all())
        vals = Binomial(
            total_count=torch.tensor(1.0).cuda(), probs=torch.tensor(0.5).cuda()
        ).sample(torch.Size((10000,)))
        # vals should be roughly half zeroes, half ones
        assert (vals == 0.0).sum() > 4000
        assert (vals == 1.0).sum() > 4000

    @set_default_dtype(torch.double)
    def test_multinomial_1d(self):
        total_count = 10
        p = torch.tensor([0.1, 0.2, 0.3], requires_grad=True)
        self.assertEqual(Multinomial(total_count, p).sample().size(), (3,))
        self.assertEqual(Multinomial(total_count, p).sample((2, 2)).size(), (2, 2, 3))
        self.assertEqual(Multinomial(total_count, p).sample((1,)).size(), (1, 3))
        self._gradcheck_log_prob(lambda p: Multinomial(total_count, p), [p])
        self._gradcheck_log_prob(lambda p: Multinomial(total_count, None, p.log()), [p])
        self.assertRaises(NotImplementedError, Multinomial(10, p).rsample)

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    @set_default_dtype(torch.double)
    def test_multinomial_1d_log_prob_and_entropy(self):
        total_count = 10
        p = torch.tensor([0.1, 0.2, 0.3], requires_grad=True)
        dist = Multinomial(total_count, probs=p)
        x = dist.sample()
        log_prob = dist.log_prob(x)
        expected = torch.tensor(
            scipy.stats.multinomial.logpmf(
                x.numpy(), n=total_count, p=dist.probs.detach().numpy()
            )
        )
        self.assertEqual(log_prob, expected)

        dist = Multinomial(total_count, logits=p.log())
        x = dist.sample()
        log_prob = dist.log_prob(x)
        expected = torch.tensor(
            scipy.stats.multinomial.logpmf(
                x.numpy(), n=total_count, p=dist.probs.detach().numpy()
            )
        )
        self.assertEqual(log_prob, expected)

        expected = scipy.stats.multinomial.entropy(
            total_count, dist.probs.detach().numpy()
        )
        self.assertEqual(dist.entropy(), expected, atol=1e-3, rtol=0)

    @set_default_dtype(torch.double)
    def test_multinomial_2d(self):
        total_count = 10
        probabilities = [[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]]
        probabilities_1 = [[1.0, 0.0], [0.0, 1.0]]
        p = torch.tensor(probabilities, requires_grad=True)
        s = torch.tensor(probabilities_1, requires_grad=True)
        self.assertEqual(Multinomial(total_count, p).sample().size(), (2, 3))
        self.assertEqual(
            Multinomial(total_count, p).sample(sample_shape=(3, 4)).size(), (3, 4, 2, 3)
        )
        self.assertEqual(Multinomial(total_count, p).sample((6,)).size(), (6, 2, 3))
        set_rng_seed(0)
        self._gradcheck_log_prob(lambda p: Multinomial(total_count, p), [p])
        self._gradcheck_log_prob(lambda p: Multinomial(total_count, None, p.log()), [p])

        # sample check for extreme value of probs
        self.assertEqual(
            Multinomial(total_count, s).sample(),
            torch.tensor([[total_count, 0], [0, total_count]], dtype=torch.float64),
        )

    @set_default_dtype(torch.double)
    def test_categorical_1d(self):
        p = torch.tensor([0.1, 0.2, 0.3], requires_grad=True)
        self.assertTrue(is_all_nan(Categorical(p).mean))
        self.assertTrue(is_all_nan(Categorical(p).variance))
        self.assertEqual(Categorical(p).sample().size(), ())
        self.assertFalse(Categorical(p).sample().requires_grad)
        self.assertEqual(Categorical(p).sample((2, 2)).size(), (2, 2))
        self.assertEqual(Categorical(p).sample((1,)).size(), (1,))
        self._gradcheck_log_prob(Categorical, (p,))
        self.assertRaises(NotImplementedError, Categorical(p).rsample)

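    # The expected entropies in test_categorical_2d below follow from
    # H = -sum(q * log(q)) over the row-normalized probabilities q:
    #   [0.1, 0.2, 0.3] -> [1/6, 1/3, 1/2]:  H ~= 1.01140
    #   [0.5, 0.3, 0.2] (already normalized):  H ~= 1.02965
    # For the gh-40553 case, a -inf logit removes that category and the rest
    # is renormalized, so each masked category should contribute 0 (not nan):
    #   row 0, [0.1, 0.2, -] -> [1/3, 2/3]:  H ~= 0.63651
    #   row 1, [0.5, -, 0.2] -> [5/7, 2/7]:  H ~= 0.59827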
    @set_default_dtype(torch.double)
    def test_categorical_2d(self):
        probabilities = [[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]]
        probabilities_1 = [[1.0, 0.0], [0.0, 1.0]]
        p = torch.tensor(probabilities, requires_grad=True)
        s = torch.tensor(probabilities_1, requires_grad=True)
        self.assertEqual(Categorical(p).mean.size(), (2,))
        self.assertEqual(Categorical(p).variance.size(), (2,))
        self.assertTrue(is_all_nan(Categorical(p).mean))
        self.assertTrue(is_all_nan(Categorical(p).variance))
        self.assertEqual(Categorical(p).sample().size(), (2,))
        self.assertEqual(Categorical(p).sample(sample_shape=(3, 4)).size(), (3, 4, 2))
        self.assertEqual(Categorical(p).sample((6,)).size(), (6, 2))
        self._gradcheck_log_prob(Categorical, (p,))

        # sample check for extreme value of probs
        set_rng_seed(0)
        self.assertEqual(
            Categorical(s).sample(sample_shape=(2,)), torch.tensor([[0, 1], [0, 1]])
        )

        def ref_log_prob(idx, val, log_prob):
            sample_prob = p[idx][val] / p[idx].sum()
1879*da0073e9SAndroid Build Coastguard Worker self.assertEqual(log_prob, math.log(sample_prob)) 1880*da0073e9SAndroid Build Coastguard Worker 1881*da0073e9SAndroid Build Coastguard Worker self._check_log_prob(Categorical(p), ref_log_prob) 1882*da0073e9SAndroid Build Coastguard Worker self._check_log_prob(Categorical(logits=p.log()), ref_log_prob) 1883*da0073e9SAndroid Build Coastguard Worker 1884*da0073e9SAndroid Build Coastguard Worker # check entropy computation 1885*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 1886*da0073e9SAndroid Build Coastguard Worker Categorical(p).entropy(), torch.tensor([1.0114, 1.0297]), atol=1e-4, rtol=0 1887*da0073e9SAndroid Build Coastguard Worker ) 1888*da0073e9SAndroid Build Coastguard Worker self.assertEqual(Categorical(s).entropy(), torch.tensor([0.0, 0.0])) 1889*da0073e9SAndroid Build Coastguard Worker # issue gh-40553 1890*da0073e9SAndroid Build Coastguard Worker logits = p.log() 1891*da0073e9SAndroid Build Coastguard Worker logits[1, 1] = logits[0, 2] = float("-inf") 1892*da0073e9SAndroid Build Coastguard Worker e = Categorical(logits=logits).entropy() 1893*da0073e9SAndroid Build Coastguard Worker self.assertEqual(e, torch.tensor([0.6365, 0.5983]), atol=1e-4, rtol=0) 1894*da0073e9SAndroid Build Coastguard Worker 1895*da0073e9SAndroid Build Coastguard Worker def test_categorical_enumerate_support(self): 1896*da0073e9SAndroid Build Coastguard Worker examples = [ 1897*da0073e9SAndroid Build Coastguard Worker ({"probs": [0.1, 0.2, 0.7]}, [0, 1, 2]), 1898*da0073e9SAndroid Build Coastguard Worker ({"probs": [[0.1, 0.9], [0.3, 0.7]]}, [[0], [1]]), 1899*da0073e9SAndroid Build Coastguard Worker ] 1900*da0073e9SAndroid Build Coastguard Worker self._check_enumerate_support(Categorical, examples) 1901*da0073e9SAndroid Build Coastguard Worker 1902*da0073e9SAndroid Build Coastguard Worker @set_default_dtype(torch.double) 1903*da0073e9SAndroid Build Coastguard Worker def test_one_hot_categorical_1d(self): 1904*da0073e9SAndroid 
        p = torch.tensor([0.1, 0.2, 0.3], requires_grad=True)
        self.assertEqual(OneHotCategorical(p).sample().size(), (3,))
        self.assertFalse(OneHotCategorical(p).sample().requires_grad)
        self.assertEqual(OneHotCategorical(p).sample((2, 2)).size(), (2, 2, 3))
        self.assertEqual(OneHotCategorical(p).sample((1,)).size(), (1, 3))
        self._gradcheck_log_prob(OneHotCategorical, (p,))
        self.assertRaises(NotImplementedError, OneHotCategorical(p).rsample)

    @set_default_dtype(torch.double)
    def test_one_hot_categorical_2d(self):
        probabilities = [[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]]
        probabilities_1 = [[1.0, 0.0], [0.0, 1.0]]
        p = torch.tensor(probabilities, requires_grad=True)
        s = torch.tensor(probabilities_1, requires_grad=True)
        self.assertEqual(OneHotCategorical(p).sample().size(), (2, 3))
        self.assertEqual(
            OneHotCategorical(p).sample(sample_shape=(3, 4)).size(), (3, 4, 2, 3)
        )
        self.assertEqual(OneHotCategorical(p).sample((6,)).size(), (6, 2, 3))
        self._gradcheck_log_prob(OneHotCategorical, (p,))

        dist = OneHotCategorical(p)
        x = dist.sample()
        self.assertEqual(dist.log_prob(x), Categorical(p).log_prob(x.max(-1)[1]))

    def test_one_hot_categorical_enumerate_support(self):
        examples = [
            ({"probs": [0.1, 0.2, 0.7]}, [[1, 0, 0], [0, 1, 0], [0, 0, 1]]),
            ({"probs": [[0.1, 0.9], [0.3, 0.7]]}, [[[1, 0]], [[0, 1]]]),
        ]
        self._check_enumerate_support(OneHotCategorical, examples)

    def test_poisson_forward_ad(self):
        self._check_forward_ad(torch.poisson)

    def test_poisson_shape(self):
        rate = torch.randn(2, 3).abs().requires_grad_()
        rate_1d = torch.randn(1).abs().requires_grad_()
        self.assertEqual(Poisson(rate).sample().size(), (2, 3))
        self.assertEqual(Poisson(rate).sample((7,)).size(), (7, 2, 3))
        self.assertEqual(Poisson(rate_1d).sample().size(), (1,))
        self.assertEqual(Poisson(rate_1d).sample((1,)).size(), (1, 1))
        self.assertEqual(Poisson(2.0).sample((2,)).size(), (2,))

    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
    @set_default_dtype(torch.double)
    def test_poisson_log_prob(self):
        rate = torch.randn(2, 3).abs().requires_grad_()
        rate_1d = torch.randn(1).abs().requires_grad_()
        rate_zero = torch.zeros([], requires_grad=True)

        def ref_log_prob(ref_rate, idx, x, log_prob):
            l = ref_rate.view(-1)[idx].detach()
            expected = scipy.stats.poisson.logpmf(x, l)
            self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)

        set_rng_seed(0)
        self._check_log_prob(Poisson(rate), lambda *args: ref_log_prob(rate, *args))
        self._check_log_prob(
            Poisson(rate_zero), lambda *args: ref_log_prob(rate_zero, *args)
        )
        self._gradcheck_log_prob(Poisson, (rate,))
        self._gradcheck_log_prob(Poisson, (rate_1d,))

        # We cannot check gradients automatically for zero rates because the finite difference
        # approximation enters the forbidden parameter space. We instead compare with the
        # theoretical result: d/d(rate) log P(x | rate) = x / rate - 1, which is +inf at
        # rate = 0 for x = 1.
        dist = Poisson(rate_zero)
        dist.log_prob(torch.ones_like(rate_zero)).backward()
        self.assertEqual(rate_zero.grad, torch.inf)

    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
    def test_poisson_sample(self):
        set_rng_seed(1)  # see Note [Randomized statistical tests]
        saved_dtype = torch.get_default_dtype()
        try:
            for dtype in [torch.float, torch.double, torch.bfloat16, torch.half]:
                torch.set_default_dtype(dtype)
                for rate in [0.1, 1.0, 5.0]:
                    self._check_sampler_discrete(
                        Poisson(rate),
                        scipy.stats.poisson(rate),
                        f"Poisson(lambda={rate})",
                        failure_rate=1e-3,
                    )
        finally:
            # Restore the default dtype even if a check fails.
            torch.set_default_dtype(saved_dtype)

    @unittest.skipIf(not TEST_CUDA, "CUDA not found")
    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
    def test_poisson_gpu_sample(self):
        set_rng_seed(1)
        for rate in [0.12, 0.9, 4.0]:
            self._check_sampler_discrete(
                Poisson(torch.tensor([rate]).cuda()),
                scipy.stats.poisson(rate),
                f"Poisson(lambda={rate}, cuda)",
                failure_rate=1e-3,
            )

    @set_default_dtype(torch.double)
    def test_relaxed_bernoulli(self):
        p = torch.tensor([0.7, 0.2, 0.4], requires_grad=True)
        r = torch.tensor(0.3, requires_grad=True)
        s = 0.3
        temp = torch.tensor(0.67, requires_grad=True)
        self.assertEqual(RelaxedBernoulli(temp, p).sample((8,)).size(), (8, 3))
        self.assertFalse(RelaxedBernoulli(temp, p).sample().requires_grad)
        self.assertEqual(RelaxedBernoulli(temp, r).sample((8,)).size(), (8,))
        self.assertEqual(RelaxedBernoulli(temp, r).sample().size(), ())
        self.assertEqual(RelaxedBernoulli(temp, r).sample((3, 2)).size(), (3, 2))
        self.assertEqual(RelaxedBernoulli(temp, s).sample().size(), ())
        self._gradcheck_log_prob(RelaxedBernoulli, (temp, p))
        self._gradcheck_log_prob(RelaxedBernoulli, (temp, r))

        # test that rsample doesn't fail
        s = RelaxedBernoulli(temp, p).rsample()
        s.backward(torch.ones_like(s))

    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
    def test_rounded_relaxed_bernoulli(self):
        set_rng_seed(0)  # see Note [Randomized statistical tests]

        class Rounded:
            def __init__(self, dist):
                self.dist = dist

            def sample(self, *args, **kwargs):
                return torch.round(self.dist.sample(*args, **kwargs))

        for probs, temp in product([0.1, 0.2, 0.8], [0.1, 1.0, 10.0]):
            self._check_sampler_discrete(
                Rounded(RelaxedBernoulli(temp, probs)),
                scipy.stats.bernoulli(probs),
                f"Rounded(RelaxedBernoulli(temp={temp}, probs={probs}))",
                failure_rate=1e-3,
            )

        # As the temperature grows, sigmoid((logits + logistic noise) / temp)
        # tends to 0.5 regardless of probs.
        for probs in [0.001, 0.2, 0.999]:
            equal_probs = torch.tensor(0.5)
            dist = RelaxedBernoulli(1e10, probs)
            s = dist.rsample()
            self.assertEqual(equal_probs, s)

    @set_default_dtype(torch.double)
    def test_relaxed_one_hot_categorical_1d(self):
        p = torch.tensor([0.1, 0.2, 0.3], requires_grad=True)
        temp = torch.tensor(0.67, requires_grad=True)
        self.assertEqual(
            RelaxedOneHotCategorical(probs=p, temperature=temp).sample().size(), (3,)
        )
        self.assertFalse(
            RelaxedOneHotCategorical(probs=p, temperature=temp).sample().requires_grad
        )
        self.assertEqual(
            RelaxedOneHotCategorical(probs=p, temperature=temp).sample((2, 2)).size(),
            (2, 2, 3),
        )
        self.assertEqual(
            RelaxedOneHotCategorical(probs=p, temperature=temp).sample((1,)).size(),
            (1, 3),
        )
        self._gradcheck_log_prob(
            lambda t, p: RelaxedOneHotCategorical(t, p, validate_args=False), (temp, p)
        )

    @set_default_dtype(torch.double)
    def test_relaxed_one_hot_categorical_2d(self):
        probabilities = [[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]]
        probabilities_1 = [[1.0, 0.0], [0.0, 1.0]]
        temp = torch.tensor([3.0], requires_grad=True)
        # The lower the temperature, the more unstable the log_prob gradcheck is
        # w.r.t. the sample. Values below 0.25 empirically fail the default tolerance.
        temp_2 = torch.tensor([0.25], requires_grad=True)
        p = torch.tensor(probabilities, requires_grad=True)
        s = torch.tensor(probabilities_1, requires_grad=True)
        self.assertEqual(RelaxedOneHotCategorical(temp, p).sample().size(), (2, 3))
        self.assertEqual(
            RelaxedOneHotCategorical(temp, p).sample(sample_shape=(3, 4)).size(),
            (3, 4, 2, 3),
        )
        self.assertEqual(
            RelaxedOneHotCategorical(temp, p).sample((6,)).size(), (6, 2, 3)
        )
        self._gradcheck_log_prob(
            lambda t, p: RelaxedOneHotCategorical(t, p, validate_args=False), (temp, p)
        )
        self._gradcheck_log_prob(
            lambda t, p: RelaxedOneHotCategorical(t, p, validate_args=False),
            (temp_2, p),
        )

    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
    def test_argmax_relaxed_categorical(self):
        set_rng_seed(0)  # see Note [Randomized statistical tests]

        class ArgMax:
            def __init__(self, dist):
                self.dist = dist

            def sample(self, *args, **kwargs):
                s = self.dist.sample(*args, **kwargs)
                _, idx = torch.max(s, -1)
                return idx

        class ScipyCategorical:
            def __init__(self, dist):
                self.dist = dist

            def pmf(self, samples):
                new_samples = np.zeros(samples.shape + self.dist.p.shape)
                new_samples[np.arange(samples.shape[0]), samples] = 1
                return self.dist.pmf(new_samples)

        for probs, temp in product(
            [torch.tensor([0.1, 0.9]), torch.tensor([0.2, 0.2, 0.6])], [0.1, 1.0, 10.0]
        ):
            self._check_sampler_discrete(
                ArgMax(RelaxedOneHotCategorical(temp, probs)),
                ScipyCategorical(scipy.stats.multinomial(1, probs)),
                f"ArgMax(RelaxedOneHotCategorical(temp={temp}, probs={probs}))",
                failure_rate=1e-3,
            )

        # As the temperature grows, softmax((logits + Gumbel noise) / temp)
        # tends to the uniform vector 1/K regardless of probs.
        for probs in [torch.tensor([0.1, 0.9]), torch.tensor([0.2, 0.2, 0.6])]:
            equal_probs = torch.ones(probs.size()) / probs.size()[0]
            dist = RelaxedOneHotCategorical(1e10, probs)
            s = dist.rsample()
            self.assertEqual(equal_probs, s)

    @set_default_dtype(torch.double)
    def test_uniform(self):
        low = torch.zeros(5, 5, requires_grad=True)
        high = (torch.ones(5, 5) * 3).requires_grad_()
        low_1d = torch.zeros(1, requires_grad=True)
        high_1d = (torch.ones(1) * 3).requires_grad_()
        self.assertEqual(Uniform(low, high).sample().size(), (5, 5))
        self.assertEqual(Uniform(low, high).sample((7,)).size(), (7, 5, 5))
        self.assertEqual(Uniform(low_1d, high_1d).sample().size(), (1,))
        self.assertEqual(Uniform(low_1d, high_1d).sample((1,)).size(), (1, 1))
        self.assertEqual(Uniform(0.0, 1.0).sample((1,)).size(), (1,))

        # Check log_prob computation for values outside the support
        uniform = Uniform(low_1d, high_1d, validate_args=False)
        above_high = torch.tensor([4.0])
        below_low = torch.tensor([-1.0])
        self.assertEqual(uniform.log_prob(above_high).item(), -inf)
        self.assertEqual(uniform.log_prob(below_low).item(), -inf)

        # Check cdf computation for values outside the support
        self.assertEqual(uniform.cdf(below_low).item(), 0)
        self.assertEqual(uniform.cdf(above_high).item(), 1)

        set_rng_seed(1)
        self._gradcheck_log_prob(Uniform, (low, high))
        self._gradcheck_log_prob(Uniform, (low, 1.0))
        self._gradcheck_log_prob(Uniform, (0.0, high))

        # rsample computes low + u * (high - low) with u ~ U(0, 1), so
        # d/d(low) = 1 - u and d/d(high) = u. Replay the RNG state to recover u.
        state = torch.get_rng_state()
        rand = low.new(low.size()).uniform_()
        torch.set_rng_state(state)
        u = Uniform(low, high).rsample()
        u.backward(torch.ones_like(u))
        self.assertEqual(low.grad, 1 - rand)
        self.assertEqual(high.grad, rand)
        low.grad.zero_()
        high.grad.zero_()

        self._check_forward_ad(lambda x: x.uniform_())

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_vonmises_sample(self):
        for loc in [0.0, math.pi / 2.0]:
            for concentration in [0.03, 0.3, 1.0, 10.0, 100.0]:
                self._check_sampler_sampler(
                    VonMises(loc, concentration),
                    scipy.stats.vonmises(loc=loc, kappa=concentration),
                    f"VonMises(loc={loc}, concentration={concentration})",
                    num_samples=int(1e5),
                    circular=True,
                )

    def test_vonmises_logprob(self):
        concentrations = [0.01, 0.03, 0.1, 0.3, 1.0, 3.0, 10.0, 30.0, 100.0]
        for concentration in concentrations:
            grid = torch.arange(0.0, 2 * math.pi, 1e-4)
            prob = VonMises(0.0, concentration).log_prob(grid).exp()
            # Riemann-sum approximation of the integral of the density over
            # [0, 2*pi); a normalized density integrates to 1.
            norm = prob.mean().item() * 2 * math.pi
            self.assertLess(abs(norm - 1), 1e-3)

    @set_default_dtype(torch.double)
    def test_cauchy(self):
        loc = torch.zeros(5, 5, requires_grad=True)
        scale = torch.ones(5, 5, requires_grad=True)
        loc_1d = torch.zeros(1, requires_grad=True)
        scale_1d = torch.ones(1, requires_grad=True)
        self.assertTrue(is_all_nan(Cauchy(loc_1d, scale_1d).mean))
        self.assertEqual(Cauchy(loc_1d, scale_1d).variance, inf)
        self.assertEqual(Cauchy(loc, scale).sample().size(), (5, 5))
        self.assertEqual(Cauchy(loc, scale).sample((7,)).size(), (7, 5, 5))
        self.assertEqual(Cauchy(loc_1d, scale_1d).sample().size(), (1,))
        self.assertEqual(Cauchy(loc_1d, scale_1d).sample((1,)).size(), (1, 1))
        self.assertEqual(Cauchy(0.0, 1.0).sample((1,)).size(), (1,))

        set_rng_seed(1)
        self._gradcheck_log_prob(Cauchy, (loc, scale))
        self._gradcheck_log_prob(Cauchy, (loc, 1.0))
        self._gradcheck_log_prob(Cauchy, (0.0, scale))

        # rsample computes loc + eps * scale with eps ~ Cauchy(0, 1), so
        # d/d(loc) = 1 and d/d(scale) = eps. Replay the RNG state to recover eps.
        state = torch.get_rng_state()
        eps = loc.new(loc.size()).cauchy_()
        torch.set_rng_state(state)
        c = Cauchy(loc, scale).rsample()
        c.backward(torch.ones_like(c))
        self.assertEqual(loc.grad, torch.ones_like(scale))
        self.assertEqual(scale.grad, eps)
        loc.grad.zero_()
        scale.grad.zero_()

        self._check_forward_ad(lambda x: x.cauchy_())

    @set_default_dtype(torch.double)
    def test_halfcauchy(self):
        scale = torch.ones(5, 5, requires_grad=True)
        scale_1d = torch.ones(1, requires_grad=True)
        self.assertTrue(torch.isinf(HalfCauchy(scale_1d).mean).all())
        self.assertEqual(HalfCauchy(scale_1d).variance, inf)
        self.assertEqual(HalfCauchy(scale).sample().size(), (5, 5))
        self.assertEqual(HalfCauchy(scale).sample((7,)).size(), (7, 5, 5))
        self.assertEqual(HalfCauchy(scale_1d).sample().size(), (1,))
        self.assertEqual(HalfCauchy(scale_1d).sample((1,)).size(), (1, 1))
        self.assertEqual(HalfCauchy(1.0).sample((1,)).size(), (1,))

        set_rng_seed(1)
        self._gradcheck_log_prob(HalfCauchy, (scale,))
        self._gradcheck_log_prob(HalfCauchy, (1.0,))

        state = torch.get_rng_state()
        eps = scale.new(scale.size()).cauchy_().abs_()
        torch.set_rng_state(state)
        c = HalfCauchy(scale).rsample()
        c.backward(torch.ones_like(c))
        self.assertEqual(scale.grad, eps)
        scale.grad.zero_()

    @set_default_dtype(torch.double)
    def test_halfnormal(self):
        std = torch.randn(5, 5).abs().requires_grad_()
        std_1d = torch.randn(1).abs().requires_grad_()
        std_delta = torch.tensor([1e-5, 1e-5])
        self.assertEqual(HalfNormal(std).sample().size(), (5, 5))
        self.assertEqual(HalfNormal(std).sample((7,)).size(), (7, 5, 5))
        self.assertEqual(HalfNormal(std_1d).sample((1,)).size(), (1, 1))
        self.assertEqual(HalfNormal(std_1d).sample().size(), (1,))
        self.assertEqual(HalfNormal(0.6).sample((1,)).size(), (1,))
        self.assertEqual(HalfNormal(50.0).sample((1,)).size(), (1,))

        # sample check for extreme value of std
        set_rng_seed(1)
        self.assertEqual(
            HalfNormal(std_delta).sample(sample_shape=(1, 2)),
            torch.tensor([[[0.0, 0.0], [0.0, 0.0]]]),
            atol=1e-4,
            rtol=0,
        )

        self._gradcheck_log_prob(HalfNormal, (std,))
        self._gradcheck_log_prob(HalfNormal, (1.0,))

        # check .log_prob() can broadcast.
        dist = HalfNormal(torch.ones(2, 1, 4))
        log_prob = dist.log_prob(torch.ones(3, 1))
        self.assertEqual(log_prob.shape, (2, 3, 4))

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_halfnormal_logprob(self):
        std = torch.randn(5, 1).abs().requires_grad_()

        def ref_log_prob(idx, x, log_prob):
            s = std.view(-1)[idx].detach()
            expected = scipy.stats.halfnorm(scale=s).logpdf(x)
            self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)

        self._check_log_prob(HalfNormal(std), ref_log_prob)

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_halfnormal_sample(self):
        set_rng_seed(0)  # see Note [Randomized statistical tests]
        for std in [0.1, 1.0, 10.0]:
            self._check_sampler_sampler(
                HalfNormal(std),
                scipy.stats.halfnorm(scale=std),
                f"HalfNormal(scale={std})",
            )

    @set_default_dtype(torch.double)
    def test_inversegamma(self):
        alpha = torch.randn(2, 3).exp().requires_grad_()
        beta = torch.randn(2, 3).exp().requires_grad_()
        alpha_1d = torch.randn(1).exp().requires_grad_()
        beta_1d = torch.randn(1).exp().requires_grad_()
        self.assertEqual(InverseGamma(alpha, beta).sample().size(), (2, 3))
        self.assertEqual(InverseGamma(alpha, beta).sample((5,)).size(), (5, 2, 3))
        self.assertEqual(InverseGamma(alpha_1d, beta_1d).sample((1,)).size(), (1, 1))
        self.assertEqual(InverseGamma(alpha_1d, beta_1d).sample().size(), (1,))
        self.assertEqual(InverseGamma(0.5, 0.5).sample().size(), ())
        self.assertEqual(InverseGamma(0.5, 0.5).sample((1,)).size(), (1,))

        self._gradcheck_log_prob(InverseGamma, (alpha, beta))

        # check .log_prob() can broadcast.
        dist = InverseGamma(torch.ones(4), torch.ones(2, 1, 1))
        log_prob = dist.log_prob(torch.ones(3, 1))
        self.assertEqual(log_prob.shape, (2, 3, 4))

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_inversegamma_sample(self):
        set_rng_seed(0)  # see Note [Randomized statistical tests]
        for concentration, rate in product([2, 5], [0.1, 1.0, 10.0]):
            self._check_sampler_sampler(
                InverseGamma(concentration, rate),
                scipy.stats.invgamma(concentration, scale=rate),
                f"InverseGamma(concentration={concentration}, rate={rate})",
            )

    @set_default_dtype(torch.double)
    def test_lognormal(self):
        mean = torch.randn(5, 5, requires_grad=True)
        std = torch.randn(5, 5).abs().requires_grad_()
        mean_1d = torch.randn(1, requires_grad=True)
        std_1d = torch.randn(1).abs().requires_grad_()
        mean_delta = torch.tensor([1.0, 0.0])
        std_delta = torch.tensor([1e-5, 1e-5])
        self.assertEqual(LogNormal(mean, std).sample().size(), (5, 5))
        self.assertEqual(LogNormal(mean, std).sample((7,)).size(), (7, 5, 5))
        self.assertEqual(LogNormal(mean_1d, std_1d).sample((1,)).size(), (1, 1))
        self.assertEqual(LogNormal(mean_1d, std_1d).sample().size(), (1,))
        self.assertEqual(LogNormal(0.2, 0.6).sample((1,)).size(), (1,))
        self.assertEqual(LogNormal(-0.7, 50.0).sample((1,)).size(), (1,))

        # with a near-zero std, samples should collapse onto exp(mean)
        set_rng_seed(1)
        self.assertEqual(
            LogNormal(mean_delta, std_delta).sample(sample_shape=(1, 2)),
            torch.tensor([[[math.exp(1), 1.0], [math.exp(1), 1.0]]]),
            atol=1e-4,
            rtol=0,
        )

        self._gradcheck_log_prob(LogNormal, (mean, std))
        self._gradcheck_log_prob(LogNormal, (mean, 1.0))
        self._gradcheck_log_prob(LogNormal, (0.0, std))

        # check .log_prob() can broadcast.
        dist = LogNormal(torch.zeros(4), torch.ones(2, 1, 1))
        log_prob = dist.log_prob(torch.ones(3, 1))
        self.assertEqual(log_prob.shape, (2, 3, 4))

        self._check_forward_ad(lambda x: x.log_normal_())

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_lognormal_logprob(self):
        mean = torch.randn(5, 1, requires_grad=True)
        std = torch.randn(5, 1).abs().requires_grad_()

        def ref_log_prob(idx, x, log_prob):
            m = mean.view(-1)[idx].detach()
            s = std.view(-1)[idx].detach()
            expected = scipy.stats.lognorm(s=s, scale=math.exp(m)).logpdf(x)
            self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)

        self._check_log_prob(LogNormal(mean, std), ref_log_prob)

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_lognormal_sample(self):
        set_rng_seed(0)  # see Note [Randomized statistical tests]
        for mean, std in product([-1.0, 0.0, 1.0], [0.1, 1.0, 10.0]):
            self._check_sampler_sampler(
                LogNormal(mean, std),
                scipy.stats.lognorm(scale=math.exp(mean), s=std),
                f"LogNormal(loc={mean}, scale={std})",
            )

    @set_default_dtype(torch.double)
    def test_logisticnormal(self):
        set_rng_seed(1)  # see Note [Randomized statistical tests]
        mean = torch.randn(5, 5).requires_grad_()
        std = torch.randn(5, 5).abs().requires_grad_()
        mean_1d = torch.randn(1).requires_grad_()
        std_1d = torch.randn(1).abs().requires_grad_()
        mean_delta = torch.tensor([1.0, 0.0])
        std_delta = torch.tensor([1e-5, 1e-5])
        self.assertEqual(LogisticNormal(mean, std).sample().size(), (5, 6))
        self.assertEqual(LogisticNormal(mean, std).sample((7,)).size(), (7, 5, 6))
        self.assertEqual(LogisticNormal(mean_1d, std_1d).sample((1,)).size(), (1, 2))
        self.assertEqual(LogisticNormal(mean_1d, std_1d).sample().size(), (2,))
        self.assertEqual(LogisticNormal(0.2, 0.6).sample().size(), (2,))
        self.assertEqual(LogisticNormal(-0.7, 50.0).sample().size(), (2,))

        # with a near-zero std, samples should collapse onto the
        # stick-breaking image of the mean
        set_rng_seed(1)
        self.assertEqual(
            LogisticNormal(mean_delta, std_delta).sample(),
            torch.tensor(
                [
                    math.exp(1) / (1.0 + 1.0 + math.exp(1)),
                    1.0 / (1.0 + 1.0 + math.exp(1)),
                    1.0 / (1.0 + 1.0 + math.exp(1)),
                ]
            ),
            atol=1e-4,
            rtol=0,
        )

        # TODO: gradcheck seems to mutate the sample values so that the simplex
        # constraint fails by a very small margin.
        self._gradcheck_log_prob(
            lambda m, s: LogisticNormal(m, s, validate_args=False), (mean, std)
        )
        self._gradcheck_log_prob(
            lambda m, s: LogisticNormal(m, s, validate_args=False), (mean, 1.0)
        )
        self._gradcheck_log_prob(
            lambda m, s: LogisticNormal(m, s, validate_args=False), (0.0, std)
        )

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_logisticnormal_logprob(self):
        mean = torch.randn(5, 7).requires_grad_()
        std = torch.randn(5, 7).abs().requires_grad_()

        # Smoke test for now
        # TODO: Once _check_log_prob works with multidimensional distributions,
        # add proper testing of the log probabilities.
        dist = LogisticNormal(mean, std)
        assert dist.log_prob(dist.sample()).detach().cpu().numpy().shape == (5,)

    def _get_logistic_normal_ref_sampler(self, base_dist):
        # Reference sampler: draw K-dimensional samples from base_dist and map
        # them onto the (K + 1)-simplex with the same stick-breaking map that
        # torch's StickBreakingTransform applies.
        def _sampler(num_samples):
            x = base_dist.rvs(num_samples)
            # offset[i] = log(K - i) for i = 0, ..., K - 1
            offset = np.log((x.shape[-1] + 1) - np.ones_like(x).cumsum(-1))
            # fraction of the remaining stick taken at each break (logistic sigmoid)
            z = 1.0 / (1.0 + np.exp(offset - x))
            # length of stick left after each break
            z_cumprod = np.cumprod(1 - z, axis=-1)
            y1 = np.pad(z, ((0, 0), (0, 1)), mode="constant", constant_values=1.0)
            y2 = np.pad(
                z_cumprod, ((0, 0), (1, 0)), mode="constant", constant_values=1.0
            )
            # piece i = z[i] * prod_{j < i}(1 - z[j]); the last piece is the remainder
            return y1 * y2

        return _sampler

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_logisticnormal_sample(self):
        set_rng_seed(0)  # see Note [Randomized statistical tests]
        means = map(np.asarray, [(-1.0, -1.0), (0.0, 0.0), (1.0, 1.0)])
        covs = map(np.diag, [(0.1, 0.1), (1.0, 1.0), (10.0, 10.0)])
        for mean, cov in product(means, covs):
            base_dist = scipy.stats.multivariate_normal(mean=mean, cov=cov)
            ref_dist = scipy.stats.multivariate_normal(mean=mean, cov=cov)
            ref_dist.rvs = self._get_logistic_normal_ref_sampler(base_dist)
            mean_th = torch.tensor(mean)
            std_th = torch.tensor(np.sqrt(np.diag(cov)))
            self._check_sampler_sampler(
                LogisticNormal(mean_th, std_th),
                ref_dist,
                f"LogisticNormal(loc={mean_th}, scale={std_th})",
                multivariate=True,
            )

    def test_mixture_same_family_shape(self):
        normal_case_1d = MixtureSameFamily(
            Categorical(torch.rand(5)), Normal(torch.randn(5), torch.rand(5))
        )
        normal_case_1d_batch = MixtureSameFamily(
            Categorical(torch.rand(3, 5)), Normal(torch.randn(3, 5), torch.rand(3, 5))
        )
        normal_case_1d_multi_batch = MixtureSameFamily(
            Categorical(torch.rand(4, 3, 5)),
            Normal(torch.randn(4, 3, 5), torch.rand(4, 3, 5)),
        )
        normal_case_2d = MixtureSameFamily(
            Categorical(torch.rand(5)),
            Independent(Normal(torch.randn(5, 2), torch.rand(5, 2)), 1),
        )
        normal_case_2d_batch = MixtureSameFamily(
            Categorical(torch.rand(3, 5)),
            Independent(Normal(torch.randn(3, 5, 2), torch.rand(3, 5, 2)), 1),
        )
        normal_case_2d_multi_batch = MixtureSameFamily(
            Categorical(torch.rand(4, 3, 5)),
            Independent(Normal(torch.randn(4, 3, 5, 2), torch.rand(4, 3, 5, 2)), 1),
        )

        self.assertEqual(normal_case_1d.sample().size(), ())
        self.assertEqual(normal_case_1d.sample((2,)).size(), (2,))
        self.assertEqual(normal_case_1d.sample((2, 7)).size(), (2, 7))
        self.assertEqual(normal_case_1d_batch.sample().size(), (3,))
        self.assertEqual(normal_case_1d_batch.sample((2,)).size(), (2, 3))
        self.assertEqual(normal_case_1d_batch.sample((2, 7)).size(), (2, 7, 3))
        self.assertEqual(normal_case_1d_multi_batch.sample().size(), (4, 3))
        self.assertEqual(normal_case_1d_multi_batch.sample((2,)).size(), (2, 4, 3))
        self.assertEqual(normal_case_1d_multi_batch.sample((2, 7)).size(), (2, 7, 4, 3))

        self.assertEqual(normal_case_2d.sample().size(), (2,))
        self.assertEqual(normal_case_2d.sample((2,)).size(), (2, 2))
        self.assertEqual(normal_case_2d.sample((2, 7)).size(), (2, 7, 2))
        self.assertEqual(normal_case_2d_batch.sample().size(), (3, 2))
        self.assertEqual(normal_case_2d_batch.sample((2,)).size(), (2, 3, 2))
        self.assertEqual(normal_case_2d_batch.sample((2, 7)).size(), (2, 7, 3, 2))
        self.assertEqual(normal_case_2d_multi_batch.sample().size(), (4, 3, 2))
        self.assertEqual(normal_case_2d_multi_batch.sample((2,)).size(), (2, 4, 3, 2))
        self.assertEqual(
            normal_case_2d_multi_batch.sample((2, 7)).size(), (2, 7, 4, 3, 2)
        )

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_mixture_same_family_log_prob(self):
        probs = torch.rand(5, 5).softmax(dim=-1)
        loc = torch.randn(5, 5)
        scale = torch.rand(5, 5)

        def ref_log_prob(idx, x, log_prob):
            p = probs[idx].numpy()
            m = loc[idx].numpy()
            s = scale[idx].numpy()
            mix = scipy.stats.multinomial(1, p)
            comp = scipy.stats.norm(m, s)
            expected = scipy.special.logsumexp(comp.logpdf(x) + np.log(mix.p))
            self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)

        self._check_log_prob(
            MixtureSameFamily(Categorical(probs=probs), Normal(loc, scale)),
            ref_log_prob,
        )

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_mixture_same_family_sample(self):
        probs = torch.rand(5).softmax(dim=-1)
        loc = torch.randn(5)
        scale = torch.rand(5)

        class ScipyMixtureNormal:
            def __init__(self, probs, mu, std):
                self.probs = probs
                self.mu = mu
                self.std = std

            def rvs(self, n_sample):
                comp_samples = [
                    scipy.stats.norm(m, s).rvs(n_sample)
                    for m, s in zip(self.mu, self.std)
                ]
                mix_samples = scipy.stats.multinomial(1, self.probs).rvs(n_sample)
                samples = []
                for i in range(n_sample):
                    samples.append(comp_samples[mix_samples[i].argmax()][i])
                return np.asarray(samples)

        self._check_sampler_sampler(
            MixtureSameFamily(Categorical(probs=probs), Normal(loc, scale)),
            ScipyMixtureNormal(probs.numpy(), loc.numpy(), scale.numpy()),
            f"""MixtureSameFamily(Categorical(probs={probs}),
            Normal(loc={loc}, scale={scale}))""",
        )

    @set_default_dtype(torch.double)
    def test_normal(self):
        loc = torch.randn(5, 5, requires_grad=True)
        scale = torch.randn(5, 5).abs().requires_grad_()
        loc_1d = torch.randn(1, requires_grad=True)
        scale_1d = torch.randn(1).abs().requires_grad_()
        loc_delta = torch.tensor([1.0, 0.0])
        scale_delta = torch.tensor([1e-5, 1e-5])
        self.assertEqual(Normal(loc, scale).sample().size(), (5, 5))
        self.assertEqual(Normal(loc, scale).sample((7,)).size(), (7, 5, 5))
        self.assertEqual(Normal(loc_1d, scale_1d).sample((1,)).size(), (1, 1))
        self.assertEqual(Normal(loc_1d, scale_1d).sample().size(), (1,))
        self.assertEqual(Normal(0.2, 0.6).sample((1,)).size(), (1,))
        self.assertEqual(Normal(-0.7, 50.0).sample((1,)).size(), (1,))

        # with a near-zero std, samples should collapse onto the mean
        set_rng_seed(1)
        self.assertEqual(
            Normal(loc_delta, scale_delta).sample(sample_shape=(1, 2)),
            torch.tensor([[[1.0, 0.0], [1.0, 0.0]]]),
            atol=1e-4,
            rtol=0,
        )

        self._gradcheck_log_prob(Normal, (loc, scale))
        self._gradcheck_log_prob(Normal, (loc, 1.0))
        self._gradcheck_log_prob(Normal, (0.0, scale))

        state = torch.get_rng_state()
        eps = torch.normal(torch.zeros_like(loc), torch.ones_like(scale))
        torch.set_rng_state(state)
        z = Normal(loc, scale).rsample()
        z.backward(torch.ones_like(z))
        self.assertEqual(loc.grad, torch.ones_like(loc))
        self.assertEqual(scale.grad, eps)
        loc.grad.zero_()
        scale.grad.zero_()
        self.assertEqual(z.size(), (5, 5))

        def ref_log_prob(idx, x, log_prob):
            m = loc.view(-1)[idx]
            s = scale.view(-1)[idx]
            expected = math.exp(-((x - m) ** 2) / (2 * s**2)) / math.sqrt(
                2 * math.pi * s**2
            )
            self.assertEqual(log_prob, math.log(expected), atol=1e-3, rtol=0)

        self._check_log_prob(Normal(loc, scale), ref_log_prob)
        self._check_forward_ad(torch.normal)
        self._check_forward_ad(lambda x: torch.normal(x, 0.5))
        self._check_forward_ad(lambda x: torch.normal(0.2, x))
        self._check_forward_ad(lambda x: torch.normal(x, x))
        self._check_forward_ad(lambda x: x.normal_())

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_normal_sample(self):
        set_rng_seed(0)  # see Note [Randomized statistical tests]
        for loc, scale in product([-1.0, 0.0, 1.0], [0.1, 1.0, 10.0]):
            self._check_sampler_sampler(
                Normal(loc, scale),
                scipy.stats.norm(loc=loc, scale=scale),
                f"Normal(mean={loc}, std={scale})",
            )

    @set_default_dtype(torch.double)
    def test_lowrank_multivariate_normal_shape(self):
        mean = torch.randn(5, 3, requires_grad=True)
        mean_no_batch = torch.randn(3, requires_grad=True)
        mean_multi_batch = torch.randn(6, 5, 3, requires_grad=True)

        # construct PSD covariance
        cov_factor = torch.randn(3, 1, requires_grad=True)
        cov_diag = torch.randn(3).abs().requires_grad_()

        # construct batch of PSD covariances
        cov_factor_batched = torch.randn(6, 5, 3, 2, requires_grad=True)
        cov_diag_batched = torch.randn(6, 5, 3).abs().requires_grad_()

        # ensure that sample, batch, and event shapes are all handled correctly
        self.assertEqual(
            LowRankMultivariateNormal(mean, cov_factor, cov_diag).sample().size(),
            (5, 3),
        )
        self.assertEqual(
            LowRankMultivariateNormal(mean_no_batch, cov_factor, cov_diag)
            .sample()
            .size(),
            (3,),
        )
        self.assertEqual(
            LowRankMultivariateNormal(mean_multi_batch, cov_factor, cov_diag)
            .sample()
            .size(),
            (6, 5, 3),
        )
        self.assertEqual(
            LowRankMultivariateNormal(mean, cov_factor, cov_diag).sample((2,)).size(),
            (2, 5, 3),
        )
        self.assertEqual(
            LowRankMultivariateNormal(mean_no_batch, cov_factor, cov_diag)
            .sample((2,))
            .size(),
            (2, 3),
        )
        self.assertEqual(
            LowRankMultivariateNormal(mean_multi_batch, cov_factor, cov_diag)
            .sample((2,))
            .size(),
            (2, 6, 5, 3),
        )
        self.assertEqual(
            LowRankMultivariateNormal(mean, cov_factor, cov_diag).sample((2, 7)).size(),
            (2, 7, 5, 3),
        )
        self.assertEqual(
            LowRankMultivariateNormal(mean_no_batch, cov_factor, cov_diag)
            .sample((2, 7))
            .size(),
            (2, 7, 3),
        )
        self.assertEqual(
            LowRankMultivariateNormal(mean_multi_batch, cov_factor, cov_diag)
            .sample((2, 7))
            .size(),
            (2, 7, 6, 5, 3),
        )
        self.assertEqual(
            LowRankMultivariateNormal(mean, cov_factor_batched, cov_diag_batched)
            .sample((2, 7))
            .size(),
            (2, 7, 6, 5, 3),
        )
        self.assertEqual(
            LowRankMultivariateNormal(
                mean_no_batch, cov_factor_batched, cov_diag_batched
            )
            .sample((2, 7))
            .size(),
            (2, 7, 6, 5, 3),
        )
        self.assertEqual(
            LowRankMultivariateNormal(
                mean_multi_batch, cov_factor_batched, cov_diag_batched
            )
            .sample((2, 7))
            .size(),
            (2, 7, 6, 5, 3),
        )

        # check gradients
        self._gradcheck_log_prob(
            LowRankMultivariateNormal, (mean, cov_factor, cov_diag)
        )
        self._gradcheck_log_prob(
            LowRankMultivariateNormal, (mean_multi_batch, cov_factor, cov_diag)
        )
        self._gradcheck_log_prob(
            LowRankMultivariateNormal,
            (mean_multi_batch, cov_factor_batched, cov_diag_batched),
        )

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_lowrank_multivariate_normal_log_prob(self):
        mean = torch.randn(3, requires_grad=True)
        cov_factor = torch.randn(3, 1, requires_grad=True)
        cov_diag = torch.randn(3).abs().requires_grad_()
        cov = cov_factor.matmul(cov_factor.t()) + cov_diag.diag()

        # check that log_prob values match scipy's logpdf for the equivalent
        # dense covariance cov_factor @ cov_factor.T + diag(cov_diag)
        dist1 = LowRankMultivariateNormal(mean, cov_factor, cov_diag)
        ref_dist = scipy.stats.multivariate_normal(
            mean.detach().numpy(), cov.detach().numpy()
        )

        x = dist1.sample((10,))
        expected = ref_dist.logpdf(x.numpy())

        self.assertEqual(
            0.0,
            np.mean((dist1.log_prob(x).detach().numpy() - expected) ** 2),
            atol=1e-3,
            rtol=0,
        )

        # Double-check that batched versions behave the same as unbatched
        mean = torch.randn(5, 3, requires_grad=True)
        cov_factor = torch.randn(5, 3, 2, requires_grad=True)
        cov_diag = torch.randn(5, 3).abs().requires_grad_()

        dist_batched = LowRankMultivariateNormal(mean, cov_factor, cov_diag)
        dist_unbatched = [
            LowRankMultivariateNormal(mean[i], cov_factor[i], cov_diag[i])
            for i in range(mean.size(0))
        ]

        x = dist_batched.sample((10,))
        batched_prob = dist_batched.log_prob(x)
        unbatched_prob = torch.stack(
            [dist_unbatched[i].log_prob(x[:, i]) for i in range(5)]
        ).t()

        self.assertEqual(batched_prob.shape, unbatched_prob.shape)
        self.assertEqual(batched_prob, unbatched_prob, atol=1e-3, rtol=0)

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_lowrank_multivariate_normal_sample(self):
        set_rng_seed(0)  # see Note [Randomized statistical tests]
        mean = torch.randn(5, requires_grad=True)
        cov_factor = torch.randn(5, 1, requires_grad=True)
        cov_diag = torch.randn(5).abs().requires_grad_()
        cov = cov_factor.matmul(cov_factor.t()) + cov_diag.diag()

        self._check_sampler_sampler(
            LowRankMultivariateNormal(mean, cov_factor, cov_diag),
            scipy.stats.multivariate_normal(
                mean.detach().numpy(), cov.detach().numpy()
            ),
            f"LowRankMultivariateNormal(loc={mean}, cov_factor={cov_factor}, cov_diag={cov_diag})",
            multivariate=True,
        )

    def test_lowrank_multivariate_normal_properties(self):
        loc = torch.randn(5)
        cov_factor = torch.randn(5, 2)
        cov_diag = torch.randn(5).abs()
        cov = cov_factor.matmul(cov_factor.t()) + cov_diag.diag()
        m1 = LowRankMultivariateNormal(loc, cov_factor, cov_diag)
        m2 = MultivariateNormal(loc=loc, covariance_matrix=cov)
        self.assertEqual(m1.mean, m2.mean)
        self.assertEqual(m1.variance, m2.variance)
        self.assertEqual(m1.covariance_matrix, m2.covariance_matrix)
        self.assertEqual(m1.scale_tril, m2.scale_tril)
        self.assertEqual(m1.precision_matrix, m2.precision_matrix)
        self.assertEqual(m1.entropy(), m2.entropy())

    def test_lowrank_multivariate_normal_moments(self):
        set_rng_seed(0)  # see Note [Randomized statistical tests]
        mean = torch.randn(5)
        cov_factor = torch.randn(5, 2)
        cov_diag = torch.randn(5).abs()
        d = LowRankMultivariateNormal(mean, cov_factor, cov_diag)
        samples = d.rsample((100000,))
        empirical_mean = samples.mean(0)
        self.assertEqual(d.mean, empirical_mean, atol=0.01, rtol=0)
        empirical_var = samples.var(0)
        self.assertEqual(d.variance, empirical_var, atol=0.02, rtol=0)

    @set_default_dtype(torch.double)
    def test_multivariate_normal_shape(self):
        mean = torch.randn(5, 3, requires_grad=True)
        mean_no_batch = torch.randn(3, requires_grad=True)
        mean_multi_batch = torch.randn(6, 5, 3, requires_grad=True)

        # construct PSD covariance
        tmp = torch.randn(3, 10)
        cov = (torch.matmul(tmp, tmp.t()) / tmp.size(-1)).requires_grad_()
        prec = cov.inverse().requires_grad_()
        scale_tril = torch.linalg.cholesky(cov).requires_grad_()

        # construct batch of PSD covariances
        tmp = torch.randn(6, 5, 3, 10)
        cov_batched = (tmp.unsqueeze(-2) * tmp.unsqueeze(-3)).mean(-1).requires_grad_()
        prec_batched = cov_batched.inverse()
        scale_tril_batched = torch.linalg.cholesky(cov_batched)

        # ensure that sample, batch, event shapes all handled correctly
        self.assertEqual(MultivariateNormal(mean, cov).sample().size(), (5, 3))
        self.assertEqual(MultivariateNormal(mean_no_batch, cov).sample().size(), (3,))
        self.assertEqual(
            MultivariateNormal(mean_multi_batch, cov).sample().size(), (6, 5, 3)
        )
        self.assertEqual(MultivariateNormal(mean, cov).sample((2,)).size(), (2, 5, 3))
        self.assertEqual(
            MultivariateNormal(mean_no_batch, cov).sample((2,)).size(), (2, 3)
        )
        self.assertEqual(
            MultivariateNormal(mean_multi_batch, cov).sample((2,)).size(), (2, 6, 5, 3)
        )
        self.assertEqual(
            MultivariateNormal(mean, cov).sample((2, 7)).size(), (2, 7, 5, 3)
        )
        self.assertEqual(
            MultivariateNormal(mean_no_batch, cov).sample((2, 7)).size(), (2, 7, 3)
        )
        self.assertEqual(
            MultivariateNormal(mean_multi_batch, cov).sample((2, 7)).size(),
            (2, 7, 6, 5, 3),
        )
        self.assertEqual(
            MultivariateNormal(mean, cov_batched).sample((2, 7)).size(), (2, 7, 6, 5, 3)
        )
        self.assertEqual(
            MultivariateNormal(mean_no_batch, cov_batched).sample((2, 7)).size(),
            (2, 7, 6, 5, 3),
        )
        self.assertEqual(
            MultivariateNormal(mean_multi_batch, cov_batched).sample((2, 7)).size(),
            (2, 7, 6, 5, 3),
        )
        self.assertEqual(
            MultivariateNormal(mean, precision_matrix=prec).sample((2, 7)).size(),
            (2, 7, 5, 3),
        )
        self.assertEqual(
            MultivariateNormal(mean, precision_matrix=prec_batched)
            .sample((2, 7))
            .size(),
            (2, 7, 6, 5, 3),
        )
        self.assertEqual(
            MultivariateNormal(mean, scale_tril=scale_tril).sample((2, 7)).size(),
            (2, 7, 5, 3),
        )
        self.assertEqual(
            MultivariateNormal(mean, scale_tril=scale_tril_batched)
            .sample((2, 7))
            .size(),
            (2, 7, 6, 5, 3),
        )

        # check gradients
        # We write a custom gradcheck function to maintain the symmetry
        # of the perturbed covariances and their inverses (precision),
        # and to keep the perturbed scale_tril lower triangular
        def multivariate_normal_log_prob_gradcheck(
            mean, covariance=None, precision=None, scale_tril=None
        ):
            mvn_samples = (
                MultivariateNormal(mean, covariance, precision, scale_tril)
                .sample()
                .requires_grad_()
            )

            def gradcheck_func(samples, mu, sigma, prec, scale_tril):
                if sigma is not None:
                    sigma = 0.5 * (sigma + sigma.mT)  # Ensure symmetry of covariance
                if prec is not None:
                    prec = 0.5 * (prec + prec.mT)  # Ensure symmetry of precision
                if scale_tril is not None:
                    scale_tril = scale_tril.tril()
                return MultivariateNormal(mu, sigma, prec, scale_tril).log_prob(samples)

            gradcheck(
                gradcheck_func,
                (mvn_samples, mean, covariance, precision, scale_tril),
                raise_exception=True,
            )

        multivariate_normal_log_prob_gradcheck(mean, cov)
        multivariate_normal_log_prob_gradcheck(mean_multi_batch, cov)
        multivariate_normal_log_prob_gradcheck(mean_multi_batch, cov_batched)
        multivariate_normal_log_prob_gradcheck(mean, None, prec)
        multivariate_normal_log_prob_gradcheck(mean_no_batch, None, prec_batched)
        multivariate_normal_log_prob_gradcheck(mean, None, None, scale_tril)
        multivariate_normal_log_prob_gradcheck(
            mean_no_batch, None, None, scale_tril_batched
        )

    @set_default_dtype(torch.double)
    def test_multivariate_normal_stable_with_precision_matrix(self):
        x = torch.randn(10)
        P = torch.exp(-((x - x.unsqueeze(-1)) ** 2))  # RBF kernel
        MultivariateNormal(x.new_zeros(10), precision_matrix=P)

    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
    def test_multivariate_normal_log_prob(self):
        mean = torch.randn(3, requires_grad=True)
        tmp = torch.randn(3, 10)
        cov = (torch.matmul(tmp, tmp.t()) / tmp.size(-1)).requires_grad_()
        prec = cov.inverse().requires_grad_()
        scale_tril = torch.linalg.cholesky(cov).requires_grad_()

        # check that log_prob values match scipy logpdf, and that the
        # covariance_matrix, precision_matrix and scale_tril
        # parameterizations are equivalent
        dist1 = MultivariateNormal(mean, cov)
        dist2 = MultivariateNormal(mean, precision_matrix=prec)
        dist3 = MultivariateNormal(mean, scale_tril=scale_tril)
        ref_dist = scipy.stats.multivariate_normal(
            mean.detach().numpy(), cov.detach().numpy()
        )

        x = dist1.sample((10,))
        expected = ref_dist.logpdf(x.numpy())

        self.assertEqual(
            0.0,
            np.mean((dist1.log_prob(x).detach().numpy() - expected) ** 2),
            atol=1e-3,
            rtol=0,
        )
        self.assertEqual(
            0.0,
            np.mean((dist2.log_prob(x).detach().numpy() - expected) ** 2),
            atol=1e-3,
            rtol=0,
        )
        self.assertEqual(
            0.0,
            np.mean((dist3.log_prob(x).detach().numpy() - expected) ** 2),
            atol=1e-3,
            rtol=0,
        )

        # Double-check that batched versions behave the same as unbatched
        mean = torch.randn(5, 3, requires_grad=True)
        tmp = torch.randn(5, 3, 10)
        cov = (tmp.unsqueeze(-2) * tmp.unsqueeze(-3)).mean(-1).requires_grad_()

        dist_batched = MultivariateNormal(mean, cov)
        dist_unbatched = [
            MultivariateNormal(mean[i], cov[i]) for i in range(mean.size(0))
        ]

        x = dist_batched.sample((10,))
        batched_prob = dist_batched.log_prob(x)
        unbatched_prob = torch.stack(
            [dist_unbatched[i].log_prob(x[:, i]) for i in range(5)]
        ).t()

        self.assertEqual(batched_prob.shape, unbatched_prob.shape)
        self.assertEqual(batched_prob, unbatched_prob, atol=1e-3, rtol=0)

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_multivariate_normal_sample(self):
        set_rng_seed(0)  # see Note [Randomized statistical tests]
        mean = torch.randn(3, requires_grad=True)
        tmp = torch.randn(3, 10)
        cov = (torch.matmul(tmp, tmp.t()) / tmp.size(-1)).requires_grad_()
        prec = cov.inverse().requires_grad_()
        scale_tril = torch.linalg.cholesky(cov).requires_grad_()

        self._check_sampler_sampler(
            MultivariateNormal(mean, cov),
            scipy.stats.multivariate_normal(
                mean.detach().numpy(), cov.detach().numpy()
            ),
            f"MultivariateNormal(loc={mean}, cov={cov})",
            multivariate=True,
        )
        self._check_sampler_sampler(
            MultivariateNormal(mean, precision_matrix=prec),
            scipy.stats.multivariate_normal(
                mean.detach().numpy(), cov.detach().numpy()
            ),
            f"MultivariateNormal(loc={mean}, precision_matrix={prec})",
            multivariate=True,
        )
        self._check_sampler_sampler(
            MultivariateNormal(mean, scale_tril=scale_tril),
            scipy.stats.multivariate_normal(
                mean.detach().numpy(), cov.detach().numpy()
            ),
            f"MultivariateNormal(loc={mean}, scale_tril={scale_tril})",
            multivariate=True,
        )

    @set_default_dtype(torch.double)
    def test_multivariate_normal_properties(self):
        loc = torch.randn(5)
        scale_tril = transform_to(constraints.lower_cholesky)(torch.randn(5, 5))
        m = MultivariateNormal(loc=loc, scale_tril=scale_tril)
        self.assertEqual(m.covariance_matrix, m.scale_tril.mm(m.scale_tril.t()))
        self.assertEqual(
            m.covariance_matrix.mm(m.precision_matrix), torch.eye(m.event_shape[0])
        )
        self.assertEqual(m.scale_tril, torch.linalg.cholesky(m.covariance_matrix))

    @set_default_dtype(torch.double)
    def test_multivariate_normal_moments(self):
        set_rng_seed(0)  # see Note [Randomized statistical tests]
        mean = torch.randn(5)
        scale_tril = transform_to(constraints.lower_cholesky)(torch.randn(5, 5))
        d = MultivariateNormal(mean, scale_tril=scale_tril)
        samples = d.rsample((100000,))
        empirical_mean = samples.mean(0)
        self.assertEqual(d.mean, empirical_mean, atol=0.01, rtol=0)
        empirical_var = samples.var(0)
        self.assertEqual(d.variance, empirical_var, atol=0.05, rtol=0)

    # The Wishart tests below apply the same checks as the MultivariateNormal tests above
    @set_default_dtype(torch.double)
    def test_wishart_shape(self):
        set_rng_seed(0)  # see Note [Randomized statistical tests]
        ndim = 3

        df = torch.rand(5, requires_grad=True) + ndim
        df_no_batch = torch.rand([], requires_grad=True) + ndim
        df_multi_batch = torch.rand(6, 5, requires_grad=True) + ndim

        # construct PSD covariance
        tmp = torch.randn(ndim, 10)
        cov = (torch.matmul(tmp, tmp.t()) / tmp.size(-1)).requires_grad_()
        prec = cov.inverse().requires_grad_()
        scale_tril = torch.linalg.cholesky(cov).requires_grad_()

        # construct batch of PSD covariances
        tmp = torch.randn(6, 5, ndim, 10)
        cov_batched = (tmp.unsqueeze(-2) * tmp.unsqueeze(-3)).mean(-1).requires_grad_()
        prec_batched = cov_batched.inverse()
        scale_tril_batched = torch.linalg.cholesky(cov_batched)

        # ensure that sample, batch, event shapes all handled correctly
        self.assertEqual(Wishart(df, cov).sample().size(), (5, ndim, ndim))
        self.assertEqual(Wishart(df_no_batch, cov).sample().size(), (ndim, ndim))
        self.assertEqual(
            Wishart(df_multi_batch, cov).sample().size(), (6, 5, ndim, ndim)
        )
        self.assertEqual(Wishart(df, cov).sample((2,)).size(), (2, 5, ndim, ndim))
        self.assertEqual(Wishart(df_no_batch, cov).sample((2,)).size(), (2, ndim, ndim))
        self.assertEqual(
            Wishart(df_multi_batch, cov).sample((2,)).size(), (2, 6, 5, ndim, ndim)
        )
        self.assertEqual(Wishart(df, cov).sample((2, 7)).size(), (2, 7, 5, ndim, ndim))
        self.assertEqual(
            Wishart(df_no_batch, cov).sample((2, 7)).size(), (2, 7, ndim, ndim)
        )
        self.assertEqual(
            Wishart(df_multi_batch, cov).sample((2, 7)).size(), (2, 7, 6, 5, ndim, ndim)
        )
        self.assertEqual(
            Wishart(df, cov_batched).sample((2, 7)).size(), (2, 7, 6, 5, ndim, ndim)
        )
        self.assertEqual(
            Wishart(df_no_batch, cov_batched).sample((2, 7)).size(),
            (2, 7, 6, 5, ndim, ndim),
        )
        self.assertEqual(
            Wishart(df_multi_batch, cov_batched).sample((2, 7)).size(),
            (2, 7, 6, 5, ndim, ndim),
        )
        self.assertEqual(
            Wishart(df, precision_matrix=prec).sample((2, 7)).size(),
            (2, 7, 5, ndim, ndim),
        )
        self.assertEqual(
            Wishart(df, precision_matrix=prec_batched).sample((2, 7)).size(),
            (2, 7, 6, 5, ndim, ndim),
        )
        self.assertEqual(
            Wishart(df, scale_tril=scale_tril).sample((2, 7)).size(),
            (2, 7, 5, ndim, ndim),
        )
        self.assertEqual(
            Wishart(df, scale_tril=scale_tril_batched).sample((2, 7)).size(),
            (2, 7, 6, 5, ndim, ndim),
        )

        # check gradients
        # Adapted from the custom gradcheck in test_multivariate_normal_shape
        def wishart_log_prob_gradcheck(
            df=None, covariance=None, precision=None, scale_tril=None
        ):
            wishart_samples = (
                Wishart(df, covariance, precision, scale_tril).sample().requires_grad_()
            )

            def gradcheck_func(samples, nu, sigma, prec, scale_tril):
                if sigma is not None:
                    sigma = 0.5 * (sigma + sigma.mT)  # Ensure symmetry of covariance
                if prec is not None:
                    prec = 0.5 * (prec + prec.mT)  # Ensure symmetry of precision
                if scale_tril is not None:
                    scale_tril = scale_tril.tril()
                return Wishart(nu, sigma, prec, scale_tril).log_prob(samples)

            gradcheck(
                gradcheck_func,
                (wishart_samples, df, covariance, precision, scale_tril),
                raise_exception=True,
            )

        wishart_log_prob_gradcheck(df, cov)
        wishart_log_prob_gradcheck(df_multi_batch, cov)
        wishart_log_prob_gradcheck(df_multi_batch, cov_batched)
        wishart_log_prob_gradcheck(df, None, prec)
        wishart_log_prob_gradcheck(df_no_batch, None, prec_batched)
        wishart_log_prob_gradcheck(df, None, None, scale_tril)
        wishart_log_prob_gradcheck(df_no_batch, None, None, scale_tril_batched)

    def test_wishart_stable_with_precision_matrix(self):
        set_rng_seed(0)  # see Note [Randomized statistical tests]
        ndim = 10
        x = torch.randn(ndim)
        P = torch.exp(-((x - x.unsqueeze(-1)) ** 2))  # RBF kernel
        Wishart(torch.tensor(ndim), precision_matrix=P)

    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
    @set_default_dtype(torch.double)
    def test_wishart_log_prob(self):
        set_rng_seed(0)  # see Note [Randomized statistical tests]
        ndim = 3
        df = torch.rand([], requires_grad=True) + ndim - 1
        # SciPy only accepts ndim - 1 < df < ndim for the Wishart distribution
        # starting with version 1.7.0, so shift df up by one on older versions
        if version.parse(scipy.__version__) < version.parse("1.7.0"):
            df += 1.0
        tmp = torch.randn(ndim, 10)
        cov = (torch.matmul(tmp, tmp.t()) / tmp.size(-1)).requires_grad_()
        prec = cov.inverse().requires_grad_()
        scale_tril = torch.linalg.cholesky(cov).requires_grad_()

        # Check that log_prob values match SciPy's logpdf, and that the
        # covariance_matrix, precision_matrix and scale_tril parameterizations
        # are equivalent.
        dist1 = Wishart(df, cov)
        dist2 = Wishart(df, precision_matrix=prec)
        dist3 = Wishart(df, scale_tril=scale_tril)
        ref_dist = scipy.stats.wishart(df.item(), cov.detach().numpy())

        x = dist1.sample((1000,))
        # SciPy expects the sample axis last, i.e. shape (ndim, ndim, num_samples)
        expected = ref_dist.logpdf(x.transpose(0, 2).numpy())

        self.assertEqual(
            0.0,
            np.mean((dist1.log_prob(x).detach().numpy() - expected) ** 2),
            atol=1e-3,
            rtol=0,
        )
        self.assertEqual(
            0.0,
            np.mean((dist2.log_prob(x).detach().numpy() - expected) ** 2),
            atol=1e-3,
            rtol=0,
        )
        self.assertEqual(
            0.0,
            np.mean((dist3.log_prob(x).detach().numpy() - expected) ** 2),
            atol=1e-3,
            rtol=0,
        )

        # Double-check that batched versions behave the same as unbatched
        df = torch.rand(5, requires_grad=True) + ndim - 1
        # SciPy supports ndim - 1 < df < ndim for the Wishart distribution only
        # since version 1.7.0; on older versions, shift df up by one.
        if version.parse(scipy.__version__) < version.parse("1.7.0"):
            df += 1.0
        tmp = torch.randn(5, ndim, 10)
        cov = (tmp.unsqueeze(-2) * tmp.unsqueeze(-3)).mean(-1).requires_grad_()

        dist_batched = Wishart(df, cov)
        dist_unbatched = [Wishart(df[i], cov[i]) for i in range(df.size(0))]

        x = dist_batched.sample((1000,))
        batched_prob = dist_batched.log_prob(x)
        unbatched_prob = torch.stack(
            [dist_unbatched[i].log_prob(x[:, i]) for i in range(5)]
        ).t()

        self.assertEqual(batched_prob.shape, unbatched_prob.shape)
        self.assertEqual(batched_prob, unbatched_prob, atol=1e-3, rtol=0)

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    @set_default_dtype(torch.double)
    def test_wishart_sample(self):
        set_rng_seed(0)  # see Note [Randomized statistical tests]
        ndim = 3
        df = torch.rand([], requires_grad=True) + ndim - 1
        # SciPy supports ndim - 1 < df < ndim for the Wishart distribution only
        # since version 1.7.0; on older versions, shift df up by one.
        if version.parse(scipy.__version__) < version.parse("1.7.0"):
            df += 1.0
        tmp = torch.randn(ndim, 10)
        cov = (torch.matmul(tmp, tmp.t()) / tmp.size(-1)).requires_grad_()
        prec = cov.inverse().requires_grad_()
        scale_tril = torch.linalg.cholesky(cov).requires_grad_()

        ref_dist = scipy.stats.wishart(df.item(), cov.detach().numpy())

        self._check_sampler_sampler(
            Wishart(df, cov),
            ref_dist,
            f"Wishart(df={df}, covariance_matrix={cov})",
            multivariate=True,
        )
        self._check_sampler_sampler(
            Wishart(df, precision_matrix=prec),
            ref_dist,
            f"Wishart(df={df}, precision_matrix={prec})",
            multivariate=True,
        )
        self._check_sampler_sampler(
            Wishart(df, scale_tril=scale_tril),
            ref_dist,
            f"Wishart(df={df}, scale_tril={scale_tril})",
            multivariate=True,
        )

    def test_wishart_properties(self):
        set_rng_seed(0)  # see Note [Randomized statistical tests]
        ndim = 5
        df = torch.rand([]) + ndim - 1
        scale_tril = transform_to(constraints.lower_cholesky)(torch.randn(ndim, ndim))
        m = Wishart(df=df, scale_tril=scale_tril)
        self.assertEqual(m.covariance_matrix, m.scale_tril.mm(m.scale_tril.t()))
        self.assertEqual(
            m.covariance_matrix.mm(m.precision_matrix), torch.eye(m.event_shape[0])
        )
        self.assertEqual(m.scale_tril, torch.linalg.cholesky(m.covariance_matrix))

    def test_wishart_moments(self):
        set_rng_seed(0)  # see Note [Randomized statistical tests]
        ndim = 3
        df = torch.rand([]) + ndim - 1
        scale_tril = transform_to(constraints.lower_cholesky)(torch.randn(ndim, ndim))
        d = Wishart(df=df, scale_tril=scale_tril)
        # Closed-form moments, with Sigma = scale_tril @ scale_tril.T:
        # E[W] = df * Sigma and Var(W_ij) = df * (Sigma_ij**2 + Sigma_ii * Sigma_jj)
        samples = d.rsample((ndim * ndim * 100000,))
        empirical_mean = samples.mean(0)
        self.assertEqual(d.mean, empirical_mean, atol=0.5, rtol=0)
        empirical_var = samples.var(0)
        self.assertEqual(d.variance, empirical_var, atol=0.5, rtol=0)

    @set_default_dtype(torch.double)
    def test_exponential(self):
        rate = torch.randn(5, 5).abs().requires_grad_()
        rate_1d = torch.randn(1).abs().requires_grad_()
        self.assertEqual(Exponential(rate).sample().size(), (5, 5))
        self.assertEqual(Exponential(rate).sample((7,)).size(), (7, 5, 5))
        self.assertEqual(Exponential(rate_1d).sample((1,)).size(), (1, 1))
        self.assertEqual(Exponential(rate_1d).sample().size(), (1,))
        self.assertEqual(Exponential(0.2).sample((1,)).size(), (1,))
        self.assertEqual(Exponential(50.0).sample((1,)).size(), (1,))

        self._gradcheck_log_prob(Exponential, (rate,))
        # Replay the same RNG state so eps matches the Exp(1) noise drawn inside
        # rsample(); with z = eps / rate, dz/drate = -eps / rate**2.
        state = torch.get_rng_state()
        eps = rate.new(rate.size()).exponential_()
        torch.set_rng_state(state)
        z = Exponential(rate).rsample()
        z.backward(torch.ones_like(z))
        self.assertEqual(rate.grad, -eps / rate**2)
        rate.grad.zero_()
        self.assertEqual(z.size(), (5, 5))

        def ref_log_prob(idx, x, log_prob):
            m = rate.view(-1)[idx]
            expected = math.log(m) - m * x
            self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)

        self._check_log_prob(Exponential(rate), ref_log_prob)
        self._check_forward_ad(lambda x: x.exponential_())

        def mean_var(lambd, sample):
            # Exponential(lambd) has mean 1 / lambd and variance 1 / lambd**2
            sample.exponential_(lambd)
            mean = sample.float().mean()
            var = sample.float().var()
            self.assertEqual((1.0 / lambd), mean, atol=2e-2, rtol=2e-2)
            self.assertEqual((1.0 / lambd) ** 2, var, atol=2e-2, rtol=2e-2)

        for dtype in [torch.float, torch.double, torch.bfloat16, torch.float16]:
            for lambd in [0.2, 0.5, 1.0, 1.5, 2.0, 5.0]:
                sample_len = 50000
                mean_var(lambd, torch.rand(sample_len, dtype=dtype))

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_exponential_sample(self):
        set_rng_seed(1)  # see Note [Randomized statistical tests]
        for rate in [1e-5, 1.0, 10.0]:
            self._check_sampler_sampler(
                Exponential(rate),
                scipy.stats.expon(scale=1.0 / rate),
                f"Exponential(rate={rate})",
            )

    @set_default_dtype(torch.double)
    def test_laplace(self):
        loc = torch.randn(5, 5, requires_grad=True)
        scale = torch.randn(5, 5).abs().requires_grad_()
        loc_1d = torch.randn(1, requires_grad=True)
        scale_1d = torch.randn(1, requires_grad=True)
        loc_delta = torch.tensor([1.0, 0.0])
        scale_delta = torch.tensor([1e-5, 1e-5])
        self.assertEqual(Laplace(loc, scale).sample().size(), (5, 5))
        self.assertEqual(Laplace(loc, scale).sample((7,)).size(), (7, 5, 5))
        self.assertEqual(Laplace(loc_1d, scale_1d).sample((1,)).size(), (1, 1))
        self.assertEqual(Laplace(loc_1d, scale_1d).sample().size(), (1,))
        self.assertEqual(Laplace(0.2, 0.6).sample((1,)).size(), (1,))
        self.assertEqual(Laplace(-0.7, 50.0).sample((1,)).size(), (1,))

        # Sample check for an extreme parameter value: with scale close to zero,
        # samples should be very close to loc.
        set_rng_seed(0)
        self.assertEqual(
            Laplace(loc_delta, scale_delta).sample(sample_shape=(1, 2)),
            torch.tensor([[[1.0, 0.0], [1.0, 0.0]]]),
            atol=1e-4,
            rtol=0,
        )

        self._gradcheck_log_prob(Laplace, (loc, scale))
        self._gradcheck_log_prob(Laplace, (loc, 1.0))
        self._gradcheck_log_prob(Laplace, (0.0, scale))

        # Replay the same RNG state; the expected gradients correspond to the
        # inverse-CDF form z = loc - scale * sign(eps) * log1p(-2 * |eps|).
        state = torch.get_rng_state()
        eps = torch.ones_like(loc).uniform_(-0.5, 0.5)
        torch.set_rng_state(state)
        z = Laplace(loc, scale).rsample()
        z.backward(torch.ones_like(z))
        self.assertEqual(loc.grad, torch.ones_like(loc))
        self.assertEqual(scale.grad, -eps.sign() * torch.log1p(-2 * eps.abs()))
        loc.grad.zero_()
        scale.grad.zero_()
        self.assertEqual(z.size(), (5, 5))

        def ref_log_prob(idx, x, log_prob):
            m = loc.view(-1)[idx]
            s = scale.view(-1)[idx]
            expected = -math.log(2 * s) - abs(x - m) / s
            self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)

        self._check_log_prob(Laplace(loc, scale), ref_log_prob)

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    @set_default_dtype(torch.double)
    def test_laplace_sample(self):
        set_rng_seed(1)  # see Note [Randomized statistical tests]
        for loc, scale in product([-1.0, 0.0, 1.0], [0.1, 1.0, 10.0]):
            self._check_sampler_sampler(
                Laplace(loc, scale),
                scipy.stats.laplace(loc=loc, scale=scale),
                f"Laplace(loc={loc}, scale={scale})",
            )

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_gamma_shape(self):
        alpha = torch.randn(2, 3).exp().requires_grad_()
        beta = torch.randn(2, 3).exp().requires_grad_()
        alpha_1d = torch.randn(1).exp().requires_grad_()
        beta_1d = torch.randn(1).exp().requires_grad_()
        self.assertEqual(Gamma(alpha, beta).sample().size(), (2, 3))
        self.assertEqual(Gamma(alpha, beta).sample((5,)).size(), (5, 2, 3))
        self.assertEqual(Gamma(alpha_1d, beta_1d).sample((1,)).size(), (1, 1))
        self.assertEqual(Gamma(alpha_1d, beta_1d).sample().size(), (1,))
        self.assertEqual(Gamma(0.5, 0.5).sample().size(), ())
        self.assertEqual(Gamma(0.5, 0.5).sample((1,)).size(), (1,))

        def ref_log_prob(idx, x, log_prob):
            a = alpha.view(-1)[idx].detach()
            b = beta.view(-1)[idx].detach()
            # SciPy parameterizes the gamma distribution by scale = 1 / rate
            expected = scipy.stats.gamma.logpdf(x, a, scale=1 / b)
            self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)

        self._check_log_prob(Gamma(alpha, beta), ref_log_prob)

    @unittest.skipIf(not TEST_CUDA, "CUDA not found")
    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_gamma_gpu_shape(self):
        alpha = torch.randn(2, 3).cuda().exp().requires_grad_()
        beta = torch.randn(2, 3).cuda().exp().requires_grad_()
        alpha_1d = torch.randn(1).cuda().exp().requires_grad_()
        beta_1d = torch.randn(1).cuda().exp().requires_grad_()
        self.assertEqual(Gamma(alpha, beta).sample().size(), (2, 3))
        self.assertEqual(Gamma(alpha, beta).sample((5,)).size(), (5, 2, 3))
        self.assertEqual(Gamma(alpha_1d, beta_1d).sample((1,)).size(), (1, 1))
        self.assertEqual(Gamma(alpha_1d, beta_1d).sample().size(), (1,))
        self.assertEqual(Gamma(0.5, 0.5).sample().size(), ())
        self.assertEqual(Gamma(0.5, 0.5).sample((1,)).size(), (1,))

        def ref_log_prob(idx, x, log_prob):
            a = alpha.view(-1)[idx].detach().cpu()
            b = beta.view(-1)[idx].detach().cpu()
            expected = scipy.stats.gamma.logpdf(x.cpu(), a, scale=1 / b)
            self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)

        self._check_log_prob(Gamma(alpha, beta), ref_log_prob)

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_gamma_sample(self):
        set_rng_seed(0)  # see Note [Randomized statistical tests]
        for alpha, beta in product([0.1, 1.0, 5.0], [0.1, 1.0, 10.0]):
            self._check_sampler_sampler(
                Gamma(alpha, beta),
                scipy.stats.gamma(alpha, scale=1.0 / beta),
                f"Gamma(concentration={alpha}, rate={beta})",
            )

    @unittest.skipIf(not TEST_CUDA, "CUDA not found")
    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_gamma_gpu_sample(self):
        set_rng_seed(0)  # see Note [Randomized statistical tests]
        for alpha, beta in product([0.1, 1.0, 5.0], [0.1, 1.0, 10.0]):
            a, b = torch.tensor([alpha]).cuda(), torch.tensor([beta]).cuda()
            self._check_sampler_sampler(
                Gamma(a, b),
                scipy.stats.gamma(alpha, scale=1.0 / beta),
                f"Gamma(alpha={alpha}, beta={beta})",
                failure_rate=1e-4,
            )

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_pareto(self):
        scale = torch.randn(2, 3).abs().requires_grad_()
        alpha = torch.randn(2, 3).abs().requires_grad_()
        scale_1d = torch.randn(1).abs().requires_grad_()
        alpha_1d = torch.randn(1).abs().requires_grad_()
        # A Pareto distribution has an infinite mean for alpha <= 1 and an
        # infinite variance for alpha <= 2.
        self.assertEqual(Pareto(scale_1d, 0.5).mean, inf)
        self.assertEqual(Pareto(scale_1d, 0.5).variance, inf)
        self.assertEqual(Pareto(scale, alpha).sample().size(), (2, 3))
        self.assertEqual(Pareto(scale, alpha).sample((5,)).size(), (5, 2, 3))
        self.assertEqual(Pareto(scale_1d, alpha_1d).sample((1,)).size(), (1, 1))
        self.assertEqual(Pareto(scale_1d, alpha_1d).sample().size(), (1,))
        self.assertEqual(Pareto(1.0, 1.0).sample().size(), ())
        self.assertEqual(Pareto(1.0, 1.0).sample((1,)).size(), (1,))

        def ref_log_prob(idx, x, log_prob):
            s = scale.view(-1)[idx].detach()
            a = alpha.view(-1)[idx].detach()
            expected = scipy.stats.pareto.logpdf(x, a, scale=s)
            self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)

        self._check_log_prob(Pareto(scale, alpha), ref_log_prob)

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_pareto_sample(self):
        set_rng_seed(1)  # see Note [Randomized statistical tests]
        for scale, alpha in product([0.1, 1.0, 5.0], [0.1, 1.0, 10.0]):
            self._check_sampler_sampler(
                Pareto(scale, alpha),
                scipy.stats.pareto(alpha, scale=scale),
                f"Pareto(scale={scale}, alpha={alpha})",
            )

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_gumbel(self):
        loc = torch.randn(2, 3, requires_grad=True)
        scale = torch.randn(2, 3).abs().requires_grad_()
        loc_1d = torch.randn(1, requires_grad=True)
        scale_1d = torch.randn(1).abs().requires_grad_()
        self.assertEqual(Gumbel(loc, scale).sample().size(), (2, 3))
        self.assertEqual(Gumbel(loc, scale).sample((5,)).size(), (5, 2, 3))
        self.assertEqual(Gumbel(loc_1d, scale_1d).sample().size(), (1,))
        self.assertEqual(Gumbel(loc_1d, scale_1d).sample((1,)).size(), (1, 1))
        self.assertEqual(Gumbel(1.0, 1.0).sample().size(), ())
        self.assertEqual(Gumbel(1.0, 1.0).sample((1,)).size(), (1,))
        # The CDF exp(-exp(-(x - loc) / scale)) should saturate cleanly to 1 and 0
        # at extreme inputs, in both float32 and float64.
        self.assertEqual(
            Gumbel(
                torch.tensor(0.0, dtype=torch.float32),
                torch.tensor(1.0, dtype=torch.float32),
                validate_args=False,
            ).cdf(20.0),
            1.0,
            atol=1e-4,
            rtol=0,
        )
        self.assertEqual(
            Gumbel(
                torch.tensor(0.0, dtype=torch.float64),
                torch.tensor(1.0, dtype=torch.float64),
                validate_args=False,
            ).cdf(50.0),
            1.0,
            atol=1e-4,
            rtol=0,
        )
        self.assertEqual(
            Gumbel(
                torch.tensor(0.0, dtype=torch.float32),
                torch.tensor(1.0, dtype=torch.float32),
                validate_args=False,
            ).cdf(-5.0),
            0.0,
            atol=1e-4,
            rtol=0,
        )
        self.assertEqual(
            Gumbel(
                torch.tensor(0.0, dtype=torch.float64),
                torch.tensor(1.0, dtype=torch.float64),
                validate_args=False,
            ).cdf(-10.0),
            0.0,
            atol=1e-8,
            rtol=0,
        )

        def ref_log_prob(idx, x, log_prob):
            l = loc.view(-1)[idx].detach()
            s = scale.view(-1)[idx].detach()
            expected = scipy.stats.gumbel_r.logpdf(x, loc=l, scale=s)
            self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)

        self._check_log_prob(Gumbel(loc, scale), ref_log_prob)

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    @set_default_dtype(torch.double)
    def test_gumbel_sample(self):
        set_rng_seed(1)  # see Note [Randomized statistical tests]
        for loc, scale in product([-5.0, -1.0, -0.1, 0.1, 1.0, 5.0], [0.1, 1.0, 10.0]):
            self._check_sampler_sampler(
                Gumbel(loc, scale),
                scipy.stats.gumbel_r(loc=loc, scale=scale),
                f"Gumbel(loc={loc}, scale={scale})",
            )

    def test_kumaraswamy_shape(self):
        concentration1 = torch.randn(2, 3).abs().requires_grad_()
        concentration0 = torch.randn(2, 3).abs().requires_grad_()
        concentration1_1d = torch.randn(1).abs().requires_grad_()
        concentration0_1d = torch.randn(1).abs().requires_grad_()
        self.assertEqual(
            Kumaraswamy(concentration1, concentration0).sample().size(), (2, 3)
        )
        self.assertEqual(
            Kumaraswamy(concentration1, concentration0).sample((5,)).size(), (5, 2, 3)
        )
        self.assertEqual(
            Kumaraswamy(concentration1_1d, concentration0_1d).sample().size(), (1,)
        )
        self.assertEqual(
            Kumaraswamy(concentration1_1d, concentration0_1d).sample((1,)).size(),
            (1, 1),
        )
        self.assertEqual(Kumaraswamy(1.0, 1.0).sample().size(), ())
        self.assertEqual(Kumaraswamy(1.0, 1.0).sample((1,)).size(), (1,))

    # The Kumaraswamy distribution is not implemented in SciPy, so instead of a
    # reference comparison these tests check .mean and .variance empirically.
    def test_kumaraswamy_mean_variance(self):
        c1_1 = torch.randn(2, 3).abs().requires_grad_()
        c0_1 = torch.randn(2, 3).abs().requires_grad_()
        c1_2 = torch.randn(4).abs().requires_grad_()
        c0_2 = torch.randn(4).abs().requires_grad_()
        cases = [(c1_1, c0_1), (c1_2, c0_2)]
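        # Worked reference for the checks below (standard closed forms, stated
        # for readers and not used by the assertions themselves): with
        # a = concentration1 and b = concentration0, the density on (0, 1) is
        #     f(x) = a * b * x**(a - 1) * (1 - x**a)**(b - 1),
        # and substituting t = x**a gives the raw moments
        #     E[X**n] = b * B(1 + n / a, b),  where B is the Beta function,
        # so mean = b * B(1 + 1 / a, b) and variance = b * B(1 + 2 / a, b) - mean**2.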
        for i, (a, b) in enumerate(cases):
            m = Kumaraswamy(a, b)
            samples = m.sample((60000,))
            expected = samples.mean(0)
            actual = m.mean
            error = (expected - actual).abs()
            # `error == error` is False exactly at NaN entries, so NaNs are skipped.
            max_error = max(error[error == error])
            self.assertLess(
                max_error,
                0.01,
                f"Kumaraswamy example {i + 1}/{len(cases)}, incorrect .mean",
            )
            expected = samples.var(0)
            actual = m.variance
            error = (expected - actual).abs()
            max_error = max(error[error == error])
            self.assertLess(
                max_error,
                0.01,
                f"Kumaraswamy example {i + 1}/{len(cases)}, incorrect .variance",
            )

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_fishersnedecor(self):
        df1 = torch.randn(2, 3).abs().requires_grad_()
        df2 = torch.randn(2, 3).abs().requires_grad_()
        df1_1d = torch.randn(1).abs()
        df2_1d = torch.randn(1).abs()
        self.assertTrue(is_all_nan(FisherSnedecor(1, 2).mean))
        self.assertTrue(is_all_nan(FisherSnedecor(1, 4).variance))
        self.assertEqual(FisherSnedecor(df1, df2).sample().size(), (2, 3))
        self.assertEqual(FisherSnedecor(df1, df2).sample((5,)).size(), (5, 2, 3))
        self.assertEqual(FisherSnedecor(df1_1d, df2_1d).sample().size(), (1,))
        self.assertEqual(FisherSnedecor(df1_1d, df2_1d).sample((1,)).size(), (1, 1))
        self.assertEqual(FisherSnedecor(1.0, 1.0).sample().size(), ())
        self.assertEqual(FisherSnedecor(1.0, 1.0).sample((1,)).size(), (1,))

        def ref_log_prob(idx, x, log_prob):
            f1 = df1.view(-1)[idx].detach()
            f2 = df2.view(-1)[idx].detach()
            expected = scipy.stats.f.logpdf(x, f1, f2)
            self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)

        self._check_log_prob(FisherSnedecor(df1, df2), ref_log_prob)

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_fishersnedecor_sample(self):
        set_rng_seed(1)  # see Note [Randomized statistical tests]
        for df1, df2 in product([0.1, 0.5, 1.0, 5.0, 10.0], [0.1, 0.5, 1.0, 5.0, 10.0]):
            self._check_sampler_sampler(
                FisherSnedecor(df1, df2),
                scipy.stats.f(df1, df2),
                f"FisherSnedecor(df1={df1}, df2={df2})",
            )

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_chi2_shape(self):
        df = torch.randn(2, 3).exp().requires_grad_()
        df_1d = torch.randn(1).exp().requires_grad_()
        self.assertEqual(Chi2(df).sample().size(), (2, 3))
        self.assertEqual(Chi2(df).sample((5,)).size(), (5, 2, 3))
        self.assertEqual(Chi2(df_1d).sample((1,)).size(), (1, 1))
        self.assertEqual(Chi2(df_1d).sample().size(), (1,))
        self.assertEqual(
            Chi2(torch.tensor(0.5, requires_grad=True)).sample().size(), ()
        )
        self.assertEqual(Chi2(0.5).sample().size(), ())
        self.assertEqual(Chi2(0.5).sample((1,)).size(), (1,))

        def ref_log_prob(idx, x, log_prob):
            d = df.view(-1)[idx].detach()
            expected = scipy.stats.chi2.logpdf(x, d)
            self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)

        self._check_log_prob(Chi2(df), ref_log_prob)

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_chi2_sample(self):
        set_rng_seed(0)  # see Note [Randomized statistical tests]
        for df in [0.1, 1.0, 5.0]:
            self._check_sampler_sampler(
                Chi2(df), scipy.stats.chi2(df), f"Chi2(df={df})"
            )

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_studentT(self):
        df = torch.randn(2, 3).exp().requires_grad_()
        df_1d = torch.randn(1).exp().requires_grad_()
        self.assertTrue(is_all_nan(StudentT(1).mean))
        self.assertTrue(is_all_nan(StudentT(1).variance))
        self.assertEqual(StudentT(2).variance, inf)
        self.assertEqual(StudentT(df).sample().size(), (2, 3))
        self.assertEqual(StudentT(df).sample((5,)).size(), (5, 2, 3))
        self.assertEqual(StudentT(df_1d).sample((1,)).size(), (1, 1))
        self.assertEqual(StudentT(df_1d).sample().size(), (1,))
        self.assertEqual(
            StudentT(torch.tensor(0.5, requires_grad=True)).sample().size(), ()
        )
        self.assertEqual(StudentT(0.5).sample().size(), ())
        self.assertEqual(StudentT(0.5).sample((1,)).size(), (1,))

        def ref_log_prob(idx, x, log_prob):
            d = df.view(-1)[idx].detach()
            expected = scipy.stats.t.logpdf(x, d)
            self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)

        self._check_log_prob(StudentT(df), ref_log_prob)

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    @set_default_dtype(torch.double)
    def test_studentT_sample(self):
        set_rng_seed(11)  # see Note [Randomized statistical tests]
        for df, loc, scale in product(
            [0.1, 1.0, 5.0, 10.0], [-1.0, 0.0, 1.0], [0.1, 1.0, 10.0]
        ):
            self._check_sampler_sampler(
                StudentT(df=df, loc=loc, scale=scale),
                scipy.stats.t(df=df, loc=loc, scale=scale),
                f"StudentT(df={df}, loc={loc}, scale={scale})",
            )

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_studentT_log_prob(self):
        set_rng_seed(0)  # see Note [Randomized statistical tests]
        num_samples = 10
        for df, loc, scale in product(
            [0.1, 1.0, 5.0, 10.0], [-1.0, 0.0, 1.0], [0.1, 1.0, 10.0]
        ):
            dist = StudentT(df=df, loc=loc, scale=scale)
            x = dist.sample((num_samples,))
            actual_log_prob = dist.log_prob(x)
            for i in range(num_samples):
                expected_log_prob = scipy.stats.t.logpdf(
                    x[i], df=df, loc=loc, scale=scale
                )
                self.assertEqual(
                    float(actual_log_prob[i]),
                    float(expected_log_prob),
                    atol=1e-3,
                    rtol=0,
                )

    def test_dirichlet_shape(self):
        alpha = torch.randn(2, 3).exp().requires_grad_()
        alpha_1d = torch.randn(4).exp().requires_grad_()
        self.assertEqual(Dirichlet(alpha).sample().size(), (2, 3))
        self.assertEqual(Dirichlet(alpha).sample((5,)).size(), (5, 2, 3))
        self.assertEqual(Dirichlet(alpha_1d).sample().size(), (4,))
        self.assertEqual(Dirichlet(alpha_1d).sample((1,)).size(), (1, 4))

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    @set_default_dtype(torch.double)
    def test_dirichlet_log_prob(self):
        num_samples = 10
        alpha = torch.exp(torch.randn(5))
        dist = Dirichlet(alpha)
        x = dist.sample((num_samples,))
        actual_log_prob = dist.log_prob(x)
        for i in range(num_samples):
            expected_log_prob = scipy.stats.dirichlet.logpdf(
                x[i].numpy(), alpha.numpy()
            )
            self.assertEqual(actual_log_prob[i], expected_log_prob, atol=1e-3, rtol=0)

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_dirichlet_log_prob_zero(self):
        # Specifically test the special case where x=0 and alpha=1. The PDF is
        # proportional to x**(alpha-1), which in this case works out to 0**0=1.
        # The log PDF of this term should therefore be 0. However, it's easy
        # to accidentally introduce NaNs by calculating log(x) without regard
        # for the value of alpha-1.
        alpha = torch.tensor([1, 2])
        dist = Dirichlet(alpha)
        x = torch.tensor([0, 1])
        actual_log_prob = dist.log_prob(x)
        expected_log_prob = scipy.stats.dirichlet.logpdf(x.numpy(), alpha.numpy())
        self.assertEqual(actual_log_prob, expected_log_prob, atol=1e-3, rtol=0)

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_dirichlet_sample(self):
        set_rng_seed(0)  # see Note [Randomized statistical tests]
        alpha = torch.exp(torch.randn(3))
        self._check_sampler_sampler(
            Dirichlet(alpha),
            scipy.stats.dirichlet(alpha.numpy()),
            f"Dirichlet(alpha={list(alpha)})",
            multivariate=True,
        )

    def test_dirichlet_mode(self):
        # Test a few edge cases for the Dirichlet distribution mode. This also
        # covers Beta distributions.
        concentrations_and_modes = [
            ([2, 2, 1], [0.5, 0.5, 0.0]),
            ([3, 2, 1], [2 / 3, 1 / 3, 0]),
            ([0.5, 0.2, 0.2], [1.0, 0.0, 0.0]),
            ([1, 1, 1], [nan, nan, nan]),
        ]
        for concentration, mode in concentrations_and_modes:
            dist = Dirichlet(torch.tensor(concentration))
            self.assertEqual(dist.mode, torch.tensor(mode))

    def test_beta_shape(self):
        con1 = torch.randn(2, 3).exp().requires_grad_()
        con0 = torch.randn(2, 3).exp().requires_grad_()
        con1_1d = torch.randn(4).exp().requires_grad_()
        con0_1d = torch.randn(4).exp().requires_grad_()
        self.assertEqual(Beta(con1, con0).sample().size(), (2, 3))
        self.assertEqual(Beta(con1, con0).sample((5,)).size(), (5, 2, 3))
        self.assertEqual(Beta(con1_1d, con0_1d).sample().size(), (4,))
        self.assertEqual(Beta(con1_1d, con0_1d).sample((1,)).size(), (1, 4))
        self.assertEqual(Beta(0.1, 0.3).sample().size(), ())
        self.assertEqual(Beta(0.1, 0.3).sample((5,)).size(), (5,))

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_beta_log_prob(self):
        for _ in range(100):
            con1 = np.exp(np.random.normal())
            con0 = np.exp(np.random.normal())
            dist = Beta(con1, con0)
            x = dist.sample()
            actual_log_prob = dist.log_prob(x).sum()
            expected_log_prob = scipy.stats.beta.logpdf(x, con1, con0)
            self.assertEqual(
                float(actual_log_prob), float(expected_log_prob), atol=1e-3, rtol=0
            )

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    @set_default_dtype(torch.double)
    def test_beta_sample(self):
        set_rng_seed(1)  # see Note [Randomized statistical tests]
        for con1, con0 in product([0.1, 1.0, 10.0], [0.1, 1.0, 10.0]):
            self._check_sampler_sampler(
                Beta(con1, con0),
                scipy.stats.beta(con1, con0),
                f"Beta(alpha={con1}, beta={con0})",
            )
        # Check that small alphas do not cause NaNs.
        for Tensor in [torch.FloatTensor, torch.DoubleTensor]:
            x = Beta(Tensor([1e-6]), Tensor([1e-6])).sample()[0]
            self.assertTrue(np.isfinite(x) and x > 0, f"Invalid Beta.sample(): {x}")

    def test_beta_underflow(self):
        # For low values of (alpha, beta), the gamma samples can underflow
        # with float32 and result in a spurious mode at 0.5. To prevent this,
        # torch._sample_dirichlet works with double precision for intermediate
        # calculations.
        set_rng_seed(1)
        num_samples = 50000
        for dtype in [torch.float, torch.double]:
            conc = torch.tensor(1e-2, dtype=dtype)
            beta_samples = Beta(conc, conc).sample([num_samples])
            self.assertEqual((beta_samples == 0).sum(), 0)
            self.assertEqual((beta_samples == 1).sum(), 0)
            # check that the probability mass is concentrated near 0 and 1
            frac_zeros = float((beta_samples < 0.1).sum()) / num_samples
            frac_ones = float((beta_samples > 0.9).sum()) / num_samples
            self.assertEqual(frac_zeros, 0.5, atol=0.05, rtol=0)
            self.assertEqual(frac_ones, 0.5, atol=0.05, rtol=0)

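    # Rough estimate of the float32 underflow described in test_beta_underflow.
    # This assumes Beta(a, b) is sampled as X / (X + Y) with X ~ Gamma(a, 1) and
    # Y ~ Gamma(b, 1); the comment in that test identifies the gamma draws as the
    # culprit. For a small shape a, P(X < eps) ~= eps**a / Gamma(a + 1). With a = 1e-2:
    #   float32, eps = 1.4e-45 (smallest subnormal):  eps**a ~= exp(-1.03) ~= 0.36
    #   float64, eps = 4.9e-324 (smallest subnormal): eps**a ~= exp(-7.44) ~= 6e-4
    # So roughly a third of float32 gamma draws would land at or below the
    # representable range. In double precision the fraction is well under 0.1%.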
    @unittest.skipIf(not TEST_CUDA, "CUDA not found")
    def test_beta_underflow_gpu(self):
        set_rng_seed(1)
        num_samples = 50000
        conc = torch.tensor(1e-2, dtype=torch.float64).cuda()
        beta_samples = Beta(conc, conc).sample([num_samples])
        self.assertEqual((beta_samples == 0).sum(), 0)
        self.assertEqual((beta_samples == 1).sum(), 0)
        # check that the probability mass is concentrated near 0 and 1
        frac_zeros = float((beta_samples < 0.1).sum()) / num_samples
        frac_ones = float((beta_samples > 0.9).sum()) / num_samples
        # TODO: tighten the tolerance once the imbalance on GPU is fixed.
        self.assertEqual(frac_zeros, 0.5, atol=0.12, rtol=0)
        self.assertEqual(frac_ones, 0.5, atol=0.12, rtol=0)

    @set_default_dtype(torch.double)
    def test_continuous_bernoulli(self):
        p = torch.tensor([0.7, 0.2, 0.4], requires_grad=True)
        r = torch.tensor(0.3, requires_grad=True)
        s = 0.3
        self.assertEqual(ContinuousBernoulli(p).sample((8,)).size(), (8, 3))
        self.assertFalse(ContinuousBernoulli(p).sample().requires_grad)
        self.assertEqual(ContinuousBernoulli(r).sample((8,)).size(), (8,))
        self.assertEqual(ContinuousBernoulli(r).sample().size(), ())
        self.assertEqual(ContinuousBernoulli(r).sample((3, 2)).size(), (3, 2))
        self.assertEqual(ContinuousBernoulli(s).sample().size(), ())
        self._gradcheck_log_prob(ContinuousBernoulli, (p,))

        def ref_log_prob(idx, val, log_prob):
            prob = p[idx]
            # (0.499, 0.501) matches the default `lims` of ContinuousBernoulli
            if prob > 0.499 and prob < 0.501:
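                # Background for this branch (standard continuous Bernoulli facts):
                #     p(x | prob) = C(prob) * prob**x * (1 - prob)**(1 - x), x in [0, 1]
                #     C(prob) = 2 * atanh(1 - 2 * prob) / (1 - 2 * prob)
                # C is 0/0 at prob = 0.5. With t = 1 - 2 * prob,
                # atanh(t) / t = 1 + t**2 / 3 + t**4 / 5 + ..., so
                #     log C = log 2 + t**2 / 3 + 13 * t**4 / 90 + O(t**6)
                #           = log 2 + 4/3 * (prob - 0.5)**2 + 104/45 * (prob - 0.5)**4 + ...
                # which is the truncated Taylor series used below.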
                log_norm_const = (
                    math.log(2.0)
                    + 4.0 / 3.0 * math.pow(prob - 0.5, 2)
                    + 104.0 / 45.0 * math.pow(prob - 0.5, 4)
                )
            else:
                log_norm_const = math.log(
                    2.0 * math.atanh(1.0 - 2.0 * prob) / (1.0 - 2.0 * prob)
                )
            res = (
                val * math.log(prob) + (1.0 - val) * math.log1p(-prob) + log_norm_const
            )
            self.assertEqual(log_prob, res)

        self._check_log_prob(ContinuousBernoulli(p), ref_log_prob)
        self._check_log_prob(
            ContinuousBernoulli(logits=p.log() - (-p).log1p()), ref_log_prob
        )

        # check entropy computation
        self.assertEqual(
            ContinuousBernoulli(p).entropy(),
            torch.tensor([-0.02938, -0.07641, -0.00682]),
            atol=1e-4,
            rtol=0,
        )
        # The entropy below corresponds to the clamped value of prob when using
        # float64; the value for float32 should be -1.76898.
        self.assertEqual(
            ContinuousBernoulli(torch.tensor([0.0])).entropy(),
            torch.tensor([-2.58473]),
            atol=1e-5,
            rtol=0,
        )
        self.assertEqual(
            ContinuousBernoulli(s).entropy(), torch.tensor(-0.02938), atol=1e-4, rtol=0
        )

    def test_continuous_bernoulli_3d(self):
        p = torch.full((2, 3, 5), 0.5).requires_grad_()
        self.assertEqual(ContinuousBernoulli(p).sample().size(), (2, 3, 5))
        self.assertEqual(
            ContinuousBernoulli(p).sample(sample_shape=(2, 5)).size(), (2, 5, 2, 3, 5)
        )
        self.assertEqual(ContinuousBernoulli(p).sample((2,)).size(), (2, 2, 3, 5))

    def test_lkj_cholesky_log_prob(self):
        def tril_cholesky_to_tril_corr(x):
            x = vec_to_tril_matrix(x, -1)
            diag = (1 - (x * x).sum(-1)).sqrt().diag_embed()
            x = x + diag
            return tril_matrix_to_vec(x @ x.T, -1)

        for dim in range(2, 5):
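            # Change of variables, assuming (as the correction below implies)
            # that lkj.log_prob is a density over the strictly lower-triangular
            # entries of the Cholesky factor L. tril_cholesky_to_tril_corr maps
            # those entries to the strictly lower entries of C = L @ L.T, so
            #     log p_C(C) = log p_L(L) - log|det dC/dL|
            # For concentration=1 this should be the same constant for every
            # sample, which is what the assertions after the loop check.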
            log_probs = []
            lkj = LKJCholesky(dim, concentration=1.0, validate_args=True)
            for _ in range(2):
                sample = lkj.sample()
                sample_tril = tril_matrix_to_vec(sample, diag=-1)
                log_prob = lkj.log_prob(sample)
                log_abs_det_jacobian = torch.slogdet(
                    jacobian(tril_cholesky_to_tril_corr, sample_tril)
                ).logabsdet
                log_probs.append(log_prob - log_abs_det_jacobian)
            # For concentration=1., the density is uniform over the space of all
            # correlation matrices.
            if dim == 2:
                # For dim=2 the only free entry is the correlation r in (-1, 1),
                # so the uniform pdf is 0.5; the Cholesky-to-correlation map is the
                # identity on r, so the log-Jacobian adjustment is 0.
                self.assertTrue(
                    all(
                        torch.allclose(x, torch.tensor(0.5).log(), atol=1e-10)
                        for x in log_probs
                    )
                )
            self.assertEqual(log_probs[0], log_probs[1])
            invalid_sample = torch.cat([sample, sample.new_ones(1, dim)], dim=0)
            self.assertRaises(ValueError, lambda: lkj.log_prob(invalid_sample))

    def test_independent_shape(self):
        for Dist, params in _get_examples():
            for param in params:
                base_dist = Dist(**param)
                x = base_dist.sample()
                base_log_prob_shape = base_dist.log_prob(x).shape
                for reinterpreted_batch_ndims in range(len(base_dist.batch_shape) + 1):
                    indep_dist = Independent(base_dist, reinterpreted_batch_ndims)
                    indep_log_prob_shape = base_log_prob_shape[
                        : len(base_log_prob_shape) - reinterpreted_batch_ndims
                    ]
                    self.assertEqual(indep_dist.log_prob(x).shape, indep_log_prob_shape)
                    self.assertEqual(
                        indep_dist.sample().shape, base_dist.sample().shape
                    )
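                    # Independent(base_dist, k) reinterprets the rightmost k batch
                    # dims of base_dist as event dims. With n = len(base_dist.batch_shape):
                    #   batch_shape = base_dist.batch_shape[: n - k]
                    #   event_shape = base_dist.batch_shape[n - k :] + base_dist.event_shape
                    # Samples therefore keep their overall shape, while log_prob sums
                    # over the k reinterpreted dims. That is why the expected
                    # log_prob shape above drops the last k entries.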
                    self.assertEqual(indep_dist.has_rsample, base_dist.has_rsample)
                    if indep_dist.has_rsample:
                        self.assertEqual(
                            indep_dist.rsample().shape, base_dist.rsample().shape
                        )
                    try:
                        self.assertEqual(
                            indep_dist.enumerate_support().shape,
                            base_dist.enumerate_support().shape,
                        )
                        self.assertEqual(indep_dist.mean.shape, base_dist.mean.shape)
                    except NotImplementedError:
                        pass
                    try:
                        self.assertEqual(
                            indep_dist.variance.shape, base_dist.variance.shape
                        )
                    except NotImplementedError:
                        pass
                    try:
                        self.assertEqual(
                            indep_dist.entropy().shape, indep_log_prob_shape
                        )
                    except NotImplementedError:
                        pass

    def test_independent_expand(self):
        for Dist, params in _get_examples():
            for param in params:
                base_dist = Dist(**param)
                for reinterpreted_batch_ndims in range(len(base_dist.batch_shape) + 1):
                    for s in [torch.Size(), torch.Size((2,)), torch.Size((2, 3))]:
                        indep_dist = Independent(base_dist, reinterpreted_batch_ndims)
                        expanded_shape = s + indep_dist.batch_shape
                        expanded = indep_dist.expand(expanded_shape)
                        expanded_sample = expanded.sample()
                        expected_shape = expanded_shape + indep_dist.event_shape
                        self.assertEqual(expanded_sample.shape, expected_shape)
                        self.assertEqual(
                            expanded.log_prob(expanded_sample),
                            indep_dist.log_prob(expanded_sample),
                        )
                        self.assertEqual(expanded.event_shape, indep_dist.event_shape)
                        self.assertEqual(expanded.batch_shape, expanded_shape)

    @set_default_dtype(torch.double)
    def test_cdf_icdf_inverse(self):
        # Checks that icdf(cdf(x)) recovers x for every distribution implementing both.
        for Dist, params in _get_examples():
            for i, param in enumerate(params):
                dist = Dist(**param)
                samples = dist.sample(sample_shape=(20,))
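                # Worked example of the round trip: Exponential(rate) has
                # cdf(x) = 1 - exp(-rate * x) and icdf(p) = -log1p(-p) / rate, so
                # icdf(cdf(x)) = -log(exp(-rate * x)) / rate = x up to rounding.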
                try:
                    cdf = dist.cdf(samples)
                    actual = dist.icdf(cdf)
                except NotImplementedError:
                    continue
                rel_error = torch.abs(actual - samples) / (1e-10 + torch.abs(samples))
                self.assertLess(
                    rel_error.max(),
                    1e-4,
                    msg="\n".join(
                        [
                            f"{Dist.__name__} example {i + 1}/{len(params)}, icdf(cdf(x)) != x",
                            f"x = {samples}",
                            f"cdf(x) = {cdf}",
                            f"icdf(cdf(x)) = {actual}",
                        ]
                    ),
                )

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_gamma_log_prob_at_boundary(self):
        for concentration, log_prob in [(0.5, inf), (1, 0), (2, -inf)]:
            dist = Gamma(concentration, 1)
            scipy_dist = scipy.stats.gamma(concentration)
            self.assertAlmostEqual(dist.log_prob(0), log_prob)
            self.assertAlmostEqual(dist.log_prob(0), scipy_dist.logpdf(0))

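    # Worked example of the identity checked by test_cdf_log_prob: for Normal(0, 1),
    # cdf(x) = 0.5 * (1 + erf(x / sqrt(2))), whose derivative is
    # exp(-x**2 / 2) / sqrt(2 * pi), exactly pdf(x); at x = 0 both sides equal
    # 1 / sqrt(2 * pi), about 0.3989.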
    @set_default_dtype(torch.double)
    def test_cdf_log_prob(self):
        # Checks that the derivative of the CDF equals the PDF at sampled values.
        for Dist, params in _get_examples():
            for i, param in enumerate(params):
                # Gradients with respect to the parameters (e.g. the Gamma
                # concentration) are not needed here.
                param = {
                    key: value.detach() if isinstance(value, torch.Tensor) else value
                    for key, value in param.items()
                }
                dist = Dist(**param)
                samples = dist.sample()
                if not dist.support.is_discrete:
                    samples.requires_grad_()
                try:
                    cdfs = dist.cdf(samples)
                    pdfs = dist.log_prob(samples).exp()
                except NotImplementedError:
                    continue
                # Compare the raw derivative; it should not be wrapped in torch.abs().
                cdfs_derivative = grad(cdfs.sum(), [samples])[0]
                self.assertEqual(
                    cdfs_derivative,
                    pdfs,
                    msg="\n".join(
                        [
                            f"{Dist.__name__} example {i + 1}/{len(params)}, d(cdf)/dx != pdf(x)",
                            f"x = {samples}",
                            f"cdf = {cdfs}",
                            f"pdf = {pdfs}",
                            f"grad(cdf) = {cdfs_derivative}",
                        ]
                    ),
                )

    def test_valid_parameter_broadcasting(self):
        # Checks that parameters of different shapes broadcast correctly for
        # distributions with multiple parameters. Each example is a
        # (distribution instance, expected sample shape) pair.
        valid_examples = [
            (Normal(loc=torch.tensor([0.0, 0.0]), scale=1), (2,)),
            (Normal(loc=0, scale=torch.tensor([1.0, 1.0])), (2,)),
            (Normal(loc=torch.tensor([0.0, 0.0]), scale=torch.tensor([1.0])), (2,)),
            (
                Normal(
                    loc=torch.tensor([0.0, 0.0]), scale=torch.tensor([[1.0], [1.0]])
                ),
                (2, 2),
            ),
            (Normal(loc=torch.tensor([0.0, 0.0]), scale=torch.tensor([[1.0]])), (1, 2)),
            (Normal(loc=torch.tensor([0.0]), scale=torch.tensor([[1.0]])), (1, 1)),
            (FisherSnedecor(df1=torch.tensor([1.0, 1.0]), df2=1), (2,)),
            (FisherSnedecor(df1=1, df2=torch.tensor([1.0, 1.0])), (2,)),
            (
                FisherSnedecor(df1=torch.tensor([1.0, 1.0]), df2=torch.tensor([1.0])),
                (2,),
            ),
            (
                FisherSnedecor(
                    df1=torch.tensor([1.0, 1.0]), df2=torch.tensor([[1.0], [1.0]])
                ),
                (2, 2),
            ),
            (
                FisherSnedecor(df1=torch.tensor([1.0, 1.0]), df2=torch.tensor([[1.0]])),
                (1, 2),
            ),
            (
                FisherSnedecor(df1=torch.tensor([1.0]), df2=torch.tensor([[1.0]])),
                (1, 1),
            ),
            (Gamma(concentration=torch.tensor([1.0, 1.0]), rate=1), (2,)),
            (Gamma(concentration=1, rate=torch.tensor([1.0, 1.0])), (2,)),
            (
                Gamma(
                    concentration=torch.tensor([1.0, 1.0]),
                    rate=torch.tensor([[1.0], [1.0], [1.0]]),
                ),
                (3, 2),
            ),
            (
                Gamma(
                    concentration=torch.tensor([1.0, 1.0]),
                    rate=torch.tensor([[1.0], [1.0]]),
                ),
                (2, 2),
            ),
            (
                Gamma(
                    concentration=torch.tensor([1.0, 1.0]), rate=torch.tensor([[1.0]])
                ),
                (1, 2),
            ),
            (
                Gamma(concentration=torch.tensor([1.0]), rate=torch.tensor([[1.0]])),
                (1, 1),
            ),
            (Gumbel(loc=torch.tensor([0.0, 0.0]), scale=1), (2,)),
            (Gumbel(loc=0, scale=torch.tensor([1.0, 1.0])), (2,)),
            (Gumbel(loc=torch.tensor([0.0, 0.0]), scale=torch.tensor([1.0])), (2,)),
            (
                Gumbel(
                    loc=torch.tensor([0.0, 0.0]), scale=torch.tensor([[1.0], [1.0]])
                ),
                (2, 2),
            ),
            (Gumbel(loc=torch.tensor([0.0, 0.0]), scale=torch.tensor([[1.0]])), (1, 2)),
            (Gumbel(loc=torch.tensor([0.0]), scale=torch.tensor([[1.0]])), (1, 1)),
            (
                Kumaraswamy(
                    concentration1=torch.tensor([1.0, 1.0]), concentration0=1.0
                ),
                (2,),
            ),
            (
                Kumaraswamy(concentration1=1, concentration0=torch.tensor([1.0, 1.0])),
                (2,),
            ),
            (
                Kumaraswamy(
                    concentration1=torch.tensor([1.0, 1.0]),
                    concentration0=torch.tensor([1.0]),
                ),
                (2,),
            ),
            (
                Kumaraswamy(
                    concentration1=torch.tensor([1.0, 1.0]),
                    concentration0=torch.tensor([[1.0], [1.0]]),
                ),
                (2, 2),
            ),
            (
                Kumaraswamy(
                    concentration1=torch.tensor([1.0, 1.0]),
                    concentration0=torch.tensor([[1.0]]),
                ),
                (1, 2),
            ),
            (
                Kumaraswamy(
                    concentration1=torch.tensor([1.0]),
                    concentration0=torch.tensor([[1.0]]),
                ),
                (1, 1),
            ),
            (Laplace(loc=torch.tensor([0.0, 0.0]), scale=1), (2,)),
            (Laplace(loc=0, scale=torch.tensor([1.0, 1.0])), (2,)),
            (Laplace(loc=torch.tensor([0.0, 0.0]), scale=torch.tensor([1.0])), (2,)),
            (
                Laplace(
                    loc=torch.tensor([0.0, 0.0]), scale=torch.tensor([[1.0], [1.0]])
                ),
                (2, 2),
            ),
            (
                Laplace(loc=torch.tensor([0.0, 0.0]), scale=torch.tensor([[1.0]])),
                (1, 2),
            ),
            (Laplace(loc=torch.tensor([0.0]), scale=torch.tensor([[1.0]])), (1, 1)),
            (Pareto(scale=torch.tensor([1.0, 1.0]), alpha=1), (2,)),
            (Pareto(scale=1, alpha=torch.tensor([1.0, 1.0])), (2,)),
            (Pareto(scale=torch.tensor([1.0, 1.0]), alpha=torch.tensor([1.0])), (2,)),
            (
                Pareto(
                    scale=torch.tensor([1.0, 1.0]), alpha=torch.tensor([[1.0], [1.0]])
                ),
                (2, 2),
            ),
            (
                Pareto(scale=torch.tensor([1.0, 1.0]), alpha=torch.tensor([[1.0]])),
                (1, 2),
            ),
            (Pareto(scale=torch.tensor([1.0]), alpha=torch.tensor([[1.0]])), (1, 1)),
            (StudentT(df=torch.tensor([1.0, 1.0]), loc=1), (2,)),
            (StudentT(df=1, scale=torch.tensor([1.0, 1.0])), (2,)),
            (StudentT(df=torch.tensor([1.0, 1.0]), loc=torch.tensor([1.0])), (2,)),
            (
                StudentT(
                    df=torch.tensor([1.0, 1.0]), scale=torch.tensor([[1.0], [1.0]])
                ),
                (2, 2),
            ),
            (StudentT(df=torch.tensor([1.0, 1.0]), loc=torch.tensor([[1.0]])), (1, 2)),
            (StudentT(df=torch.tensor([1.0]), scale=torch.tensor([[1.0]])), (1, 1)),
            (StudentT(df=1.0, loc=torch.zeros(5, 1), scale=torch.ones(3)), (5, 3)),
        ]

        for dist, expected_size in valid_examples:
            actual_size = dist.sample().size()
            self.assertEqual(
                actual_size,
                expected_size,
                msg=f"{dist} actual size: {actual_size} != expected size: {expected_size}",
            )

            sample_shape = torch.Size((2,))
            expected_size = sample_shape + expected_size
            actual_size = dist.sample(sample_shape).size()
            self.assertEqual(
                actual_size,
                expected_size,
                msg=f"{dist} actual size: {actual_size} != expected size: {expected_size}",
            )

    def test_invalid_parameter_broadcasting(self):
        # Parameter shapes that cannot be broadcast together; constructing these
        # distributions should raise an error. Each example is a
        # (distribution class, distribution params) pair.
        invalid_examples = [
            (
                Normal,
                {"loc": torch.tensor([[0, 0]]), "scale": torch.tensor([1, 1, 1, 1])},
            ),
            (
                Normal,
                {
                    "loc": torch.tensor([[[0, 0, 0], [0, 0, 0]]]),
                    "scale": torch.tensor([1, 1]),
                },
            ),
            (
                FisherSnedecor,
                {
                    "df1": torch.tensor([1, 1]),
                    "df2": torch.tensor([1, 1, 1]),
                },
            ),
            (
                Gumbel,
                {"loc": torch.tensor([[0, 0]]), "scale": torch.tensor([1, 1, 1, 1])},
            ),
            (
                Gumbel,
                {
                    "loc": torch.tensor([[[0, 0, 0], [0, 0, 0]]]),
                    "scale": torch.tensor([1, 1]),
                },
            ),
            (
                Gamma,
                {
                    "concentration": torch.tensor([0, 0]),
                    "rate": torch.tensor([1, 1, 1]),
                },
            ),
            (
                Kumaraswamy,
                {
                    "concentration1": torch.tensor([[1, 1]]),
                    "concentration0": torch.tensor([1, 1, 1, 1]),
                },
            ),
            (
                Kumaraswamy,
                {
                    "concentration1": torch.tensor([[[1, 1, 1], [1, 1, 1]]]),
                    "concentration0": torch.tensor([1, 1]),
                },
            ),
            (Laplace, {"loc": torch.tensor([0, 0]), "scale": torch.tensor([1, 1, 1])}),
            (Pareto, {"scale": torch.tensor([1, 1]), "alpha": torch.tensor([1, 1, 1])}),
            (
                StudentT,
                {
                    "df": torch.tensor([1.0, 1.0]),
                    "scale": torch.tensor([1.0, 1.0, 1.0]),
                },
            ),
            (
                StudentT,
                {"df": torch.tensor([1.0, 1.0]), "loc": torch.tensor([1.0, 1.0, 1.0])},
            ),
        ]

        for dist, kwargs in invalid_examples:
            self.assertRaises(RuntimeError, dist, **kwargs)

    def _test_discrete_distribution_mode(self, dist, sanitized_mode, batch_isfinite):
        # We cannot easily check the mode of a discrete distribution directly, but we
        # can look at the neighboring values on either side and check that their log
        # probability is no larger than at the mode.
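        # Worked example: Poisson(rate=3.5) has mode floor(3.5) = 3. Since
        # P(k + 1) / P(k) = rate / (k + 1), we get P(4) / P(3) = 3.5 / 4 < 1 and
        # P(3) / P(2) = 3.5 / 3 > 1, so log_prob(3) exceeds both log_prob(2) and
        # log_prob(4), which is what stepping by -1 and +1 verifies.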
        for step in [-1, 1]:
            log_prob_mode = dist.log_prob(sanitized_mode)
            if isinstance(dist, OneHotCategorical):
                idx = (dist._categorical.mode + 1) % dist.probs.shape[-1]
                other = torch.nn.functional.one_hot(
                    idx, num_classes=dist.probs.shape[-1]
                ).to(dist.mode)
            else:
                other = dist.mode + step
            mask = batch_isfinite & dist.support.check(other)
            self.assertTrue(mask.any() or dist.mode.unique().numel() == 1)
            # Add a dimension to the right if the event shape is not a scalar, e.g. OneHotCategorical.
            other = torch.where(
                mask[..., None] if mask.ndim < other.ndim else mask,
                other,
                dist.sample(),
            )
            log_prob_other = dist.log_prob(other)
            delta = log_prob_mode - log_prob_other
            # Allow up to 1e-12 rounding error.
            self.assertTrue((-1e-12 < delta[mask].detach()).all())
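
    # Why the continuous check perturbs in unconstrained space: for a distribution
    # supported on the positive reals, transform_to(dist.support) is an
    # ExpTransform, so transform.inv(mode) = log(mode) and every perturbed point
    # exp(log(mode) + eps) is still positive. Perturbing there never produces
    # values outside the support.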

    def _test_continuous_distribution_mode(self, dist, sanitized_mode, batch_isfinite):
        # We perturb the mode in the unconstrained space and expect the log probability to decrease.
        num_points = 10
        transform = transform_to(dist.support)
        unconstrained_mode = transform.inv(sanitized_mode)
        perturbation = 1e-5 * (
            torch.rand((num_points,) + unconstrained_mode.shape) - 0.5
        )
        perturbed_mode = transform(perturbation + unconstrained_mode)
        log_prob_mode = dist.log_prob(sanitized_mode)
        log_prob_other = dist.log_prob(perturbed_mode)
        delta = log_prob_mode - log_prob_other

        # We pass the test with a small tolerance to allow for rounding and manually set the
        # difference to zero if both log probs are infinite with the same sign.
        both_infinite_with_same_sign = (log_prob_mode == log_prob_other) & (
            log_prob_mode.abs() == inf
        )
        delta[both_infinite_with_same_sign] = 0.0
        ordering = (delta > -1e-12).all(axis=0)
        self.assertTrue(ordering[batch_isfinite].all())

    @set_default_dtype(torch.double)
    def test_mode(self):
        discrete_distributions = (
            Bernoulli,
            Binomial,
            Categorical,
            Geometric,
            NegativeBinomial,
            OneHotCategorical,
            Poisson,
        )
        no_mode_available = (
            ContinuousBernoulli,
            LKJCholesky,
            LogisticNormal,
            MixtureSameFamily,
            Multinomial,
            RelaxedBernoulli,
            RelaxedOneHotCategorical,
        )

        for dist_cls, params in _get_examples():
            for param in params:
                dist = dist_cls(**param)
                if (
                    isinstance(dist, no_mode_available)
                    or type(dist) is TransformedDistribution
                ):
                    with self.assertRaises(NotImplementedError):
                        dist.mode
                    continue

                # Check that either all or no elements in the event shape are nan: the mode cannot be
                # defined for part of an event.
                isfinite = dist.mode.isfinite().reshape(
                    dist.batch_shape + (dist.event_shape.numel(),)
                )
                batch_isfinite = isfinite.all(axis=-1)
                self.assertTrue((batch_isfinite | ~isfinite.any(axis=-1)).all())

                # We sanitize undefined modes by sampling from the distribution.
                sanitized_mode = torch.where(
                    ~dist.mode.isnan(), dist.mode, dist.sample()
                )
                if isinstance(dist, discrete_distributions):
                    self._test_discrete_distribution_mode(
                        dist, sanitized_mode, batch_isfinite
                    )
                else:
                    self._test_continuous_distribution_mode(
                        dist, sanitized_mode, batch_isfinite
                    )

                self.assertFalse(dist.log_prob(sanitized_mode).isnan().any())


# These tests are only needed for a few distributions that implement custom
# reparameterized gradients. Most .rsample() implementations simply rely on
# the reparameterization trick and do not need to be tested for accuracy.
@skipIfTorchDynamo("Not a TorchDynamo suitable test")
class TestRsample(DistributionsTestCase):
    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_gamma(self):
        num_samples = 100
        for alpha in [1e-2, 1e-1, 1e0, 1e1, 1e2, 1e3, 1e4]:
            alphas = torch.tensor(
                [alpha] * num_samples, dtype=torch.float, requires_grad=True
            )
            betas = alphas.new_ones(num_samples)
            x = Gamma(alphas, betas).rsample()
            x.sum().backward()
            x, ind = x.sort()
            x = x.detach().numpy()
            actual_grad = alphas.grad[ind].numpy()
            # Compare with expected gradient dx/dalpha along constant cdf(x,alpha).
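            # Derivation: holding u = F(x; alpha) fixed and differentiating
            # F(x(alpha); alpha) = u with respect to alpha gives
            # pdf(x; alpha) * dx/dalpha + dF/dalpha = 0, hence
            # dx/dalpha = -(dF/dalpha) / pdf(x; alpha). Here dF/dalpha is a central
            # finite difference of the scipy CDF (cdf_alpha) and pdf(x; alpha) is cdf_x.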
            cdf = scipy.stats.gamma.cdf
            pdf = scipy.stats.gamma.pdf
            eps = 0.01 * alpha / (1.0 + alpha**0.5)
            cdf_alpha = (cdf(x, alpha + eps) - cdf(x, alpha - eps)) / (2 * eps)
            cdf_x = pdf(x, alpha)
            expected_grad = -cdf_alpha / cdf_x
            rel_error = np.abs(actual_grad - expected_grad) / (expected_grad + 1e-30)
            self.assertLess(
                np.max(rel_error),
                0.0005,
                "\n".join(
                    [
                        f"Bad gradient dx/dalpha for x ~ Gamma({alpha}, 1)",
                        f"x {x}",
                        f"expected {expected_grad}",
                        f"actual {actual_grad}",
                        f"rel error {rel_error}",
                        f"max error {rel_error.max()}",
                        f"at alpha={alpha}, x={x[rel_error.argmax()]}",
                    ]
                ),
            )

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_chi2(self):
        num_samples = 100
        for df in [1e-2, 1e-1, 1e0, 1e1, 1e2, 1e3, 1e4]:
            dfs = torch.tensor(
                [df] * num_samples, dtype=torch.float, requires_grad=True
            )
            x = Chi2(dfs).rsample()
            x.sum().backward()
            x, ind = x.sort()
            x = x.detach().numpy()
            actual_grad = dfs.grad[ind].numpy()
            # Compare with expected gradient dx/ddf along constant cdf(x,df).
            cdf = scipy.stats.chi2.cdf
            pdf = scipy.stats.chi2.pdf
            eps = 0.01 * df / (1.0 + df**0.5)
            cdf_df = (cdf(x, df + eps) - cdf(x, df - eps)) / (2 * eps)
            cdf_x = pdf(x, df)
            expected_grad = -cdf_df / cdf_x
            rel_error = np.abs(actual_grad - expected_grad) / (expected_grad + 1e-30)
            self.assertLess(
                np.max(rel_error),
                0.001,
                "\n".join(
                    [
                        f"Bad gradient dx/ddf for x ~ Chi2({df})",
                        f"x {x}",
                        f"expected {expected_grad}",
                        f"actual {actual_grad}",
                        f"rel error {rel_error}",
                        f"max error {rel_error.max()}",
                    ]
                ),
            )

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_dirichlet_on_diagonal(self):
        num_samples = 20
        grid = [1e-1, 1e0, 1e1]
        for a0, a1, a2 in product(grid, grid, grid):
            alphas = torch.tensor(
                [[a0, a1, a2]] * num_samples, dtype=torch.float, requires_grad=True
            )
            x = Dirichlet(alphas).rsample()[:, 0]
            x.sum().backward()
            x, ind = x.sort()
            x = x.detach().numpy()
            actual_grad = alphas.grad[ind].numpy()[:, 0]
            # Compare with expected gradient dx/dalpha0 along constant cdf(x,alpha).
            # This reduces to a distribution Beta(alpha[0], alpha[1] + alpha[2]).
            cdf = scipy.stats.beta.cdf
            pdf = scipy.stats.beta.pdf
            alpha, beta = a0, a1 + a2
            eps = 0.01 * alpha / (1.0 + np.sqrt(alpha))
            cdf_alpha = (cdf(x, alpha + eps, beta) - cdf(x, alpha - eps, beta)) / (
                2 * eps
            )
            cdf_x = pdf(x, alpha, beta)
            expected_grad = -cdf_alpha / cdf_x
            rel_error = np.abs(actual_grad - expected_grad) / (expected_grad + 1e-30)
            self.assertLess(
                np.max(rel_error),
                0.001,
                "\n".join(
                    [
                        f"Bad gradient dx[0]/dalpha[0] for Dirichlet([{a0}, {a1}, {a2}])",
                        f"x {x}",
                        f"expected {expected_grad}",
                        f"actual {actual_grad}",
                        f"rel error {rel_error}",
                        f"max error {rel_error.max()}",
                        f"at x={x[rel_error.argmax()]}",
                    ]
                ),
            )

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_beta_wrt_alpha(self):
        num_samples = 20
        grid = [1e-2, 1e-1, 1e0, 1e1, 1e2]
        for con1, con0 in product(grid, grid):
            con1s = torch.tensor(
                [con1] * num_samples, dtype=torch.float, requires_grad=True
            )
            con0s = con1s.new_tensor([con0] * num_samples)
            x = Beta(con1s, con0s).rsample()
            x.sum().backward()
            x, ind = x.sort()
            x = x.detach().numpy()
            actual_grad = con1s.grad[ind].numpy()
            # Compare with expected gradient dx/dcon1 along constant cdf(x,con1,con0).
            cdf = scipy.stats.beta.cdf
            pdf = scipy.stats.beta.pdf
            eps = 0.01 * con1 / (1.0 + np.sqrt(con1))
            cdf_alpha = (cdf(x, con1 + eps, con0) - cdf(x, con1 - eps, con0)) / (
                2 * eps
            )
            cdf_x = pdf(x, con1, con0)
            expected_grad = -cdf_alpha / cdf_x
            rel_error = np.abs(actual_grad - expected_grad) / (expected_grad + 1e-30)
            self.assertLess(
                np.max(rel_error),
                0.005,
                "\n".join(
                    [
                        f"Bad gradient dx/dcon1 for x ~ Beta({con1}, {con0})",
                        f"x {x}",
                        f"expected {expected_grad}",
                        f"actual {actual_grad}",
                        f"rel error {rel_error}",
                        f"max error {rel_error.max()}",
                        f"at x = {x[rel_error.argmax()]}",
                    ]
                ),
            )

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_beta_wrt_beta(self):
        num_samples = 20
        grid = [1e-2, 1e-1, 1e0, 1e1, 1e2]
        for con1, con0 in product(grid, grid):
            con0s = torch.tensor(
                [con0] * num_samples, dtype=torch.float, requires_grad=True
            )
            con1s = con0s.new_tensor([con1] * num_samples)
            x = Beta(con1s, con0s).rsample()
            x.sum().backward()
            x, ind = x.sort()
            x = x.detach().numpy()
            actual_grad = con0s.grad[ind].numpy()
            # Compare with expected gradient dx/dcon0 along constant cdf(x,con1,con0).
            cdf = scipy.stats.beta.cdf
            pdf = scipy.stats.beta.pdf
            eps = 0.01 * con0 / (1.0 + np.sqrt(con0))
            cdf_beta = (cdf(x, con1, con0 + eps) - cdf(x, con1, con0 - eps)) / (2 * eps)
            cdf_x = pdf(x, con1, con0)
            expected_grad = -cdf_beta / cdf_x
            rel_error = np.abs(actual_grad - expected_grad) / (expected_grad + 1e-30)
            self.assertLess(
                np.max(rel_error),
                0.005,
                "\n".join(
                    [
                        f"Bad gradient dx/dcon0 for x ~ Beta({con1}, {con0})",
                        f"x {x}",
                        f"expected {expected_grad}",
                        f"actual {actual_grad}",
                        f"rel error {rel_error}",
                        f"max error {rel_error.max()}",
                        f"at x = {x[rel_error.argmax()]!r}",
                    ]
                ),
            )

    def test_dirichlet_multivariate(self):
        alpha_crit = 0.25 * (5.0**0.5 - 1.0)
        num_samples = 100000
        for shift in [-0.1, -0.05, -0.01, 0.0, 0.01, 0.05, 0.10]:
            alpha = alpha_crit + shift
            alpha = torch.tensor([alpha], dtype=torch.float, requires_grad=True)
            alpha_vec = torch.cat([alpha, alpha, alpha.new([1])])
            z = Dirichlet(alpha_vec.expand(num_samples, 3)).rsample()
            mean_z3 = 1.0 / (2.0 * alpha + 1.0)
            loss = torch.pow(z[:, 2] - mean_z3, 2.0).mean()
            actual_grad = grad(loss, [alpha])[0]
            # Compute expected gradient by hand.
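            # Derivation: with concentrations (alpha, alpha, 1), the marginal is
            # z3 ~ Beta(1, 2 * alpha), so mean_z3 = 1 / (2 * alpha + 1).
            # The loss therefore estimates
            #     Var(z3) = alpha / ((1 + alpha) * (1 + 2 * alpha) ** 2).
            # The gradient flowing through mean_z3 is multiplied by
            # E[z3 - mean_z3] = 0, so it vanishes in expectation.
            # Differentiating with respect to alpha:
            #     dVar/dalpha = (1 - 2 * alpha - 4 * alpha**2)
            #                   / ((1 + alpha) ** 2 * (1 + 2 * alpha) ** 3)
            # The numerator is zero at alpha_crit = (sqrt(5) - 1) / 4. The shifts
            # therefore probe both signs of the gradient around that root.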
            num = 1.0 - 2.0 * alpha - 4.0 * alpha**2
            den = (1.0 + alpha) ** 2 * (1.0 + 2.0 * alpha) ** 3
            expected_grad = num / den
            self.assertEqual(
                actual_grad,
                expected_grad,
                atol=0.002,
                rtol=0,
                msg="\n".join(
                    [
                        "alpha = alpha_c + %.2g" % shift,  # noqa: UP031
                        "expected_grad: %.5g" % expected_grad,  # noqa: UP031
                        "actual_grad: %.5g" % actual_grad,  # noqa: UP031
                        "error = %.2g"  # noqa: UP031
                        % torch.abs(expected_grad - actual_grad).max(),  # noqa: UP031
                    ]
                ),
            )

    @set_default_dtype(torch.double)
    def test_dirichlet_tangent_field(self):
        num_samples = 20
        alpha_grid = [0.5, 1.0, 2.0]

        # v = dx/dalpha[0] is the reparameterized gradient aka tangent field.
        def compute_v(x, alpha):
            return torch.stack(
                [
                    _Dirichlet_backward(x, alpha, torch.eye(3, 3)[i].expand_as(x))[:, 0]
                    for i in range(3)
                ],
                dim=-1,
            )

        for a1, a2, a3 in product(alpha_grid, alpha_grid, alpha_grid):
            alpha = torch.tensor([a1, a2, a3], requires_grad=True).expand(
                num_samples, 3
            )
            x = Dirichlet(alpha).rsample()
            dlogp_da = grad(
                [Dirichlet(alpha).log_prob(x.detach()).sum()],
                [alpha],
                retain_graph=True,
            )[0][:, 0]
            dlogp_dx = grad(
                [Dirichlet(alpha.detach()).log_prob(x).sum()], [x], retain_graph=True
            )[0]
            v = torch.stack(
                [
                    grad([x[:, i].sum()], [alpha], retain_graph=True)[0][:, 0]
                    for i in range(3)
                ],
                dim=-1,
            )
            # Compute remaining properties by finite difference.
            self.assertEqual(compute_v(x, alpha), v, msg="Bug in compute_v() helper")
            # dx is an arbitrary orthonormal basis tangent to the simplex.
            dx = torch.tensor([[2.0, -1.0, -1.0], [0.0, 1.0, -1.0]])
            dx /= dx.norm(2, -1, True)
            eps = 1e-2 * x.min(-1, True)[0]  # avoid boundary
            dv0 = (
                compute_v(x + eps * dx[0], alpha) - compute_v(x - eps * dx[0], alpha)
            ) / (2 * eps)
            dv1 = (
                compute_v(x + eps * dx[1], alpha) - compute_v(x - eps * dx[1], alpha)
            ) / (2 * eps)
            div_v = (dv0 * dx[0] + dv1 * dx[1]).sum(-1)
            # This is a modification of the standard continuity equation, using the product rule to allow
            # expression in terms of log_prob rather than the less numerically stable log_prob.exp().
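            # Derivation: a reparameterized sampler transports probability mass.
            # The density p(x; alpha) and the tangent field v = dx/dalpha[0]
            # therefore satisfy the continuity equation on the simplex:
            #     dp/dalpha[0] + div(p * v) = 0
            # Expand div(p * v) = grad(p) . v + p * div(v) and divide by p:
            #     dlog(p)/dalpha[0] + grad(log p) . v + div(v) = 0
            # Here div(v) is taken within the simplex using the orthonormal tangent
            # basis dx. The `error` computed next should therefore be close to zero.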
            error = dlogp_da + (dlogp_dx * v).sum(-1) + div_v
            self.assertLess(
                torch.abs(error).max(),
                0.005,
                "\n".join(
                    [
                        f"Dirichlet([{a1}, {a2}, {a3}]) gradient violates continuity equation:",
                        f"error = {error}",
                    ]
                ),
            )


class TestDistributionShapes(DistributionsTestCase):
    def setUp(self):
        super().setUp()
        self.scalar_sample = 1
        self.tensor_sample_1 = torch.ones(3, 2)
        self.tensor_sample_2 = torch.ones(3, 2, 3)

    def test_entropy_shape(self):
        for Dist, params in _get_examples():
            for i, param in enumerate(params):
                dist = Dist(validate_args=False, **param)
                try:
                    actual_shape = dist.entropy().size()
                    expected_shape = (
                        dist.batch_shape if dist.batch_shape else torch.Size()
                    )
                    message = f"{Dist.__name__} example {i + 1}/{len(params)}, shape mismatch. expected {expected_shape}, actual {actual_shape}"  # noqa: B950
                    self.assertEqual(actual_shape, expected_shape, msg=message)
                except NotImplementedError:
                    continue

    def test_bernoulli_shape_scalar_params(self):
        bernoulli = Bernoulli(0.3)
        self.assertEqual(bernoulli._batch_shape, torch.Size())
        self.assertEqual(bernoulli._event_shape, torch.Size())
        self.assertEqual(bernoulli.sample().size(), torch.Size())
        self.assertEqual(bernoulli.sample((3, 2)).size(), torch.Size((3, 2)))
        self.assertRaises(ValueError, bernoulli.log_prob, self.scalar_sample)
        self.assertEqual(
            bernoulli.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
        )
        self.assertEqual(
            bernoulli.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
        )

    def test_bernoulli_shape_tensor_params(self):
        bernoulli = Bernoulli(torch.tensor([[0.6, 0.3], [0.6, 0.3], [0.6, 0.3]]))
        self.assertEqual(bernoulli._batch_shape, torch.Size((3, 2)))
        self.assertEqual(bernoulli._event_shape, torch.Size(()))
        self.assertEqual(bernoulli.sample().size(), torch.Size((3, 2)))
        self.assertEqual(bernoulli.sample((3, 2)).size(), torch.Size((3, 2, 3, 2)))
        self.assertEqual(
            bernoulli.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
        )
        self.assertRaises(ValueError, bernoulli.log_prob, self.tensor_sample_2)
        self.assertEqual(
            bernoulli.log_prob(torch.ones(3, 1, 1)).size(), torch.Size((3, 3, 2))
        )

    def test_geometric_shape_scalar_params(self):
        geometric = Geometric(0.3)
        self.assertEqual(geometric._batch_shape, torch.Size())
        self.assertEqual(geometric._event_shape, torch.Size())
        self.assertEqual(geometric.sample().size(), torch.Size())
        self.assertEqual(geometric.sample((3, 2)).size(), torch.Size((3, 2)))
        self.assertRaises(ValueError, geometric.log_prob, self.scalar_sample)
        self.assertEqual(
            geometric.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
        )
        self.assertEqual(
            geometric.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
        )

    def test_geometric_shape_tensor_params(self):
        geometric = Geometric(torch.tensor([[0.6, 0.3], [0.6, 0.3], [0.6, 0.3]]))
        self.assertEqual(geometric._batch_shape, torch.Size((3, 2)))
        self.assertEqual(geometric._event_shape, torch.Size(()))
        self.assertEqual(geometric.sample().size(), torch.Size((3, 2)))
        self.assertEqual(geometric.sample((3, 2)).size(), torch.Size((3, 2, 3, 2)))
        self.assertEqual(
            geometric.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
        )
        self.assertRaises(ValueError, geometric.log_prob, self.tensor_sample_2)
        self.assertEqual(
            geometric.log_prob(torch.ones(3, 1, 1)).size(), torch.Size((3, 3, 2))
        )

    def test_beta_shape_scalar_params(self):
        dist = Beta(0.1, 0.1)
        self.assertEqual(dist._batch_shape, torch.Size())
        self.assertEqual(dist._event_shape, torch.Size())
        self.assertEqual(dist.sample().size(), torch.Size())
        self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2)))
        self.assertRaises(ValueError, dist.log_prob, self.scalar_sample)
        self.assertEqual(dist.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
        self.assertEqual(
            dist.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
        )

    def test_beta_shape_tensor_params(self):
        dist = Beta(
            torch.tensor([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]),
            torch.tensor([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]),
        )
        self.assertEqual(dist._batch_shape, torch.Size((3, 2)))
        self.assertEqual(dist._event_shape, torch.Size(()))
        self.assertEqual(dist.sample().size(), torch.Size((3, 2)))
        self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2, 3, 2)))
        self.assertEqual(dist.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
        self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
        self.assertEqual(
            dist.log_prob(torch.ones(3, 1, 1)).size(), torch.Size((3, 3, 2))
        )

    def test_binomial_shape(self):
        dist = Binomial(10, torch.tensor([0.6, 0.3]))
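        # Shape conventions exercised below (a brief sketch, not a full spec):
        #   sample(sample_shape).shape == sample_shape + batch_shape + event_shape.
        #   With batch_shape (2,), sample((3, 2)) therefore has shape (3, 2, 2).
        #   log_prob(value) needs value.shape to broadcast against
        #   batch_shape + event_shape.
        #   (3, 2) broadcasts with (2,) to give (3, 2).
        #   (3, 2, 3) does not (trailing dims 3 vs 2), so argument validation
        #   raises ValueError.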
        self.assertEqual(dist._batch_shape, torch.Size((2,)))
        self.assertEqual(dist._event_shape, torch.Size(()))
        self.assertEqual(dist.sample().size(), torch.Size((2,)))
        self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2, 2)))
        self.assertEqual(dist.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
        self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)

    def test_binomial_shape_vectorized_n(self):
        dist = Binomial(
            torch.tensor([[10, 3, 1], [4, 8, 4]]), torch.tensor([0.6, 0.3, 0.1])
        )
        self.assertEqual(dist._batch_shape, torch.Size((2, 3)))
        self.assertEqual(dist._event_shape, torch.Size(()))
        self.assertEqual(dist.sample().size(), torch.Size((2, 3)))
        self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2, 2, 3)))
        self.assertEqual(
            dist.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
        )
        self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_1)

    def test_multinomial_shape(self):
        dist = Multinomial(10, torch.tensor([[0.6, 0.3], [0.6, 0.3], [0.6, 0.3]]))
        self.assertEqual(dist._batch_shape, torch.Size((3,)))
        self.assertEqual(dist._event_shape, torch.Size((2,)))
        self.assertEqual(dist.sample().size(), torch.Size((3, 2)))
        self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2, 3, 2)))
        self.assertEqual(dist.log_prob(self.tensor_sample_1).size(), torch.Size((3,)))
        self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
        self.assertEqual(dist.log_prob(torch.ones(3, 1, 2)).size(), torch.Size((3, 3)))

    def test_categorical_shape(self):
        # unbatched
        dist = Categorical(torch.tensor([0.6, 0.3, 0.1]))
        self.assertEqual(dist._batch_shape, torch.Size(()))
        self.assertEqual(dist._event_shape, torch.Size(()))
        self.assertEqual(dist.sample().size(), torch.Size())
        self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2)))
        self.assertEqual(dist.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
        self.assertEqual(
            dist.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
        )
        self.assertEqual(dist.log_prob(torch.ones(3, 1)).size(), torch.Size((3, 1)))
        # batched
        dist = Categorical(torch.tensor([[0.6, 0.3], [0.6, 0.3], [0.6, 0.3]]))
        self.assertEqual(dist._batch_shape, torch.Size((3,)))
        self.assertEqual(dist._event_shape, torch.Size(()))
        self.assertEqual(dist.sample().size(), torch.Size((3,)))
        self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2, 3)))
        self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_1)
        self.assertEqual(
            dist.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
        )
        self.assertEqual(dist.log_prob(torch.ones(3, 1)).size(), torch.Size((3, 3)))

    def test_one_hot_categorical_shape(self):
        # unbatched
        dist = OneHotCategorical(torch.tensor([0.6, 0.3, 0.1]))
        self.assertEqual(dist._batch_shape, torch.Size(()))
        self.assertEqual(dist._event_shape, torch.Size((3,)))
        self.assertEqual(dist.sample().size(), torch.Size((3,)))
        self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2, 3)))
        self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_1)
        sample = torch.tensor([0.0, 1.0, 0.0]).expand(3, 2, 3)
        self.assertEqual(dist.log_prob(sample).size(), torch.Size((3, 2)))
        self.assertEqual(
            dist.log_prob(dist.enumerate_support()).size(), torch.Size((3,))
        )
        sample = torch.eye(3)
        self.assertEqual(dist.log_prob(sample).size(), torch.Size((3,)))
        # batched
        dist = OneHotCategorical(torch.tensor([[0.6, 0.3], [0.6, 0.3], [0.6, 0.3]]))
        self.assertEqual(dist._batch_shape, torch.Size((3,)))
        self.assertEqual(dist._event_shape, torch.Size((2,)))
        self.assertEqual(dist.sample().size(), torch.Size((3, 2)))
        self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2, 3, 2)))
        sample = torch.tensor([0.0, 1.0])
        self.assertEqual(dist.log_prob(sample).size(), torch.Size((3,)))
        self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
        self.assertEqual(
            dist.log_prob(dist.enumerate_support()).size(), torch.Size((2, 3))
        )
        sample = torch.tensor([0.0, 1.0]).expand(3, 1, 2)
        self.assertEqual(dist.log_prob(sample).size(), torch.Size((3, 3)))

    def test_cauchy_shape_scalar_params(self):
        cauchy = Cauchy(0, 1)
        self.assertEqual(cauchy._batch_shape, torch.Size())
        self.assertEqual(cauchy._event_shape, torch.Size())
        self.assertEqual(cauchy.sample().size(), torch.Size())
        self.assertEqual(cauchy.sample(torch.Size((3, 2))).size(), torch.Size((3, 2)))
        self.assertRaises(ValueError, cauchy.log_prob, self.scalar_sample)
        self.assertEqual(
            cauchy.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
        )
        self.assertEqual(
            cauchy.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
        )

    def test_cauchy_shape_tensor_params(self):
        cauchy = Cauchy(torch.tensor([0.0, 0.0]), torch.tensor([1.0, 1.0]))
        self.assertEqual(cauchy._batch_shape, torch.Size((2,)))
        self.assertEqual(cauchy._event_shape, torch.Size(()))
        self.assertEqual(cauchy.sample().size(), torch.Size((2,)))
        self.assertEqual(
            cauchy.sample(torch.Size((3, 2))).size(), torch.Size((3, 2, 2))
        )
        self.assertEqual(
            cauchy.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
        )
        self.assertRaises(ValueError, cauchy.log_prob, self.tensor_sample_2)
        self.assertEqual(cauchy.log_prob(torch.ones(2, 1)).size(), torch.Size((2, 2)))

    def test_halfcauchy_shape_scalar_params(self):
        halfcauchy = HalfCauchy(1)
        self.assertEqual(halfcauchy._batch_shape, torch.Size())
        self.assertEqual(halfcauchy._event_shape, torch.Size())
        self.assertEqual(halfcauchy.sample().size(), torch.Size())
        self.assertEqual(
            halfcauchy.sample(torch.Size((3, 2))).size(), torch.Size((3, 2))
        )
        self.assertRaises(ValueError, halfcauchy.log_prob, self.scalar_sample)
        self.assertEqual(
            halfcauchy.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
        )
        self.assertEqual(
            halfcauchy.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
        )

    def test_halfcauchy_shape_tensor_params(self):
        halfcauchy = HalfCauchy(torch.tensor([1.0, 1.0]))
        self.assertEqual(halfcauchy._batch_shape, torch.Size((2,)))
        self.assertEqual(halfcauchy._event_shape, torch.Size(()))
        self.assertEqual(halfcauchy.sample().size(), torch.Size((2,)))
        self.assertEqual(
            halfcauchy.sample(torch.Size((3, 2))).size(), torch.Size((3, 2, 2))
        )
        self.assertEqual(
            halfcauchy.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
        )
        self.assertRaises(ValueError, halfcauchy.log_prob, self.tensor_sample_2)
        self.assertEqual(
            halfcauchy.log_prob(torch.ones(2, 1)).size(), torch.Size((2, 2))
        )

    def test_dirichlet_shape(self):
        dist = Dirichlet(torch.tensor([[0.6, 0.3], [1.6, 1.3], [2.6, 2.3]]))
        self.assertEqual(dist._batch_shape, torch.Size((3,)))
        self.assertEqual(dist._event_shape, torch.Size((2,)))
        self.assertEqual(dist.sample().size(), torch.Size((3, 2)))
        self.assertEqual(dist.sample((5, 4)).size(), torch.Size((5, 4, 3, 2)))
        simplex_sample = self.tensor_sample_1 / self.tensor_sample_1.sum(
            -1, keepdim=True
        )
        self.assertEqual(dist.log_prob(simplex_sample).size(), torch.Size((3,)))
        self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
        simplex_sample = torch.ones(3, 1, 2)
        simplex_sample = simplex_sample / simplex_sample.sum(-1).unsqueeze(-1)
        self.assertEqual(dist.log_prob(simplex_sample).size(), torch.Size((3, 3)))

    def test_mixture_same_family_shape(self):
        dist = MixtureSameFamily(
            Categorical(torch.rand(5)), Normal(torch.randn(5), torch.rand(5))
        )
        self.assertEqual(dist._batch_shape, torch.Size())
        self.assertEqual(dist._event_shape, torch.Size())
        self.assertEqual(dist.sample().size(), torch.Size())
        self.assertEqual(dist.sample((5, 4)).size(), torch.Size((5, 4)))
        self.assertEqual(dist.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
        self.assertEqual(
            dist.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
        )

    def test_gamma_shape_scalar_params(self):
        gamma = Gamma(1, 1)
        self.assertEqual(gamma._batch_shape, torch.Size())
        self.assertEqual(gamma._event_shape, torch.Size())
        self.assertEqual(gamma.sample().size(), torch.Size())
        self.assertEqual(gamma.sample((3, 2)).size(), torch.Size((3, 2)))
        self.assertEqual(gamma.log_prob(self.scalar_sample).size(), torch.Size())
        self.assertEqual(
            gamma.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
        )
        self.assertEqual(
            gamma.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
        )

    def test_gamma_shape_tensor_params(self):
        gamma = Gamma(torch.tensor([1.0, 1.0]), torch.tensor([1.0, 1.0]))
        self.assertEqual(gamma._batch_shape, torch.Size((2,)))
        self.assertEqual(gamma._event_shape, torch.Size(()))
        self.assertEqual(gamma.sample().size(), torch.Size((2,)))
        self.assertEqual(gamma.sample((3, 2)).size(), torch.Size((3, 2, 2)))
        self.assertEqual(
            gamma.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
        )
        self.assertRaises(ValueError, gamma.log_prob, self.tensor_sample_2)
        self.assertEqual(gamma.log_prob(torch.ones(2, 1)).size(), torch.Size((2, 2)))

    def test_chi2_shape_scalar_params(self):
        chi2 = Chi2(1)
        self.assertEqual(chi2._batch_shape, torch.Size())
        self.assertEqual(chi2._event_shape, torch.Size())
        self.assertEqual(chi2.sample().size(), torch.Size())
        self.assertEqual(chi2.sample((3, 2)).size(), torch.Size((3, 2)))
        self.assertEqual(chi2.log_prob(self.scalar_sample).size(), torch.Size())
        self.assertEqual(chi2.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
        self.assertEqual(
            chi2.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
        )

    def test_chi2_shape_tensor_params(self):
        chi2 = Chi2(torch.tensor([1.0, 1.0]))
        self.assertEqual(chi2._batch_shape, torch.Size((2,)))
        self.assertEqual(chi2._event_shape, torch.Size(()))
        self.assertEqual(chi2.sample().size(), torch.Size((2,)))
        self.assertEqual(chi2.sample((3, 2)).size(), torch.Size((3, 2, 2)))
        self.assertEqual(chi2.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
        self.assertRaises(ValueError, chi2.log_prob, self.tensor_sample_2)
        self.assertEqual(chi2.log_prob(torch.ones(2, 1)).size(), torch.Size((2, 2)))

    def test_studentT_shape_scalar_params(self):
        st = StudentT(1)
        self.assertEqual(st._batch_shape, torch.Size())
        self.assertEqual(st._event_shape, torch.Size())
        self.assertEqual(st.sample().size(), torch.Size())
        self.assertEqual(st.sample((3, 2)).size(), torch.Size((3, 2)))
        self.assertRaises(ValueError, st.log_prob, self.scalar_sample)
        self.assertEqual(st.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
        self.assertEqual(
            st.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
        )

    def test_studentT_shape_tensor_params(self):
        st = StudentT(torch.tensor([1.0, 1.0]))
        self.assertEqual(st._batch_shape, torch.Size((2,)))
        self.assertEqual(st._event_shape, torch.Size(()))
        self.assertEqual(st.sample().size(), torch.Size((2,)))
        self.assertEqual(st.sample((3, 2)).size(), torch.Size((3, 2, 2)))
        self.assertEqual(st.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
        self.assertRaises(ValueError, st.log_prob, self.tensor_sample_2)
        self.assertEqual(st.log_prob(torch.ones(2, 1)).size(), torch.Size((2, 2)))

    def test_pareto_shape_scalar_params(self):
        pareto = Pareto(1, 1)
        self.assertEqual(pareto._batch_shape, torch.Size())
        self.assertEqual(pareto._event_shape, torch.Size())
        self.assertEqual(pareto.sample().size(), torch.Size())
        self.assertEqual(pareto.sample((3, 2)).size(), torch.Size((3, 2)))
        self.assertEqual(
            pareto.log_prob(self.tensor_sample_1 + 1).size(), torch.Size((3, 2))
        )
        self.assertEqual(
            pareto.log_prob(self.tensor_sample_2 + 1).size(), torch.Size((3, 2, 3))
        )

    def test_gumbel_shape_scalar_params(self):
        gumbel = Gumbel(1, 1)
        self.assertEqual(gumbel._batch_shape, torch.Size())
        self.assertEqual(gumbel._event_shape, torch.Size())
        self.assertEqual(gumbel.sample().size(), torch.Size())
        self.assertEqual(gumbel.sample((3, 2)).size(), torch.Size((3, 2)))
        self.assertEqual(
            gumbel.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
        )
        self.assertEqual(
            gumbel.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
        )

    def test_kumaraswamy_shape_scalar_params(self):
        kumaraswamy = Kumaraswamy(1, 1)
        self.assertEqual(kumaraswamy._batch_shape, torch.Size())
        self.assertEqual(kumaraswamy._event_shape, torch.Size())
        self.assertEqual(kumaraswamy.sample().size(), torch.Size())
        self.assertEqual(kumaraswamy.sample((3, 2)).size(), torch.Size((3, 2)))
        self.assertEqual(
            kumaraswamy.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
        )
        self.assertEqual(
            kumaraswamy.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
        )

    def test_vonmises_shape_tensor_params(self):
        von_mises = VonMises(torch.tensor([0.0, 0.0]), torch.tensor([1.0, 1.0]))
        self.assertEqual(von_mises._batch_shape, torch.Size((2,)))
        self.assertEqual(von_mises._event_shape, torch.Size(()))
        self.assertEqual(von_mises.sample().size(), torch.Size((2,)))
        self.assertEqual(
            von_mises.sample(torch.Size((3, 2))).size(), torch.Size((3, 2, 2))
        )
        self.assertEqual(
            von_mises.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
        )
        self.assertEqual(
            von_mises.log_prob(torch.ones(2, 1)).size(), torch.Size((2, 2))
        )

    def test_vonmises_shape_scalar_params(self):
        von_mises = VonMises(0.0, 1.0)
        self.assertEqual(von_mises._batch_shape, torch.Size())
        self.assertEqual(von_mises._event_shape, torch.Size())
        self.assertEqual(von_mises.sample().size(), torch.Size())
        self.assertEqual(
            von_mises.sample(torch.Size((3, 2))).size(), torch.Size((3, 2))
        )
        self.assertEqual(
            von_mises.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
        )
        self.assertEqual(
            von_mises.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
        )

    def test_weibull_shape_scalar_params(self):
        weibull = Weibull(1, 1)
        self.assertEqual(weibull._batch_shape, torch.Size())
        self.assertEqual(weibull._event_shape, torch.Size())
        self.assertEqual(weibull.sample().size(), torch.Size())
        self.assertEqual(weibull.sample((3, 2)).size(), torch.Size((3, 2)))
        self.assertEqual(
            weibull.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
        )
        self.assertEqual(
            weibull.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
        )

    def test_wishart_shape_scalar_params(self):
        wishart = Wishart(torch.tensor(1), torch.tensor([[1.0]]))
        self.assertEqual(wishart._batch_shape, torch.Size())
        self.assertEqual(wishart._event_shape, torch.Size((1, 1)))
        self.assertEqual(wishart.sample().size(), torch.Size((1, 1)))
        self.assertEqual(wishart.sample((3, 2)).size(), torch.Size((3, 2, 1, 1)))
        self.assertRaises(ValueError, wishart.log_prob, self.scalar_sample)

    def test_wishart_shape_tensor_params(self):
        wishart = Wishart(torch.tensor([1.0, 1.0]), torch.tensor([[[1.0]], [[1.0]]]))
        self.assertEqual(wishart._batch_shape, torch.Size((2,)))
        self.assertEqual(wishart._event_shape, torch.Size((1, 1)))
        self.assertEqual(wishart.sample().size(), torch.Size((2, 1, 1)))
        self.assertEqual(wishart.sample((3, 2)).size(), torch.Size((3, 2, 2, 1, 1)))
        self.assertRaises(ValueError, wishart.log_prob, self.tensor_sample_2)
        self.assertEqual(wishart.log_prob(torch.ones(2, 1, 1)).size(), torch.Size((2,)))

    def test_normal_shape_scalar_params(self):
        normal = Normal(0, 1)
        self.assertEqual(normal._batch_shape, torch.Size())
        self.assertEqual(normal._event_shape, torch.Size())
        self.assertEqual(normal.sample().size(), torch.Size())
        self.assertEqual(normal.sample((3, 2)).size(), torch.Size((3, 2)))
        self.assertRaises(ValueError, normal.log_prob, self.scalar_sample)
        self.assertEqual(
            normal.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
        )
        self.assertEqual(
            normal.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
        )

    def test_normal_shape_tensor_params(self):
        normal = Normal(torch.tensor([0.0, 0.0]), torch.tensor([1.0, 1.0]))
        self.assertEqual(normal._batch_shape, torch.Size((2,)))
        self.assertEqual(normal._event_shape, torch.Size(()))
        self.assertEqual(normal.sample().size(), torch.Size((2,)))
        self.assertEqual(normal.sample((3, 2)).size(), torch.Size((3, 2, 2)))
        self.assertEqual(
            normal.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
        )
        self.assertRaises(ValueError, normal.log_prob, self.tensor_sample_2)
        self.assertEqual(normal.log_prob(torch.ones(2, 1)).size(), torch.Size((2, 2)))

    def test_uniform_shape_scalar_params(self):
        uniform = Uniform(0, 1)
        self.assertEqual(uniform._batch_shape, torch.Size())
        self.assertEqual(uniform._event_shape, torch.Size())
        self.assertEqual(uniform.sample().size(), torch.Size())
        self.assertEqual(uniform.sample(torch.Size((3, 2))).size(), torch.Size((3, 2)))
        self.assertRaises(ValueError, uniform.log_prob, self.scalar_sample)
        self.assertEqual(
            uniform.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
        )
        self.assertEqual(
            uniform.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
        )

    def test_uniform_shape_tensor_params(self):
        uniform = Uniform(torch.tensor([0.0, 0.0]), torch.tensor([1.0, 1.0]))
        self.assertEqual(uniform._batch_shape, torch.Size((2,)))
        self.assertEqual(uniform._event_shape, torch.Size(()))
        self.assertEqual(uniform.sample().size(), torch.Size((2,)))
        self.assertEqual(
            uniform.sample(torch.Size((3, 2))).size(), torch.Size((3, 2, 2))
        )
        self.assertEqual(
            uniform.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
        )
        self.assertRaises(ValueError, uniform.log_prob, self.tensor_sample_2)
        self.assertEqual(uniform.log_prob(torch.ones(2, 1)).size(), torch.Size((2, 2)))

    def test_exponential_shape_scalar_param(self):
        expon = Exponential(1.0)
        self.assertEqual(expon._batch_shape, torch.Size())
        self.assertEqual(expon._event_shape, torch.Size())
        self.assertEqual(expon.sample().size(), torch.Size())
        self.assertEqual(expon.sample((3, 2)).size(), torch.Size((3, 2)))
        self.assertRaises(ValueError, expon.log_prob, self.scalar_sample)
        self.assertEqual(
            expon.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
        )
        self.assertEqual(
            expon.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
        )

    def test_exponential_shape_tensor_param(self):
        expon = Exponential(torch.tensor([1.0, 1.0]))
        self.assertEqual(expon._batch_shape, torch.Size((2,)))
        self.assertEqual(expon._event_shape, torch.Size(()))
        self.assertEqual(expon.sample().size(), torch.Size((2,)))
        self.assertEqual(expon.sample((3, 2)).size(), torch.Size((3, 2, 2)))
        self.assertEqual(
            expon.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
        )
        self.assertRaises(ValueError, expon.log_prob, self.tensor_sample_2)
        self.assertEqual(expon.log_prob(torch.ones(2, 2)).size(), torch.Size((2, 2)))

    def test_laplace_shape_scalar_params(self):
        laplace = Laplace(0, 1)
        self.assertEqual(laplace._batch_shape, torch.Size())
        self.assertEqual(laplace._event_shape, torch.Size())
        self.assertEqual(laplace.sample().size(), torch.Size())
        self.assertEqual(laplace.sample((3, 2)).size(), torch.Size((3, 2)))
        self.assertRaises(ValueError, laplace.log_prob, self.scalar_sample)
        self.assertEqual(
            laplace.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
        )
        self.assertEqual(
            laplace.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
        )

    def test_laplace_shape_tensor_params(self):
        laplace = Laplace(torch.tensor([0.0, 0.0]), torch.tensor([1.0, 1.0]))
        self.assertEqual(laplace._batch_shape, torch.Size((2,)))
        self.assertEqual(laplace._event_shape, torch.Size(()))
        self.assertEqual(laplace.sample().size(), torch.Size((2,)))
        self.assertEqual(laplace.sample((3, 2)).size(), torch.Size((3, 2, 2)))
        self.assertEqual(
            laplace.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
        )
        self.assertRaises(ValueError, laplace.log_prob, self.tensor_sample_2)
        self.assertEqual(laplace.log_prob(torch.ones(2, 1)).size(), torch.Size((2, 2)))

    def test_continuous_bernoulli_shape_scalar_params(self):
        continuous_bernoulli = ContinuousBernoulli(0.3)
        self.assertEqual(continuous_bernoulli._batch_shape, torch.Size())
        self.assertEqual(continuous_bernoulli._event_shape, torch.Size())
        self.assertEqual(continuous_bernoulli.sample().size(), torch.Size())
        self.assertEqual(continuous_bernoulli.sample((3, 2)).size(), torch.Size((3, 2)))
        self.assertRaises(ValueError, continuous_bernoulli.log_prob, self.scalar_sample)
        self.assertEqual(
            continuous_bernoulli.log_prob(self.tensor_sample_1).size(),
            torch.Size((3, 2)),
        )
        self.assertEqual(
            continuous_bernoulli.log_prob(self.tensor_sample_2).size(),
            torch.Size((3, 2, 3)),
        )

    def test_continuous_bernoulli_shape_tensor_params(self):
        continuous_bernoulli = ContinuousBernoulli(
            torch.tensor([[0.6, 0.3], [0.6, 0.3], [0.6, 0.3]])
        )
        self.assertEqual(continuous_bernoulli._batch_shape, torch.Size((3, 2)))
        self.assertEqual(continuous_bernoulli._event_shape, torch.Size(()))
        self.assertEqual(continuous_bernoulli.sample().size(), torch.Size((3, 2)))
        self.assertEqual(
            continuous_bernoulli.sample((3, 2)).size(), torch.Size((3, 2, 3, 2))
        )
        self.assertEqual(
            continuous_bernoulli.log_prob(self.tensor_sample_1).size(),
            torch.Size((3, 2)),
        )
        self.assertRaises(
            ValueError, continuous_bernoulli.log_prob, self.tensor_sample_2
        )
        self.assertEqual(
            continuous_bernoulli.log_prob(torch.ones(3, 1, 1)).size(),
            torch.Size((3, 3, 2)),
        )

    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
    def test_mixture_same_family_mean_shape(self):
        mix_distribution = Categorical(torch.ones([3, 1, 3]))
        component_distribution = Normal(torch.zeros([3, 3, 3]), torch.ones([3, 3, 3]))
        gmm = MixtureSameFamily(mix_distribution, component_distribution)
        self.assertEqual(len(gmm.mean.shape), 2)


@skipIfTorchDynamo("Not a TorchDynamo suitable test")
class TestKL(DistributionsTestCase):
    def setUp(self):
        super().setUp()

        class Binomial30(Binomial):
            def __init__(self, probs):
                super().__init__(30, probs)

        # These are pairs of distributions with 4 x 4 parameters as specified.
        # The first of each pair (e.g. bernoulli[0]) varies column-wise and the second
        # (e.g. bernoulli[1]) varies row-wise; that way we test all parameter pairs.
        bernoulli = pairwise(Bernoulli, [0.1, 0.2, 0.6, 0.9])
        binomial30 = pairwise(Binomial30, [0.1, 0.2, 0.6, 0.9])
        binomial_vectorized_count = (
            Binomial(torch.tensor([3, 4]), torch.tensor([0.4, 0.6])),
            Binomial(torch.tensor([3, 4]), torch.tensor([0.5, 0.8])),
        )
        beta = pairwise(Beta, [1.0, 2.5, 1.0, 2.5], [1.5, 1.5, 3.5, 3.5])
        categorical = pairwise(
            Categorical,
            [[0.4, 0.3, 0.3], [0.2, 0.7, 0.1], [0.33, 0.33, 0.34], [0.2, 0.2, 0.6]],
        )
        cauchy = pairwise(Cauchy, [-2.0, 2.0, -3.0, 3.0], [1.0, 2.0, 1.0, 2.0])
        chi2 = pairwise(Chi2, [1.0, 2.0, 2.5, 5.0])
        dirichlet = pairwise(
            Dirichlet,
            [[0.1, 0.2, 0.7], [0.5, 0.4, 0.1], [0.33, 0.33, 0.34], [0.2, 0.2, 0.4]],
        )
        exponential = pairwise(Exponential, [1.0, 2.5, 5.0, 10.0])
        gamma = pairwise(Gamma, [1.0, 2.5, 1.0, 2.5], [1.5, 1.5, 3.5, 3.5])
        gumbel = pairwise(Gumbel, [-2.0, 4.0, -3.0, 6.0], [1.0, 2.5, 1.0, 2.5])
        halfnormal = pairwise(HalfNormal, [1.0, 2.0, 1.0, 2.0])
        inversegamma = pairwise(
            InverseGamma, [1.0, 2.5, 1.0, 2.5], [1.5, 1.5, 3.5, 3.5]
        )
        laplace = pairwise(Laplace, [-2.0, 4.0, -3.0, 6.0], [1.0, 2.5, 1.0, 2.5])
        lognormal = pairwise(LogNormal, [-2.0, 2.0, -3.0, 3.0], [1.0, 2.0, 1.0, 2.0])
        normal = pairwise(Normal, [-2.0, 2.0, -3.0, 3.0], [1.0, 2.0, 1.0, 2.0])
        independent = (Independent(normal[0], 1), Independent(normal[1], 1))
        onehotcategorical = pairwise(
            OneHotCategorical,
            [[0.4, 0.3, 0.3], [0.2, 0.7, 0.1], [0.33, 0.33, 0.34], [0.2, 0.2, 0.6]],
        )
        pareto = (
            Pareto(
                torch.tensor([2.5, 4.0, 2.5, 4.0]).expand(4, 4),
                torch.tensor([2.25, 3.75, 2.25, 3.75]).expand(4, 4),
            ),
            Pareto(
                torch.tensor([2.25, 3.75, 2.25, 3.8]).expand(4, 4),
                torch.tensor([2.25, 3.75, 2.25, 3.75]).expand(4, 4),
            ),
        )
        poisson = pairwise(Poisson, [0.3, 1.0, 5.0, 10.0])
        uniform_within_unit = pairwise(
            Uniform, [0.1, 0.9, 0.2, 0.75], [0.15, 0.95, 0.25, 0.8]
        )
        uniform_positive = pairwise(Uniform, [1, 1.5, 2, 4], [1.2, 2.0, 3, 7])
        uniform_real = pairwise(Uniform, [-2.0, -1, 0, 2], [-1.0, 1, 1, 4])
        uniform_pareto = pairwise(Uniform, [6.5, 7.5, 6.5, 8.5], [7.5, 8.5, 9.5, 9.5])
        continuous_bernoulli = pairwise(ContinuousBernoulli, [0.1, 0.2, 0.5, 0.9])

        # These tests should pass with precision = 0.01, but that makes tests very expensive.
        # Instead, we test with precision = 0.1 and only test with higher precision locally
        # when adding a new KL implementation.
        # The following pairs are not tested due to very high variance of the Monte Carlo
        # estimator; their implementations have been reviewed with extra care:
        # - (pareto, normal)
        self.precision = 0.1  # Set this to 0.01 when testing a new KL implementation.
        self.max_samples = int(1e07)  # Increase this when testing at smaller precision.
        self.samples_per_batch = int(1e04)
        self.finite_examples = [
            (bernoulli, bernoulli),
            (bernoulli, poisson),
            (beta, beta),
            (beta, chi2),
            (beta, exponential),
            (beta, gamma),
            (beta, normal),
            (binomial30, binomial30),
            (binomial_vectorized_count, binomial_vectorized_count),
            (categorical, categorical),
            (cauchy, cauchy),
            (chi2, chi2),
            (chi2, exponential),
            (chi2, gamma),
            (chi2, normal),
            (dirichlet, dirichlet),
            (exponential, chi2),
            (exponential, exponential),
            (exponential, gamma),
            (exponential, gumbel),
            (exponential, normal),
            (gamma, chi2),
            (gamma, exponential),
            (gamma, gamma),
            (gamma, gumbel),
            (gamma, normal),
            (gumbel, gumbel),
            (gumbel, normal),
            (halfnormal, halfnormal),
            (independent, independent),
            (inversegamma, inversegamma),
            (laplace, laplace),
            (lognormal, lognormal),
            (laplace, normal),
            (normal, gumbel),
            (normal, laplace),
            (normal, normal),
            (onehotcategorical, onehotcategorical),
            (pareto, chi2),
            (pareto, pareto),
            (pareto, exponential),
            (pareto, gamma),
            (poisson, poisson),
            (uniform_within_unit, beta),
            (uniform_positive, chi2),
            (uniform_positive, exponential),
            (uniform_positive, gamma),
            (uniform_real, gumbel),
            (uniform_real, normal),
            (uniform_pareto, pareto),
            (continuous_bernoulli, continuous_bernoulli),
            (continuous_bernoulli, exponential),
            (continuous_bernoulli, normal),
            (beta, continuous_bernoulli),
        ]

        self.infinite_examples = [
            (Bernoulli(0), Bernoulli(1)),
            (Bernoulli(1), Bernoulli(0)),
            (
                Categorical(torch.tensor([0.9, 0.1])),
                Categorical(torch.tensor([1.0, 0.0])),
            ),
            (
                Categorical(torch.tensor([[0.9, 0.1], [0.9, 0.1]])),
                Categorical(torch.tensor([1.0, 0.0])),
            ),
            (Beta(1, 2), Uniform(0.25, 1)),
            (Beta(1, 2), Uniform(0, 0.75)),
            (Beta(1, 2), Uniform(0.25, 0.75)),
            (Beta(1, 2), Pareto(1, 2)),
            (Binomial(31, 0.7), Binomial(30, 0.3)),
            (
                Binomial(torch.tensor([3, 4]), torch.tensor([0.4, 0.6])),
                Binomial(torch.tensor([2, 3]), torch.tensor([0.5, 0.8])),
            ),
            (Chi2(1), Beta(2, 3)),
            (Chi2(1), Pareto(2, 3)),
            (Chi2(1), Uniform(-2, 3)),
            (Exponential(1), Beta(2, 3)),
            (Exponential(1), Pareto(2, 3)),
            (Exponential(1), Uniform(-2, 3)),
            (Gamma(1, 2), Beta(3, 4)),
            (Gamma(1, 2), Pareto(3, 4)),
            (Gamma(1, 2), Uniform(-3, 4)),
            (Gumbel(-1, 2), Beta(3, 4)),
            (Gumbel(-1, 2), Chi2(3)),
            (Gumbel(-1, 2), Exponential(3)),
            (Gumbel(-1, 2), Gamma(3, 4)),
            (Gumbel(-1, 2), Pareto(3, 4)),
            (Gumbel(-1, 2), Uniform(-3, 4)),
            (Laplace(-1, 2), Beta(3, 4)),
            (Laplace(-1, 2), Chi2(3)),
            (Laplace(-1, 2), Exponential(3)),
            (Laplace(-1, 2), Gamma(3, 4)),
            (Laplace(-1, 2), Pareto(3, 4)),
            (Laplace(-1, 2), Uniform(-3, 4)),
            (Normal(-1, 2), Beta(3, 4)),
            (Normal(-1, 2), Chi2(3)),
            (Normal(-1, 2), Exponential(3)),
            (Normal(-1, 2), Gamma(3, 4)),
            (Normal(-1, 2), Pareto(3, 4)),
            (Normal(-1, 2), Uniform(-3, 4)),
            (Pareto(2, 1), Chi2(3)),
            (Pareto(2, 1), Exponential(3)),
            (Pareto(2, 1), Gamma(3, 4)),
            (Pareto(1, 2), Normal(-3, 4)),
            (Pareto(1, 2), Pareto(3, 4)),
            (Poisson(2), Bernoulli(0.5)),
            (Poisson(2.3), Binomial(10, 0.2)),
            (Uniform(-1, 1), Beta(2, 2)),
            (Uniform(0, 2), Beta(3, 4)),
            (Uniform(-1, 2), Beta(3, 4)),
            (Uniform(-1, 2), Chi2(3)),
            (Uniform(-1, 2), Exponential(3)),
            (Uniform(-1, 2), Gamma(3, 4)),
            (Uniform(-1, 2), Pareto(3, 4)),
            (ContinuousBernoulli(0.25), Uniform(0.25, 1)),
            (ContinuousBernoulli(0.25), Uniform(0, 0.75)),
            (ContinuousBernoulli(0.25), Uniform(0.25, 0.75)),
            (ContinuousBernoulli(0.25), Pareto(1, 2)),
            (Exponential(1), ContinuousBernoulli(0.75)),
            (Gamma(1, 2), ContinuousBernoulli(0.75)),
            (Gumbel(-1, 2), ContinuousBernoulli(0.75)),
            (Laplace(-1, 2), ContinuousBernoulli(0.75)),
            (Normal(-1, 2), ContinuousBernoulli(0.75)),
            (Uniform(-1, 1), ContinuousBernoulli(0.75)),
            (Uniform(0, 2), ContinuousBernoulli(0.75)),
            (Uniform(-1, 2), ContinuousBernoulli(0.75)),
        ]

    def test_kl_monte_carlo(self):
        set_rng_seed(0)  # see Note [Randomized statistical tests]
        for (p, _), (_, q) in self.finite_examples:
            actual = kl_divergence(p, q)
            numerator = 0
            denominator = 0
            while denominator < self.max_samples:
                x = p.sample(sample_shape=(self.samples_per_batch,))
                numerator += (p.log_prob(x) - q.log_prob(x)).sum(0)
                denominator += x.size(0)
                expected = numerator / denominator
                error = torch.abs(expected - actual) / (1 + expected)
                if error[error == error].max() < self.precision:
                    break
            self.assertLess(
                error[error == error].max(),
                self.precision,
                "\n".join(
                    [
                        f"Incorrect KL({type(p).__name__}, {type(q).__name__}).",
                        f"Expected ({denominator} Monte Carlo samples): {expected}",
                        f"Actual (analytic): {actual}",
                    ]
                ),
            )

    # Multivariate normal has a separate Monte Carlo-based test because it requires
    # randomly generated positive (semi-)definite matrices. n is set to 5, but can be
    # increased during testing.
    def test_kl_multivariate_normal(self):
        set_rng_seed(0)  # see Note [Randomized statistical tests]
        n = 5  # Number of tests for multivariate_normal
        for i in range(0, n):
            loc = [torch.randn(4) for _ in range(0, 2)]
            scale_tril = [
                transform_to(constraints.lower_cholesky)(torch.randn(4, 4))
                for _ in range(0, 2)
            ]
            p = MultivariateNormal(loc=loc[0], scale_tril=scale_tril[0])
            q = MultivariateNormal(loc=loc[1], scale_tril=scale_tril[1])
            actual = kl_divergence(p, q)
            numerator = 0
            denominator = 0
            while denominator < self.max_samples:
                x = p.sample(sample_shape=(self.samples_per_batch,))
                numerator += (p.log_prob(x) - q.log_prob(x)).sum(0)
                denominator += x.size(0)
                expected = numerator / denominator
                error = torch.abs(expected - actual) / (1 + expected)
                if error[error == error].max() < self.precision:
                    break
            self.assertLess(
                error[error == error].max(),
                self.precision,
                "\n".join(
                    [
                        f"Incorrect KL(MultivariateNormal, MultivariateNormal) instance {i + 1}/{n}",
                        f"Expected ({denominator} Monte Carlo samples): {expected}",
                        f"Actual (analytic): {actual}",
                    ]
                ),
            )

    def test_kl_multivariate_normal_batched(self):
        b = 7  # Number of batches
        loc = [torch.randn(b, 3) for _ in range(0, 2)]
        scale_tril = [
            transform_to(constraints.lower_cholesky)(torch.randn(b, 3, 3))
            for _ in range(0, 2)
        ]
        expected_kl = torch.stack(
            [
                kl_divergence(
                    MultivariateNormal(loc[0][i], scale_tril=scale_tril[0][i]),
                    MultivariateNormal(loc[1][i], scale_tril=scale_tril[1][i]),
                )
                for i in range(0, b)
            ]
        )
        actual_kl = kl_divergence(
            MultivariateNormal(loc[0], scale_tril=scale_tril[0]),
            MultivariateNormal(loc[1], scale_tril=scale_tril[1]),
        )
        self.assertEqual(expected_kl, actual_kl)

    def test_kl_multivariate_normal_batched_broadcasted(self):
        b = 7  # Number of batches
        loc = [torch.randn(b, 3) for _ in range(0, 2)]
        scale_tril = [
            transform_to(constraints.lower_cholesky)(torch.randn(b, 3, 3)),
            transform_to(constraints.lower_cholesky)(torch.randn(3, 3)),
        ]
        expected_kl = torch.stack(
            [
                kl_divergence(
                    MultivariateNormal(loc[0][i], scale_tril=scale_tril[0][i]),
                    MultivariateNormal(loc[1][i], scale_tril=scale_tril[1]),
                )
                for i in range(0, b)
            ]
        )
        actual_kl = kl_divergence(
            MultivariateNormal(loc[0], scale_tril=scale_tril[0]),
            MultivariateNormal(loc[1], scale_tril=scale_tril[1]),
        )
        self.assertEqual(expected_kl, actual_kl)

    def test_kl_lowrank_multivariate_normal(self):
        set_rng_seed(0)  # see Note [Randomized statistical tests]
        n = 5  # Number of tests for lowrank_multivariate_normal
        for i in range(0, n):
            loc = [torch.randn(4) for _ in range(0, 2)]
            cov_factor = [torch.randn(4, 3) for _ in range(0, 2)]
            cov_diag = [
                transform_to(constraints.positive)(torch.randn(4)) for _ in range(0, 2)
            ]
            covariance_matrix = [
                cov_factor[i].matmul(cov_factor[i].t()) + cov_diag[i].diag()
                for i in range(0, 2)
            ]
            p = LowRankMultivariateNormal(loc[0], cov_factor[0], cov_diag[0])
            q = LowRankMultivariateNormal(loc[1], cov_factor[1], cov_diag[1])
            p_full = MultivariateNormal(loc[0], covariance_matrix[0])
            q_full = MultivariateNormal(loc[1], covariance_matrix[1])
            expected = kl_divergence(p_full, q_full)

            actual_lowrank_lowrank = kl_divergence(p, q)
            actual_lowrank_full = kl_divergence(p, q_full)
            actual_full_lowrank = kl_divergence(p_full, q)

            error_lowrank_lowrank = torch.abs(actual_lowrank_lowrank - expected).max()
            self.assertLess(
                error_lowrank_lowrank,
                self.precision,
                "\n".join(
                    [
                        f"Incorrect KL(LowRankMultivariateNormal, LowRankMultivariateNormal) instance {i + 1}/{n}",
                        f"Expected (from KL MultivariateNormal): {expected}",
                        f"Actual (analytic): {actual_lowrank_lowrank}",
                    ]
                ),
            )

            error_lowrank_full = torch.abs(actual_lowrank_full - expected).max()
            self.assertLess(
                error_lowrank_full,
                self.precision,
                "\n".join(
                    [
                        f"Incorrect KL(LowRankMultivariateNormal, MultivariateNormal) instance {i + 1}/{n}",
                        f"Expected (from KL MultivariateNormal): {expected}",
                        f"Actual (analytic): {actual_lowrank_full}",
                    ]
                ),
            )

            error_full_lowrank = torch.abs(actual_full_lowrank - expected).max()
            self.assertLess(
                error_full_lowrank,
                self.precision,
                "\n".join(
                    [
                        f"Incorrect KL(MultivariateNormal, LowRankMultivariateNormal) instance {i + 1}/{n}",
                        f"Expected (from KL MultivariateNormal): {expected}",
                        f"Actual (analytic): {actual_full_lowrank}",
                    ]
                ),
            )

    def test_kl_lowrank_multivariate_normal_batched(self):
        b = 7  # Number of batches
        loc = [torch.randn(b, 3) for _ in range(0, 2)]
        cov_factor = [torch.randn(b, 3, 2) for _ in range(0, 2)]
        cov_diag = [
            transform_to(constraints.positive)(torch.randn(b, 3)) for _ in range(0, 2)
        ]
        expected_kl = torch.stack(
            [
                kl_divergence(
                    LowRankMultivariateNormal(
                        loc[0][i], cov_factor[0][i], cov_diag[0][i]
                    ),
                    LowRankMultivariateNormal(
                        loc[1][i], cov_factor[1][i], cov_diag[1][i]
                    ),
                )
Build Coastguard Worker for i in range(0, b) 5744*da0073e9SAndroid Build Coastguard Worker ] 5745*da0073e9SAndroid Build Coastguard Worker ) 5746*da0073e9SAndroid Build Coastguard Worker actual_kl = kl_divergence( 5747*da0073e9SAndroid Build Coastguard Worker LowRankMultivariateNormal(loc[0], cov_factor[0], cov_diag[0]), 5748*da0073e9SAndroid Build Coastguard Worker LowRankMultivariateNormal(loc[1], cov_factor[1], cov_diag[1]), 5749*da0073e9SAndroid Build Coastguard Worker ) 5750*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected_kl, actual_kl) 5751*da0073e9SAndroid Build Coastguard Worker 5752*da0073e9SAndroid Build Coastguard Worker def test_kl_exponential_family(self): 5753*da0073e9SAndroid Build Coastguard Worker for (p, _), (_, q) in self.finite_examples: 5754*da0073e9SAndroid Build Coastguard Worker if type(p) == type(q) and issubclass(type(p), ExponentialFamily): 5755*da0073e9SAndroid Build Coastguard Worker actual = kl_divergence(p, q) 5756*da0073e9SAndroid Build Coastguard Worker expected = _kl_expfamily_expfamily(p, q) 5757*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 5758*da0073e9SAndroid Build Coastguard Worker actual, 5759*da0073e9SAndroid Build Coastguard Worker expected, 5760*da0073e9SAndroid Build Coastguard Worker msg="\n".join( 5761*da0073e9SAndroid Build Coastguard Worker [ 5762*da0073e9SAndroid Build Coastguard Worker f"Incorrect KL({type(p).__name__}, {type(q).__name__}).", 5763*da0073e9SAndroid Build Coastguard Worker f"Expected (using Bregman Divergence) {expected}", 5764*da0073e9SAndroid Build Coastguard Worker f"Actual (analytic) {actual}", 5765*da0073e9SAndroid Build Coastguard Worker f"max error = {torch.abs(actual - expected).max()}", 5766*da0073e9SAndroid Build Coastguard Worker ] 5767*da0073e9SAndroid Build Coastguard Worker ), 5768*da0073e9SAndroid Build Coastguard Worker ) 5769*da0073e9SAndroid Build Coastguard Worker 5770*da0073e9SAndroid Build Coastguard Worker def test_kl_infinite(self): 
        for p, q in self.infinite_examples:
            self.assertTrue(
                (kl_divergence(p, q) == inf).all(),
                f"Incorrect KL({type(p).__name__}, {type(q).__name__})",
            )

    def test_kl_edgecases(self):
        self.assertEqual(kl_divergence(Bernoulli(0), Bernoulli(0)), 0)
        self.assertEqual(kl_divergence(Bernoulli(1), Bernoulli(1)), 0)
        self.assertEqual(
            kl_divergence(
                Categorical(torch.tensor([0.0, 1.0])),
                Categorical(torch.tensor([0.0, 1.0])),
            ),
            0,
        )

    def test_kl_shape(self):
        for Dist, params in _get_examples():
            for i, param in enumerate(params):
                dist = Dist(**param)
                try:
                    kl = kl_divergence(dist, dist)
                except NotImplementedError:
                    continue
                expected_shape = dist.batch_shape if dist.batch_shape else torch.Size()
                self.assertEqual(
                    kl.shape,
                    expected_shape,
                    msg="\n".join(
                        [
                            f"{Dist.__name__} example {i + 1}/{len(params)}",
                            f"Expected {expected_shape}",
                            f"Actual {kl.shape}",
                        ]
                    ),
                )

    def test_kl_transformed(self):
        # Regression test for https://github.com/pytorch/pytorch/issues/34859
        scale = torch.ones(2, 3)
        loc = torch.zeros(2, 3)
        normal = Normal(loc=loc, scale=scale)
        diag_normal = Independent(normal, reinterpreted_batch_ndims=1)
        trans_dist = TransformedDistribution(
            diag_normal, AffineTransform(loc=0.0, scale=2.0)
        )
        self.assertEqual(kl_divergence(diag_normal, diag_normal).shape, (2,))
        self.assertEqual(kl_divergence(trans_dist, trans_dist).shape, (2,))

    @set_default_dtype(torch.double)
    def test_entropy_monte_carlo(self):
        set_rng_seed(0)  # see Note [Randomized statistical tests]
        for Dist, params in _get_examples():
            for i, param in enumerate(params):
                dist = Dist(**param)
                try:
                    actual = dist.entropy()
                except NotImplementedError:
                    continue
                # Monte Carlo estimate of the entropy: H[p] = -E_p[log p(x)].
                x = dist.sample(sample_shape=(60000,))
                expected = -dist.log_prob(x).mean(0)
                # Exclude entries whose Monte Carlo estimate is infinite by
                # copying the analytic value into them.
                ignore = (expected == inf) | (expected == -inf)
                expected[ignore] = actual[ignore]
                self.assertEqual(
                    actual,
                    expected,
                    atol=0.2,
                    rtol=0,
                    msg="\n".join(
                        [
                            f"{Dist.__name__} example {i + 1}/{len(params)}, incorrect .entropy().",
                            f"Expected (monte carlo) {expected}",
                            f"Actual (analytic) {actual}",
                            f"max error = {torch.abs(actual - expected).max()}",
                        ]
                    ),
                )

    @set_default_dtype(torch.double)
    def test_entropy_exponential_family(self):
        for Dist, params in _get_examples():
            if not issubclass(Dist, ExponentialFamily):
                continue
            for i, param in enumerate(params):
                dist = Dist(**param)
                try:
                    actual = dist.entropy()
                except NotImplementedError:
                    continue
                try:
                    expected = ExponentialFamily.entropy(dist)
                except NotImplementedError:
                    continue
                self.assertEqual(
                    actual,
                    expected,
                    msg="\n".join(
                        [
                            f"{Dist.__name__} example {i + 1}/{len(params)}, incorrect .entropy().",
                            f"Expected (Bregman Divergence) {expected}",
                            f"Actual (analytic) {actual}",
                            f"max error = {torch.abs(actual - expected).max()}",
                        ]
                    ),
                )


class TestConstraints(DistributionsTestCase):
    def test_params_constraints(self):
        normalize_probs_dists = (
            Categorical,
            Multinomial,
            OneHotCategorical,
            OneHotCategoricalStraightThrough,
            RelaxedOneHotCategorical,
        )

        for Dist, params in _get_examples():
            for i, param in enumerate(params):
                dist = Dist(**param)
                for name, value in param.items():
                    if isinstance(value, numbers.Number):
                        value = torch.tensor([value])
                    if Dist in normalize_probs_dists and name == "probs":
                        # These distributions accept positive probs, but elsewhere we
                        # use a stricter constraint to the simplex.
                        value = value / value.sum(-1, True)
                    try:
                        constraint = dist.arg_constraints[name]
                    except KeyError:
                        continue  # ignore optional parameters

                    # Check param shape is compatible with distribution shape.
                    self.assertGreaterEqual(value.dim(), constraint.event_dim)
                    value_batch_shape = value.shape[
                        : value.dim() - constraint.event_dim
                    ]
                    # torch.broadcast_shapes raises a RuntimeError if the two
                    # batch shapes cannot be broadcast together.
                    torch.broadcast_shapes(dist.batch_shape, value_batch_shape)

                    if is_dependent(constraint):
                        continue

                    message = f"{Dist.__name__} example {i + 1}/{len(params)} parameter {name} = {value}"
                    self.assertTrue(constraint.check(value).all(), msg=message)

    def test_support_constraints(self):
        for Dist, params in _get_examples():
            self.assertIsInstance(Dist.support, Constraint)
            for i, param in enumerate(params):
                dist = Dist(**param)
                value = dist.sample()
                constraint = dist.support
                message = (
                    f"{Dist.__name__} example {i + 1}/{len(params)} sample = {value}"
                )
                self.assertEqual(
                    constraint.event_dim, len(dist.event_shape), msg=message
                )
                ok = constraint.check(value)
                self.assertEqual(ok.shape, dist.batch_shape, msg=message)
                self.assertTrue(ok.all(), msg=message)


@skipIfTorchDynamo("Not a TorchDynamo suitable test")
class TestNumericalStability(DistributionsTestCase):
    def _test_pdf_score(
        self,
        dist_class,
        x,
        expected_value,
        probs=None,
        logits=None,
        expected_gradient=None,
        atol=1e-5,
    ):
        # Build dist_class from probs (or, if probs is None, from logits), then
        # check log_prob(x) against expected_value and, when given, the gradient
        # of log_prob with respect to that parameter against expected_gradient.
        if probs is not None:
            p = probs.detach().requires_grad_()
            dist = dist_class(p)
        else:
            p = logits.detach().requires_grad_()
            dist = dist_class(logits=p)
        log_pdf = dist.log_prob(x)
        log_pdf.sum().backward()
        self.assertEqual(
            log_pdf,
            expected_value,
            atol=atol,
            rtol=0,
            msg=f"Incorrect value for tensor type: {type(x)}. Expected = {expected_value}, Actual = {log_pdf}",
        )
        if expected_gradient is not None:
            self.assertEqual(
                p.grad,
                expected_gradient,
                atol=atol,
                rtol=0,
                msg=f"Incorrect gradient for tensor type: {type(x)}. Expected = {expected_gradient}, Actual = {p.grad}",
            )

    def test_bernoulli_gradient(self):
        for tensor_type in [torch.FloatTensor, torch.DoubleTensor]:
            self._test_pdf_score(
                dist_class=Bernoulli,
                probs=tensor_type([0]),
                x=tensor_type([0]),
                expected_value=tensor_type([0]),
                expected_gradient=tensor_type([0]),
            )

            # probs=0 is clamped to eps when converted to logits, so
            # log p(x=1) evaluates to log(eps) rather than -inf.
            self._test_pdf_score(
                dist_class=Bernoulli,
                probs=tensor_type([0]),
                x=tensor_type([1]),
                expected_value=tensor_type(
                    [torch.finfo(tensor_type([]).dtype).eps]
                ).log(),
                expected_gradient=tensor_type([0]),
            )

            self._test_pdf_score(
                dist_class=Bernoulli,
                probs=tensor_type([1e-4]),
                x=tensor_type([1]),
                expected_value=tensor_type([math.log(1e-4)]),
                expected_gradient=tensor_type([10000]),
            )

            # Lower precision (atol=2) because 0.9999 is not exactly
            # representable in float32:
            # >>> 1 / (1 - torch.FloatTensor([0.9999]))
            # tensor([9998.3408])
            self._test_pdf_score(
                dist_class=Bernoulli,
                probs=tensor_type([1 - 1e-4]),
                x=tensor_type([0]),
                expected_value=tensor_type([math.log(1e-4)]),
                expected_gradient=tensor_type([-10000]),
                atol=2,
            )

            self._test_pdf_score(
                dist_class=Bernoulli,
                logits=tensor_type([math.log(9999)]),
                x=tensor_type([0]),
                expected_value=tensor_type([math.log(1e-4)]),
                expected_gradient=tensor_type([-1]),
                atol=1e-3,
            )

    def test_bernoulli_with_logits_underflow(self):
        for tensor_type, lim in [
            (torch.FloatTensor, -1e38),
            (torch.DoubleTensor, -1e308),
        ]:
            self._test_pdf_score(
                dist_class=Bernoulli,
                logits=tensor_type([lim]),
                x=tensor_type([0]),
                expected_value=tensor_type([0]),
                expected_gradient=tensor_type([0]),
            )

    def test_bernoulli_with_logits_overflow(self):
        for tensor_type, lim in [
            (torch.FloatTensor, 1e38),
            (torch.DoubleTensor, 1e308),
        ]:
            self._test_pdf_score(
                dist_class=Bernoulli,
                logits=tensor_type([lim]),
                x=tensor_type([1]),
                expected_value=tensor_type([0]),
                expected_gradient=tensor_type([0]),
            )

    def test_categorical_log_prob(self):
        for dtype in [torch.float, torch.double]:
            p = torch.tensor([0, 1], dtype=dtype, requires_grad=True)
            categorical = OneHotCategorical(p)
            log_pdf = categorical.log_prob(torch.tensor([0, 1], dtype=dtype))
            self.assertEqual(log_pdf.item(), 0)

    def test_categorical_log_prob_with_logits(self):
        for dtype in [torch.float, torch.double]:
            p = torch.tensor([-inf, 0], dtype=dtype, requires_grad=True)
            categorical = OneHotCategorical(logits=p)
            log_pdf_prob_1 = categorical.log_prob(torch.tensor([0, 1], dtype=dtype))
            self.assertEqual(log_pdf_prob_1.item(), 0)
            log_pdf_prob_0 = categorical.log_prob(torch.tensor([1, 0], dtype=dtype))
            self.assertEqual(log_pdf_prob_0.item(), -inf)

    def test_multinomial_log_prob(self):
        for dtype in [torch.float, torch.double]:
            p = torch.tensor([0, 1], dtype=dtype, requires_grad=True)
            s = torch.tensor([0, 10], dtype=dtype)
            multinomial = Multinomial(10, p)
            log_pdf = multinomial.log_prob(s)
            self.assertEqual(log_pdf.item(), 0)

    def test_multinomial_log_prob_with_logits(self):
        for dtype in [torch.float, torch.double]:
            p = torch.tensor([-inf, 0], dtype=dtype, requires_grad=True)
            multinomial = Multinomial(10, logits=p)
            log_pdf_prob_1 = multinomial.log_prob(torch.tensor([0, 10], dtype=dtype))
            self.assertEqual(log_pdf_prob_1.item(), 0)
            log_pdf_prob_0 = multinomial.log_prob(torch.tensor([10, 0], dtype=dtype))
            self.assertEqual(log_pdf_prob_0.item(), -inf)

    def test_continuous_bernoulli_gradient(self):
        # Reference log-density of the continuous Bernoulli: the Bernoulli
        # log-likelihood plus log C(p), where C(p) = 2 * atanh(1 - 2p) / (1 - 2p).
        # Near p = 0.5 a Taylor expansion of log C(p) is used instead.
        def expec_val(x, probs=None, logits=None):
            assert not (probs is None and logits is None)
            if logits is not None:
                probs = 1.0 / (1.0 + math.exp(-logits))
            bern_log_lik = x * math.log(probs) + (1.0 - x) * math.log1p(-probs)
            if probs < 0.499 or probs > 0.501:  # default `lims` of ContinuousBernoulli
                log_norm_const = (
                    math.log(math.fabs(math.atanh(1.0 - 2.0 * probs)))
                    - math.log(math.fabs(1.0 - 2.0 * probs))
                    + math.log(2.0)
                )
else: 6093*da0073e9SAndroid Build Coastguard Worker aux = math.pow(probs - 0.5, 2) 6094*da0073e9SAndroid Build Coastguard Worker log_norm_const = math.log(2.0) + (4.0 / 3.0 + 104.0 / 45.0 * aux) * aux 6095*da0073e9SAndroid Build Coastguard Worker log_lik = bern_log_lik + log_norm_const 6096*da0073e9SAndroid Build Coastguard Worker return log_lik 6097*da0073e9SAndroid Build Coastguard Worker 6098*da0073e9SAndroid Build Coastguard Worker def expec_grad(x, probs=None, logits=None): 6099*da0073e9SAndroid Build Coastguard Worker assert not (probs is None and logits is None) 6100*da0073e9SAndroid Build Coastguard Worker if logits is not None: 6101*da0073e9SAndroid Build Coastguard Worker probs = 1.0 / (1.0 + math.exp(-logits)) 6102*da0073e9SAndroid Build Coastguard Worker grad_bern_log_lik = x / probs - (1.0 - x) / (1.0 - probs) 6103*da0073e9SAndroid Build Coastguard Worker if probs < 0.499 or probs > 0.501: # using default values of lims here 6104*da0073e9SAndroid Build Coastguard Worker grad_log_c = ( 6105*da0073e9SAndroid Build Coastguard Worker 2.0 * probs 6106*da0073e9SAndroid Build Coastguard Worker - 4.0 * (probs - 1.0) * probs * math.atanh(1.0 - 2.0 * probs) 6107*da0073e9SAndroid Build Coastguard Worker - 1.0 6108*da0073e9SAndroid Build Coastguard Worker ) 6109*da0073e9SAndroid Build Coastguard Worker grad_log_c /= ( 6110*da0073e9SAndroid Build Coastguard Worker 2.0 6111*da0073e9SAndroid Build Coastguard Worker * (probs - 1.0) 6112*da0073e9SAndroid Build Coastguard Worker * probs 6113*da0073e9SAndroid Build Coastguard Worker * (2.0 * probs - 1.0) 6114*da0073e9SAndroid Build Coastguard Worker * math.atanh(1.0 - 2.0 * probs) 6115*da0073e9SAndroid Build Coastguard Worker ) 6116*da0073e9SAndroid Build Coastguard Worker else: 6117*da0073e9SAndroid Build Coastguard Worker grad_log_c = 8.0 / 3.0 * (probs - 0.5) + 416.0 / 45.0 * math.pow( 6118*da0073e9SAndroid Build Coastguard Worker probs - 0.5, 3 6119*da0073e9SAndroid Build Coastguard Worker ) 6120*da0073e9SAndroid 
Build Coastguard Worker grad = grad_bern_log_lik + grad_log_c 6121*da0073e9SAndroid Build Coastguard Worker if logits is not None: 6122*da0073e9SAndroid Build Coastguard Worker grad *= 1.0 / (1.0 + math.exp(logits)) - 1.0 / math.pow( 6123*da0073e9SAndroid Build Coastguard Worker 1.0 + math.exp(logits), 2 6124*da0073e9SAndroid Build Coastguard Worker ) 6125*da0073e9SAndroid Build Coastguard Worker return grad 6126*da0073e9SAndroid Build Coastguard Worker 6127*da0073e9SAndroid Build Coastguard Worker for tensor_type in [torch.FloatTensor, torch.DoubleTensor]: 6128*da0073e9SAndroid Build Coastguard Worker self._test_pdf_score( 6129*da0073e9SAndroid Build Coastguard Worker dist_class=ContinuousBernoulli, 6130*da0073e9SAndroid Build Coastguard Worker probs=tensor_type([0.1]), 6131*da0073e9SAndroid Build Coastguard Worker x=tensor_type([0.1]), 6132*da0073e9SAndroid Build Coastguard Worker expected_value=tensor_type([expec_val(0.1, probs=0.1)]), 6133*da0073e9SAndroid Build Coastguard Worker expected_gradient=tensor_type([expec_grad(0.1, probs=0.1)]), 6134*da0073e9SAndroid Build Coastguard Worker ) 6135*da0073e9SAndroid Build Coastguard Worker 6136*da0073e9SAndroid Build Coastguard Worker self._test_pdf_score( 6137*da0073e9SAndroid Build Coastguard Worker dist_class=ContinuousBernoulli, 6138*da0073e9SAndroid Build Coastguard Worker probs=tensor_type([0.1]), 6139*da0073e9SAndroid Build Coastguard Worker x=tensor_type([1.0]), 6140*da0073e9SAndroid Build Coastguard Worker expected_value=tensor_type([expec_val(1.0, probs=0.1)]), 6141*da0073e9SAndroid Build Coastguard Worker expected_gradient=tensor_type([expec_grad(1.0, probs=0.1)]), 6142*da0073e9SAndroid Build Coastguard Worker ) 6143*da0073e9SAndroid Build Coastguard Worker 6144*da0073e9SAndroid Build Coastguard Worker self._test_pdf_score( 6145*da0073e9SAndroid Build Coastguard Worker dist_class=ContinuousBernoulli, 6146*da0073e9SAndroid Build Coastguard Worker probs=tensor_type([0.4999]), 6147*da0073e9SAndroid Build 
                x=tensor_type([0.9]),
                expected_value=tensor_type([expec_val(0.9, probs=0.4999)]),
                expected_gradient=tensor_type([expec_grad(0.9, probs=0.4999)]),
            )

            self._test_pdf_score(
                dist_class=ContinuousBernoulli,
                probs=tensor_type([1e-4]),
                x=tensor_type([1]),
                expected_value=tensor_type([expec_val(1, probs=1e-4)]),
                expected_gradient=tensor_type([expec_grad(1, probs=1e-4)]),
                atol=1e-3,
            )

            self._test_pdf_score(
                dist_class=ContinuousBernoulli,
                probs=tensor_type([1 - 1e-4]),
                x=tensor_type([0.1]),
                expected_value=tensor_type([expec_val(0.1, probs=1 - 1e-4)]),
                expected_gradient=tensor_type([expec_grad(0.1, probs=1 - 1e-4)]),
                atol=2,
            )

            self._test_pdf_score(
                dist_class=ContinuousBernoulli,
                logits=tensor_type([math.log(9999)]),
                x=tensor_type([0]),
                expected_value=tensor_type([expec_val(0, logits=math.log(9999))]),
                expected_gradient=tensor_type([expec_grad(0, logits=math.log(9999))]),
                atol=1e-3,
            )

            self._test_pdf_score(
                dist_class=ContinuousBernoulli,
                logits=tensor_type([0.001]),
                x=tensor_type([0.5]),
                expected_value=tensor_type([expec_val(0.5, logits=0.001)]),
                expected_gradient=tensor_type([expec_grad(0.5, logits=0.001)]),
            )

    def test_continuous_bernoulli_with_logits_underflow(self):
        for tensor_type, lim, expected in [
            (torch.FloatTensor, -1e38, 2.76898),
            (torch.DoubleTensor, -1e308, 3.58473),
        ]:
            self._test_pdf_score(
                dist_class=ContinuousBernoulli,
                logits=tensor_type([lim]),
                x=tensor_type([0]),
                expected_value=tensor_type([expected]),
                expected_gradient=tensor_type([0.0]),
            )

    def test_continuous_bernoulli_with_logits_overflow(self):
        for tensor_type, lim, expected in [
            (torch.FloatTensor, 1e38, 2.76898),
            (torch.DoubleTensor, 1e308, 3.58473),
        ]:
            self._test_pdf_score(
                dist_class=ContinuousBernoulli,
                logits=tensor_type([lim]),
                x=tensor_type([1]),
                expected_value=tensor_type([expected]),
                expected_gradient=tensor_type([0.0]),
            )


# TODO: make this a pytest parameterized test
class TestLazyLogitsInitialization(DistributionsTestCase):
    def setUp(self):
        super().setUp()
        # ContinuousBernoulli is not tested because log_prob is not computed simply
        # from 'logits', but 'probs' is also needed
        self.examples = [
            e
            for e in _get_examples()
            if e.Dist
            in (Categorical, OneHotCategorical, Bernoulli, Binomial, Multinomial)
        ]

    def test_lazy_logits_initialization(self):
        for Dist, params in self.examples:
            param = params[0].copy()
            if "probs" not in param:
                continue
            probs = param.pop("probs")
            param["logits"] = probs_to_logits(probs)
            dist = Dist(**param)
            # Create new instance to generate a valid sample
            dist.log_prob(Dist(**param).sample())
            message = f"Failed for {Dist.__name__} example 0/{len(params)}"
            self.assertNotIn("probs", dist.__dict__, msg=message)
            try:
                dist.enumerate_support()
            except NotImplementedError:
                pass
            self.assertNotIn("probs", dist.__dict__, msg=message)
            batch_shape, event_shape = dist.batch_shape, dist.event_shape
            self.assertNotIn("probs", dist.__dict__, msg=message)

    def test_lazy_probs_initialization(self):
        for Dist, params in self.examples:
            param = params[0].copy()
            if "probs" not in param:
                continue
            dist = Dist(**param)
            dist.sample()
            message = f"Failed for {Dist.__name__} example 0/{len(params)}"
            self.assertNotIn("logits", dist.__dict__, msg=message)
            try:
                dist.enumerate_support()
            except NotImplementedError:
                pass
            self.assertNotIn("logits", dist.__dict__, msg=message)
            batch_shape, event_shape = dist.batch_shape, dist.event_shape
            self.assertNotIn("logits", dist.__dict__, msg=message)


@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
@skipIfTorchDynamo("FIXME: Tries to trace through SciPy and fails")
class TestAgainstScipy(DistributionsTestCase):
    def setUp(self):
        super().setUp()
        positive_var = torch.randn(20, dtype=torch.double).exp()
        positive_var2 = torch.randn(20, dtype=torch.double).exp()
        random_var = torch.randn(20, dtype=torch.double)
        simplex_tensor = softmax(torch.randn(20, dtype=torch.double), dim=-1)
        cov_tensor = torch.randn(20, 20, dtype=torch.double)
        cov_tensor = cov_tensor @ cov_tensor.mT
        self.distribution_pairs = [
            (Bernoulli(simplex_tensor), scipy.stats.bernoulli(simplex_tensor)),
            (
                Beta(positive_var, positive_var2),
                scipy.stats.beta(positive_var, positive_var2),
            ),
            (
                Binomial(10, simplex_tensor),
                scipy.stats.binom(
                    10 * np.ones(simplex_tensor.shape), simplex_tensor.numpy()
                ),
            ),
            (
                Cauchy(random_var, positive_var),
                scipy.stats.cauchy(loc=random_var, scale=positive_var),
            ),
            (Dirichlet(positive_var), scipy.stats.dirichlet(positive_var)),
            (
                Exponential(positive_var),
                scipy.stats.expon(scale=positive_var.reciprocal()),
            ),
            (
                FisherSnedecor(
                    positive_var, 4 + positive_var2
                ),  # var for df2<=4 is undefined
                scipy.stats.f(positive_var, 4 + positive_var2),
            ),
            (
                Gamma(positive_var, positive_var2),
                scipy.stats.gamma(positive_var, scale=positive_var2.reciprocal()),
            ),
            (Geometric(simplex_tensor), scipy.stats.geom(simplex_tensor, loc=-1)),
            (
                Gumbel(random_var, positive_var2),
                scipy.stats.gumbel_r(random_var, positive_var2),
            ),
            (HalfCauchy(positive_var), scipy.stats.halfcauchy(scale=positive_var)),
            (HalfNormal(positive_var2), scipy.stats.halfnorm(scale=positive_var2)),
            (
                InverseGamma(positive_var, positive_var2),
                scipy.stats.invgamma(positive_var, scale=positive_var2),
            ),
            (
                Laplace(random_var, positive_var2),
                scipy.stats.laplace(random_var, positive_var2),
            ),
            (
                # Tests fail 1e-5 threshold if scale > 3
                LogNormal(random_var, positive_var.clamp(max=3)),
                scipy.stats.lognorm(
                    s=positive_var.clamp(max=3), scale=random_var.exp()
                ),
            ),
            (
                LowRankMultivariateNormal(
                    random_var, torch.zeros(20, 1, dtype=torch.double), positive_var2
                ),
                scipy.stats.multivariate_normal(random_var, torch.diag(positive_var2)),
            ),
            (
                Multinomial(10, simplex_tensor),
                scipy.stats.multinomial(10, simplex_tensor),
            ),
            (
                MultivariateNormal(random_var, torch.diag(positive_var2)),
                scipy.stats.multivariate_normal(random_var, torch.diag(positive_var2)),
            ),
            (
                MultivariateNormal(random_var, cov_tensor),
                scipy.stats.multivariate_normal(random_var, cov_tensor),
            ),
            (
                Normal(random_var, positive_var2),
                scipy.stats.norm(random_var, positive_var2),
            ),
            (
                OneHotCategorical(simplex_tensor),
                scipy.stats.multinomial(1, simplex_tensor),
            ),
            (
                Pareto(positive_var, 2 + positive_var2),
                scipy.stats.pareto(2 + positive_var2, scale=positive_var),
            ),
            (Poisson(positive_var), scipy.stats.poisson(positive_var)),
            (
                StudentT(2 + positive_var, random_var, positive_var2),
                scipy.stats.t(2 + positive_var, random_var, positive_var2),
            ),
            (
                Uniform(random_var, random_var + positive_var),
                scipy.stats.uniform(random_var, positive_var),
            ),
            (
                VonMises(random_var, positive_var),
                scipy.stats.vonmises(positive_var, loc=random_var),
            ),
            (
                Weibull(
                    positive_var[0], positive_var2[0]
                ),  # scipy var for Weibull only supports scalars
                scipy.stats.weibull_min(c=positive_var2[0], scale=positive_var[0]),
            ),
            (
                # scipy var for Wishart only supports scalars
                # SciPy allows ndim - 1 < df < ndim for the Wishart distribution since version 1.7.0
                Wishart(
                    (
                        20
                        if version.parse(scipy.__version__) < version.parse("1.7.0")
                        else 19
                    )
                    + positive_var[0],
                    cov_tensor,
                ),
                scipy.stats.wishart(
                    (
                        20
                        if version.parse(scipy.__version__) < version.parse("1.7.0")
                        else 19
                    )
                    + positive_var[0].item(),
                    cov_tensor,
                ),
            ),
        ]

    def test_mean(self):
        for pytorch_dist, scipy_dist in self.distribution_pairs:
            if isinstance(pytorch_dist, (Cauchy, HalfCauchy)):
                # Cauchy, HalfCauchy distributions' mean is nan, skipping check
                continue
            elif isinstance(
                pytorch_dist, (LowRankMultivariateNormal, MultivariateNormal)
            ):
                self.assertEqual(pytorch_dist.mean, scipy_dist.mean, msg=pytorch_dist)
            else:
                self.assertEqual(pytorch_dist.mean, scipy_dist.mean(), msg=pytorch_dist)

    def test_variance_stddev(self):
        for pytorch_dist, scipy_dist in self.distribution_pairs:
            if isinstance(pytorch_dist, (Cauchy, HalfCauchy, VonMises)):
                # Cauchy, HalfCauchy distributions' standard deviation is nan, skipping check
                # VonMises variance is circular and scipy doesn't produce a correct result
                continue
            elif isinstance(pytorch_dist, (Multinomial, OneHotCategorical)):
                self.assertEqual(
                    pytorch_dist.variance, np.diag(scipy_dist.cov()), msg=pytorch_dist
                )
                self.assertEqual(
                    pytorch_dist.stddev,
                    np.diag(scipy_dist.cov()) ** 0.5,
                    msg=pytorch_dist,
                )
            elif isinstance(
                pytorch_dist, (LowRankMultivariateNormal, MultivariateNormal)
            ):
                self.assertEqual(
                    pytorch_dist.variance, np.diag(scipy_dist.cov), msg=pytorch_dist
                )
                self.assertEqual(
                    pytorch_dist.stddev,
                    np.diag(scipy_dist.cov) ** 0.5,
                    msg=pytorch_dist,
                )
            else:
                self.assertEqual(
                    pytorch_dist.variance, scipy_dist.var(), msg=pytorch_dist
                )
                self.assertEqual(
                    pytorch_dist.stddev, scipy_dist.var() ** 0.5, msg=pytorch_dist
                )

    @set_default_dtype(torch.double)
    def test_cdf(self):
        for pytorch_dist, scipy_dist in self.distribution_pairs:
            samples = pytorch_dist.sample((5,))
            try:
                cdf = pytorch_dist.cdf(samples)
            except NotImplementedError:
                continue
            self.assertEqual(cdf, scipy_dist.cdf(samples), msg=pytorch_dist)

    def test_icdf(self):
        for pytorch_dist, scipy_dist in self.distribution_pairs:
            samples = torch.rand((5,) + pytorch_dist.batch_shape, dtype=torch.double)
            try:
                icdf = pytorch_dist.icdf(samples)
            except NotImplementedError:
                continue
            self.assertEqual(icdf, scipy_dist.ppf(samples), msg=pytorch_dist)


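# Illustrative sketch only: ``_geometric_mean_matches_shifted_scipy`` is a
# hypothetical helper that is not used by any test in this file. It shows why
# TestAgainstScipy pairs ``Geometric(p)`` with ``scipy.stats.geom(p, loc=-1)``.
# torch.distributions.Geometric counts the failures before the first success
# (support {0, 1, ...}, pmf (1 - p)**k * p), whereas scipy.stats.geom counts
# the trials up to and including the first success (support {1, 2, ...}).
# Shifting SciPy's support by ``loc=-1`` aligns the two. The helper estimates
# the mean of the failure-counting pmf with a truncated series and compares it
# with SciPy's closed-form mean 1 / p shifted by ``loc``; it needs neither
# torch nor scipy.
def _geometric_mean_matches_shifted_scipy(p, loc=-1, n_terms=10_000, tol=1e-9):
    # Mean of the failure-counting pmf (1 - p)**k * p, summed for k < n_terms;
    # the truncated tail is negligible for the p values used here.
    series_mean = sum(k * (1.0 - p) ** k * p for k in range(n_terms))
    # scipy.stats.geom(p).mean() is 1 / p, and ``loc`` shifts the mean by loc.
    shifted_trials_mean = 1.0 / p + loc
    return abs(series_mean - shifted_trials_mean) < tol

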
class TestFunctors(DistributionsTestCase):
    def test_cat_transform(self):
        x1 = -1 * torch.arange(1, 101, dtype=torch.float).view(-1, 100)
        x2 = (torch.arange(1, 101, dtype=torch.float).view(-1, 100) - 1) / 100
        x3 = torch.arange(1, 101, dtype=torch.float).view(-1, 100)
        t1, t2, t3 = ExpTransform(), AffineTransform(1, 100), identity_transform
        dim = 0
        x = torch.cat([x1, x2, x3], dim=dim)
        t = CatTransform([t1, t2, t3], dim=dim)
        actual_dom_check = t.domain.check(x)
        expected_dom_check = torch.cat(
            [t1.domain.check(x1), t2.domain.check(x2), t3.domain.check(x3)], dim=dim
        )
        self.assertEqual(expected_dom_check, actual_dom_check)
        actual = t(x)
        expected = torch.cat([t1(x1), t2(x2), t3(x3)], dim=dim)
        self.assertEqual(expected, actual)
        y1 = torch.arange(1, 101, dtype=torch.float).view(-1, 100)
        y2 = torch.arange(1, 101, dtype=torch.float).view(-1, 100)
        y3 = torch.arange(1, 101, dtype=torch.float).view(-1, 100)
        y = torch.cat([y1, y2, y3], dim=dim)
        actual_cod_check = t.codomain.check(y)
        expected_cod_check = torch.cat(
            [t1.codomain.check(y1), t2.codomain.check(y2), t3.codomain.check(y3)],
            dim=dim,
        )
        self.assertEqual(actual_cod_check, expected_cod_check)
        actual_inv = t.inv(y)
        expected_inv = torch.cat([t1.inv(y1), t2.inv(y2), t3.inv(y3)], dim=dim)
        self.assertEqual(expected_inv, actual_inv)
        actual_jac = t.log_abs_det_jacobian(x, y)
        expected_jac = torch.cat(
            [
                t1.log_abs_det_jacobian(x1, y1),
                t2.log_abs_det_jacobian(x2, y2),
                t3.log_abs_det_jacobian(x3, y3),
            ],
            dim=dim,
        )
        self.assertEqual(actual_jac, expected_jac)

    def test_cat_transform_non_uniform(self):
        x1 = -1 * torch.arange(1, 101, dtype=torch.float).view(-1, 100)
        x2 = torch.cat(
            [
                (torch.arange(1, 101, dtype=torch.float).view(-1, 100) - 1) / 100,
                torch.arange(1, 101, dtype=torch.float).view(-1, 100),
            ]
        )
        t1 = ExpTransform()
        t2 = CatTransform([AffineTransform(1, 100), identity_transform], dim=0)
        dim = 0
        x = torch.cat([x1, x2], dim=dim)
        t = CatTransform([t1, t2], dim=dim, lengths=[1, 2])
        actual_dom_check = t.domain.check(x)
        expected_dom_check = torch.cat(
            [t1.domain.check(x1), t2.domain.check(x2)], dim=dim
        )
        self.assertEqual(expected_dom_check, actual_dom_check)
        actual = t(x)
        expected = torch.cat([t1(x1), t2(x2)], dim=dim)
        self.assertEqual(expected, actual)
        y1 = torch.arange(1, 101, dtype=torch.float).view(-1, 100)
        y2 = torch.cat(
            [
                torch.arange(1, 101, dtype=torch.float).view(-1, 100),
                torch.arange(1, 101, dtype=torch.float).view(-1, 100),
            ]
        )
        y = torch.cat([y1, y2], dim=dim)
        actual_cod_check = t.codomain.check(y)
        expected_cod_check = torch.cat(
            [t1.codomain.check(y1), t2.codomain.check(y2)], dim=dim
        )
        self.assertEqual(actual_cod_check, expected_cod_check)
        actual_inv = t.inv(y)
        expected_inv = torch.cat([t1.inv(y1), t2.inv(y2)], dim=dim)
        self.assertEqual(expected_inv, actual_inv)
        actual_jac = t.log_abs_det_jacobian(x, y)
        expected_jac = torch.cat(
            [t1.log_abs_det_jacobian(x1, y1), t2.log_abs_det_jacobian(x2, y2)], dim=dim
        )
        self.assertEqual(actual_jac, expected_jac)

    def test_cat_event_dim(self):
        t1 = AffineTransform(0, 2 * torch.ones(2), event_dim=1)
        t2 = AffineTransform(0, 2 * torch.ones(2), event_dim=1)
        dim = 1
        bs = 16
        x1 = torch.randn(bs, 2)
        x2 = torch.randn(bs, 2)
        x = torch.cat([x1, x2], dim=1)
        t = CatTransform([t1, t2], dim=dim, lengths=[2, 2])
        y1 = t1(x1)
        y2 = t2(x2)
        y = t(x)
        actual_jac = t.log_abs_det_jacobian(x, y)
        expected_jac = sum(
            [t1.log_abs_det_jacobian(x1, y1), t2.log_abs_det_jacobian(x2, y2)]
        )
        self.assertEqual(actual_jac, expected_jac)

    def test_stack_transform(self):
        x1 = -1 * torch.arange(1, 101, dtype=torch.float)
        x2 = (torch.arange(1, 101, dtype=torch.float) - 1) / 100
        x3 = torch.arange(1, 101, dtype=torch.float)
        t1, t2, t3 = ExpTransform(), AffineTransform(1, 100), identity_transform
        dim = 0
        x = torch.stack([x1, x2, x3], dim=dim)
        t = StackTransform([t1, t2, t3], dim=dim)
        actual_dom_check = t.domain.check(x)
        expected_dom_check = torch.stack(
            [t1.domain.check(x1), t2.domain.check(x2), t3.domain.check(x3)], dim=dim
        )
        self.assertEqual(expected_dom_check, actual_dom_check)
        actual = t(x)
        expected = torch.stack([t1(x1), t2(x2), t3(x3)], dim=dim)
        self.assertEqual(expected, actual)
        y1 = torch.arange(1, 101, dtype=torch.float)
        y2 = torch.arange(1, 101, dtype=torch.float)
        y3 = torch.arange(1, 101, dtype=torch.float)
        y = torch.stack([y1, y2, y3], dim=dim)
        actual_cod_check = t.codomain.check(y)
        expected_cod_check = torch.stack(
6591*da0073e9SAndroid Build Coastguard Worker [t1.codomain.check(y1), t2.codomain.check(y2), t3.codomain.check(y3)], 6592*da0073e9SAndroid Build Coastguard Worker dim=dim, 6593*da0073e9SAndroid Build Coastguard Worker ) 6594*da0073e9SAndroid Build Coastguard Worker self.assertEqual(actual_cod_check, expected_cod_check) 6595*da0073e9SAndroid Build Coastguard Worker actual_inv = t.inv(x) 6596*da0073e9SAndroid Build Coastguard Worker expected_inv = torch.stack([t1.inv(x1), t2.inv(x2), t3.inv(x3)], dim=dim) 6597*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected_inv, actual_inv) 6598*da0073e9SAndroid Build Coastguard Worker actual_jac = t.log_abs_det_jacobian(x, y) 6599*da0073e9SAndroid Build Coastguard Worker expected_jac = torch.stack( 6600*da0073e9SAndroid Build Coastguard Worker [ 6601*da0073e9SAndroid Build Coastguard Worker t1.log_abs_det_jacobian(x1, y1), 6602*da0073e9SAndroid Build Coastguard Worker t2.log_abs_det_jacobian(x2, y2), 6603*da0073e9SAndroid Build Coastguard Worker t3.log_abs_det_jacobian(x3, y3), 6604*da0073e9SAndroid Build Coastguard Worker ], 6605*da0073e9SAndroid Build Coastguard Worker dim=dim, 6606*da0073e9SAndroid Build Coastguard Worker ) 6607*da0073e9SAndroid Build Coastguard Worker self.assertEqual(actual_jac, expected_jac) 6608*da0073e9SAndroid Build Coastguard Worker 6609*da0073e9SAndroid Build Coastguard Worker 6610*da0073e9SAndroid Build Coastguard Workerclass TestValidation(DistributionsTestCase): 6611*da0073e9SAndroid Build Coastguard Worker def test_valid(self): 6612*da0073e9SAndroid Build Coastguard Worker for Dist, params in _get_examples(): 6613*da0073e9SAndroid Build Coastguard Worker for param in params: 6614*da0073e9SAndroid Build Coastguard Worker Dist(validate_args=True, **param) 6615*da0073e9SAndroid Build Coastguard Worker 6616*da0073e9SAndroid Build Coastguard Worker @set_default_dtype(torch.double) 6617*da0073e9SAndroid Build Coastguard Worker def test_invalid_log_probs_arg(self): 6618*da0073e9SAndroid 
Build Coastguard Worker # Check that validation errors are indeed disabled, 6619*da0073e9SAndroid Build Coastguard Worker # but they might raise another error 6620*da0073e9SAndroid Build Coastguard Worker for Dist, params in _get_examples(): 6621*da0073e9SAndroid Build Coastguard Worker if Dist == TransformedDistribution: 6622*da0073e9SAndroid Build Coastguard Worker # TransformedDistribution has a distribution instance 6623*da0073e9SAndroid Build Coastguard Worker # as the argument, so we cannot do much about that 6624*da0073e9SAndroid Build Coastguard Worker continue 6625*da0073e9SAndroid Build Coastguard Worker for i, param in enumerate(params): 6626*da0073e9SAndroid Build Coastguard Worker d_nonval = Dist(validate_args=False, **param) 6627*da0073e9SAndroid Build Coastguard Worker d_val = Dist(validate_args=True, **param) 6628*da0073e9SAndroid Build Coastguard Worker for v in torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0]): 6629*da0073e9SAndroid Build Coastguard Worker # samples with incorrect shape must throw ValueError only 6630*da0073e9SAndroid Build Coastguard Worker try: 6631*da0073e9SAndroid Build Coastguard Worker log_prob = d_val.log_prob(v) 6632*da0073e9SAndroid Build Coastguard Worker except ValueError: 6633*da0073e9SAndroid Build Coastguard Worker pass 6634*da0073e9SAndroid Build Coastguard Worker # get sample of correct shape 6635*da0073e9SAndroid Build Coastguard Worker val = torch.full(d_val.batch_shape + d_val.event_shape, v) 6636*da0073e9SAndroid Build Coastguard Worker # check samples with incorrect support 6637*da0073e9SAndroid Build Coastguard Worker try: 6638*da0073e9SAndroid Build Coastguard Worker log_prob = d_val.log_prob(val) 6639*da0073e9SAndroid Build Coastguard Worker except ValueError as e: 6640*da0073e9SAndroid Build Coastguard Worker if e.args and "must be within the support" in e.args[0]: 6641*da0073e9SAndroid Build Coastguard Worker try: 6642*da0073e9SAndroid Build Coastguard Worker log_prob = d_nonval.log_prob(val) 
6643*da0073e9SAndroid Build Coastguard Worker except RuntimeError: 6644*da0073e9SAndroid Build Coastguard Worker pass 6645*da0073e9SAndroid Build Coastguard Worker 6646*da0073e9SAndroid Build Coastguard Worker # check correct samples are ok 6647*da0073e9SAndroid Build Coastguard Worker valid_value = d_val.sample() 6648*da0073e9SAndroid Build Coastguard Worker d_val.log_prob(valid_value) 6649*da0073e9SAndroid Build Coastguard Worker # check invalid values raise ValueError 6650*da0073e9SAndroid Build Coastguard Worker if valid_value.dtype == torch.long: 6651*da0073e9SAndroid Build Coastguard Worker valid_value = valid_value.float() 6652*da0073e9SAndroid Build Coastguard Worker invalid_value = torch.full_like(valid_value, math.nan) 6653*da0073e9SAndroid Build Coastguard Worker try: 6654*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 6655*da0073e9SAndroid Build Coastguard Worker ValueError, 6656*da0073e9SAndroid Build Coastguard Worker "Expected value argument .* to be within the support .*", 6657*da0073e9SAndroid Build Coastguard Worker ): 6658*da0073e9SAndroid Build Coastguard Worker d_val.log_prob(invalid_value) 6659*da0073e9SAndroid Build Coastguard Worker except AssertionError as e: 6660*da0073e9SAndroid Build Coastguard Worker fail_string = "Support ValueError not raised for {} example {}/{}" 6661*da0073e9SAndroid Build Coastguard Worker raise AssertionError( 6662*da0073e9SAndroid Build Coastguard Worker fail_string.format(Dist.__name__, i + 1, len(params)) 6663*da0073e9SAndroid Build Coastguard Worker ) from e 6664*da0073e9SAndroid Build Coastguard Worker 6665*da0073e9SAndroid Build Coastguard Worker @set_default_dtype(torch.double) 6666*da0073e9SAndroid Build Coastguard Worker def test_invalid(self): 6667*da0073e9SAndroid Build Coastguard Worker for Dist, params in _get_bad_examples(): 6668*da0073e9SAndroid Build Coastguard Worker for i, param in enumerate(params): 6669*da0073e9SAndroid Build Coastguard Worker try: 6670*da0073e9SAndroid 
Build Coastguard Worker with self.assertRaises(ValueError): 6671*da0073e9SAndroid Build Coastguard Worker Dist(validate_args=True, **param) 6672*da0073e9SAndroid Build Coastguard Worker except AssertionError as e: 6673*da0073e9SAndroid Build Coastguard Worker fail_string = "ValueError not raised for {} example {}/{}" 6674*da0073e9SAndroid Build Coastguard Worker raise AssertionError( 6675*da0073e9SAndroid Build Coastguard Worker fail_string.format(Dist.__name__, i + 1, len(params)) 6676*da0073e9SAndroid Build Coastguard Worker ) from e 6677*da0073e9SAndroid Build Coastguard Worker 6678*da0073e9SAndroid Build Coastguard Worker def test_warning_unimplemented_constraints(self): 6679*da0073e9SAndroid Build Coastguard Worker class Delta(Distribution): 6680*da0073e9SAndroid Build Coastguard Worker def __init__(self, validate_args=True): 6681*da0073e9SAndroid Build Coastguard Worker super().__init__(validate_args=validate_args) 6682*da0073e9SAndroid Build Coastguard Worker 6683*da0073e9SAndroid Build Coastguard Worker def sample(self, sample_shape=torch.Size()): 6684*da0073e9SAndroid Build Coastguard Worker return torch.tensor(0.0).expand(sample_shape) 6685*da0073e9SAndroid Build Coastguard Worker 6686*da0073e9SAndroid Build Coastguard Worker def log_prob(self, value): 6687*da0073e9SAndroid Build Coastguard Worker if self._validate_args: 6688*da0073e9SAndroid Build Coastguard Worker self._validate_sample(value) 6689*da0073e9SAndroid Build Coastguard Worker value[value != 0.0] = -float("inf") 6690*da0073e9SAndroid Build Coastguard Worker value[value == 0.0] = 0.0 6691*da0073e9SAndroid Build Coastguard Worker return value 6692*da0073e9SAndroid Build Coastguard Worker 6693*da0073e9SAndroid Build Coastguard Worker with self.assertWarns(UserWarning): 6694*da0073e9SAndroid Build Coastguard Worker d = Delta() 6695*da0073e9SAndroid Build Coastguard Worker sample = d.sample((2,)) 6696*da0073e9SAndroid Build Coastguard Worker with self.assertWarns(UserWarning): 
6697*da0073e9SAndroid Build Coastguard Worker d.log_prob(sample) 6698*da0073e9SAndroid Build Coastguard Worker 6699*da0073e9SAndroid Build Coastguard Worker 6700*da0073e9SAndroid Build Coastguard Workerclass TestJit(DistributionsTestCase): 6701*da0073e9SAndroid Build Coastguard Worker def _examples(self): 6702*da0073e9SAndroid Build Coastguard Worker for Dist, params in _get_examples(): 6703*da0073e9SAndroid Build Coastguard Worker for param in params: 6704*da0073e9SAndroid Build Coastguard Worker keys = param.keys() 6705*da0073e9SAndroid Build Coastguard Worker values = tuple(param[key] for key in keys) 6706*da0073e9SAndroid Build Coastguard Worker if not all(isinstance(x, torch.Tensor) for x in values): 6707*da0073e9SAndroid Build Coastguard Worker continue 6708*da0073e9SAndroid Build Coastguard Worker sample = Dist(**param).sample() 6709*da0073e9SAndroid Build Coastguard Worker yield Dist, keys, values, sample 6710*da0073e9SAndroid Build Coastguard Worker 6711*da0073e9SAndroid Build Coastguard Worker def _perturb_tensor(self, value, constraint): 6712*da0073e9SAndroid Build Coastguard Worker if isinstance(constraint, constraints._IntegerGreaterThan): 6713*da0073e9SAndroid Build Coastguard Worker return value + 1 6714*da0073e9SAndroid Build Coastguard Worker if isinstance( 6715*da0073e9SAndroid Build Coastguard Worker constraint, 6716*da0073e9SAndroid Build Coastguard Worker (constraints._PositiveDefinite, constraints._PositiveSemidefinite), 6717*da0073e9SAndroid Build Coastguard Worker ): 6718*da0073e9SAndroid Build Coastguard Worker return value + torch.eye(value.shape[-1]) 6719*da0073e9SAndroid Build Coastguard Worker if value.dtype in [torch.float, torch.double]: 6720*da0073e9SAndroid Build Coastguard Worker transform = transform_to(constraint) 6721*da0073e9SAndroid Build Coastguard Worker delta = value.new(value.shape).normal_() 6722*da0073e9SAndroid Build Coastguard Worker return transform(transform.inv(value) + delta) 6723*da0073e9SAndroid Build Coastguard 
Worker if value.dtype == torch.long: 6724*da0073e9SAndroid Build Coastguard Worker result = value.clone() 6725*da0073e9SAndroid Build Coastguard Worker result[value == 0] = 1 6726*da0073e9SAndroid Build Coastguard Worker result[value == 1] = 0 6727*da0073e9SAndroid Build Coastguard Worker return result 6728*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 6729*da0073e9SAndroid Build Coastguard Worker 6730*da0073e9SAndroid Build Coastguard Worker def _perturb(self, Dist, keys, values, sample): 6731*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 6732*da0073e9SAndroid Build Coastguard Worker if Dist is Uniform: 6733*da0073e9SAndroid Build Coastguard Worker param = dict(zip(keys, values)) 6734*da0073e9SAndroid Build Coastguard Worker param["low"] = param["low"] - torch.rand(param["low"].shape) 6735*da0073e9SAndroid Build Coastguard Worker param["high"] = param["high"] + torch.rand(param["high"].shape) 6736*da0073e9SAndroid Build Coastguard Worker values = [param[key] for key in keys] 6737*da0073e9SAndroid Build Coastguard Worker else: 6738*da0073e9SAndroid Build Coastguard Worker values = [ 6739*da0073e9SAndroid Build Coastguard Worker self._perturb_tensor( 6740*da0073e9SAndroid Build Coastguard Worker value, Dist.arg_constraints.get(key, constraints.real) 6741*da0073e9SAndroid Build Coastguard Worker ) 6742*da0073e9SAndroid Build Coastguard Worker for key, value in zip(keys, values) 6743*da0073e9SAndroid Build Coastguard Worker ] 6744*da0073e9SAndroid Build Coastguard Worker param = dict(zip(keys, values)) 6745*da0073e9SAndroid Build Coastguard Worker sample = Dist(**param).sample() 6746*da0073e9SAndroid Build Coastguard Worker return values, sample 6747*da0073e9SAndroid Build Coastguard Worker 6748*da0073e9SAndroid Build Coastguard Worker @set_default_dtype(torch.double) 6749*da0073e9SAndroid Build Coastguard Worker def test_sample(self): 6750*da0073e9SAndroid Build Coastguard Worker for Dist, keys, values, sample in 
self._examples(): 6751*da0073e9SAndroid Build Coastguard Worker 6752*da0073e9SAndroid Build Coastguard Worker def f(*values): 6753*da0073e9SAndroid Build Coastguard Worker param = dict(zip(keys, values)) 6754*da0073e9SAndroid Build Coastguard Worker dist = Dist(**param) 6755*da0073e9SAndroid Build Coastguard Worker return dist.sample() 6756*da0073e9SAndroid Build Coastguard Worker 6757*da0073e9SAndroid Build Coastguard Worker traced_f = torch.jit.trace(f, values, check_trace=False) 6758*da0073e9SAndroid Build Coastguard Worker 6759*da0073e9SAndroid Build Coastguard Worker # FIXME Schema not found for node 6760*da0073e9SAndroid Build Coastguard Worker xfail = [ 6761*da0073e9SAndroid Build Coastguard Worker Cauchy, # aten::cauchy(Double(2,1), float, float, Generator) 6762*da0073e9SAndroid Build Coastguard Worker HalfCauchy, # aten::cauchy(Double(2, 1), float, float, Generator) 6763*da0073e9SAndroid Build Coastguard Worker VonMises, # Variance is not Euclidean 6764*da0073e9SAndroid Build Coastguard Worker ] 6765*da0073e9SAndroid Build Coastguard Worker if Dist in xfail: 6766*da0073e9SAndroid Build Coastguard Worker continue 6767*da0073e9SAndroid Build Coastguard Worker 6768*da0073e9SAndroid Build Coastguard Worker with torch.random.fork_rng(): 6769*da0073e9SAndroid Build Coastguard Worker sample = f(*values) 6770*da0073e9SAndroid Build Coastguard Worker traced_sample = traced_f(*values) 6771*da0073e9SAndroid Build Coastguard Worker self.assertEqual(sample, traced_sample) 6772*da0073e9SAndroid Build Coastguard Worker 6773*da0073e9SAndroid Build Coastguard Worker # FIXME no nondeterministic nodes found in trace 6774*da0073e9SAndroid Build Coastguard Worker xfail = [Beta, Dirichlet] 6775*da0073e9SAndroid Build Coastguard Worker if Dist not in xfail: 6776*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 6777*da0073e9SAndroid Build Coastguard Worker any(n.isNondeterministic() for n in traced_f.graph.nodes()) 6778*da0073e9SAndroid Build Coastguard Worker ) 

    def test_rsample(self):
        for Dist, keys, values, sample in self._examples():
            if not Dist.has_rsample:
                continue

            def f(*values):
                param = dict(zip(keys, values))
                dist = Dist(**param)
                return dist.rsample()

            traced_f = torch.jit.trace(f, values, check_trace=False)

            # FIXME Schema not found for node
            xfail = [
                Cauchy,  # aten::cauchy(Double(2,1), float, float, Generator)
                HalfCauchy,  # aten::cauchy(Double(2, 1), float, float, Generator)
            ]
            if Dist in xfail:
                continue

            # fork_rng restores the RNG state on exit, so the traced call
            # below consumes the same random numbers as the eager call
            with torch.random.fork_rng():
                sample = f(*values)
            traced_sample = traced_f(*values)
            self.assertEqual(sample, traced_sample)

            # FIXME no nondeterministic nodes found in trace
            xfail = [Beta, Dirichlet]
            if Dist not in xfail:
                self.assertTrue(
                    any(n.isNondeterministic() for n in traced_f.graph.nodes())
                )

    @set_default_dtype(torch.double)
    def test_log_prob(self):
        for Dist, keys, values, sample in self._examples():
            # FIXME traced functions produce incorrect results
            xfail = [LowRankMultivariateNormal, MultivariateNormal]
            if Dist in xfail:
                continue

            def f(sample, *values):
                param = dict(zip(keys, values))
                dist = Dist(**param)
                return dist.log_prob(sample)

            traced_f = torch.jit.trace(f, (sample,) + values)

            # check on different data
            values, sample = self._perturb(Dist, keys, values, sample)
            expected = f(sample, *values)
            actual = traced_f(sample, *values)
            self.assertEqual(
                expected,
                actual,
                msg=f"{Dist.__name__}\nExpected:\n{expected}\nActual:\n{actual}",
            )

    def test_enumerate_support(self):
        for Dist, keys, values, sample in self._examples():
            # FIXME traced functions produce incorrect results
            xfail = [Binomial]
            if Dist in xfail:
                continue

            def f(*values):
                param = dict(zip(keys, values))
                dist = Dist(**param)
                return dist.enumerate_support()

            try:
                traced_f = torch.jit.trace(f, values)
            except NotImplementedError:
                continue

            # check on different data
            values, sample = self._perturb(Dist, keys, values, sample)
            expected = f(*values)
            actual = traced_f(*values)
            self.assertEqual(
                expected,
                actual,
                msg=f"{Dist.__name__}\nExpected:\n{expected}\nActual:\n{actual}",
            )

    def test_mean(self):
        for Dist, keys, values, sample in self._examples():

            def f(*values):
                param = dict(zip(keys, values))
                dist = Dist(**param)
                return dist.mean

            try:
                traced_f = torch.jit.trace(f, values)
            except NotImplementedError:
                continue

            # check on different data
            values, sample = self._perturb(Dist, keys, values, sample)
            expected = f(*values)
            actual = traced_f(*values)
            expected[expected == float("inf")] = 0.0
            actual[actual == float("inf")] = 0.0
            self.assertEqual(
                expected,
                actual,
                msg=f"{Dist.__name__}\nExpected:\n{expected}\nActual:\n{actual}",
            )

    def test_variance(self):
        for Dist, keys, values, sample in self._examples():
            if Dist in [Cauchy, HalfCauchy]:
                continue  # infinite variance

            def f(*values):
                param = dict(zip(keys, values))
                dist = Dist(**param)
                return dist.variance

            try:
                traced_f = torch.jit.trace(f, values)
            except NotImplementedError:
                continue

            # check on different data
            values, sample = self._perturb(Dist, keys, values, sample)
            expected = f(*values).clone()
            actual = traced_f(*values).clone()
            expected[expected == float("inf")] = 0.0
            actual[actual == float("inf")] = 0.0
            self.assertEqual(
                expected,
                actual,
                msg=f"{Dist.__name__}\nExpected:\n{expected}\nActual:\n{actual}",
            )

    @set_default_dtype(torch.double)
    def test_entropy(self):
        for Dist, keys, values, sample in self._examples():
            # FIXME traced functions produce incorrect results
            xfail = [LowRankMultivariateNormal, MultivariateNormal]
            if Dist in xfail:
                continue

            def f(*values):
                param = dict(zip(keys, values))
                dist = Dist(**param)
                return dist.entropy()

            try:
                traced_f = torch.jit.trace(f, values)
            except NotImplementedError:
                continue

            # check on different data
            values, sample = self._perturb(Dist, keys, values, sample)
            expected = f(*values)
            actual = traced_f(*values)
            self.assertEqual(
                expected,
                actual,
                msg=f"{Dist.__name__}\nExpected:\n{expected}\nActual:\n{actual}",
            )

    @set_default_dtype(torch.double)
    def test_cdf(self):
        for Dist, keys, values, sample in self._examples():

            def f(sample, *values):
                param = dict(zip(keys, values))
                dist = Dist(**param)
                cdf = dist.cdf(sample)
                return dist.icdf(cdf)

            try:
                traced_f = torch.jit.trace(f, (sample,) + values)
            except NotImplementedError:
                continue

            # check on different data
            values, sample = self._perturb(Dist, keys, values, sample)
            expected = f(sample, *values)
            actual = traced_f(sample, *values)
            self.assertEqual(
                expected,
                actual,
                msg=f"{Dist.__name__}\nExpected:\n{expected}\nActual:\n{actual}",
            )


if __name__ == "__main__" and torch._C.has_lapack:
    TestCase._default_dtype_check_enabled = True
    run_tests()