# Owner(s): ["module: distributions"]

"""
Note [Randomized statistical tests]
-----------------------------------

This note describes how to maintain tests in this file as random sources
change. This file contains three types of randomized tests:

1. The easiest type of randomized test is one that should always pass but is
   initialized with random data. If these fail, something is wrong, but it's
   fine to use a fixed seed by inheriting from common.TestCase.

2. The trickier tests are statistical tests. These tests explicitly call
   set_rng_seed(n) and are marked "see Note [Randomized statistical tests]".
   These statistical tests have a known positive failure rate
   (we set failure_rate=1e-3 by default). We need to balance the strength of
   these tests against the annoyance of false alarms. One approach that works
   is to specifically set seeds in each of the randomized tests. When a random
   generator occasionally changes (as in #4312, which vectorized the
   Box-Muller sampler), some of these statistical tests may (rarely) fail. If
   one fails in this case, it's fine to increment the seed of the failing
   test (but you shouldn't need to increment it more than once; otherwise
   something is probably actually wrong).

3. `test_geometric_sample`, `test_binomial_sample` and `test_poisson_sample`
   are validated against `scipy.stats` reference distributions, which are not
   guaranteed to be identical across different versions of scipy (namely, they
   yield invalid results in 1.7+).
"""

import math
import numbers
import unittest
from collections import namedtuple
from itertools import product
from random import shuffle

from packaging import version

import torch
import torch.autograd.forward_ad as fwAD
from torch import inf, nan
from torch.autograd import grad
from torch.autograd.functional import jacobian
from torch.distributions import (
    Bernoulli,
    Beta,
    Binomial,
    Categorical,
    Cauchy,
    Chi2,
    constraints,
    ContinuousBernoulli,
    Dirichlet,
    Distribution,
    Exponential,
    ExponentialFamily,
    FisherSnedecor,
    Gamma,
    Geometric,
    Gumbel,
    HalfCauchy,
    HalfNormal,
    Independent,
    InverseGamma,
    kl_divergence,
    Kumaraswamy,
    Laplace,
    LKJCholesky,
    LogisticNormal,
    LogNormal,
    LowRankMultivariateNormal,
    MixtureSameFamily,
    Multinomial,
    MultivariateNormal,
    NegativeBinomial,
    Normal,
    OneHotCategorical,
    OneHotCategoricalStraightThrough,
    Pareto,
    Poisson,
    RelaxedBernoulli,
    RelaxedOneHotCategorical,
    StudentT,
    TransformedDistribution,
    Uniform,
    VonMises,
    Weibull,
    Wishart,
)
from torch.distributions.constraint_registry import transform_to
from torch.distributions.constraints import Constraint, is_dependent
from torch.distributions.dirichlet import _Dirichlet_backward
from torch.distributions.kl import _kl_expfamily_expfamily
from torch.distributions.transforms import (
    AffineTransform,
    CatTransform,
    ExpTransform,
    identity_transform,
    StackTransform,
)
from torch.distributions.utils import (
    lazy_property,
    probs_to_logits,
    tril_matrix_to_vec,
    vec_to_tril_matrix,
)
from torch.nn.functional import softmax
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_utils import (
    gradcheck,
    load_tests,
    run_tests,
    set_default_dtype,
    set_rng_seed,
    skipIfTorchDynamo,
    TestCase,
)


# load_tests from torch.testing._internal.common_utils is used to automatically
# filter tests for sharding on sandcastle. This line silences flake warnings
load_tests = load_tests

TEST_NUMPY = True
try:
    import numpy as np
    import scipy.special
    import scipy.stats
except ImportError:
    TEST_NUMPY = False
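
# Hedged sketch (assumption, not used by the tests below): per item 3 of the
# note above, scipy-backed checks can be gated on the installed scipy version
# with the `version` helper imported above, e.g.:
#
#     SCIPY_AT_LEAST_1_7 = TEST_NUMPY and version.parse(
#         scipy.__version__
#     ) >= version.parse("1.7")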


def pairwise(Dist, *params):
    """
    Creates a pair of distributions `Dist` initialized so that every element
    of `params` is paired with every other element.
    """
    params1 = [torch.tensor([p] * len(p)) for p in params]
    params2 = [p.transpose(0, 1) for p in params1]
    return Dist(*params1), Dist(*params2)


def is_all_nan(tensor):
    """
    Checks whether all entries of a tensor are NaN.
    """
    return (tensor != tensor).all()
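

# Hedged usage sketch (assumption): pairwise() yields two broadcastable
# parameterizations so that position (i, j) pairs parameter i with parameter
# j, which is convenient for all-pairs checks such as
#
#     d1, d2 = pairwise(Normal, [0.0, 1.0], [1.0, 2.0])
#     kl = kl_divergence(d1, d2)  # batch_shape (2, 2)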
], 260 ), 261 Example( 262 Multinomial, 263 [ 264 { 265 "probs": torch.tensor( 266 [[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True 267 ), 268 "total_count": 10, 269 }, 270 { 271 "probs": torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True), 272 "total_count": 10, 273 }, 274 ], 275 ), 276 Example( 277 Cauchy, 278 [ 279 {"loc": 0.0, "scale": 1.0}, 280 {"loc": torch.tensor([0.0]), "scale": 1.0}, 281 { 282 "loc": torch.tensor([[0.0], [0.0]]), 283 "scale": torch.tensor([[1.0], [1.0]]), 284 }, 285 ], 286 ), 287 Example( 288 Chi2, 289 [ 290 {"df": torch.randn(2, 3).exp().requires_grad_()}, 291 {"df": torch.randn(1).exp().requires_grad_()}, 292 ], 293 ), 294 Example( 295 StudentT, 296 [ 297 {"df": torch.randn(2, 3).exp().requires_grad_()}, 298 {"df": torch.randn(1).exp().requires_grad_()}, 299 ], 300 ), 301 Example( 302 Dirichlet, 303 [ 304 {"concentration": torch.randn(2, 3).exp().requires_grad_()}, 305 {"concentration": torch.randn(4).exp().requires_grad_()}, 306 ], 307 ), 308 Example( 309 Exponential, 310 [ 311 {"rate": torch.randn(5, 5).abs().requires_grad_()}, 312 {"rate": torch.randn(1).abs().requires_grad_()}, 313 ], 314 ), 315 Example( 316 FisherSnedecor, 317 [ 318 { 319 "df1": torch.randn(5, 5).abs().requires_grad_(), 320 "df2": torch.randn(5, 5).abs().requires_grad_(), 321 }, 322 { 323 "df1": torch.randn(1).abs().requires_grad_(), 324 "df2": torch.randn(1).abs().requires_grad_(), 325 }, 326 { 327 "df1": torch.tensor([1.0]), 328 "df2": 1.0, 329 }, 330 ], 331 ), 332 Example( 333 Gamma, 334 [ 335 { 336 "concentration": torch.randn(2, 3).exp().requires_grad_(), 337 "rate": torch.randn(2, 3).exp().requires_grad_(), 338 }, 339 { 340 "concentration": torch.randn(1).exp().requires_grad_(), 341 "rate": torch.randn(1).exp().requires_grad_(), 342 }, 343 ], 344 ), 345 Example( 346 Gumbel, 347 [ 348 { 349 "loc": torch.randn(5, 5, requires_grad=True), 350 "scale": torch.randn(5, 5).abs().requires_grad_(), 351 }, 352 { 353 "loc": torch.randn(1, requires_grad=True), 354 "scale": torch.randn(1).abs().requires_grad_(), 355 }, 356 ], 357 ), 358 Example(HalfCauchy, [{"scale": 1.0}, {"scale": torch.tensor([[1.0], [1.0]])}]), 359 Example( 360 HalfNormal, 361 [ 362 {"scale": torch.randn(5, 5).abs().requires_grad_()}, 363 {"scale": torch.randn(1).abs().requires_grad_()}, 364 {"scale": torch.tensor([1e-5, 1e-5], requires_grad=True)}, 365 ], 366 ), 367 Example( 368 Independent, 369 [ 370 { 371 "base_distribution": Normal( 372 torch.randn(2, 3, requires_grad=True), 373 torch.randn(2, 3).abs().requires_grad_(), 374 ), 375 "reinterpreted_batch_ndims": 0, 376 }, 377 { 378 "base_distribution": Normal( 379 torch.randn(2, 3, requires_grad=True), 380 torch.randn(2, 3).abs().requires_grad_(), 381 ), 382 "reinterpreted_batch_ndims": 1, 383 }, 384 { 385 "base_distribution": Normal( 386 torch.randn(2, 3, requires_grad=True), 387 torch.randn(2, 3).abs().requires_grad_(), 388 ), 389 "reinterpreted_batch_ndims": 2, 390 }, 391 { 392 "base_distribution": Normal( 393 torch.randn(2, 3, 5, requires_grad=True), 394 torch.randn(2, 3, 5).abs().requires_grad_(), 395 ), 396 "reinterpreted_batch_ndims": 2, 397 }, 398 { 399 "base_distribution": Normal( 400 torch.randn(2, 3, 5, requires_grad=True), 401 torch.randn(2, 3, 5).abs().requires_grad_(), 402 ), 403 "reinterpreted_batch_ndims": 3, 404 }, 405 ], 406 ), 407 Example( 408 Kumaraswamy, 409 [ 410 { 411 "concentration1": torch.empty(2, 3).uniform_(1, 2).requires_grad_(), 412 "concentration0": torch.empty(2, 3).uniform_(1, 2).requires_grad_(), 413 }, 414 { 415 "concentration1": 
                    "concentration1": torch.rand(4).uniform_(1, 2).requires_grad_(),
                    "concentration0": torch.rand(4).uniform_(1, 2).requires_grad_(),
                },
            ],
        ),
        Example(
            LKJCholesky,
            [
                {"dim": 2, "concentration": 0.5},
                {
                    "dim": 3,
                    "concentration": torch.tensor([0.5, 1.0, 2.0]),
                },
                {"dim": 100, "concentration": 4.0},
            ],
        ),
        Example(
            Laplace,
            [
                {
                    "loc": torch.randn(5, 5, requires_grad=True),
                    "scale": torch.randn(5, 5).abs().requires_grad_(),
                },
                {
                    "loc": torch.randn(1, requires_grad=True),
                    "scale": torch.randn(1).abs().requires_grad_(),
                },
                {
                    "loc": torch.tensor([1.0, 0.0], requires_grad=True),
                    "scale": torch.tensor([1e-5, 1e-5], requires_grad=True),
                },
            ],
        ),
        Example(
            LogNormal,
            [
                {
                    "loc": torch.randn(5, 5, requires_grad=True),
                    "scale": torch.randn(5, 5).abs().requires_grad_(),
                },
                {
                    "loc": torch.randn(1, requires_grad=True),
                    "scale": torch.randn(1).abs().requires_grad_(),
                },
                {
                    "loc": torch.tensor([1.0, 0.0], requires_grad=True),
                    "scale": torch.tensor([1e-5, 1e-5], requires_grad=True),
                },
            ],
        ),
        Example(
            LogisticNormal,
            [
                {
                    "loc": torch.randn(5, 5).requires_grad_(),
                    "scale": torch.randn(5, 5).abs().requires_grad_(),
                },
                {
                    "loc": torch.randn(1).requires_grad_(),
                    "scale": torch.randn(1).abs().requires_grad_(),
                },
                {
                    "loc": torch.tensor([1.0, 0.0], requires_grad=True),
                    "scale": torch.tensor([1e-5, 1e-5], requires_grad=True),
                },
            ],
        ),
        Example(
            LowRankMultivariateNormal,
            [
                {
                    "loc": torch.randn(5, 2, requires_grad=True),
                    "cov_factor": torch.randn(5, 2, 1, requires_grad=True),
                    "cov_diag": torch.tensor([2.0, 0.25], requires_grad=True),
                },
                {
                    "loc": torch.randn(4, 3, requires_grad=True),
                    "cov_factor": torch.randn(3, 2, requires_grad=True),
                    "cov_diag": torch.tensor([5.0, 1.5, 3.0], requires_grad=True),
                },
            ],
        ),
        Example(
            MultivariateNormal,
            [
                {
                    "loc": torch.randn(5, 2, requires_grad=True),
                    "covariance_matrix": torch.tensor(
                        [[2.0, 0.3], [0.3, 0.25]], requires_grad=True
                    ),
                },
                {
                    "loc": torch.randn(2, 3, requires_grad=True),
                    "precision_matrix": torch.tensor(
                        [[2.0, 0.1, 0.0], [0.1, 0.25, 0.0], [0.0, 0.0, 0.3]],
                        requires_grad=True,
                    ),
                },
                {
                    "loc": torch.randn(5, 3, 2, requires_grad=True),
                    "scale_tril": torch.tensor(
                        [
                            [[2.0, 0.0], [-0.5, 0.25]],
                            [[2.0, 0.0], [0.3, 0.25]],
                            [[5.0, 0.0], [-0.5, 1.5]],
                        ],
                        requires_grad=True,
                    ),
                },
                {
                    "loc": torch.tensor([1.0, -1.0]),
                    "covariance_matrix": torch.tensor([[5.0, -0.5], [-0.5, 1.5]]),
                },
            ],
        ),
        Example(
            Normal,
            [
                {
                    "loc": torch.randn(5, 5, requires_grad=True),
                    "scale": torch.randn(5, 5).abs().requires_grad_(),
                },
                {
                    "loc": torch.randn(1, requires_grad=True),
                    "scale": torch.randn(1).abs().requires_grad_(),
                },
                {
                    "loc": torch.tensor([1.0, 0.0], requires_grad=True),
                    "scale": torch.tensor([1e-5, 1e-5], requires_grad=True),
                },
            ],
        ),
        Example(
            OneHotCategorical,
            [
                {
                    "probs": torch.tensor(
                        [[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True
                    )
                },
                {"probs": torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True)},
                {"logits": torch.tensor([[0.0, 0.0], [0.0, 0.0]], requires_grad=True)},
            ],
        ),
        Example(
            OneHotCategoricalStraightThrough,
            [
                {
"probs": torch.tensor( 564 [[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True 565 ) 566 }, 567 {"probs": torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True)}, 568 {"logits": torch.tensor([[0.0, 0.0], [0.0, 0.0]], requires_grad=True)}, 569 ], 570 ), 571 Example( 572 Pareto, 573 [ 574 {"scale": 1.0, "alpha": 1.0}, 575 { 576 "scale": (torch.randn(5, 5).abs() + 0.1).requires_grad_(), 577 "alpha": (torch.randn(5, 5).abs() + 0.1).requires_grad_(), 578 }, 579 {"scale": torch.tensor([1.0]), "alpha": 1.0}, 580 ], 581 ), 582 Example( 583 Poisson, 584 [ 585 { 586 "rate": torch.randn(5, 5).abs().requires_grad_(), 587 }, 588 { 589 "rate": torch.randn(3).abs().requires_grad_(), 590 }, 591 { 592 "rate": 0.2, 593 }, 594 { 595 "rate": torch.tensor([0.0], requires_grad=True), 596 }, 597 { 598 "rate": 0.0, 599 }, 600 ], 601 ), 602 Example( 603 RelaxedBernoulli, 604 [ 605 { 606 "temperature": torch.tensor([0.5], requires_grad=True), 607 "probs": torch.tensor([0.7, 0.2, 0.4], requires_grad=True), 608 }, 609 { 610 "temperature": torch.tensor([2.0]), 611 "probs": torch.tensor([0.3]), 612 }, 613 { 614 "temperature": torch.tensor([7.2]), 615 "logits": torch.tensor([-2.0, 2.0, 1.0, 5.0]), 616 }, 617 ], 618 ), 619 Example( 620 RelaxedOneHotCategorical, 621 [ 622 { 623 "temperature": torch.tensor([0.5], requires_grad=True), 624 "probs": torch.tensor( 625 [[0.1, 0.2, 0.7], [0.5, 0.3, 0.2]], requires_grad=True 626 ), 627 }, 628 { 629 "temperature": torch.tensor([2.0]), 630 "probs": torch.tensor([[1.0, 0.0], [0.0, 1.0]]), 631 }, 632 { 633 "temperature": torch.tensor([7.2]), 634 "logits": torch.tensor([[-2.0, 2.0], [1.0, 5.0]]), 635 }, 636 ], 637 ), 638 Example( 639 TransformedDistribution, 640 [ 641 { 642 "base_distribution": Normal( 643 torch.randn(2, 3, requires_grad=True), 644 torch.randn(2, 3).abs().requires_grad_(), 645 ), 646 "transforms": [], 647 }, 648 { 649 "base_distribution": Normal( 650 torch.randn(2, 3, requires_grad=True), 651 torch.randn(2, 3).abs().requires_grad_(), 652 ), 653 "transforms": ExpTransform(), 654 }, 655 { 656 "base_distribution": Normal( 657 torch.randn(2, 3, 5, requires_grad=True), 658 torch.randn(2, 3, 5).abs().requires_grad_(), 659 ), 660 "transforms": [ 661 AffineTransform(torch.randn(3, 5), torch.randn(3, 5)), 662 ExpTransform(), 663 ], 664 }, 665 { 666 "base_distribution": Normal( 667 torch.randn(2, 3, 5, requires_grad=True), 668 torch.randn(2, 3, 5).abs().requires_grad_(), 669 ), 670 "transforms": AffineTransform(1, 2), 671 }, 672 { 673 "base_distribution": Uniform( 674 torch.tensor(1e8).log(), torch.tensor(1e10).log() 675 ), 676 "transforms": ExpTransform(), 677 }, 678 ], 679 ), 680 Example( 681 Uniform, 682 [ 683 { 684 "low": torch.zeros(5, 5, requires_grad=True), 685 "high": torch.ones(5, 5, requires_grad=True), 686 }, 687 { 688 "low": torch.zeros(1, requires_grad=True), 689 "high": torch.ones(1, requires_grad=True), 690 }, 691 { 692 "low": torch.tensor([1.0, 1.0], requires_grad=True), 693 "high": torch.tensor([2.0, 3.0], requires_grad=True), 694 }, 695 ], 696 ), 697 Example( 698 Weibull, 699 [ 700 { 701 "scale": torch.randn(5, 5).abs().requires_grad_(), 702 "concentration": torch.randn(1).abs().requires_grad_(), 703 } 704 ], 705 ), 706 Example( 707 Wishart, 708 [ 709 { 710 "covariance_matrix": torch.tensor( 711 [[2.0, 0.3], [0.3, 0.25]], requires_grad=True 712 ), 713 "df": torch.tensor([3.0], requires_grad=True), 714 }, 715 { 716 "precision_matrix": torch.tensor( 717 [[2.0, 0.1, 0.0], [0.1, 0.25, 0.0], [0.0, 0.0, 0.3]], 718 requires_grad=True, 719 ), 720 "df": 
                    "df": torch.tensor([5.0, 4], requires_grad=True),
                },
                {
                    "scale_tril": torch.tensor(
                        [
                            [[2.0, 0.0], [-0.5, 0.25]],
                            [[2.0, 0.0], [0.3, 0.25]],
                            [[5.0, 0.0], [-0.5, 1.5]],
                        ],
                        requires_grad=True,
                    ),
                    "df": torch.tensor([5.0, 3.5, 3], requires_grad=True),
                },
                {
                    "covariance_matrix": torch.tensor([[5.0, -0.5], [-0.5, 1.5]]),
                    "df": torch.tensor([3.0]),
                },
                {
                    "covariance_matrix": torch.tensor([[5.0, -0.5], [-0.5, 1.5]]),
                    "df": 3.0,
                },
            ],
        ),
        Example(
            MixtureSameFamily,
            [
                {
                    "mixture_distribution": Categorical(
                        torch.rand(5, requires_grad=True)
                    ),
                    "component_distribution": Normal(
                        torch.randn(5, requires_grad=True),
                        torch.rand(5, requires_grad=True),
                    ),
                },
                {
                    "mixture_distribution": Categorical(
                        torch.rand(5, requires_grad=True)
                    ),
                    "component_distribution": MultivariateNormal(
                        loc=torch.randn(5, 2, requires_grad=True),
                        covariance_matrix=torch.tensor(
                            [[2.0, 0.3], [0.3, 0.25]], requires_grad=True
                        ),
                    ),
                },
            ],
        ),
        Example(
            VonMises,
            [
                {
                    "loc": torch.tensor(1.0, requires_grad=True),
                    "concentration": torch.tensor(10.0, requires_grad=True),
                },
                {
                    "loc": torch.tensor([0.0, math.pi / 2], requires_grad=True),
                    "concentration": torch.tensor([1.0, 10.0], requires_grad=True),
                },
            ],
        ),
        Example(
            ContinuousBernoulli,
            [
                {"probs": torch.tensor([0.7, 0.2, 0.4], requires_grad=True)},
                {"probs": torch.tensor([0.3], requires_grad=True)},
                {"probs": 0.3},
                {"logits": torch.tensor([0.0], requires_grad=True)},
            ],
        ),
        Example(
            InverseGamma,
            [
                {
                    "concentration": torch.randn(2, 3).exp().requires_grad_(),
                    "rate": torch.randn(2, 3).exp().requires_grad_(),
                },
                {
                    "concentration": torch.randn(1).exp().requires_grad_(),
                    "rate": torch.randn(1).exp().requires_grad_(),
                },
            ],
        ),
    ]
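

# Hedged sketch (illustration): the generic tests below consume this registry
# by instantiating every parameterization, e.g.
#
#     for Dist, params in _get_examples():
#         for param in params:
#             dist = Dist(**param)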
"total_count": 10, 879 }, 880 ], 881 ), 882 Example( 883 Cauchy, 884 [ 885 {"loc": 0.0, "scale": -1.0}, 886 {"loc": torch.tensor([0.0]), "scale": 0.0}, 887 { 888 "loc": torch.tensor([[0.0], [-2.0]]), 889 "scale": torch.tensor([[-0.000001], [1.0]]), 890 }, 891 ], 892 ), 893 Example( 894 Chi2, 895 [ 896 {"df": torch.tensor([0.0], requires_grad=True)}, 897 {"df": torch.tensor([-2.0], requires_grad=True)}, 898 ], 899 ), 900 Example( 901 StudentT, 902 [ 903 {"df": torch.tensor([0.0], requires_grad=True)}, 904 {"df": torch.tensor([-2.0], requires_grad=True)}, 905 ], 906 ), 907 Example( 908 Dirichlet, 909 [ 910 {"concentration": torch.tensor([0.0], requires_grad=True)}, 911 {"concentration": torch.tensor([-2.0], requires_grad=True)}, 912 ], 913 ), 914 Example( 915 Exponential, 916 [ 917 {"rate": torch.tensor([0.0, 0.0], requires_grad=True)}, 918 {"rate": torch.tensor([-2.0], requires_grad=True)}, 919 ], 920 ), 921 Example( 922 FisherSnedecor, 923 [ 924 { 925 "df1": torch.tensor([0.0, 0.0], requires_grad=True), 926 "df2": torch.tensor([-1.0, -100.0], requires_grad=True), 927 }, 928 { 929 "df1": torch.tensor([1.0, 1.0], requires_grad=True), 930 "df2": torch.tensor([0.0, 0.0], requires_grad=True), 931 }, 932 ], 933 ), 934 Example( 935 Gamma, 936 [ 937 { 938 "concentration": torch.tensor([0.0, 0.0], requires_grad=True), 939 "rate": torch.tensor([-1.0, -100.0], requires_grad=True), 940 }, 941 { 942 "concentration": torch.tensor([1.0, 1.0], requires_grad=True), 943 "rate": torch.tensor([0.0, 0.0], requires_grad=True), 944 }, 945 ], 946 ), 947 Example( 948 Gumbel, 949 [ 950 { 951 "loc": torch.tensor([1.0, 1.0], requires_grad=True), 952 "scale": torch.tensor([0.0, 1.0], requires_grad=True), 953 }, 954 { 955 "loc": torch.tensor([1.0, 1.0], requires_grad=True), 956 "scale": torch.tensor([1.0, -1.0], requires_grad=True), 957 }, 958 ], 959 ), 960 Example( 961 HalfCauchy, 962 [ 963 {"scale": -1.0}, 964 {"scale": 0.0}, 965 {"scale": torch.tensor([[-0.000001], [1.0]])}, 966 ], 967 ), 968 Example( 969 HalfNormal, 970 [ 971 {"scale": torch.tensor([0.0, 1.0], requires_grad=True)}, 972 {"scale": torch.tensor([1.0, -1.0], requires_grad=True)}, 973 ], 974 ), 975 Example( 976 LKJCholesky, 977 [ 978 {"dim": -2, "concentration": 0.1}, 979 { 980 "dim": 1, 981 "concentration": 2.0, 982 }, 983 { 984 "dim": 2, 985 "concentration": 0.0, 986 }, 987 ], 988 ), 989 Example( 990 Laplace, 991 [ 992 { 993 "loc": torch.tensor([1.0, 1.0], requires_grad=True), 994 "scale": torch.tensor([0.0, 1.0], requires_grad=True), 995 }, 996 { 997 "loc": torch.tensor([1.0, 1.0], requires_grad=True), 998 "scale": torch.tensor([1.0, -1.0], requires_grad=True), 999 }, 1000 ], 1001 ), 1002 Example( 1003 LogNormal, 1004 [ 1005 { 1006 "loc": torch.tensor([1.0, 1.0], requires_grad=True), 1007 "scale": torch.tensor([0.0, 1.0], requires_grad=True), 1008 }, 1009 { 1010 "loc": torch.tensor([1.0, 1.0], requires_grad=True), 1011 "scale": torch.tensor([1.0, -1.0], requires_grad=True), 1012 }, 1013 ], 1014 ), 1015 Example( 1016 MultivariateNormal, 1017 [ 1018 { 1019 "loc": torch.tensor([1.0, 1.0], requires_grad=True), 1020 "covariance_matrix": torch.tensor( 1021 [[1.0, 0.0], [0.0, -2.0]], requires_grad=True 1022 ), 1023 }, 1024 ], 1025 ), 1026 Example( 1027 Normal, 1028 [ 1029 { 1030 "loc": torch.tensor([1.0, 1.0], requires_grad=True), 1031 "scale": torch.tensor([0.0, 1.0], requires_grad=True), 1032 }, 1033 { 1034 "loc": torch.tensor([1.0, 1.0], requires_grad=True), 1035 "scale": torch.tensor([1.0, -1.0], requires_grad=True), 1036 }, 1037 { 1038 "loc": 
                    "loc": torch.tensor([1.0, 0.0], requires_grad=True),
                    "scale": torch.tensor([1e-5, -1e-5], requires_grad=True),
                },
            ],
        ),
        Example(
            OneHotCategorical,
            [
                {
                    "probs": torch.tensor(
                        [[0.1, 0.2, 0.3], [0.1, -10.0, 0.2]], requires_grad=True
                    )
                },
                {"probs": torch.tensor([[0.0, 0.0], [0.0, 0.0]], requires_grad=True)},
            ],
        ),
        Example(
            OneHotCategoricalStraightThrough,
            [
                {
                    "probs": torch.tensor(
                        [[0.1, 0.2, 0.3], [0.1, -10.0, 0.2]], requires_grad=True
                    )
                },
                {"probs": torch.tensor([[0.0, 0.0], [0.0, 0.0]], requires_grad=True)},
            ],
        ),
        Example(
            Pareto,
            [
                {"scale": 0.0, "alpha": 0.0},
                {
                    "scale": torch.tensor([0.0, 0.0], requires_grad=True),
                    "alpha": torch.tensor([-1e-5, 0.0], requires_grad=True),
                },
                {"scale": torch.tensor([1.0]), "alpha": -1.0},
            ],
        ),
        Example(
            Poisson,
            [
                {
                    "rate": torch.tensor([-0.1], requires_grad=True),
                },
                {
                    "rate": -1.0,
                },
            ],
        ),
        Example(
            RelaxedBernoulli,
            [
                {
                    "temperature": torch.tensor([1.5], requires_grad=True),
                    "probs": torch.tensor([1.7, 0.2, 0.4], requires_grad=True),
                },
                {
                    "temperature": torch.tensor([2.0]),
                    "probs": torch.tensor([-1.0]),
                },
            ],
        ),
        Example(
            RelaxedOneHotCategorical,
            [
                {
                    "temperature": torch.tensor([0.5], requires_grad=True),
                    "probs": torch.tensor(
                        [[-0.1, 0.2, 0.7], [0.5, 0.3, 0.2]], requires_grad=True
                    ),
                },
                {
                    "temperature": torch.tensor([2.0]),
                    "probs": torch.tensor([[-1.0, 0.0], [-1.0, 1.1]]),
                },
            ],
        ),
        Example(
            TransformedDistribution,
            [
                {
                    "base_distribution": Normal(0, 1),
                    "transforms": lambda x: x,
                },
                {
                    "base_distribution": Normal(0, 1),
                    "transforms": [lambda x: x],
                },
            ],
        ),
        Example(
            Uniform,
            [
                {
                    "low": torch.tensor([2.0], requires_grad=True),
                    "high": torch.tensor([2.0], requires_grad=True),
                },
                {
                    "low": torch.tensor([0.0], requires_grad=True),
                    "high": torch.tensor([0.0], requires_grad=True),
                },
                {
                    "low": torch.tensor([1.0], requires_grad=True),
                    "high": torch.tensor([0.0], requires_grad=True),
                },
            ],
        ),
        Example(
            Weibull,
            [
                {
                    "scale": torch.tensor([0.0], requires_grad=True),
                    "concentration": torch.tensor([0.0], requires_grad=True),
                },
                {
                    "scale": torch.tensor([1.0], requires_grad=True),
                    "concentration": torch.tensor([-1.0], requires_grad=True),
                },
            ],
        ),
        Example(
            Wishart,
            [
                {
                    "covariance_matrix": torch.tensor(
                        [[1.0, 0.0], [0.0, -2.0]], requires_grad=True
                    ),
                    "df": torch.tensor([1.5], requires_grad=True),
                },
                {
                    "covariance_matrix": torch.tensor(
                        [[1.0, 1.0], [1.0, -2.0]], requires_grad=True
                    ),
                    "df": torch.tensor([3.0], requires_grad=True),
                },
                {
                    "covariance_matrix": torch.tensor(
                        [[1.0, 1.0], [1.0, -2.0]], requires_grad=True
                    ),
                    "df": 3.0,
                },
            ],
        ),
        Example(
            ContinuousBernoulli,
            [
                {"probs": torch.tensor([1.1, 0.2, 0.4], requires_grad=True)},
                {"probs": torch.tensor([-0.5], requires_grad=True)},
                {"probs": 1.00001},
            ],
        ),
        Example(
            InverseGamma,
            [
                {
                    "concentration": torch.tensor([0.0, 0.0], requires_grad=True),
                    "rate": torch.tensor([-1.0, -100.0], requires_grad=True),
                },
                {
                    "concentration": torch.tensor([1.0, 1.0], requires_grad=True),
                    "rate": torch.tensor([0.0, 0.0], requires_grad=True),
                },
            ],
        ),
    ]
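

# Hedged sketch (assumption): the bad examples are parameterizations that
# violate each distribution's constraints, so constructing them with
# validation enabled should raise, e.g.
#
#     for Dist, params in _get_bad_examples():
#         for param in params:
#             with self.assertRaises(ValueError):
#                 Dist(**param)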


class DistributionsTestCase(TestCase):
    def setUp(self):
        """The tests assume that the validation flag is set."""
        torch.distributions.Distribution.set_default_validate_args(True)
        super().setUp()


@skipIfTorchDynamo("Not a TorchDynamo suitable test")
class TestDistributions(DistributionsTestCase):
    _do_cuda_memory_leak_check = True
    _do_cuda_non_default_stream = True

    def _gradcheck_log_prob(self, dist_ctor, ctor_params):
        # performs gradient checks on log_prob
        distribution = dist_ctor(*ctor_params)
        s = distribution.sample()
        if not distribution.support.is_discrete:
            s = s.detach().requires_grad_()

        expected_shape = distribution.batch_shape + distribution.event_shape
        self.assertEqual(s.size(), expected_shape)

        def apply_fn(s, *params):
            return dist_ctor(*params).log_prob(s)

        gradcheck(apply_fn, (s,) + tuple(ctor_params), raise_exception=True)

    def _check_forward_ad(self, fn):
        with fwAD.dual_level():
            x = torch.tensor(1.0)
            t = torch.tensor(1.0)
            dual = fwAD.make_dual(x, t)
            dual_out = fn(dual)
            self.assertEqual(
                torch.count_nonzero(fwAD.unpack_dual(dual_out).tangent).item(), 0
            )

    def _check_log_prob(self, dist, assert_fn):
        # checks that the log_prob matches a reference function
        s = dist.sample()
        log_probs = dist.log_prob(s)
        log_probs_data_flat = log_probs.view(-1)
        s_data_flat = s.view(len(log_probs_data_flat), -1)
        for i, (val, log_prob) in enumerate(zip(s_data_flat, log_probs_data_flat)):
            assert_fn(i, val.squeeze(), log_prob)

    def _check_sampler_sampler(
        self,
        torch_dist,
        ref_dist,
        message,
        multivariate=False,
        circular=False,
        num_samples=10000,
        failure_rate=1e-3,
    ):
        # Checks that the .sample() method matches a reference function.
        torch_samples = torch_dist.sample((num_samples,)).squeeze()
        torch_samples = torch_samples.cpu().numpy()
        ref_samples = ref_dist.rvs(num_samples).astype(np.float64)
        if multivariate:
            # Project onto a random axis.
            axis = np.random.normal(size=(1,) + torch_samples.shape[1:])
            axis /= np.linalg.norm(axis)
            torch_samples = (axis * torch_samples).reshape(num_samples, -1).sum(-1)
            ref_samples = (axis * ref_samples).reshape(num_samples, -1).sum(-1)
        samples = [(x, +1) for x in torch_samples] + [(x, -1) for x in ref_samples]
        if circular:
            samples = [(np.cos(x), v) for (x, v) in samples]
        shuffle(
            samples
        )  # necessary to prevent stable sort from making uneven bins for discrete
        samples.sort(key=lambda x: x[0])
        samples = np.array(samples)[:, 1]

        # Aggregate into bins filled with roughly zero-mean unit-variance RVs.
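        # Sketch of the statistic (added note): under the null hypothesis that
        # both samplers agree, the sorted +1/-1 labels are exchangeable, so
        # each bin mean is roughly Gaussian with stddev samples_per_bin**-0.5;
        # the erfinv-based threshold below bounds the total false-alarm
        # probability across all bins by roughly `failure_rate`.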
        num_bins = 10
        samples_per_bin = len(samples) // num_bins
        bins = samples.reshape((num_bins, samples_per_bin)).mean(axis=1)
        stddev = samples_per_bin**-0.5
        threshold = stddev * scipy.special.erfinv(1 - 2 * failure_rate / num_bins)
        message = f"{message}.sample() is biased:\n{bins}"
        for bias in bins:
            self.assertLess(-threshold, bias, message)
            self.assertLess(bias, threshold, message)

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def _check_sampler_discrete(
        self, torch_dist, ref_dist, message, num_samples=10000, failure_rate=1e-3
    ):
        """Runs a Chi2-test for the support, but ignores tail instead of combining"""
        torch_samples = torch_dist.sample((num_samples,)).squeeze()
        torch_samples = (
            torch_samples.float()
            if torch_samples.dtype == torch.bfloat16
            else torch_samples
        )
        torch_samples = torch_samples.cpu().numpy()
        unique, counts = np.unique(torch_samples, return_counts=True)
        pmf = ref_dist.pmf(unique)
        pmf = pmf / pmf.sum()  # renormalize to 1.0 for chisq test
        msk = (counts > 5) & ((pmf * num_samples) > 5)
        self.assertGreater(
            pmf[msk].sum(),
            0.9,
            "Distribution is too sparse for test; try increasing num_samples",
        )
        # Add a remainder bucket that combines counts for all values
        # below threshold, if such values exist (i.e. mask has False entries).
        if not msk.all():
            counts = np.concatenate([counts[msk], np.sum(counts[~msk], keepdims=True)])
            pmf = np.concatenate([pmf[msk], np.sum(pmf[~msk], keepdims=True)])
        chisq, p = scipy.stats.chisquare(counts, pmf * num_samples)
        self.assertGreater(p, failure_rate, message)

    def _check_enumerate_support(self, dist, examples):
        for params, expected in examples:
            params = {k: torch.tensor(v) for k, v in params.items()}
            d = dist(**params)
            actual = d.enumerate_support(expand=False)
            expected = torch.tensor(expected, dtype=actual.dtype)
            self.assertEqual(actual, expected)
            actual = d.enumerate_support(expand=True)
            expected_with_expand = expected.expand(
                (-1,) + d.batch_shape + d.event_shape
            )
            self.assertEqual(actual, expected_with_expand)

    def test_repr(self):
        for Dist, params in _get_examples():
            for param in params:
                dist = Dist(**param)
                self.assertTrue(repr(dist).startswith(dist.__class__.__name__))

    def test_sample_detached(self):
        for Dist, params in _get_examples():
            for i, param in enumerate(params):
                variable_params = [
                    p for p in param.values() if getattr(p, "requires_grad", False)
                ]
                if not variable_params:
                    continue
                dist = Dist(**param)
                sample = dist.sample()
                self.assertFalse(
                    sample.requires_grad,
                    msg=f"{Dist.__name__} example {i + 1}/{len(params)}, .sample() is not detached",
                )

    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
    def test_rsample_requires_grad(self):
        for Dist, params in _get_examples():
            for i, param in enumerate(params):
                if not any(getattr(p, "requires_grad", False) for p in param.values()):
                    continue
                dist = Dist(**param)
                if not dist.has_rsample:
                    continue
                sample = dist.rsample()
                self.assertTrue(
                    sample.requires_grad,
                    msg=f"{Dist.__name__} example {i + 1}/{len(params)}, .rsample() does not require grad",
                )
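
    # Hedged sketch (illustration, not collected as a test): rsample() keeps
    # the sample attached to the autograd graph via reparameterization, e.g.
    # Normal.rsample() behaves like loc + scale * eps with eps ~ N(0, 1), so
    # gradients flow back to loc and scale.
    def _example_rsample_reparameterization(self):
        loc = torch.tensor(0.0, requires_grad=True)
        scale = torch.tensor(1.0, requires_grad=True)
        Normal(loc, scale).rsample().backward()
        assert loc.grad is not None and scale.grad is not None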

    def test_enumerate_support_type(self):
        for Dist, params in _get_examples():
            for i, param in enumerate(params):
                dist = Dist(**param)
                try:
                    self.assertTrue(
                        type(dist.sample()) is type(dist.enumerate_support()),
                        msg=(
                            "{} example {}/{}, return type mismatch between "
                            + "sample and enumerate_support."
                        ).format(Dist.__name__, i + 1, len(params)),
                    )
                except NotImplementedError:
                    pass

    def test_lazy_property_grad(self):
        x = torch.randn(1, requires_grad=True)

        class Dummy:
            @lazy_property
            def y(self):
                return x + 1

        def test():
            x.grad = None
            Dummy().y.backward()
            self.assertEqual(x.grad, torch.ones(1))

        test()
        with torch.no_grad():
            test()

        mean = torch.randn(2)
        cov = torch.eye(2, requires_grad=True)
        distn = MultivariateNormal(mean, cov)
        with torch.no_grad():
            distn.scale_tril
        distn.scale_tril.sum().backward()
        self.assertIsNotNone(cov.grad)

    def test_has_examples(self):
        distributions_with_examples = {e.Dist for e in _get_examples()}
        for Dist in globals().values():
            if (
                isinstance(Dist, type)
                and issubclass(Dist, Distribution)
                and Dist is not Distribution
                and Dist is not ExponentialFamily
            ):
                self.assertIn(
                    Dist,
                    distributions_with_examples,
                    f"Please add {Dist.__name__} to the _get_examples list in test_distributions.py",
                )

    def test_support_attributes(self):
        for Dist, params in _get_examples():
            for param in params:
                d = Dist(**param)
                event_dim = len(d.event_shape)
                self.assertEqual(d.support.event_dim, event_dim)
                try:
                    self.assertEqual(Dist.support.event_dim, event_dim)
                except NotImplementedError:
                    pass
                is_discrete = d.support.is_discrete
                try:
                    self.assertEqual(Dist.support.is_discrete, is_discrete)
                except NotImplementedError:
                    pass

    def test_distribution_expand(self):
        shapes = [torch.Size(), torch.Size((2,)), torch.Size((2, 1))]
        for Dist, params in _get_examples():
            for param in params:
                for shape in shapes:
                    d = Dist(**param)
                    expanded_shape = shape + d.batch_shape
                    original_shape = d.batch_shape + d.event_shape
                    expected_shape = shape + original_shape
                    expanded = d.expand(batch_shape=list(expanded_shape))
                    sample = expanded.sample()
                    actual_shape = expanded.sample().shape
                    self.assertEqual(expanded.__class__, d.__class__)
                    self.assertEqual(d.sample().shape, original_shape)
                    self.assertEqual(expanded.log_prob(sample), d.log_prob(sample))
                    self.assertEqual(actual_shape, expected_shape)
                    self.assertEqual(expanded.batch_shape, expanded_shape)
                    try:
                        self.assertEqual(
                            expanded.mean, d.mean.expand(expanded_shape + d.event_shape)
                        )
                        self.assertEqual(
                            expanded.variance,
                            d.variance.expand(expanded_shape + d.event_shape),
                        )
                    except NotImplementedError:
                        pass
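
    # Hedged sketch (illustration, not collected as a test): expand() prepends
    # batch dimensions by broadcasting parameters, without resampling, e.g.
    #
    #     d = Normal(torch.zeros(3), torch.ones(3))   # batch_shape (3,)
    #     d.expand(torch.Size((2, 3))).batch_shape    # -> torch.Size([2, 3])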

    def test_distribution_subclass_expand(self):
        expand_by = torch.Size((2,))
        for Dist, params in _get_examples():

            class SubClass(Dist):
                pass

            for param in params:
                d = SubClass(**param)
                expanded_shape = expand_by + d.batch_shape
                original_shape = d.batch_shape + d.event_shape
                expected_shape = expand_by + original_shape
                expanded = d.expand(batch_shape=expanded_shape)
                sample = expanded.sample()
                actual_shape = expanded.sample().shape
                self.assertEqual(expanded.__class__, d.__class__)
                self.assertEqual(d.sample().shape, original_shape)
                self.assertEqual(expanded.log_prob(sample), d.log_prob(sample))
                self.assertEqual(actual_shape, expected_shape)

    @set_default_dtype(torch.double)
    def test_bernoulli(self):
        p = torch.tensor([0.7, 0.2, 0.4], requires_grad=True)
        r = torch.tensor(0.3, requires_grad=True)
        s = 0.3
        self.assertEqual(Bernoulli(p).sample((8,)).size(), (8, 3))
        self.assertFalse(Bernoulli(p).sample().requires_grad)
        self.assertEqual(Bernoulli(r).sample((8,)).size(), (8,))
        self.assertEqual(Bernoulli(r).sample().size(), ())
        self.assertEqual(Bernoulli(r).sample((3, 2)).size(), (3, 2))
        self.assertEqual(Bernoulli(s).sample().size(), ())
        self._gradcheck_log_prob(Bernoulli, (p,))

        def ref_log_prob(idx, val, log_prob):
            prob = p[idx]
            self.assertEqual(log_prob, math.log(prob if val else 1 - prob))

        self._check_log_prob(Bernoulli(p), ref_log_prob)
        self._check_log_prob(Bernoulli(logits=p.log() - (-p).log1p()), ref_log_prob)
        self.assertRaises(NotImplementedError, Bernoulli(r).rsample)

        # check entropy computation
        self.assertEqual(
            Bernoulli(p).entropy(),
            torch.tensor([0.6108, 0.5004, 0.6730]),
            atol=1e-4,
            rtol=0,
        )
        self.assertEqual(Bernoulli(torch.tensor([0.0])).entropy(), torch.tensor([0.0]))
        self.assertEqual(
            Bernoulli(s).entropy(), torch.tensor(0.6108), atol=1e-4, rtol=0
        )

        self._check_forward_ad(torch.bernoulli)
        self._check_forward_ad(lambda x: x.bernoulli_())
        self._check_forward_ad(lambda x: x.bernoulli_(x.clone().detach()))
        self._check_forward_ad(lambda x: x.bernoulli_(x))

    def test_bernoulli_enumerate_support(self):
        examples = [
            ({"probs": [0.1]}, [[0], [1]]),
            ({"probs": [0.1, 0.9]}, [[0], [1]]),
            ({"probs": [[0.1, 0.2], [0.3, 0.4]]}, [[[0]], [[1]]]),
        ]
        self._check_enumerate_support(Bernoulli, examples)

    def test_bernoulli_3d(self):
        p = torch.full((2, 3, 5), 0.5).requires_grad_()
        self.assertEqual(Bernoulli(p).sample().size(), (2, 3, 5))
        self.assertEqual(
            Bernoulli(p).sample(sample_shape=(2, 5)).size(), (2, 5, 2, 3, 5)
        )
        self.assertEqual(Bernoulli(p).sample((2,)).size(), (2, 2, 3, 5))

    @set_default_dtype(torch.double)
    def test_geometric(self):
        p = torch.tensor([0.7, 0.2, 0.4], requires_grad=True)
        r = torch.tensor(0.3, requires_grad=True)
        s = 0.3
        self.assertEqual(Geometric(p).sample((8,)).size(), (8, 3))
        self.assertEqual(Geometric(1).sample(), 0)
        self.assertEqual(Geometric(1).log_prob(torch.tensor(1.0)), -inf)
        self.assertEqual(Geometric(1).log_prob(torch.tensor(0.0)), 0)
        self.assertFalse(Geometric(p).sample().requires_grad)
        self.assertEqual(Geometric(r).sample((8,)).size(), (8,))
        self.assertEqual(Geometric(r).sample().size(), ())
        self.assertEqual(Geometric(r).sample((3, 2)).size(), (3, 2))
        self.assertEqual(Geometric(s).sample().size(), ())
        self._gradcheck_log_prob(Geometric, (p,))
        self.assertRaises(ValueError, lambda: Geometric(0))
        self.assertRaises(NotImplementedError, Geometric(r).rsample)

        self._check_forward_ad(lambda x: x.geometric_(0.2))
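
    # Note (sketch): torch.distributions.Geometric counts failures before the
    # first success, so its support is {0, 1, 2, ...}, while scipy.stats.geom
    # counts trials and has support {1, 2, ...}; the loc=-1 shift used below
    # aligns the two parameterizations.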

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    @set_default_dtype(torch.double)
    def test_geometric_log_prob_and_entropy(self):
        p = torch.tensor([0.7, 0.2, 0.4], requires_grad=True)
        s = 0.3

        def ref_log_prob(idx, val, log_prob):
            prob = p[idx].detach()
            self.assertEqual(log_prob, scipy.stats.geom(prob, loc=-1).logpmf(val))

        self._check_log_prob(Geometric(p), ref_log_prob)
        self._check_log_prob(Geometric(logits=p.log() - (-p).log1p()), ref_log_prob)

        # check entropy computation
        self.assertEqual(
            Geometric(p).entropy(),
            scipy.stats.geom(p.detach().numpy(), loc=-1).entropy(),
            atol=1e-3,
            rtol=0,
        )
        self.assertEqual(
            float(Geometric(s).entropy()),
            scipy.stats.geom(s, loc=-1).entropy().item(),
            atol=1e-3,
            rtol=0,
        )

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_geometric_sample(self):
        set_rng_seed(0)  # see Note [Randomized statistical tests]
        for prob in [0.01, 0.18, 0.8]:
            self._check_sampler_discrete(
                Geometric(prob),
                scipy.stats.geom(p=prob, loc=-1),
                f"Geometric(prob={prob})",
            )

    @set_default_dtype(torch.double)
    def test_binomial(self):
        p = torch.arange(0.05, 1, 0.1).requires_grad_()
        for total_count in [1, 2, 10]:
            self._gradcheck_log_prob(lambda p: Binomial(total_count, p), [p])
            self._gradcheck_log_prob(
                lambda p: Binomial(total_count, None, p.log()), [p]
            )
        self.assertRaises(NotImplementedError, Binomial(10, p).rsample)

    test_binomial_half = set_default_dtype(torch.float16)(test_binomial)
    test_binomial_bfloat16 = set_default_dtype(torch.bfloat16)(test_binomial)

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_binomial_sample(self):
        set_rng_seed(0)  # see Note [Randomized statistical tests]
        for prob in [0.01, 0.1, 0.5, 0.8, 0.9]:
            for count in [2, 10, 100, 500]:
                self._check_sampler_discrete(
                    Binomial(total_count=count, probs=prob),
                    scipy.stats.binom(count, prob),
                    f"Binomial(total_count={count}, probs={prob})",
                )

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    @set_default_dtype(torch.double)
    def test_binomial_log_prob_and_entropy(self):
        probs = torch.arange(0.05, 1, 0.1)
        for total_count in [1, 2, 10]:

            def ref_log_prob(idx, x, log_prob):
                p = probs.view(-1)[idx].item()
                expected = scipy.stats.binom(total_count, p).logpmf(x)
                self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)

            self._check_log_prob(Binomial(total_count, probs), ref_log_prob)
            logits = probs_to_logits(probs, is_binary=True)
            self._check_log_prob(Binomial(total_count, logits=logits), ref_log_prob)

            bin = Binomial(total_count, logits=logits)
            self.assertEqual(
                bin.entropy(),
                scipy.stats.binom(
                    total_count, bin.probs.detach().numpy(), loc=-1
                ).entropy(),
                atol=1e-3,
                rtol=0,
            )

    def test_binomial_stable(self):
        logits = torch.tensor([-100.0, 100.0], dtype=torch.float)
        total_count = 1.0
        x = torch.tensor([0.0, 0.0], dtype=torch.float)
        log_prob = Binomial(total_count, logits=logits).log_prob(x)
        self.assertTrue(torch.isfinite(log_prob).all())

        # make sure that the grad at logits=0, value=0 is -0.5
        x = torch.tensor(0.0, requires_grad=True)
        y = Binomial(total_count, logits=x).log_prob(torch.tensor(0.0))
        self.assertEqual(grad(y, x)[0], torch.tensor(-0.5))
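
    # Worked check for the -0.5 gradient above: with total_count=1,
    # log p(x=0 | logits=l) = log(1 - sigmoid(l)) = -log(1 + exp(l)), whose
    # derivative is -sigmoid(l), i.e. -0.5 at l = 0.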

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    @set_default_dtype(torch.double)
    def test_binomial_log_prob_vectorized_count(self):
        probs = torch.tensor([0.2, 0.7, 0.9])
        for total_count, sample in [
            (torch.tensor([10]), torch.tensor([7.0, 3.0, 9.0])),
            (torch.tensor([1, 2, 10]), torch.tensor([0.0, 1.0, 9.0])),
        ]:
            log_prob = Binomial(total_count, probs).log_prob(sample)
            expected = scipy.stats.binom(
                total_count.cpu().numpy(), probs.cpu().numpy()
            ).logpmf(sample)
            self.assertEqual(log_prob, expected, atol=1e-4, rtol=0)

    def test_binomial_enumerate_support(self):
        examples = [
            ({"probs": [0.1], "total_count": 2}, [[0], [1], [2]]),
            ({"probs": [0.1, 0.9], "total_count": 2}, [[0], [1], [2]]),
            (
                {"probs": [[0.1, 0.2], [0.3, 0.4]], "total_count": 3},
                [[[0]], [[1]], [[2]], [[3]]],
            ),
        ]
        self._check_enumerate_support(Binomial, examples)

    @set_default_dtype(torch.double)
    def test_binomial_extreme_vals(self):
        total_count = 100
        bin0 = Binomial(total_count, 0)
        self.assertEqual(bin0.sample(), 0)
        self.assertEqual(bin0.log_prob(torch.tensor([0.0]))[0], 0, atol=1e-3, rtol=0)
        self.assertEqual(float(bin0.log_prob(torch.tensor([1.0])).exp()), 0)
        bin1 = Binomial(total_count, 1)
        self.assertEqual(bin1.sample(), total_count)
        self.assertEqual(
            bin1.log_prob(torch.tensor([float(total_count)]))[0], 0, atol=1e-3, rtol=0
        )
        self.assertEqual(
            float(bin1.log_prob(torch.tensor([float(total_count - 1)])).exp()), 0
        )
        zero_counts = torch.zeros(torch.Size((2, 2)))
        bin2 = Binomial(zero_counts, 1)
        self.assertEqual(bin2.sample(), zero_counts)
        self.assertEqual(bin2.log_prob(zero_counts), zero_counts)

    @set_default_dtype(torch.double)
    def test_binomial_vectorized_count(self):
        set_rng_seed(1)  # see Note [Randomized statistical tests]
        total_count = torch.tensor([[4, 7], [3, 8]], dtype=torch.float64)
        bin0 = Binomial(total_count, torch.tensor(1.0))
        self.assertEqual(bin0.sample(), total_count)
        bin1 = Binomial(total_count, torch.tensor(0.5))
        samples = bin1.sample(torch.Size((100000,)))
        self.assertTrue((samples <= total_count.type_as(samples)).all())
        self.assertEqual(samples.mean(dim=0), bin1.mean, atol=0.02, rtol=0)
        self.assertEqual(samples.var(dim=0), bin1.variance, atol=0.02, rtol=0)

    @set_default_dtype(torch.double)
    def test_negative_binomial(self):
        p = torch.arange(0.05, 1, 0.1).requires_grad_()
        for total_count in [1, 2, 10]:
            self._gradcheck_log_prob(lambda p: NegativeBinomial(total_count, p), [p])
            self._gradcheck_log_prob(
                lambda p: NegativeBinomial(total_count, None, p.log()), [p]
            )
        self.assertRaises(NotImplementedError, NegativeBinomial(10, p).rsample)
        self.assertRaises(NotImplementedError, NegativeBinomial(10, p).entropy)

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_negative_binomial_log_prob(self):
        probs = torch.arange(0.05, 1, 0.1)
        for total_count in [1, 2, 10]:

            def ref_log_prob(idx, x, log_prob):
                p = probs.view(-1)[idx].item()
                expected = scipy.stats.nbinom(total_count, 1 - p).logpmf(x)
                self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)

            self._check_log_prob(NegativeBinomial(total_count, probs), ref_log_prob)
            logits = probs_to_logits(probs, is_binary=True)
            self._check_log_prob(
                NegativeBinomial(total_count, logits=logits), ref_log_prob
            )
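
    # Note (sketch): torch's NegativeBinomial counts successes with success
    # probability `probs` until `total_count` failures occur, whereas
    # scipy.stats.nbinom counts failures until `n` successes; the roles of
    # success and failure are swapped, hence the `1 - p` in the references.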

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    @set_default_dtype(torch.double)
    def test_negative_binomial_log_prob_vectorized_count(self):
        probs = torch.tensor([0.2, 0.7, 0.9])
        for total_count, sample in [
            (torch.tensor([10]), torch.tensor([7.0, 3.0, 9.0])),
            (torch.tensor([1, 2, 10]), torch.tensor([0.0, 1.0, 9.0])),
        ]:
            log_prob = NegativeBinomial(total_count, probs).log_prob(sample)
            expected = scipy.stats.nbinom(
                total_count.cpu().numpy(), 1 - probs.cpu().numpy()
            ).logpmf(sample)
            self.assertEqual(log_prob, expected, atol=1e-4, rtol=0)

    @unittest.skipIf(not TEST_CUDA, "CUDA not found")
    def test_zero_excluded_binomial(self):
        vals = Binomial(
            total_count=torch.tensor(1.0).cuda(), probs=torch.tensor(0.9).cuda()
        ).sample(torch.Size((100000000,)))
        self.assertTrue((vals >= 0).all())
        vals = Binomial(
            total_count=torch.tensor(1.0).cuda(), probs=torch.tensor(0.1).cuda()
        ).sample(torch.Size((100000000,)))
        self.assertTrue((vals < 2).all())
        vals = Binomial(
            total_count=torch.tensor(1.0).cuda(), probs=torch.tensor(0.5).cuda()
        ).sample(torch.Size((10000,)))
        # vals should be roughly half zeroes, half ones
        assert (vals == 0.0).sum() > 4000
        assert (vals == 1.0).sum() > 4000

    @set_default_dtype(torch.double)
    def test_multinomial_1d(self):
        total_count = 10
        p = torch.tensor([0.1, 0.2, 0.3], requires_grad=True)
        self.assertEqual(Multinomial(total_count, p).sample().size(), (3,))
        self.assertEqual(Multinomial(total_count, p).sample((2, 2)).size(), (2, 2, 3))
        self.assertEqual(Multinomial(total_count, p).sample((1,)).size(), (1, 3))
        self._gradcheck_log_prob(lambda p: Multinomial(total_count, p), [p])
        self._gradcheck_log_prob(lambda p: Multinomial(total_count, None, p.log()), [p])
        self.assertRaises(NotImplementedError, Multinomial(10, p).rsample)

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    @set_default_dtype(torch.double)
    def test_multinomial_1d_log_prob_and_entropy(self):
        total_count = 10
        p = torch.tensor([0.1, 0.2, 0.3], requires_grad=True)
        dist = Multinomial(total_count, probs=p)
        x = dist.sample()
        log_prob = dist.log_prob(x)
        expected = torch.tensor(
            scipy.stats.multinomial.logpmf(
                x.numpy(), n=total_count, p=dist.probs.detach().numpy()
            )
        )
        self.assertEqual(log_prob, expected)

        dist = Multinomial(total_count, logits=p.log())
        x = dist.sample()
        log_prob = dist.log_prob(x)
        expected = torch.tensor(
            scipy.stats.multinomial.logpmf(
                x.numpy(), n=total_count, p=dist.probs.detach().numpy()
            )
        )
        self.assertEqual(log_prob, expected)

        expected = scipy.stats.multinomial.entropy(
            total_count, dist.probs.detach().numpy()
        )
        self.assertEqual(dist.entropy(), expected, atol=1e-3, rtol=0)

    @set_default_dtype(torch.double)
    def test_multinomial_2d(self):
        total_count = 10
        probabilities = [[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]]
        probabilities_1 = [[1.0, 0.0], [0.0, 1.0]]
        p = torch.tensor(probabilities, requires_grad=True)
        s = torch.tensor(probabilities_1, requires_grad=True)
        self.assertEqual(Multinomial(total_count, p).sample().size(), (2, 3))
        self.assertEqual(
            Multinomial(total_count, p).sample(sample_shape=(3, 4)).size(), (3, 4, 2, 3)
        )
        self.assertEqual(Multinomial(total_count, p).sample((6,)).size(), (6, 2, 3))
        set_rng_seed(0)
        self._gradcheck_log_prob(lambda p: Multinomial(total_count, p), [p])
        self._gradcheck_log_prob(lambda p: Multinomial(total_count, None, p.log()), [p])

        # sample check for extreme value of probs
        self.assertEqual(
            Multinomial(total_count, s).sample(),
            torch.tensor([[total_count, 0], [0, total_count]], dtype=torch.float64),
        )

    @set_default_dtype(torch.double)
    def test_categorical_1d(self):
        p = torch.tensor([0.1, 0.2, 0.3], requires_grad=True)
        self.assertTrue(is_all_nan(Categorical(p).mean))
        self.assertTrue(is_all_nan(Categorical(p).variance))
        self.assertEqual(Categorical(p).sample().size(), ())
        self.assertFalse(Categorical(p).sample().requires_grad)
        self.assertEqual(Categorical(p).sample((2, 2)).size(), (2, 2))
        self.assertEqual(Categorical(p).sample((1,)).size(), (1,))
        self._gradcheck_log_prob(Categorical, (p,))
        self.assertRaises(NotImplementedError, Categorical(p).rsample)

    @set_default_dtype(torch.double)
    def test_categorical_2d(self):
        probabilities = [[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]]
        probabilities_1 = [[1.0, 0.0], [0.0, 1.0]]
        p = torch.tensor(probabilities, requires_grad=True)
        s = torch.tensor(probabilities_1, requires_grad=True)
        self.assertEqual(Categorical(p).mean.size(), (2,))
        self.assertEqual(Categorical(p).variance.size(), (2,))
        self.assertTrue(is_all_nan(Categorical(p).mean))
        self.assertTrue(is_all_nan(Categorical(p).variance))
        self.assertEqual(Categorical(p).sample().size(), (2,))
        self.assertEqual(Categorical(p).sample(sample_shape=(3, 4)).size(), (3, 4, 2))
        self.assertEqual(Categorical(p).sample((6,)).size(), (6, 2))
        self._gradcheck_log_prob(Categorical, (p,))

        # sample check for extreme value of probs
        set_rng_seed(0)
        self.assertEqual(
            Categorical(s).sample(sample_shape=(2,)), torch.tensor([[0, 1], [0, 1]])
        )

        def ref_log_prob(idx, val, log_prob):
            sample_prob = p[idx][val] / p[idx].sum()
            self.assertEqual(log_prob, math.log(sample_prob))

        self._check_log_prob(Categorical(p), ref_log_prob)
        self._check_log_prob(Categorical(logits=p.log()), ref_log_prob)

        # check entropy computation
        self.assertEqual(
            Categorical(p).entropy(), torch.tensor([1.0114, 1.0297]), atol=1e-4, rtol=0
        )
        self.assertEqual(Categorical(s).entropy(), torch.tensor([0.0, 0.0]))
        # issue gh-40553
        logits = p.log()
        logits[1, 1] = logits[0, 2] = float("-inf")
        e = Categorical(logits=logits).entropy()
        self.assertEqual(e, torch.tensor([0.6365, 0.5983]), atol=1e-4, rtol=0)

    def test_categorical_enumerate_support(self):
        examples = [
            ({"probs": [0.1, 0.2, 0.7]}, [0, 1, 2]),
            ({"probs": [[0.1, 0.9], [0.3, 0.7]]}, [[0], [1]]),
        ]
        self._check_enumerate_support(Categorical, examples)

    @set_default_dtype(torch.double)
    def test_one_hot_categorical_1d(self):
        p = torch.tensor([0.1, 0.2, 0.3], requires_grad=True)
        self.assertEqual(OneHotCategorical(p).sample().size(), (3,))
        self.assertFalse(OneHotCategorical(p).sample().requires_grad)
        self.assertEqual(OneHotCategorical(p).sample((2, 2)).size(), (2, 2, 3))
        self.assertEqual(OneHotCategorical(p).sample((1,)).size(), (1, 3))
        self._gradcheck_log_prob(OneHotCategorical, (p,))
        self.assertRaises(NotImplementedError, OneHotCategorical(p).rsample)

    @set_default_dtype(torch.double)
    def test_one_hot_categorical_2d(self):
        probabilities = [[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]]
        probabilities_1 = [[1.0, 0.0], [0.0, 1.0]]
        p = torch.tensor(probabilities, requires_grad=True)
        s = torch.tensor(probabilities_1, requires_grad=True)
        self.assertEqual(OneHotCategorical(p).sample().size(), (2, 3))
        self.assertEqual(
            OneHotCategorical(p).sample(sample_shape=(3, 4)).size(), (3, 4, 2, 3)
        )
        self.assertEqual(OneHotCategorical(p).sample((6,)).size(), (6, 2, 3))
        self._gradcheck_log_prob(OneHotCategorical, (p,))

        dist = OneHotCategorical(p)
        x = dist.sample()
        self.assertEqual(dist.log_prob(x), Categorical(p).log_prob(x.max(-1)[1]))

    def test_one_hot_categorical_enumerate_support(self):
        examples = [
            ({"probs": [0.1, 0.2, 0.7]}, [[1, 0, 0], [0, 1, 0], [0, 0, 1]]),
            ({"probs": [[0.1, 0.9], [0.3, 0.7]]}, [[[1, 0]], [[0, 1]]]),
        ]
        self._check_enumerate_support(OneHotCategorical, examples)

    def test_poisson_forward_ad(self):
        self._check_forward_ad(torch.poisson)

    def test_poisson_shape(self):
        rate = torch.randn(2, 3).abs().requires_grad_()
        rate_1d = torch.randn(1).abs().requires_grad_()
        self.assertEqual(Poisson(rate).sample().size(), (2, 3))
        self.assertEqual(Poisson(rate).sample((7,)).size(), (7, 2, 3))
        self.assertEqual(Poisson(rate_1d).sample().size(), (1,))
        self.assertEqual(Poisson(rate_1d).sample((1,)).size(), (1, 1))
        self.assertEqual(Poisson(2.0).sample((2,)).size(), (2,))

    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
    @set_default_dtype(torch.double)
    def test_poisson_log_prob(self):
        rate = torch.randn(2, 3).abs().requires_grad_()
        rate_1d = torch.randn(1).abs().requires_grad_()
        rate_zero = torch.zeros([], requires_grad=True)

        def ref_log_prob(ref_rate, idx, x, log_prob):
            l = ref_rate.view(-1)[idx].detach()
            expected = scipy.stats.poisson.logpmf(x, l)
            self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)

        set_rng_seed(0)
        self._check_log_prob(Poisson(rate), lambda *args: ref_log_prob(rate, *args))
        self._check_log_prob(
            Poisson(rate_zero), lambda *args: ref_log_prob(rate_zero, *args)
        )
        self._gradcheck_log_prob(Poisson, (rate,))
        self._gradcheck_log_prob(Poisson, (rate_1d,))

        # We cannot check gradients automatically for zero rates because the
        # finite difference approximation enters the forbidden parameter
        # space. We instead compare with the theoretical results.
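        # Worked derivation (added note): d/d(rate) log p(k; rate) is
        # k/rate - 1, which diverges to +inf as rate -> 0 for k = 1, so the
        # expected gradient at a zero rate is inf.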
        dist = Poisson(rate_zero)
        dist.log_prob(torch.ones_like(rate_zero)).backward()
        self.assertEqual(rate_zero.grad, torch.inf)

    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
    def test_poisson_sample(self):
        set_rng_seed(1)  # see Note [Randomized statistical tests]
        saved_dtype = torch.get_default_dtype()
        for dtype in [torch.float, torch.double, torch.bfloat16, torch.half]:
            torch.set_default_dtype(dtype)
            for rate in [0.1, 1.0, 5.0]:
                self._check_sampler_discrete(
                    Poisson(rate),
                    scipy.stats.poisson(rate),
                    f"Poisson(lambda={rate})",
                    failure_rate=1e-3,
                )
        torch.set_default_dtype(saved_dtype)

    @unittest.skipIf(not TEST_CUDA, "CUDA not found")
    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
    def test_poisson_gpu_sample(self):
        set_rng_seed(1)
        for rate in [0.12, 0.9, 4.0]:
            self._check_sampler_discrete(
                Poisson(torch.tensor([rate]).cuda()),
                scipy.stats.poisson(rate),
                f"Poisson(lambda={rate}, cuda)",
                failure_rate=1e-3,
            )

    @set_default_dtype(torch.double)
    def test_relaxed_bernoulli(self):
        p = torch.tensor([0.7, 0.2, 0.4], requires_grad=True)
        r = torch.tensor(0.3, requires_grad=True)
        s = 0.3
        temp = torch.tensor(0.67, requires_grad=True)
        self.assertEqual(RelaxedBernoulli(temp, p).sample((8,)).size(), (8, 3))
        self.assertFalse(RelaxedBernoulli(temp, p).sample().requires_grad)
        self.assertEqual(RelaxedBernoulli(temp, r).sample((8,)).size(), (8,))
        self.assertEqual(RelaxedBernoulli(temp, r).sample().size(), ())
        self.assertEqual(RelaxedBernoulli(temp, r).sample((3, 2)).size(), (3, 2))
        self.assertEqual(RelaxedBernoulli(temp, s).sample().size(), ())
        self._gradcheck_log_prob(RelaxedBernoulli, (temp, p))
        self._gradcheck_log_prob(RelaxedBernoulli, (temp, r))

        # test that rsample doesn't fail
        s = RelaxedBernoulli(temp, p).rsample()
        s.backward(torch.ones_like(s))

    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
    def test_rounded_relaxed_bernoulli(self):
        set_rng_seed(0)  # see Note [Randomized statistical tests]

        class Rounded:
            def __init__(self, dist):
                self.dist = dist

            def sample(self, *args, **kwargs):
                return torch.round(self.dist.sample(*args, **kwargs))

        for probs, temp in product([0.1, 0.2, 0.8], [0.1, 1.0, 10.0]):
            self._check_sampler_discrete(
                Rounded(RelaxedBernoulli(temp, probs)),
                scipy.stats.bernoulli(probs),
                f"Rounded(RelaxedBernoulli(temp={temp}, probs={probs}))",
                failure_rate=1e-3,
            )

        for probs in [0.001, 0.2, 0.999]:
            equal_probs = torch.tensor(0.5)
            dist = RelaxedBernoulli(1e10, probs)
            s = dist.rsample()
            self.assertEqual(equal_probs, s)
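        # A short note on the extreme-temperature check above: a relaxed
        # Bernoulli sample has the form sigmoid((logits + logistic_noise) / T),
        # so as the temperature T -> inf the argument goes to 0 and every
        # sample collapses to sigmoid(0) = 0.5 regardless of probs.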
    @set_default_dtype(torch.double)
    def test_relaxed_one_hot_categorical_1d(self):
        p = torch.tensor([0.1, 0.2, 0.3], requires_grad=True)
        temp = torch.tensor(0.67, requires_grad=True)
        self.assertEqual(
            RelaxedOneHotCategorical(probs=p, temperature=temp).sample().size(), (3,)
        )
        self.assertFalse(
            RelaxedOneHotCategorical(probs=p, temperature=temp).sample().requires_grad
        )
        self.assertEqual(
            RelaxedOneHotCategorical(probs=p, temperature=temp).sample((2, 2)).size(),
            (2, 2, 3),
        )
        self.assertEqual(
            RelaxedOneHotCategorical(probs=p, temperature=temp).sample((1,)).size(),
            (1, 3),
        )
        self._gradcheck_log_prob(
            lambda t, p: RelaxedOneHotCategorical(t, p, validate_args=False), (temp, p)
        )

    @set_default_dtype(torch.double)
    def test_relaxed_one_hot_categorical_2d(self):
        probabilities = [[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]]
        probabilities_1 = [[1.0, 0.0], [0.0, 1.0]]
        temp = torch.tensor([3.0], requires_grad=True)
        # The lower the temperature, the more unstable the log_prob gradcheck is
        # w.r.t. the sample. Values below 0.25 empirically fail the default tol.
        temp_2 = torch.tensor([0.25], requires_grad=True)
        p = torch.tensor(probabilities, requires_grad=True)
        s = torch.tensor(probabilities_1, requires_grad=True)
        self.assertEqual(RelaxedOneHotCategorical(temp, p).sample().size(), (2, 3))
        self.assertEqual(
            RelaxedOneHotCategorical(temp, p).sample(sample_shape=(3, 4)).size(),
            (3, 4, 2, 3),
        )
        self.assertEqual(
            RelaxedOneHotCategorical(temp, p).sample((6,)).size(), (6, 2, 3)
        )
        self._gradcheck_log_prob(
            lambda t, p: RelaxedOneHotCategorical(t, p, validate_args=False), (temp, p)
        )
        self._gradcheck_log_prob(
            lambda t, p: RelaxedOneHotCategorical(t, p, validate_args=False),
            (temp_2, p),
        )

    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
    def test_argmax_relaxed_categorical(self):
        set_rng_seed(0)  # see Note [Randomized statistical tests]

        class ArgMax:
            def __init__(self, dist):
                self.dist = dist

            def sample(self, *args, **kwargs):
                s = self.dist.sample(*args, **kwargs)
                _, idx = torch.max(s, -1)
                return idx

        class ScipyCategorical:
            def __init__(self, dist):
                self.dist = dist

            def pmf(self, samples):
                new_samples = np.zeros(samples.shape + self.dist.p.shape)
                new_samples[np.arange(samples.shape[0]), samples] = 1
                return self.dist.pmf(new_samples)

        for probs, temp in product(
            [torch.tensor([0.1, 0.9]), torch.tensor([0.2, 0.2, 0.6])], [0.1, 1.0, 10.0]
        ):
            self._check_sampler_discrete(
                ArgMax(RelaxedOneHotCategorical(temp, probs)),
                ScipyCategorical(scipy.stats.multinomial(1, probs)),
                f"ArgMax(RelaxedOneHotCategorical(temp={temp}, probs={probs}))",
                failure_rate=1e-3,
            )

        for probs in [torch.tensor([0.1, 0.9]), torch.tensor([0.2, 0.2, 0.6])]:
            equal_probs = torch.ones(probs.size()) / probs.size()[0]
            dist = RelaxedOneHotCategorical(1e10, probs)
            s = dist.rsample()
            self.assertEqual(equal_probs, s)

    @set_default_dtype(torch.double)
    def test_uniform(self):
        low = torch.zeros(5, 5, requires_grad=True)
        high = (torch.ones(5, 5) * 3).requires_grad_()
        low_1d = torch.zeros(1, requires_grad=True)
        high_1d = (torch.ones(1) * 3).requires_grad_()
        self.assertEqual(Uniform(low, high).sample().size(), (5, 5))
        self.assertEqual(Uniform(low, high).sample((7,)).size(), (7, 5, 5))
        self.assertEqual(Uniform(low_1d, high_1d).sample().size(), (1,))
        self.assertEqual(Uniform(low_1d, high_1d).sample((1,)).size(), (1, 1))
        self.assertEqual(Uniform(0.0, 1.0).sample((1,)).size(), (1,))

        # Check log_prob computation when value outside range
        uniform = Uniform(low_1d, high_1d, validate_args=False)
        above_high = torch.tensor([4.0])
        below_low = torch.tensor([-1.0])
        self.assertEqual(uniform.log_prob(above_high).item(), -inf)
        self.assertEqual(uniform.log_prob(below_low).item(), -inf)

        # check cdf computation when value outside range
        self.assertEqual(uniform.cdf(below_low).item(), 0)
        self.assertEqual(uniform.cdf(above_high).item(), 1)

        set_rng_seed(1)
        self._gradcheck_log_prob(Uniform, (low, high))
        self._gradcheck_log_prob(Uniform, (low, 1.0))
        self._gradcheck_log_prob(Uniform, (0.0, high))

        state = torch.get_rng_state()
        rand = low.new(low.size()).uniform_()
        torch.set_rng_state(state)
        u = Uniform(low, high).rsample()
        u.backward(torch.ones_like(u))
        self.assertEqual(low.grad, 1 - rand)
        self.assertEqual(high.grad, rand)
        low.grad.zero_()
        high.grad.zero_()

        self._check_forward_ad(lambda x: x.uniform_())

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_vonmises_sample(self):
        for loc in [0.0, math.pi / 2.0]:
            for concentration in [0.03, 0.3, 1.0, 10.0, 100.0]:
                self._check_sampler_sampler(
                    VonMises(loc, concentration),
                    scipy.stats.vonmises(loc=loc, kappa=concentration),
                    f"VonMises(loc={loc}, concentration={concentration})",
                    num_samples=int(1e5),
                    circular=True,
                )

    def test_vonmises_logprob(self):
        concentrations = [0.01, 0.03, 0.1, 0.3, 1.0, 3.0, 10.0, 30.0, 100.0]
        for concentration in concentrations:
            grid = torch.arange(0.0, 2 * math.pi, 1e-4)
            prob = VonMises(0.0, concentration).log_prob(grid).exp()
            norm = prob.mean().item() * 2 * math.pi
            self.assertLess(abs(norm - 1), 1e-3)

    @set_default_dtype(torch.double)
    def test_cauchy(self):
        loc = torch.zeros(5, 5, requires_grad=True)
        scale = torch.ones(5, 5, requires_grad=True)
        loc_1d = torch.zeros(1, requires_grad=True)
        scale_1d = torch.ones(1, requires_grad=True)
        self.assertTrue(is_all_nan(Cauchy(loc_1d, scale_1d).mean))
        self.assertEqual(Cauchy(loc_1d, scale_1d).variance, inf)
        self.assertEqual(Cauchy(loc, scale).sample().size(), (5, 5))
        self.assertEqual(Cauchy(loc, scale).sample((7,)).size(), (7, 5, 5))
        self.assertEqual(Cauchy(loc_1d, scale_1d).sample().size(), (1,))
        self.assertEqual(Cauchy(loc_1d, scale_1d).sample((1,)).size(), (1, 1))
        self.assertEqual(Cauchy(0.0, 1.0).sample((1,)).size(), (1,))

        set_rng_seed(1)
        self._gradcheck_log_prob(Cauchy, (loc, scale))
        self._gradcheck_log_prob(Cauchy, (loc, 1.0))
        self._gradcheck_log_prob(Cauchy, (0.0, scale))

        state = torch.get_rng_state()
        eps = loc.new(loc.size()).cauchy_()
        torch.set_rng_state(state)
        c = Cauchy(loc, scale).rsample()
        c.backward(torch.ones_like(c))
        self.assertEqual(loc.grad, torch.ones_like(scale))
        self.assertEqual(scale.grad, eps)
        loc.grad.zero_()
        scale.grad.zero_()

        self._check_forward_ad(lambda x: x.cauchy_())

    @set_default_dtype(torch.double)
    def test_halfcauchy(self):
        scale = torch.ones(5, 5, requires_grad=True)
        scale_1d = torch.ones(1, requires_grad=True)
        self.assertTrue(torch.isinf(HalfCauchy(scale_1d).mean).all())
        self.assertEqual(HalfCauchy(scale_1d).variance, inf)
        self.assertEqual(HalfCauchy(scale).sample().size(), (5, 5))
        self.assertEqual(HalfCauchy(scale).sample((7,)).size(), (7, 5, 5))
        self.assertEqual(HalfCauchy(scale_1d).sample().size(), (1,))
        self.assertEqual(HalfCauchy(scale_1d).sample((1,)).size(), (1, 1))
        self.assertEqual(HalfCauchy(1.0).sample((1,)).size(), (1,))

        set_rng_seed(1)
        self._gradcheck_log_prob(HalfCauchy, (scale,))
        self._gradcheck_log_prob(HalfCauchy, (1.0,))
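        # The RNG-replay checks in the surrounding tests exploit the
        # reparameterization of location-scale families: a sample is
        # x = loc + scale * eps with eps drawn from the standard distribution,
        # so dx/dloc = 1 and dx/dscale = eps (for Uniform, u = low + (high - low) * r
        # gives du/dlow = 1 - r and du/dhigh = r). Re-drawing the noise under the
        # saved RNG state recovers the eps that rsample() consumed.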
        state = torch.get_rng_state()
        eps = scale.new(scale.size()).cauchy_().abs_()
        torch.set_rng_state(state)
        c = HalfCauchy(scale).rsample()
        c.backward(torch.ones_like(c))
        self.assertEqual(scale.grad, eps)
        scale.grad.zero_()

    @set_default_dtype(torch.double)
    def test_halfnormal(self):
        std = torch.randn(5, 5).abs().requires_grad_()
        std_1d = torch.randn(1).abs().requires_grad_()
        std_delta = torch.tensor([1e-5, 1e-5])
        self.assertEqual(HalfNormal(std).sample().size(), (5, 5))
        self.assertEqual(HalfNormal(std).sample((7,)).size(), (7, 5, 5))
        self.assertEqual(HalfNormal(std_1d).sample((1,)).size(), (1, 1))
        self.assertEqual(HalfNormal(std_1d).sample().size(), (1,))
        self.assertEqual(HalfNormal(0.6).sample((1,)).size(), (1,))
        self.assertEqual(HalfNormal(50.0).sample((1,)).size(), (1,))

        # sample check for extreme value of std
        set_rng_seed(1)
        self.assertEqual(
            HalfNormal(std_delta).sample(sample_shape=(1, 2)),
            torch.tensor([[[0.0, 0.0], [0.0, 0.0]]]),
            atol=1e-4,
            rtol=0,
        )

        self._gradcheck_log_prob(HalfNormal, (std,))
        self._gradcheck_log_prob(HalfNormal, (1.0,))

        # check .log_prob() can broadcast.
        dist = HalfNormal(torch.ones(2, 1, 4))
        log_prob = dist.log_prob(torch.ones(3, 1))
        self.assertEqual(log_prob.shape, (2, 3, 4))

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_halfnormal_logprob(self):
        std = torch.randn(5, 1).abs().requires_grad_()

        def ref_log_prob(idx, x, log_prob):
            s = std.view(-1)[idx].detach()
            expected = scipy.stats.halfnorm(scale=s).logpdf(x)
            self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)

        self._check_log_prob(HalfNormal(std), ref_log_prob)

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_halfnormal_sample(self):
        set_rng_seed(0)  # see Note [Randomized statistical tests]
        for std in [0.1, 1.0, 10.0]:
            self._check_sampler_sampler(
                HalfNormal(std),
                scipy.stats.halfnorm(scale=std),
                f"HalfNormal(scale={std})",
            )

    @set_default_dtype(torch.double)
    def test_inversegamma(self):
        alpha = torch.randn(2, 3).exp().requires_grad_()
        beta = torch.randn(2, 3).exp().requires_grad_()
        alpha_1d = torch.randn(1).exp().requires_grad_()
        beta_1d = torch.randn(1).exp().requires_grad_()
        self.assertEqual(InverseGamma(alpha, beta).sample().size(), (2, 3))
        self.assertEqual(InverseGamma(alpha, beta).sample((5,)).size(), (5, 2, 3))
        self.assertEqual(InverseGamma(alpha_1d, beta_1d).sample((1,)).size(), (1, 1))
        self.assertEqual(InverseGamma(alpha_1d, beta_1d).sample().size(), (1,))
        self.assertEqual(InverseGamma(0.5, 0.5).sample().size(), ())
        self.assertEqual(InverseGamma(0.5, 0.5).sample((1,)).size(), (1,))

        self._gradcheck_log_prob(InverseGamma, (alpha, beta))

        dist = InverseGamma(torch.ones(4), torch.ones(2, 1, 1))
        log_prob = dist.log_prob(torch.ones(3, 1))
        self.assertEqual(log_prob.shape, (2, 3, 4))

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_inversegamma_sample(self):
        set_rng_seed(0)  # see Note [Randomized statistical tests]
        for concentration, rate in product([2, 5], [0.1, 1.0, 10.0]):
            self._check_sampler_sampler(
                InverseGamma(concentration, rate),
                scipy.stats.invgamma(concentration, scale=rate),
                f"InverseGamma(concentration={concentration}, rate={rate})",
            )
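    # Parameterization note for the reference check above: if
    # X ~ Gamma(concentration, rate) then 1/X is inverse-gamma distributed with
    # density proportional to x**(-concentration - 1) * exp(-rate / x), which is
    # what scipy.stats.invgamma(a=concentration, scale=rate) evaluates.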
    @set_default_dtype(torch.double)
    def test_lognormal(self):
        mean = torch.randn(5, 5, requires_grad=True)
        std = torch.randn(5, 5).abs().requires_grad_()
        mean_1d = torch.randn(1, requires_grad=True)
        std_1d = torch.randn(1).abs().requires_grad_()
        mean_delta = torch.tensor([1.0, 0.0])
        std_delta = torch.tensor([1e-5, 1e-5])
        self.assertEqual(LogNormal(mean, std).sample().size(), (5, 5))
        self.assertEqual(LogNormal(mean, std).sample((7,)).size(), (7, 5, 5))
        self.assertEqual(LogNormal(mean_1d, std_1d).sample((1,)).size(), (1, 1))
        self.assertEqual(LogNormal(mean_1d, std_1d).sample().size(), (1,))
        self.assertEqual(LogNormal(0.2, 0.6).sample((1,)).size(), (1,))
        self.assertEqual(LogNormal(-0.7, 50.0).sample((1,)).size(), (1,))

        # sample check for extreme value of mean, std
        set_rng_seed(1)
        self.assertEqual(
            LogNormal(mean_delta, std_delta).sample(sample_shape=(1, 2)),
            torch.tensor([[[math.exp(1), 1.0], [math.exp(1), 1.0]]]),
            atol=1e-4,
            rtol=0,
        )

        self._gradcheck_log_prob(LogNormal, (mean, std))
        self._gradcheck_log_prob(LogNormal, (mean, 1.0))
        self._gradcheck_log_prob(LogNormal, (0.0, std))

        # check .log_prob() can broadcast.
        dist = LogNormal(torch.zeros(4), torch.ones(2, 1, 1))
        log_prob = dist.log_prob(torch.ones(3, 1))
        self.assertEqual(log_prob.shape, (2, 3, 4))

        self._check_forward_ad(lambda x: x.log_normal_())

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_lognormal_logprob(self):
        mean = torch.randn(5, 1, requires_grad=True)
        std = torch.randn(5, 1).abs().requires_grad_()

        def ref_log_prob(idx, x, log_prob):
            m = mean.view(-1)[idx].detach()
            s = std.view(-1)[idx].detach()
            expected = scipy.stats.lognorm(s=s, scale=math.exp(m)).logpdf(x)
            self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)

        self._check_log_prob(LogNormal(mean, std), ref_log_prob)

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_lognormal_sample(self):
        set_rng_seed(0)  # see Note [Randomized statistical tests]
        for mean, std in product([-1.0, 0.0, 1.0], [0.1, 1.0, 10.0]):
            self._check_sampler_sampler(
                LogNormal(mean, std),
                scipy.stats.lognorm(scale=math.exp(mean), s=std),
                f"LogNormal(loc={mean}, scale={std})",
            )

    @set_default_dtype(torch.double)
    def test_logisticnormal(self):
        set_rng_seed(1)  # see Note [Randomized statistical tests]
        mean = torch.randn(5, 5).requires_grad_()
        std = torch.randn(5, 5).abs().requires_grad_()
        mean_1d = torch.randn(1).requires_grad_()
        std_1d = torch.randn(1).abs().requires_grad_()
        mean_delta = torch.tensor([1.0, 0.0])
        std_delta = torch.tensor([1e-5, 1e-5])
        self.assertEqual(LogisticNormal(mean, std).sample().size(), (5, 6))
        self.assertEqual(LogisticNormal(mean, std).sample((7,)).size(), (7, 5, 6))
        self.assertEqual(LogisticNormal(mean_1d, std_1d).sample((1,)).size(), (1, 2))
        self.assertEqual(LogisticNormal(mean_1d, std_1d).sample().size(), (2,))
        self.assertEqual(LogisticNormal(0.2, 0.6).sample().size(), (2,))
        self.assertEqual(LogisticNormal(-0.7, 50.0).sample().size(), (2,))
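        # A sketch of where the expected values below come from, assuming the
        # stick-breaking construction LogisticNormal uses: with loc = [1, 0]
        # and scale ~ 0 the base sample is x = [1, 0], giving
        #   z0 = sigmoid(1 - log 2) = e / (2 + e),   z1 = sigmoid(0) = 1/2,
        # and the stick-breaking outputs
        #   y = [z0, z1 * (1 - z0), (1 - z0) * (1 - z1)]
        #     = [e / (2 + e), 1 / (2 + e), 1 / (2 + e)].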
        # sample check for extreme value of mean, std
        set_rng_seed(1)
        self.assertEqual(
            LogisticNormal(mean_delta, std_delta).sample(),
            torch.tensor(
                [
                    math.exp(1) / (1.0 + 1.0 + math.exp(1)),
                    1.0 / (1.0 + 1.0 + math.exp(1)),
                    1.0 / (1.0 + 1.0 + math.exp(1)),
                ]
            ),
            atol=1e-4,
            rtol=0,
        )

        # TODO: gradcheck seems to mutate the sample values so that the simplex
        # constraint fails by a very small margin.
        self._gradcheck_log_prob(
            lambda m, s: LogisticNormal(m, s, validate_args=False), (mean, std)
        )
        self._gradcheck_log_prob(
            lambda m, s: LogisticNormal(m, s, validate_args=False), (mean, 1.0)
        )
        self._gradcheck_log_prob(
            lambda m, s: LogisticNormal(m, s, validate_args=False), (0.0, std)
        )

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_logisticnormal_logprob(self):
        mean = torch.randn(5, 7).requires_grad_()
        std = torch.randn(5, 7).abs().requires_grad_()

        # Smoke test for now
        # TODO: Once _check_log_prob works with multidimensional distributions,
        # add proper testing of the log probabilities.
        dist = LogisticNormal(mean, std)
        assert dist.log_prob(dist.sample()).detach().cpu().numpy().shape == (5,)

    def _get_logistic_normal_ref_sampler(self, base_dist):
        def _sampler(num_samples):
            x = base_dist.rvs(num_samples)
            offset = np.log((x.shape[-1] + 1) - np.ones_like(x).cumsum(-1))
            z = 1.0 / (1.0 + np.exp(offset - x))
            z_cumprod = np.cumprod(1 - z, axis=-1)
            y1 = np.pad(z, ((0, 0), (0, 1)), mode="constant", constant_values=1.0)
            y2 = np.pad(
                z_cumprod, ((0, 0), (1, 0)), mode="constant", constant_values=1.0
            )
            return y1 * y2

        return _sampler

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_logisticnormal_sample(self):
        set_rng_seed(0)  # see Note [Randomized statistical tests]
        means = map(np.asarray, [(-1.0, -1.0), (0.0, 0.0), (1.0, 1.0)])
        covs = map(np.diag, [(0.1, 0.1), (1.0, 1.0), (10.0, 10.0)])
        for mean, cov in product(means, covs):
            base_dist = scipy.stats.multivariate_normal(mean=mean, cov=cov)
            ref_dist = scipy.stats.multivariate_normal(mean=mean, cov=cov)
            ref_dist.rvs = self._get_logistic_normal_ref_sampler(base_dist)
            mean_th = torch.tensor(mean)
            std_th = torch.tensor(np.sqrt(np.diag(cov)))
            self._check_sampler_sampler(
                LogisticNormal(mean_th, std_th),
                ref_dist,
                f"LogisticNormal(loc={mean_th}, scale={std_th})",
                multivariate=True,
            )
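    # The shape checks below follow the usual convention
    # sample().shape == sample_shape + batch_shape + event_shape; for
    # MixtureSameFamily the mixture dimension is marginalized out, so it
    # appears in neither the batch shape nor the event shape.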
    def test_mixture_same_family_shape(self):
        normal_case_1d = MixtureSameFamily(
            Categorical(torch.rand(5)), Normal(torch.randn(5), torch.rand(5))
        )
        normal_case_1d_batch = MixtureSameFamily(
            Categorical(torch.rand(3, 5)), Normal(torch.randn(3, 5), torch.rand(3, 5))
        )
        normal_case_1d_multi_batch = MixtureSameFamily(
            Categorical(torch.rand(4, 3, 5)),
            Normal(torch.randn(4, 3, 5), torch.rand(4, 3, 5)),
        )
        normal_case_2d = MixtureSameFamily(
            Categorical(torch.rand(5)),
            Independent(Normal(torch.randn(5, 2), torch.rand(5, 2)), 1),
        )
        normal_case_2d_batch = MixtureSameFamily(
            Categorical(torch.rand(3, 5)),
            Independent(Normal(torch.randn(3, 5, 2), torch.rand(3, 5, 2)), 1),
        )
        normal_case_2d_multi_batch = MixtureSameFamily(
            Categorical(torch.rand(4, 3, 5)),
            Independent(Normal(torch.randn(4, 3, 5, 2), torch.rand(4, 3, 5, 2)), 1),
        )

        self.assertEqual(normal_case_1d.sample().size(), ())
        self.assertEqual(normal_case_1d.sample((2,)).size(), (2,))
        self.assertEqual(normal_case_1d.sample((2, 7)).size(), (2, 7))
        self.assertEqual(normal_case_1d_batch.sample().size(), (3,))
        self.assertEqual(normal_case_1d_batch.sample((2,)).size(), (2, 3))
        self.assertEqual(normal_case_1d_batch.sample((2, 7)).size(), (2, 7, 3))
        self.assertEqual(normal_case_1d_multi_batch.sample().size(), (4, 3))
        self.assertEqual(normal_case_1d_multi_batch.sample((2,)).size(), (2, 4, 3))
        self.assertEqual(normal_case_1d_multi_batch.sample((2, 7)).size(), (2, 7, 4, 3))

        self.assertEqual(normal_case_2d.sample().size(), (2,))
        self.assertEqual(normal_case_2d.sample((2,)).size(), (2, 2))
        self.assertEqual(normal_case_2d.sample((2, 7)).size(), (2, 7, 2))
        self.assertEqual(normal_case_2d_batch.sample().size(), (3, 2))
        self.assertEqual(normal_case_2d_batch.sample((2,)).size(), (2, 3, 2))
        self.assertEqual(normal_case_2d_batch.sample((2, 7)).size(), (2, 7, 3, 2))
        self.assertEqual(normal_case_2d_multi_batch.sample().size(), (4, 3, 2))
        self.assertEqual(normal_case_2d_multi_batch.sample((2,)).size(), (2, 4, 3, 2))
        self.assertEqual(
            normal_case_2d_multi_batch.sample((2, 7)).size(), (2, 7, 4, 3, 2)
        )

    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
    def test_mixture_same_family_log_prob(self):
        probs = torch.rand(5, 5).softmax(dim=-1)
        loc = torch.randn(5, 5)
        scale = torch.rand(5, 5)

        def ref_log_prob(idx, x, log_prob):
            p = probs[idx].numpy()
            m = loc[idx].numpy()
            s = scale[idx].numpy()
            mix = scipy.stats.multinomial(1, p)
            comp = scipy.stats.norm(m, s)
            expected = scipy.special.logsumexp(comp.logpdf(x) + np.log(mix.p))
            self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)

        self._check_log_prob(
            MixtureSameFamily(Categorical(probs=probs), Normal(loc, scale)),
            ref_log_prob,
        )

    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
    def test_mixture_same_family_sample(self):
        probs = torch.rand(5).softmax(dim=-1)
        loc = torch.randn(5)
        scale = torch.rand(5)

        class ScipyMixtureNormal:
            def __init__(self, probs, mu, std):
                self.probs = probs
                self.mu = mu
                self.std = std

            def rvs(self, n_sample):
                comp_samples = [
                    scipy.stats.norm(m, s).rvs(n_sample)
                    for m, s in zip(self.mu, self.std)
                ]
                mix_samples = scipy.stats.multinomial(1, self.probs).rvs(n_sample)
                samples = []
                for i in range(n_sample):
                    samples.append(comp_samples[mix_samples[i].argmax()][i])
                return np.asarray(samples)

        self._check_sampler_sampler(
            MixtureSameFamily(Categorical(probs=probs), Normal(loc, scale)),
            ScipyMixtureNormal(probs.numpy(), loc.numpy(), scale.numpy()),
            f"""MixtureSameFamily(Categorical(probs={probs}),
            Normal(loc={loc}, scale={scale}))""",
        )

    @set_default_dtype(torch.double)
    def test_normal(self):
        loc = torch.randn(5, 5, requires_grad=True)
        scale = torch.randn(5, 5).abs().requires_grad_()
        loc_1d = torch.randn(1, requires_grad=True)
        scale_1d = torch.randn(1).abs().requires_grad_()
        loc_delta = torch.tensor([1.0, 0.0])
        scale_delta = torch.tensor([1e-5, 1e-5])
        self.assertEqual(Normal(loc, scale).sample().size(), (5, 5))
        self.assertEqual(Normal(loc, scale).sample((7,)).size(), (7, 5, 5))
        self.assertEqual(Normal(loc_1d, scale_1d).sample((1,)).size(), (1, 1))
        self.assertEqual(Normal(loc_1d, scale_1d).sample().size(), (1,))
        self.assertEqual(Normal(0.2, 0.6).sample((1,)).size(), (1,))
        self.assertEqual(Normal(-0.7, 50.0).sample((1,)).size(), (1,))
        # sample check for extreme value of mean, std
        set_rng_seed(1)
        self.assertEqual(
            Normal(loc_delta, scale_delta).sample(sample_shape=(1, 2)),
            torch.tensor([[[1.0, 0.0], [1.0, 0.0]]]),
            atol=1e-4,
            rtol=0,
        )

        self._gradcheck_log_prob(Normal, (loc, scale))
        self._gradcheck_log_prob(Normal, (loc, 1.0))
        self._gradcheck_log_prob(Normal, (0.0, scale))

        state = torch.get_rng_state()
        eps = torch.normal(torch.zeros_like(loc), torch.ones_like(scale))
        torch.set_rng_state(state)
        z = Normal(loc, scale).rsample()
        z.backward(torch.ones_like(z))
        self.assertEqual(loc.grad, torch.ones_like(loc))
        self.assertEqual(scale.grad, eps)
        loc.grad.zero_()
        scale.grad.zero_()
        self.assertEqual(z.size(), (5, 5))

        def ref_log_prob(idx, x, log_prob):
            m = loc.view(-1)[idx]
            s = scale.view(-1)[idx]
            expected = math.exp(-((x - m) ** 2) / (2 * s**2)) / math.sqrt(
                2 * math.pi * s**2
            )
            self.assertEqual(log_prob, math.log(expected), atol=1e-3, rtol=0)

        self._check_log_prob(Normal(loc, scale), ref_log_prob)
        self._check_forward_ad(torch.normal)
        self._check_forward_ad(lambda x: torch.normal(x, 0.5))
        self._check_forward_ad(lambda x: torch.normal(0.2, x))
        self._check_forward_ad(lambda x: torch.normal(x, x))
        self._check_forward_ad(lambda x: x.normal_())

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_normal_sample(self):
        set_rng_seed(0)  # see Note [Randomized statistical tests]
        for loc, scale in product([-1.0, 0.0, 1.0], [0.1, 1.0, 10.0]):
            self._check_sampler_sampler(
                Normal(loc, scale),
                scipy.stats.norm(loc=loc, scale=scale),
                f"Normal(mean={loc}, std={scale})",
            )

    @set_default_dtype(torch.double)
    def test_lowrank_multivariate_normal_shape(self):
        mean = torch.randn(5, 3, requires_grad=True)
        mean_no_batch = torch.randn(3, requires_grad=True)
        mean_multi_batch = torch.randn(6, 5, 3, requires_grad=True)

        # construct PSD covariance
        cov_factor = torch.randn(3, 1, requires_grad=True)
        cov_diag = torch.randn(3).abs().requires_grad_()

        # construct batch of PSD covariances
        cov_factor_batched = torch.randn(6, 5, 3, 2, requires_grad=True)
        cov_diag_batched = torch.randn(6, 5, 3).abs().requires_grad_()

        # ensure that sample, batch, event shapes all handled correctly
        self.assertEqual(
            LowRankMultivariateNormal(mean, cov_factor, cov_diag).sample().size(),
            (5, 3),
        )
        self.assertEqual(
            LowRankMultivariateNormal(mean_no_batch, cov_factor, cov_diag)
            .sample()
            .size(),
            (3,),
        )
        self.assertEqual(
            LowRankMultivariateNormal(mean_multi_batch, cov_factor, cov_diag)
            .sample()
            .size(),
            (6, 5, 3),
        )
        self.assertEqual(
            LowRankMultivariateNormal(mean, cov_factor, cov_diag).sample((2,)).size(),
            (2, 5, 3),
        )
        self.assertEqual(
            LowRankMultivariateNormal(mean_no_batch, cov_factor, cov_diag)
            .sample((2,))
            .size(),
            (2, 3),
        )
        self.assertEqual(
            LowRankMultivariateNormal(mean_multi_batch, cov_factor, cov_diag)
            .sample((2,))
            .size(),
            (2, 6, 5, 3),
        )
        self.assertEqual(
            LowRankMultivariateNormal(mean, cov_factor, cov_diag).sample((2, 7)).size(),
            (2, 7, 5, 3),
        )
        self.assertEqual(
            LowRankMultivariateNormal(mean_no_batch, cov_factor, cov_diag)
            .sample((2, 7))
            .size(),
            (2, 7, 3),
        )
        self.assertEqual(
            LowRankMultivariateNormal(mean_multi_batch, cov_factor, cov_diag)
            .sample((2, 7))
            .size(),
            (2, 7, 6, 5, 3),
        )
        self.assertEqual(
            LowRankMultivariateNormal(mean, cov_factor_batched, cov_diag_batched)
            .sample((2, 7))
            .size(),
            (2, 7, 6, 5, 3),
        )
        self.assertEqual(
            LowRankMultivariateNormal(
                mean_no_batch, cov_factor_batched, cov_diag_batched
            )
            .sample((2, 7))
            .size(),
            (2, 7, 6, 5, 3),
        )
        self.assertEqual(
            LowRankMultivariateNormal(
                mean_multi_batch, cov_factor_batched, cov_diag_batched
            )
            .sample((2, 7))
            .size(),
            (2, 7, 6, 5, 3),
        )

        # check gradients
        self._gradcheck_log_prob(
            LowRankMultivariateNormal, (mean, cov_factor, cov_diag)
        )
        self._gradcheck_log_prob(
            LowRankMultivariateNormal, (mean_multi_batch, cov_factor, cov_diag)
        )
        self._gradcheck_log_prob(
            LowRankMultivariateNormal,
            (mean_multi_batch, cov_factor_batched, cov_diag_batched),
        )

    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
    def test_lowrank_multivariate_normal_log_prob(self):
        mean = torch.randn(3, requires_grad=True)
        cov_factor = torch.randn(3, 1, requires_grad=True)
        cov_diag = torch.randn(3).abs().requires_grad_()
        cov = cov_factor.matmul(cov_factor.t()) + cov_diag.diag()

        # check that logprob values match scipy logpdf,
        # and that covariance and scale_tril parameters are equivalent
        dist1 = LowRankMultivariateNormal(mean, cov_factor, cov_diag)
        ref_dist = scipy.stats.multivariate_normal(
            mean.detach().numpy(), cov.detach().numpy()
        )

        x = dist1.sample((10,))
        expected = ref_dist.logpdf(x.numpy())

        self.assertEqual(
            0.0,
            np.mean((dist1.log_prob(x).detach().numpy() - expected) ** 2),
            atol=1e-3,
            rtol=0,
        )

        # Double-check that batched versions behave the same as unbatched
        mean = torch.randn(5, 3, requires_grad=True)
        cov_factor = torch.randn(5, 3, 2, requires_grad=True)
        cov_diag = torch.randn(5, 3).abs().requires_grad_()

        dist_batched = LowRankMultivariateNormal(mean, cov_factor, cov_diag)
        dist_unbatched = [
            LowRankMultivariateNormal(mean[i], cov_factor[i], cov_diag[i])
            for i in range(mean.size(0))
        ]

        x = dist_batched.sample((10,))
        batched_prob = dist_batched.log_prob(x)
        unbatched_prob = torch.stack(
            [dist_unbatched[i].log_prob(x[:, i]) for i in range(5)]
        ).t()

        self.assertEqual(batched_prob.shape, unbatched_prob.shape)
        self.assertEqual(batched_prob, unbatched_prob, atol=1e-3, rtol=0)

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_lowrank_multivariate_normal_sample(self):
        set_rng_seed(0)  # see Note [Randomized statistical tests]
        mean = torch.randn(5, requires_grad=True)
        cov_factor = torch.randn(5, 1, requires_grad=True)
        cov_diag = torch.randn(5).abs().requires_grad_()
        cov = cov_factor.matmul(cov_factor.t()) + cov_diag.diag()

        self._check_sampler_sampler(
            LowRankMultivariateNormal(mean, cov_factor, cov_diag),
            scipy.stats.multivariate_normal(
                mean.detach().numpy(), cov.detach().numpy()
            ),
            f"LowRankMultivariateNormal(loc={mean}, cov_factor={cov_factor}, cov_diag={cov_diag})",
            multivariate=True,
        )
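    # The comparisons below rest on the identity the low-rank parameterization
    # encodes: covariance_matrix = cov_factor @ cov_factor.T + diag(cov_diag).
    # Keeping the n x k factor and the diagonal lets determinants and inverses
    # be computed via the matrix determinant lemma and the Woodbury identity in
    # O(n * k^2) rather than O(n^3) for the full covariance.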
    def test_lowrank_multivariate_normal_properties(self):
        loc = torch.randn(5)
        cov_factor = torch.randn(5, 2)
        cov_diag = torch.randn(5).abs()
        cov = cov_factor.matmul(cov_factor.t()) + cov_diag.diag()
        m1 = LowRankMultivariateNormal(loc, cov_factor, cov_diag)
        m2 = MultivariateNormal(loc=loc, covariance_matrix=cov)
        self.assertEqual(m1.mean, m2.mean)
        self.assertEqual(m1.variance, m2.variance)
        self.assertEqual(m1.covariance_matrix, m2.covariance_matrix)
        self.assertEqual(m1.scale_tril, m2.scale_tril)
        self.assertEqual(m1.precision_matrix, m2.precision_matrix)
        self.assertEqual(m1.entropy(), m2.entropy())

    def test_lowrank_multivariate_normal_moments(self):
        set_rng_seed(0)  # see Note [Randomized statistical tests]
        mean = torch.randn(5)
        cov_factor = torch.randn(5, 2)
        cov_diag = torch.randn(5).abs()
        d = LowRankMultivariateNormal(mean, cov_factor, cov_diag)
        samples = d.rsample((100000,))
        empirical_mean = samples.mean(0)
        self.assertEqual(d.mean, empirical_mean, atol=0.01, rtol=0)
        empirical_var = samples.var(0)
        self.assertEqual(d.variance, empirical_var, atol=0.02, rtol=0)

    @set_default_dtype(torch.double)
    def test_multivariate_normal_shape(self):
        mean = torch.randn(5, 3, requires_grad=True)
        mean_no_batch = torch.randn(3, requires_grad=True)
        mean_multi_batch = torch.randn(6, 5, 3, requires_grad=True)

        # construct PSD covariance
        tmp = torch.randn(3, 10)
        cov = (torch.matmul(tmp, tmp.t()) / tmp.size(-1)).requires_grad_()
        prec = cov.inverse().requires_grad_()
        scale_tril = torch.linalg.cholesky(cov).requires_grad_()

        # construct batch of PSD covariances
        tmp = torch.randn(6, 5, 3, 10)
        cov_batched = (tmp.unsqueeze(-2) * tmp.unsqueeze(-3)).mean(-1).requires_grad_()
        prec_batched = cov_batched.inverse()
        scale_tril_batched = torch.linalg.cholesky(cov_batched)

        # ensure that sample, batch, event shapes all handled correctly
        self.assertEqual(MultivariateNormal(mean, cov).sample().size(), (5, 3))
        self.assertEqual(MultivariateNormal(mean_no_batch, cov).sample().size(), (3,))
        self.assertEqual(
            MultivariateNormal(mean_multi_batch, cov).sample().size(), (6, 5, 3)
        )
        self.assertEqual(MultivariateNormal(mean, cov).sample((2,)).size(), (2, 5, 3))
        self.assertEqual(
            MultivariateNormal(mean_no_batch, cov).sample((2,)).size(), (2, 3)
        )
        self.assertEqual(
            MultivariateNormal(mean_multi_batch, cov).sample((2,)).size(), (2, 6, 5, 3)
        )
        self.assertEqual(
            MultivariateNormal(mean, cov).sample((2, 7)).size(), (2, 7, 5, 3)
        )
        self.assertEqual(
            MultivariateNormal(mean_no_batch, cov).sample((2, 7)).size(), (2, 7, 3)
        )
        self.assertEqual(
            MultivariateNormal(mean_multi_batch, cov).sample((2, 7)).size(),
            (2, 7, 6, 5, 3),
        )
        self.assertEqual(
            MultivariateNormal(mean, cov_batched).sample((2, 7)).size(), (2, 7, 6, 5, 3)
        )
        self.assertEqual(
            MultivariateNormal(mean_no_batch, cov_batched).sample((2, 7)).size(),
            (2, 7, 6, 5, 3),
        )
        self.assertEqual(
            MultivariateNormal(mean_multi_batch, cov_batched).sample((2, 7)).size(),
            (2, 7, 6, 5, 3),
        )
        self.assertEqual(
            MultivariateNormal(mean, precision_matrix=prec).sample((2, 7)).size(),
            (2, 7, 5, 3),
        )
        self.assertEqual(
            MultivariateNormal(mean, precision_matrix=prec_batched)
            .sample((2, 7))
            .size(),
            (2, 7, 6, 5, 3),
        )
        self.assertEqual(
            MultivariateNormal(mean, scale_tril=scale_tril).sample((2, 7)).size(),
            (2, 7, 5, 3),
        )
        self.assertEqual(
            MultivariateNormal(mean, scale_tril=scale_tril_batched)
            .sample((2, 7))
            .size(),
            (2, 7, 6, 5, 3),
        )

        # check gradients
        # We write a custom gradcheck function to maintain the symmetry
        # of the perturbed covariances and their inverses (precision)
        def multivariate_normal_log_prob_gradcheck(
            mean, covariance=None, precision=None, scale_tril=None
        ):
            mvn_samples = (
                MultivariateNormal(mean, covariance, precision, scale_tril)
                .sample()
                .requires_grad_()
            )

            def gradcheck_func(samples, mu, sigma, prec, scale_tril):
                if sigma is not None:
                    sigma = 0.5 * (sigma + sigma.mT)  # Ensure symmetry of covariance
                if prec is not None:
                    prec = 0.5 * (prec + prec.mT)  # Ensure symmetry of precision
                if scale_tril is not None:
                    scale_tril = scale_tril.tril()
                return MultivariateNormal(mu, sigma, prec, scale_tril).log_prob(samples)

            gradcheck(
                gradcheck_func,
                (mvn_samples, mean, covariance, precision, scale_tril),
                raise_exception=True,
            )

        multivariate_normal_log_prob_gradcheck(mean, cov)
        multivariate_normal_log_prob_gradcheck(mean_multi_batch, cov)
        multivariate_normal_log_prob_gradcheck(mean_multi_batch, cov_batched)
        multivariate_normal_log_prob_gradcheck(mean, None, prec)
        multivariate_normal_log_prob_gradcheck(mean_no_batch, None, prec_batched)
        multivariate_normal_log_prob_gradcheck(mean, None, None, scale_tril)
        multivariate_normal_log_prob_gradcheck(
            mean_no_batch, None, None, scale_tril_batched
        )

    @set_default_dtype(torch.double)
    def test_multivariate_normal_stable_with_precision_matrix(self):
        x = torch.randn(10)
        P = torch.exp(-((x - x.unsqueeze(-1)) ** 2))  # RBF kernel
        MultivariateNormal(x.new_zeros(10), precision_matrix=P)

    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
    def test_multivariate_normal_log_prob(self):
        mean = torch.randn(3, requires_grad=True)
        tmp = torch.randn(3, 10)
        cov = (torch.matmul(tmp, tmp.t()) / tmp.size(-1)).requires_grad_()
        prec = cov.inverse().requires_grad_()
        scale_tril = torch.linalg.cholesky(cov).requires_grad_()

        # check that logprob values match scipy logpdf,
        # and that covariance and scale_tril parameters are equivalent
        dist1 = MultivariateNormal(mean, cov)
        dist2 = MultivariateNormal(mean, precision_matrix=prec)
        dist3 = MultivariateNormal(mean, scale_tril=scale_tril)
        ref_dist = scipy.stats.multivariate_normal(
            mean.detach().numpy(), cov.detach().numpy()
        )

        x = dist1.sample((10,))
        expected = ref_dist.logpdf(x.numpy())

        self.assertEqual(
            0.0,
            np.mean((dist1.log_prob(x).detach().numpy() - expected) ** 2),
            atol=1e-3,
            rtol=0,
        )
        self.assertEqual(
            0.0,
            np.mean((dist2.log_prob(x).detach().numpy() - expected) ** 2),
            atol=1e-3,
            rtol=0,
        )
        self.assertEqual(
            0.0,
            np.mean((dist3.log_prob(x).detach().numpy() - expected) ** 2),
            atol=1e-3,
            rtol=0,
        )

        # Double-check that batched versions behave the same as unbatched
        mean = torch.randn(5, 3, requires_grad=True)
        tmp = torch.randn(5, 3, 10)
        cov = (tmp.unsqueeze(-2) * tmp.unsqueeze(-3)).mean(-1).requires_grad_()

        dist_batched = MultivariateNormal(mean, cov)
        dist_unbatched = [
            MultivariateNormal(mean[i], cov[i]) for i in range(mean.size(0))
        ]
        x = dist_batched.sample((10,))
        batched_prob = dist_batched.log_prob(x)
        unbatched_prob = torch.stack(
            [dist_unbatched[i].log_prob(x[:, i]) for i in range(5)]
        ).t()

        self.assertEqual(batched_prob.shape, unbatched_prob.shape)
        self.assertEqual(batched_prob, unbatched_prob, atol=1e-3, rtol=0)

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_multivariate_normal_sample(self):
        set_rng_seed(0)  # see Note [Randomized statistical tests]
        mean = torch.randn(3, requires_grad=True)
        tmp = torch.randn(3, 10)
        cov = (torch.matmul(tmp, tmp.t()) / tmp.size(-1)).requires_grad_()
        prec = cov.inverse().requires_grad_()
        scale_tril = torch.linalg.cholesky(cov).requires_grad_()

        self._check_sampler_sampler(
            MultivariateNormal(mean, cov),
            scipy.stats.multivariate_normal(
                mean.detach().numpy(), cov.detach().numpy()
            ),
            f"MultivariateNormal(loc={mean}, cov={cov})",
            multivariate=True,
        )
        self._check_sampler_sampler(
            MultivariateNormal(mean, precision_matrix=prec),
            scipy.stats.multivariate_normal(
                mean.detach().numpy(), cov.detach().numpy()
            ),
            f"MultivariateNormal(loc={mean}, precision_matrix={prec})",
            multivariate=True,
        )
        self._check_sampler_sampler(
            MultivariateNormal(mean, scale_tril=scale_tril),
            scipy.stats.multivariate_normal(
                mean.detach().numpy(), cov.detach().numpy()
            ),
            f"MultivariateNormal(loc={mean}, scale_tril={scale_tril})",
            multivariate=True,
        )

    @set_default_dtype(torch.double)
    def test_multivariate_normal_properties(self):
        loc = torch.randn(5)
        scale_tril = transform_to(constraints.lower_cholesky)(torch.randn(5, 5))
        m = MultivariateNormal(loc=loc, scale_tril=scale_tril)
        self.assertEqual(m.covariance_matrix, m.scale_tril.mm(m.scale_tril.t()))
        self.assertEqual(
            m.covariance_matrix.mm(m.precision_matrix), torch.eye(m.event_shape[0])
        )
        self.assertEqual(m.scale_tril, torch.linalg.cholesky(m.covariance_matrix))

    @set_default_dtype(torch.double)
    def test_multivariate_normal_moments(self):
        set_rng_seed(0)  # see Note [Randomized statistical tests]
        mean = torch.randn(5)
        scale_tril = transform_to(constraints.lower_cholesky)(torch.randn(5, 5))
        d = MultivariateNormal(mean, scale_tril=scale_tril)
        samples = d.rsample((100000,))
        empirical_mean = samples.mean(0)
        self.assertEqual(d.mean, empirical_mean, atol=0.01, rtol=0)
        empirical_var = samples.var(0)
        self.assertEqual(d.variance, empirical_var, atol=0.05, rtol=0)
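    # The Wishart tests below draw df from (ndim - 1, ndim): PyTorch's Wishart
    # only requires df > ndim - 1, which deliberately exercises the low-df
    # regime that SciPy supports only from version 1.7.0 onwards (see the
    # version guards in the tests).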
    # Apply the same battery of tests used for the MultivariateNormal
    # distribution to the Wishart distribution
    @set_default_dtype(torch.double)
    def test_wishart_shape(self):
        set_rng_seed(0)  # see Note [Randomized statistical tests]
        ndim = 3

        df = torch.rand(5, requires_grad=True) + ndim
        df_no_batch = torch.rand([], requires_grad=True) + ndim
        df_multi_batch = torch.rand(6, 5, requires_grad=True) + ndim

        # construct PSD covariance
        tmp = torch.randn(ndim, 10)
        cov = (torch.matmul(tmp, tmp.t()) / tmp.size(-1)).requires_grad_()
        prec = cov.inverse().requires_grad_()
        scale_tril = torch.linalg.cholesky(cov).requires_grad_()

        # construct batch of PSD covariances
        tmp = torch.randn(6, 5, ndim, 10)
        cov_batched = (tmp.unsqueeze(-2) * tmp.unsqueeze(-3)).mean(-1).requires_grad_()
        prec_batched = cov_batched.inverse()
        scale_tril_batched = torch.linalg.cholesky(cov_batched)

        # ensure that sample, batch, event shapes all handled correctly
        self.assertEqual(Wishart(df, cov).sample().size(), (5, ndim, ndim))
        self.assertEqual(Wishart(df_no_batch, cov).sample().size(), (ndim, ndim))
        self.assertEqual(
            Wishart(df_multi_batch, cov).sample().size(), (6, 5, ndim, ndim)
        )
        self.assertEqual(Wishart(df, cov).sample((2,)).size(), (2, 5, ndim, ndim))
        self.assertEqual(Wishart(df_no_batch, cov).sample((2,)).size(), (2, ndim, ndim))
        self.assertEqual(
            Wishart(df_multi_batch, cov).sample((2,)).size(), (2, 6, 5, ndim, ndim)
        )
        self.assertEqual(Wishart(df, cov).sample((2, 7)).size(), (2, 7, 5, ndim, ndim))
        self.assertEqual(
            Wishart(df_no_batch, cov).sample((2, 7)).size(), (2, 7, ndim, ndim)
        )
        self.assertEqual(
            Wishart(df_multi_batch, cov).sample((2, 7)).size(), (2, 7, 6, 5, ndim, ndim)
        )
        self.assertEqual(
            Wishart(df, cov_batched).sample((2, 7)).size(), (2, 7, 6, 5, ndim, ndim)
        )
        self.assertEqual(
            Wishart(df_no_batch, cov_batched).sample((2, 7)).size(),
            (2, 7, 6, 5, ndim, ndim),
        )
        self.assertEqual(
            Wishart(df_multi_batch, cov_batched).sample((2, 7)).size(),
            (2, 7, 6, 5, ndim, ndim),
        )
        self.assertEqual(
            Wishart(df, precision_matrix=prec).sample((2, 7)).size(),
            (2, 7, 5, ndim, ndim),
        )
        self.assertEqual(
            Wishart(df, precision_matrix=prec_batched).sample((2, 7)).size(),
            (2, 7, 6, 5, ndim, ndim),
        )
        self.assertEqual(
            Wishart(df, scale_tril=scale_tril).sample((2, 7)).size(),
            (2, 7, 5, ndim, ndim),
        )
        self.assertEqual(
            Wishart(df, scale_tril=scale_tril_batched).sample((2, 7)).size(),
            (2, 7, 6, 5, ndim, ndim),
        )

        # check gradients
        # Reuse the same symmetry-preserving gradcheck approach as for
        # MultivariateNormal above
        def wishart_log_prob_gradcheck(
            df=None, covariance=None, precision=None, scale_tril=None
        ):
            wishart_samples = (
                Wishart(df, covariance, precision, scale_tril).sample().requires_grad_()
            )

            def gradcheck_func(samples, nu, sigma, prec, scale_tril):
                if sigma is not None:
                    sigma = 0.5 * (sigma + sigma.mT)  # Ensure symmetry of covariance
                if prec is not None:
                    prec = 0.5 * (prec + prec.mT)  # Ensure symmetry of precision
                if scale_tril is not None:
                    scale_tril = scale_tril.tril()
                return Wishart(nu, sigma, prec, scale_tril).log_prob(samples)

            gradcheck(
                gradcheck_func,
                (wishart_samples, df, covariance, precision, scale_tril),
                raise_exception=True,
            )

        wishart_log_prob_gradcheck(df, cov)
        wishart_log_prob_gradcheck(df_multi_batch, cov)
        wishart_log_prob_gradcheck(df_multi_batch, cov_batched)
        wishart_log_prob_gradcheck(df, None, prec)
        wishart_log_prob_gradcheck(df_no_batch, None, prec_batched)
        wishart_log_prob_gradcheck(df, None, None, scale_tril)
        wishart_log_prob_gradcheck(df_no_batch, None, None, scale_tril_batched)

    def test_wishart_stable_with_precision_matrix(self):
        set_rng_seed(0)  # see Note [Randomized statistical tests]
        ndim = 10
        x = torch.randn(ndim)
        P = torch.exp(-((x - x.unsqueeze(-1)) ** 2))  # RBF kernel
        Wishart(torch.tensor(ndim), precision_matrix=P)
    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
    @set_default_dtype(torch.double)
    def test_wishart_log_prob(self):
        set_rng_seed(0)  # see Note [Randomized statistical tests]
        ndim = 3
        df = torch.rand([], requires_grad=True) + ndim - 1
        # SciPy allows ndim - 1 < df < ndim for the Wishart distribution only
        # from version 1.7.0 onwards
        if version.parse(scipy.__version__) < version.parse("1.7.0"):
            df += 1.0
        tmp = torch.randn(ndim, 10)
        cov = (torch.matmul(tmp, tmp.t()) / tmp.size(-1)).requires_grad_()
        prec = cov.inverse().requires_grad_()
        scale_tril = torch.linalg.cholesky(cov).requires_grad_()

        # check that logprob values match scipy logpdf,
        # and that covariance and scale_tril parameters are equivalent
        dist1 = Wishart(df, cov)
        dist2 = Wishart(df, precision_matrix=prec)
        dist3 = Wishart(df, scale_tril=scale_tril)
        ref_dist = scipy.stats.wishart(df.item(), cov.detach().numpy())

        x = dist1.sample((1000,))
        expected = ref_dist.logpdf(x.transpose(0, 2).numpy())

        self.assertEqual(
            0.0,
            np.mean((dist1.log_prob(x).detach().numpy() - expected) ** 2),
            atol=1e-3,
            rtol=0,
        )
        self.assertEqual(
            0.0,
            np.mean((dist2.log_prob(x).detach().numpy() - expected) ** 2),
            atol=1e-3,
            rtol=0,
        )
        self.assertEqual(
            0.0,
            np.mean((dist3.log_prob(x).detach().numpy() - expected) ** 2),
            atol=1e-3,
            rtol=0,
        )

        # Double-check that batched versions behave the same as unbatched
        df = torch.rand(5, requires_grad=True) + ndim - 1
        # SciPy allows ndim - 1 < df < ndim for the Wishart distribution only
        # from version 1.7.0 onwards
        if version.parse(scipy.__version__) < version.parse("1.7.0"):
            df += 1.0
        tmp = torch.randn(5, ndim, 10)
        cov = (tmp.unsqueeze(-2) * tmp.unsqueeze(-3)).mean(-1).requires_grad_()

        dist_batched = Wishart(df, cov)
        dist_unbatched = [Wishart(df[i], cov[i]) for i in range(df.size(0))]

        x = dist_batched.sample((1000,))
        batched_prob = dist_batched.log_prob(x)
        unbatched_prob = torch.stack(
            [dist_unbatched[i].log_prob(x[:, i]) for i in range(5)]
        ).t()

        self.assertEqual(batched_prob.shape, unbatched_prob.shape)
        self.assertEqual(batched_prob, unbatched_prob, atol=1e-3, rtol=0)

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    @set_default_dtype(torch.double)
    def test_wishart_sample(self):
        set_rng_seed(0)  # see Note [Randomized statistical tests]
        ndim = 3
        df = torch.rand([], requires_grad=True) + ndim - 1
        # SciPy allows ndim - 1 < df < ndim for the Wishart distribution only
        # from version 1.7.0 onwards
        if version.parse(scipy.__version__) < version.parse("1.7.0"):
            df += 1.0
        tmp = torch.randn(ndim, 10)
        cov = (torch.matmul(tmp, tmp.t()) / tmp.size(-1)).requires_grad_()
        prec = cov.inverse().requires_grad_()
        scale_tril = torch.linalg.cholesky(cov).requires_grad_()

        ref_dist = scipy.stats.wishart(df.item(), cov.detach().numpy())

        self._check_sampler_sampler(
            Wishart(df, cov),
            ref_dist,
            f"Wishart(df={df}, covariance_matrix={cov})",
            multivariate=True,
        )
        self._check_sampler_sampler(
            Wishart(df, precision_matrix=prec),
            ref_dist,
            f"Wishart(df={df}, precision_matrix={prec})",
            multivariate=True,
        )
        self._check_sampler_sampler(
            Wishart(df, scale_tril=scale_tril),
            ref_dist,
            f"Wishart(df={df}, scale_tril={scale_tril})",
            multivariate=True,
        )
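    # Reference moments for the empirical checks below (standard Wishart facts):
    # for X ~ Wishart(df, Sigma), E[X] = df * Sigma and
    # Var[X_ij] = df * (Sigma_ij**2 + Sigma_ii * Sigma_jj).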
    def test_wishart_properties(self):
        set_rng_seed(0)  # see Note [Randomized statistical tests]
        ndim = 5
        df = torch.rand([]) + ndim - 1
        scale_tril = transform_to(constraints.lower_cholesky)(torch.randn(ndim, ndim))
        m = Wishart(df=df, scale_tril=scale_tril)
        self.assertEqual(m.covariance_matrix, m.scale_tril.mm(m.scale_tril.t()))
        self.assertEqual(
            m.covariance_matrix.mm(m.precision_matrix), torch.eye(m.event_shape[0])
        )
        self.assertEqual(m.scale_tril, torch.linalg.cholesky(m.covariance_matrix))

    def test_wishart_moments(self):
        set_rng_seed(0)  # see Note [Randomized statistical tests]
        ndim = 3
        df = torch.rand([]) + ndim - 1
        scale_tril = transform_to(constraints.lower_cholesky)(torch.randn(ndim, ndim))
        d = Wishart(df=df, scale_tril=scale_tril)
        samples = d.rsample((ndim * ndim * 100000,))
        empirical_mean = samples.mean(0)
        self.assertEqual(d.mean, empirical_mean, atol=0.5, rtol=0)
        empirical_var = samples.var(0)
        self.assertEqual(d.variance, empirical_var, atol=0.5, rtol=0)

    @set_default_dtype(torch.double)
    def test_exponential(self):
        rate = torch.randn(5, 5).abs().requires_grad_()
        rate_1d = torch.randn(1).abs().requires_grad_()
        self.assertEqual(Exponential(rate).sample().size(), (5, 5))
        self.assertEqual(Exponential(rate).sample((7,)).size(), (7, 5, 5))
        self.assertEqual(Exponential(rate_1d).sample((1,)).size(), (1, 1))
        self.assertEqual(Exponential(rate_1d).sample().size(), (1,))
        self.assertEqual(Exponential(0.2).sample((1,)).size(), (1,))
        self.assertEqual(Exponential(50.0).sample((1,)).size(), (1,))

        self._gradcheck_log_prob(Exponential, (rate,))
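        # Derivation for the gradient check below: the reparameterized sample
        # is z = eps / rate with eps ~ Exponential(1), so dz/drate = -eps / rate**2;
        # replaying the RNG state recovers the eps that rsample() consumed.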
        state = torch.get_rng_state()
        eps = rate.new(rate.size()).exponential_()
        torch.set_rng_state(state)
        z = Exponential(rate).rsample()
        z.backward(torch.ones_like(z))
        self.assertEqual(rate.grad, -eps / rate**2)
        rate.grad.zero_()
        self.assertEqual(z.size(), (5, 5))

        def ref_log_prob(idx, x, log_prob):
            m = rate.view(-1)[idx]
            expected = math.log(m) - m * x
            self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)

        self._check_log_prob(Exponential(rate), ref_log_prob)
        self._check_forward_ad(lambda x: x.exponential_())

        def mean_var(lambd, sample):
            sample.exponential_(lambd)
            mean = sample.float().mean()
            var = sample.float().var()
            self.assertEqual((1.0 / lambd), mean, atol=2e-2, rtol=2e-2)
            self.assertEqual((1.0 / lambd) ** 2, var, atol=2e-2, rtol=2e-2)

        for dtype in [torch.float, torch.double, torch.bfloat16, torch.float16]:
            for lambd in [0.2, 0.5, 1.0, 1.5, 2.0, 5.0]:
                sample_len = 50000
                mean_var(lambd, torch.rand(sample_len, dtype=dtype))

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_exponential_sample(self):
        set_rng_seed(1)  # see Note [Randomized statistical tests]
        for rate in [1e-5, 1.0, 10.0]:
            self._check_sampler_sampler(
                Exponential(rate),
                scipy.stats.expon(scale=1.0 / rate),
                f"Exponential(rate={rate})",
            )

    @set_default_dtype(torch.double)
    def test_laplace(self):
        loc = torch.randn(5, 5, requires_grad=True)
        scale = torch.randn(5, 5).abs().requires_grad_()
        loc_1d = torch.randn(1, requires_grad=True)
        scale_1d = torch.randn(1, requires_grad=True)
        loc_delta = torch.tensor([1.0, 0.0])
        scale_delta = torch.tensor([1e-5, 1e-5])
        self.assertEqual(Laplace(loc, scale).sample().size(), (5, 5))
        self.assertEqual(Laplace(loc, scale).sample((7,)).size(), (7, 5, 5))
        self.assertEqual(Laplace(loc_1d, scale_1d).sample((1,)).size(), (1, 1))
        self.assertEqual(Laplace(loc_1d, scale_1d).sample().size(), (1,))
        self.assertEqual(Laplace(0.2, 0.6).sample((1,)).size(), (1,))
        self.assertEqual(Laplace(-0.7, 50.0).sample((1,)).size(), (1,))

        # sample check for extreme value of mean, std
        set_rng_seed(0)
        self.assertEqual(
            Laplace(loc_delta, scale_delta).sample(sample_shape=(1, 2)),
            torch.tensor([[[1.0, 0.0], [1.0, 0.0]]]),
            atol=1e-4,
            rtol=0,
        )

        self._gradcheck_log_prob(Laplace, (loc, scale))
        self._gradcheck_log_prob(Laplace, (loc, 1.0))
        self._gradcheck_log_prob(Laplace, (0.0, scale))

        state = torch.get_rng_state()
        eps = torch.ones_like(loc).uniform_(-0.5, 0.5)
        torch.set_rng_state(state)
        z = Laplace(loc, scale).rsample()
        z.backward(torch.ones_like(z))
        self.assertEqual(loc.grad, torch.ones_like(loc))
        self.assertEqual(scale.grad, -eps.sign() * torch.log1p(-2 * eps.abs()))
        loc.grad.zero_()
        scale.grad.zero_()
        self.assertEqual(z.size(), (5, 5))

        def ref_log_prob(idx, x, log_prob):
            m = loc.view(-1)[idx]
            s = scale.view(-1)[idx]
            expected = -math.log(2 * s) - abs(x - m) / s
            self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)

        self._check_log_prob(Laplace(loc, scale), ref_log_prob)

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    @set_default_dtype(torch.double)
    def test_laplace_sample(self):
        set_rng_seed(1)  # see Note [Randomized statistical tests]
        for loc, scale in product([-1.0, 0.0, 1.0], [0.1, 1.0, 10.0]):
            self._check_sampler_sampler(
                Laplace(loc, scale),
                scipy.stats.laplace(loc=loc, scale=scale),
                f"Laplace(loc={loc}, scale={scale})",
            )

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_gamma_shape(self):
        alpha = torch.randn(2, 3).exp().requires_grad_()
        beta = torch.randn(2, 3).exp().requires_grad_()
        alpha_1d = torch.randn(1).exp().requires_grad_()
        beta_1d = torch.randn(1).exp().requires_grad_()
        self.assertEqual(Gamma(alpha, beta).sample().size(), (2, 3))
        self.assertEqual(Gamma(alpha, beta).sample((5,)).size(), (5, 2, 3))
        self.assertEqual(Gamma(alpha_1d, beta_1d).sample((1,)).size(), (1, 1))
        self.assertEqual(Gamma(alpha_1d, beta_1d).sample().size(), (1,))
        self.assertEqual(Gamma(0.5, 0.5).sample().size(), ())
        self.assertEqual(Gamma(0.5, 0.5).sample((1,)).size(), (1,))

        def ref_log_prob(idx, x, log_prob):
            a = alpha.view(-1)[idx].detach()
            b = beta.view(-1)[idx].detach()
            expected = scipy.stats.gamma.logpdf(x, a, scale=1 / b)
            self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)

        self._check_log_prob(Gamma(alpha, beta), ref_log_prob)

    @unittest.skipIf(not TEST_CUDA, "CUDA not found")
    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_gamma_gpu_shape(self):
        alpha = torch.randn(2, 3).cuda().exp().requires_grad_()
        beta = torch.randn(2, 3).cuda().exp().requires_grad_()
        alpha_1d = torch.randn(1).cuda().exp().requires_grad_()
        beta_1d = torch.randn(1).cuda().exp().requires_grad_()
        self.assertEqual(Gamma(alpha, beta).sample().size(), (2, 3))
        self.assertEqual(Gamma(alpha, beta).sample((5,)).size(), (5, 2, 3))
        self.assertEqual(Gamma(alpha_1d, beta_1d).sample((1,)).size(), (1, 1))
        self.assertEqual(Gamma(alpha_1d, beta_1d).sample().size(), (1,))
        self.assertEqual(Gamma(0.5, 0.5).sample().size(), ())
        self.assertEqual(Gamma(0.5, 0.5).sample((1,)).size(), (1,))

        def ref_log_prob(idx, x, log_prob):
            a = alpha.view(-1)[idx].detach().cpu()
            b = beta.view(-1)[idx].detach().cpu()
            expected = scipy.stats.gamma.logpdf(x.cpu(), a, scale=1 / b)
            self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)

        self._check_log_prob(Gamma(alpha, beta), ref_log_prob)

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_gamma_sample(self):
        set_rng_seed(0)  # see Note [Randomized statistical tests]
        for alpha, beta in product([0.1, 1.0, 5.0], [0.1, 1.0, 10.0]):
            self._check_sampler_sampler(
                Gamma(alpha, beta),
                scipy.stats.gamma(alpha, scale=1.0 / beta),
                f"Gamma(concentration={alpha}, rate={beta})",
            )

    @unittest.skipIf(not TEST_CUDA, "CUDA not found")
    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
    def test_gamma_gpu_sample(self):
        set_rng_seed(0)
        for alpha, beta in product([0.1, 1.0, 5.0], [0.1, 1.0, 10.0]):
            a, b = torch.tensor([alpha]).cuda(), torch.tensor([beta]).cuda()
            self._check_sampler_sampler(
                Gamma(a, b),
                scipy.stats.gamma(alpha, scale=1.0 / beta),
                f"Gamma(alpha={alpha}, beta={beta})",
                failure_rate=1e-4,
            )

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_pareto(self):
        scale = torch.randn(2, 3).abs().requires_grad_()
        alpha = torch.randn(2, 3).abs().requires_grad_()
        scale_1d = torch.randn(1).abs().requires_grad_()
        alpha_1d = torch.randn(1).abs().requires_grad_()
        self.assertEqual(Pareto(scale_1d, 0.5).mean, inf)
        self.assertEqual(Pareto(scale_1d, 0.5).variance, inf)
        self.assertEqual(Pareto(scale, alpha).sample().size(), (2, 3))
        self.assertEqual(Pareto(scale, alpha).sample((5,)).size(), (5, 2, 3))
        self.assertEqual(Pareto(scale_1d, alpha_1d).sample((1,)).size(), (1, 1))
        self.assertEqual(Pareto(scale_1d, alpha_1d).sample().size(), (1,))
        self.assertEqual(Pareto(1.0, 1.0).sample().size(), ())
        self.assertEqual(Pareto(1.0, 1.0).sample((1,)).size(), (1,))

        def ref_log_prob(idx, x, log_prob):
            s = scale.view(-1)[idx].detach()
            a = alpha.view(-1)[idx].detach()
            expected = scipy.stats.pareto.logpdf(x, a, scale=s)
            self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)

        self._check_log_prob(Pareto(scale, alpha), ref_log_prob)

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_pareto_sample(self):
        set_rng_seed(1)  # see Note [Randomized statistical tests]
        for scale, alpha in product([0.1, 1.0, 5.0], [0.1, 1.0, 10.0]):
            self._check_sampler_sampler(
                Pareto(scale, alpha),
                scipy.stats.pareto(alpha, scale=scale),
                f"Pareto(scale={scale}, alpha={alpha})",
            )
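    # The tail checks in test_gumbel below follow from the closed-form CDF
    # F(x; loc, scale) = exp(-exp(-(x - loc) / scale)): for loc=0, scale=1,
    # F(20) = exp(-exp(-20)) ~= 1 and F(-5) = exp(-exp(5)) ~= 0.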
                torch.tensor(0.0, dtype=torch.float32),
                torch.tensor(1.0, dtype=torch.float32),
                validate_args=False,
            ).cdf(20.0),
            1.0,
            atol=1e-4,
            rtol=0,
        )
        self.assertEqual(
            Gumbel(
                torch.tensor(0.0, dtype=torch.float64),
                torch.tensor(1.0, dtype=torch.float64),
                validate_args=False,
            ).cdf(50.0),
            1.0,
            atol=1e-4,
            rtol=0,
        )
        self.assertEqual(
            Gumbel(
                torch.tensor(0.0, dtype=torch.float32),
                torch.tensor(1.0, dtype=torch.float32),
                validate_args=False,
            ).cdf(-5.0),
            0.0,
            atol=1e-4,
            rtol=0,
        )
        self.assertEqual(
            Gumbel(
                torch.tensor(0.0, dtype=torch.float64),
                torch.tensor(1.0, dtype=torch.float64),
                validate_args=False,
            ).cdf(-10.0),
            0.0,
            atol=1e-8,
            rtol=0,
        )

        def ref_log_prob(idx, x, log_prob):
            l = loc.view(-1)[idx].detach()
            s = scale.view(-1)[idx].detach()
            expected = scipy.stats.gumbel_r.logpdf(x, loc=l, scale=s)
            self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)

        self._check_log_prob(Gumbel(loc, scale), ref_log_prob)

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    @set_default_dtype(torch.double)
    def test_gumbel_sample(self):
        set_rng_seed(1)  # see Note [Randomized statistical tests]
        for loc, scale in product([-5.0, -1.0, -0.1, 0.1, 1.0, 5.0], [0.1, 1.0, 10.0]):
            self._check_sampler_sampler(
                Gumbel(loc, scale),
                scipy.stats.gumbel_r(loc=loc, scale=scale),
                f"Gumbel(loc={loc}, scale={scale})",
            )

    def test_kumaraswamy_shape(self):
        concentration1 = torch.randn(2, 3).abs().requires_grad_()
        concentration0 = torch.randn(2, 3).abs().requires_grad_()
        concentration1_1d = torch.randn(1).abs().requires_grad_()
        concentration0_1d = torch.randn(1).abs().requires_grad_()
        self.assertEqual(
            Kumaraswamy(concentration1, concentration0).sample().size(), (2, 3)
        )
        self.assertEqual(
            Kumaraswamy(concentration1, concentration0).sample((5,)).size(), (5, 2, 3)
        )
        self.assertEqual(
            Kumaraswamy(concentration1_1d, concentration0_1d).sample().size(), (1,)
        )
        self.assertEqual(
            Kumaraswamy(concentration1_1d, concentration0_1d).sample((1,)).size(),
            (1, 1),
        )
        self.assertEqual(Kumaraswamy(1.0, 1.0).sample().size(), ())
        self.assertEqual(Kumaraswamy(1.0, 1.0).sample((1,)).size(), (1,))

    # The Kumaraswamy distribution is not implemented in SciPy, so these tests
    # check .mean and .variance explicitly against empirical sample estimates.
    def test_kumaraswamy_mean_variance(self):
        c1_1 = torch.randn(2, 3).abs().requires_grad_()
        c0_1 = torch.randn(2, 3).abs().requires_grad_()
        c1_2 = torch.randn(4).abs().requires_grad_()
        c0_2 = torch.randn(4).abs().requires_grad_()
        cases = [(c1_1, c0_1), (c1_2, c0_2)]
        for i, (a, b) in enumerate(cases):
            m = Kumaraswamy(a, b)
            samples = m.sample((60000,))
            expected = samples.mean(0)
            actual = m.mean
            error = (expected - actual).abs()
            max_error = max(error[error == error])
            self.assertLess(
                max_error,
                0.01,
                f"Kumaraswamy example {i + 1}/{len(cases)}, incorrect .mean",
            )
            expected = samples.var(0)
            actual = m.variance
            error = (expected - actual).abs()
            max_error = max(error[error == error])
            self.assertLess(
                max_error,
                0.01,
                f"Kumaraswamy example {i + 1}/{len(cases)}, incorrect .variance",
            )

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_fishersnedecor(self):
        df1 = torch.randn(2, 3).abs().requires_grad_()
        df2 = torch.randn(2, 3).abs().requires_grad_()
        df1_1d = torch.randn(1).abs()
        df2_1d = torch.randn(1).abs()
        self.assertTrue(is_all_nan(FisherSnedecor(1, 2).mean))
        self.assertTrue(is_all_nan(FisherSnedecor(1, 4).variance))
        self.assertEqual(FisherSnedecor(df1, df2).sample().size(), (2, 3))
        self.assertEqual(FisherSnedecor(df1, df2).sample((5,)).size(), (5, 2, 3))
        self.assertEqual(FisherSnedecor(df1_1d, df2_1d).sample().size(), (1,))
        self.assertEqual(FisherSnedecor(df1_1d, df2_1d).sample((1,)).size(), (1, 1))
        self.assertEqual(FisherSnedecor(1.0, 1.0).sample().size(), ())
        self.assertEqual(FisherSnedecor(1.0, 1.0).sample((1,)).size(), (1,))

        def ref_log_prob(idx, x, log_prob):
            f1 = df1.view(-1)[idx].detach()
            f2 = df2.view(-1)[idx].detach()
            expected = scipy.stats.f.logpdf(x, f1, f2)
            self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)

        self._check_log_prob(FisherSnedecor(df1, df2), ref_log_prob)

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_fishersnedecor_sample(self):
        set_rng_seed(1)  # see Note [Randomized statistical tests]
        for df1, df2 in product([0.1, 0.5, 1.0, 5.0, 10.0], [0.1, 0.5, 1.0, 5.0, 10.0]):
            self._check_sampler_sampler(
                FisherSnedecor(df1, df2),
                scipy.stats.f(df1, df2),
                f"FisherSnedecor(df1={df1}, df2={df2})",
            )

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_chi2_shape(self):
        df = torch.randn(2, 3).exp().requires_grad_()
        df_1d = torch.randn(1).exp().requires_grad_()
        self.assertEqual(Chi2(df).sample().size(), (2, 3))
        self.assertEqual(Chi2(df).sample((5,)).size(), (5, 2, 3))
        self.assertEqual(Chi2(df_1d).sample((1,)).size(), (1, 1))
        self.assertEqual(Chi2(df_1d).sample().size(), (1,))
        self.assertEqual(
            Chi2(torch.tensor(0.5, requires_grad=True)).sample().size(), ()
        )
        self.assertEqual(Chi2(0.5).sample().size(), ())
        self.assertEqual(Chi2(0.5).sample((1,)).size(), (1,))

        def ref_log_prob(idx, x, log_prob):
            d = df.view(-1)[idx].detach()
            expected = scipy.stats.chi2.logpdf(x, d)
            self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)

        self._check_log_prob(Chi2(df), ref_log_prob)

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_chi2_sample(self):
        set_rng_seed(0)  # see Note [Randomized statistical tests]
        for df in [0.1, 1.0, 5.0]:
            self._check_sampler_sampler(
                Chi2(df), scipy.stats.chi2(df), f"Chi2(df={df})"
            )

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_studentT(self):
        df = torch.randn(2, 3).exp().requires_grad_()
        df_1d = torch.randn(1).exp().requires_grad_()
        self.assertTrue(is_all_nan(StudentT(1).mean))
        self.assertTrue(is_all_nan(StudentT(1).variance))
        self.assertEqual(StudentT(2).variance, inf)
        self.assertEqual(StudentT(df).sample().size(), (2, 3))
        self.assertEqual(StudentT(df).sample((5,)).size(), (5, 2, 3))
        self.assertEqual(StudentT(df_1d).sample((1,)).size(), (1, 1))
        self.assertEqual(StudentT(df_1d).sample().size(), (1,))
        self.assertEqual(
            StudentT(torch.tensor(0.5, requires_grad=True)).sample().size(), ()
        )
        self.assertEqual(StudentT(0.5).sample().size(), ())
        self.assertEqual(StudentT(0.5).sample((1,)).size(), (1,))

        def ref_log_prob(idx, x, log_prob):
            d = df.view(-1)[idx].detach()
            expected = scipy.stats.t.logpdf(x, d)
            self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)

        self._check_log_prob(StudentT(df), ref_log_prob)

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    @set_default_dtype(torch.double)
    def test_studentT_sample(self):
        set_rng_seed(11)  # see Note [Randomized statistical tests]
        for df, loc, scale in product(
            [0.1, 1.0, 5.0, 10.0], [-1.0, 0.0, 1.0], [0.1, 1.0, 10.0]
        ):
            self._check_sampler_sampler(
                StudentT(df=df, loc=loc, scale=scale),
                scipy.stats.t(df=df, loc=loc, scale=scale),
                f"StudentT(df={df}, loc={loc}, scale={scale})",
            )

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_studentT_log_prob(self):
        set_rng_seed(0)  # see Note [Randomized statistical tests]
        num_samples = 10
        for df, loc, scale in product(
            [0.1, 1.0, 5.0, 10.0], [-1.0, 0.0, 1.0], [0.1, 1.0, 10.0]
        ):
            dist = StudentT(df=df, loc=loc, scale=scale)
            x = dist.sample((num_samples,))
            actual_log_prob = dist.log_prob(x)
            for i in range(num_samples):
                expected_log_prob = scipy.stats.t.logpdf(
                    x[i], df=df, loc=loc, scale=scale
                )
                self.assertEqual(
                    float(actual_log_prob[i]),
                    float(expected_log_prob),
                    atol=1e-3,
                    rtol=0,
                )

    def test_dirichlet_shape(self):
        alpha = torch.randn(2, 3).exp().requires_grad_()
        alpha_1d = torch.randn(4).exp().requires_grad_()
        self.assertEqual(Dirichlet(alpha).sample().size(), (2, 3))
        self.assertEqual(Dirichlet(alpha).sample((5,)).size(), (5, 2, 3))
        self.assertEqual(Dirichlet(alpha_1d).sample().size(), (4,))
        self.assertEqual(Dirichlet(alpha_1d).sample((1,)).size(), (1, 4))

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    @set_default_dtype(torch.double)
    def test_dirichlet_log_prob(self):
        num_samples = 10
        alpha = torch.exp(torch.randn(5))
        dist = Dirichlet(alpha)
        x = dist.sample((num_samples,))
        actual_log_prob = dist.log_prob(x)
        for i in range(num_samples):
            expected_log_prob = scipy.stats.dirichlet.logpdf(
                x[i].numpy(), alpha.numpy()
            )
            self.assertEqual(actual_log_prob[i], expected_log_prob, atol=1e-3, rtol=0)

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_dirichlet_log_prob_zero(self):
        # Specifically test the special case where x=0 and alpha=1. The PDF is
        # proportional to x**(alpha-1), which in this case works out to 0**0=1.
        # The log PDF of this term should therefore be 0. However, it's easy
        # to accidentally introduce NaNs by calculating log(x) without regard
        # for the value of alpha-1.
        alpha = torch.tensor([1, 2])
        dist = Dirichlet(alpha)
        x = torch.tensor([0, 1])
        actual_log_prob = dist.log_prob(x)
        expected_log_prob = scipy.stats.dirichlet.logpdf(x.numpy(), alpha.numpy())
        self.assertEqual(actual_log_prob, expected_log_prob, atol=1e-3, rtol=0)

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_dirichlet_sample(self):
        set_rng_seed(0)  # see Note [Randomized statistical tests]
        alpha = torch.exp(torch.randn(3))
        self._check_sampler_sampler(
            Dirichlet(alpha),
            scipy.stats.dirichlet(alpha.numpy()),
            f"Dirichlet(alpha={list(alpha)})",
            multivariate=True,
        )

    def test_dirichlet_mode(self):
        # Test a few edge cases for the Dirichlet distribution mode. This also
        # covers beta distributions.
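        # A quick sketch of where the expected values come from: when every
        # concentration alpha_i > 1, the mode of a k-dimensional Dirichlet is
        #   mode_i = (alpha_i - 1) / (sum(alpha) - k),
        # e.g. for [2, 2, 1]: (1, 1, 0) / (5 - 3) = [0.5, 0.5, 0.0]. When some
        # alpha_i < 1 the density is maximized on the boundary of the simplex
        # (hence the [1.0, 0.0, 0.0] case below), and for [1, 1, 1] the density
        # is uniform, so the mode is undefined and reported as nan.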
        concentrations_and_modes = [
            ([2, 2, 1], [0.5, 0.5, 0.0]),
            ([3, 2, 1], [2 / 3, 1 / 3, 0]),
            ([0.5, 0.2, 0.2], [1.0, 0.0, 0.0]),
            ([1, 1, 1], [nan, nan, nan]),
        ]
        for concentration, mode in concentrations_and_modes:
            dist = Dirichlet(torch.tensor(concentration))
            self.assertEqual(dist.mode, torch.tensor(mode))

    def test_beta_shape(self):
        con1 = torch.randn(2, 3).exp().requires_grad_()
        con0 = torch.randn(2, 3).exp().requires_grad_()
        con1_1d = torch.randn(4).exp().requires_grad_()
        con0_1d = torch.randn(4).exp().requires_grad_()
        self.assertEqual(Beta(con1, con0).sample().size(), (2, 3))
        self.assertEqual(Beta(con1, con0).sample((5,)).size(), (5, 2, 3))
        self.assertEqual(Beta(con1_1d, con0_1d).sample().size(), (4,))
        self.assertEqual(Beta(con1_1d, con0_1d).sample((1,)).size(), (1, 4))
        self.assertEqual(Beta(0.1, 0.3).sample().size(), ())
        self.assertEqual(Beta(0.1, 0.3).sample((5,)).size(), (5,))

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_beta_log_prob(self):
        for _ in range(100):
            con1 = np.exp(np.random.normal())
            con0 = np.exp(np.random.normal())
            dist = Beta(con1, con0)
            x = dist.sample()
            actual_log_prob = dist.log_prob(x).sum()
            expected_log_prob = scipy.stats.beta.logpdf(x, con1, con0)
            self.assertEqual(
                float(actual_log_prob), float(expected_log_prob), atol=1e-3, rtol=0
            )

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    @set_default_dtype(torch.double)
    def test_beta_sample(self):
        set_rng_seed(1)  # see Note [Randomized statistical tests]
        for con1, con0 in product([0.1, 1.0, 10.0], [0.1, 1.0, 10.0]):
            self._check_sampler_sampler(
                Beta(con1, con0),
                scipy.stats.beta(con1, con0),
                f"Beta(alpha={con1}, beta={con0})",
            )
        # Check that small alphas do not cause NaNs.
        for Tensor in [torch.FloatTensor, torch.DoubleTensor]:
            x = Beta(Tensor([1e-6]), Tensor([1e-6])).sample()[0]
            self.assertTrue(np.isfinite(x) and x > 0, f"Invalid Beta.sample(): {x}")

    def test_beta_underflow(self):
        # For low values of (alpha, beta), the gamma samples can underflow
        # with float32 and result in a spurious mode at 0.5. To prevent this,
        # torch._sample_dirichlet works with double precision for intermediate
        # calculations.
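        # A sketch of why this matters: Beta(a, b) can be sampled as
        # X / (X + Y) with X ~ Gamma(a) and Y ~ Gamma(b). For tiny
        # concentrations both gamma draws can underflow to 0 in float32, which
        # is what produces the spurious mode at 0.5 mentioned above. With
        # a = b = 1e-2 the true distribution instead puts roughly half its
        # mass near 0 and half near 1 (by symmetry), which is asserted below.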
        set_rng_seed(1)
        num_samples = 50000
        for dtype in [torch.float, torch.double]:
            conc = torch.tensor(1e-2, dtype=dtype)
            beta_samples = Beta(conc, conc).sample([num_samples])
            self.assertEqual((beta_samples == 0).sum(), 0)
            self.assertEqual((beta_samples == 1).sum(), 0)
            # assert support is concentrated around 0 and 1
            frac_zeros = float((beta_samples < 0.1).sum()) / num_samples
            frac_ones = float((beta_samples > 0.9).sum()) / num_samples
            self.assertEqual(frac_zeros, 0.5, atol=0.05, rtol=0)
            self.assertEqual(frac_ones, 0.5, atol=0.05, rtol=0)

    @unittest.skipIf(not TEST_CUDA, "CUDA not found")
    def test_beta_underflow_gpu(self):
        set_rng_seed(1)
        num_samples = 50000
        conc = torch.tensor(1e-2, dtype=torch.float64).cuda()
        beta_samples = Beta(conc, conc).sample([num_samples])
        self.assertEqual((beta_samples == 0).sum(), 0)
        self.assertEqual((beta_samples == 1).sum(), 0)
        # assert support is concentrated around 0 and 1
        frac_zeros = float((beta_samples < 0.1).sum()) / num_samples
        frac_ones = float((beta_samples > 0.9).sum()) / num_samples
        # TODO: increase precision once imbalance on GPU is fixed.
        self.assertEqual(frac_zeros, 0.5, atol=0.12, rtol=0)
        self.assertEqual(frac_ones, 0.5, atol=0.12, rtol=0)

    @set_default_dtype(torch.double)
    def test_continuous_bernoulli(self):
        p = torch.tensor([0.7, 0.2, 0.4], requires_grad=True)
        r = torch.tensor(0.3, requires_grad=True)
        s = 0.3
        self.assertEqual(ContinuousBernoulli(p).sample((8,)).size(), (8, 3))
        self.assertFalse(ContinuousBernoulli(p).sample().requires_grad)
        self.assertEqual(ContinuousBernoulli(r).sample((8,)).size(), (8,))
        self.assertEqual(ContinuousBernoulli(r).sample().size(), ())
        self.assertEqual(ContinuousBernoulli(r).sample((3, 2)).size(), (3, 2))
        self.assertEqual(ContinuousBernoulli(s).sample().size(), ())
        self._gradcheck_log_prob(ContinuousBernoulli, (p,))

        def ref_log_prob(idx, val, log_prob):
            prob = p[idx]
            if prob > 0.499 and prob < 0.501:  # using default value of lim here
                log_norm_const = (
                    math.log(2.0)
                    + 4.0 / 3.0 * math.pow(prob - 0.5, 2)
                    + 104.0 / 45.0 * math.pow(prob - 0.5, 4)
                )
            else:
                log_norm_const = math.log(
                    2.0 * math.atanh(1.0 - 2.0 * prob) / (1.0 - 2.0 * prob)
                )
            res = (
                val * math.log(prob) + (1.0 - val) * math.log1p(-prob) + log_norm_const
            )
            self.assertEqual(log_prob, res)

        self._check_log_prob(ContinuousBernoulli(p), ref_log_prob)
        self._check_log_prob(
            ContinuousBernoulli(logits=p.log() - (-p).log1p()), ref_log_prob
        )

        # check entropy computation
        self.assertEqual(
            ContinuousBernoulli(p).entropy(),
            torch.tensor([-0.02938, -0.07641, -0.00682]),
            atol=1e-4,
            rtol=0,
        )
        # The entropy below corresponds to the clamped value of prob when using
        # float64; the value for float32 would be -1.76898.
        self.assertEqual(
            ContinuousBernoulli(torch.tensor([0.0])).entropy(),
            torch.tensor([-2.58473]),
            atol=1e-5,
            rtol=0,
        )
        self.assertEqual(
            ContinuousBernoulli(s).entropy(), torch.tensor(-0.02938), atol=1e-4, rtol=0
        )

    def test_continuous_bernoulli_3d(self):
        p = torch.full((2, 3, 5), 0.5).requires_grad_()
        self.assertEqual(ContinuousBernoulli(p).sample().size(), (2, 3, 5))
        self.assertEqual(
            ContinuousBernoulli(p).sample(sample_shape=(2, 5)).size(), (2, 5, 2, 3, 5)
        )
        self.assertEqual(ContinuousBernoulli(p).sample((2,)).size(), (2, 2, 3, 5))

    def test_lkj_cholesky_log_prob(self):
        def tril_cholesky_to_tril_corr(x):
            x = vec_to_tril_matrix(x, -1)
            diag = (1 - (x * x).sum(-1)).sqrt().diag_embed()
            x = x + diag
            return tril_matrix_to_vec(x @ x.T, -1)

        for dim in range(2, 5):
            log_probs = []
            lkj = LKJCholesky(dim, concentration=1.0, validate_args=True)
            for i in range(2):
                sample = lkj.sample()
                sample_tril = tril_matrix_to_vec(sample, diag=-1)
                log_prob = lkj.log_prob(sample)
                log_abs_det_jacobian = torch.slogdet(
                    jacobian(tril_cholesky_to_tril_corr, sample_tril)
                ).logabsdet
                log_probs.append(log_prob - log_abs_det_jacobian)
            # for concentration=1., the density is uniform over the space of all
            # correlation matrices.
            if dim == 2:
                # for dim=2, pdf = 0.5 (jacobian adjustment factor is 0.)
                self.assertTrue(
                    all(
                        torch.allclose(x, torch.tensor(0.5).log(), atol=1e-10)
                        for x in log_probs
                    )
                )
            self.assertEqual(log_probs[0], log_probs[1])
            invalid_sample = torch.cat([sample, sample.new_ones(1, dim)], dim=0)
            self.assertRaises(ValueError, lambda: lkj.log_prob(invalid_sample))

    def test_independent_shape(self):
        for Dist, params in _get_examples():
            for param in params:
                base_dist = Dist(**param)
                x = base_dist.sample()
                base_log_prob_shape = base_dist.log_prob(x).shape
                for reinterpreted_batch_ndims in range(len(base_dist.batch_shape) + 1):
                    indep_dist = Independent(base_dist, reinterpreted_batch_ndims)
                    indep_log_prob_shape = base_log_prob_shape[
                        : len(base_log_prob_shape) - reinterpreted_batch_ndims
                    ]
                    self.assertEqual(indep_dist.log_prob(x).shape, indep_log_prob_shape)
                    self.assertEqual(
                        indep_dist.sample().shape, base_dist.sample().shape
                    )
                    self.assertEqual(indep_dist.has_rsample, base_dist.has_rsample)
                    if indep_dist.has_rsample:
                        self.assertEqual(
                            indep_dist.rsample().shape, base_dist.sample().shape
                        )
                    try:
                        self.assertEqual(
                            indep_dist.enumerate_support().shape,
                            base_dist.enumerate_support().shape,
                        )
                        self.assertEqual(indep_dist.mean.shape, base_dist.mean.shape)
                    except NotImplementedError:
                        pass
                    try:
                        self.assertEqual(
                            indep_dist.variance.shape, base_dist.variance.shape
                        )
                    except NotImplementedError:
                        pass
                    try:
                        self.assertEqual(
                            indep_dist.entropy().shape, indep_log_prob_shape
                        )
                    except NotImplementedError:
                        pass

    def test_independent_expand(self):
        for Dist, params in _get_examples():
            for param in params:
                base_dist = Dist(**param)
                for reinterpreted_batch_ndims in range(len(base_dist.batch_shape) + 1):
                    for s in [torch.Size(), torch.Size((2,)), torch.Size((2, 3))]:
                        indep_dist = Independent(base_dist, reinterpreted_batch_ndims)
                        expanded_shape = s + indep_dist.batch_shape
                        expanded = indep_dist.expand(expanded_shape)
                        expanded_sample = expanded.sample()
                        expected_shape = expanded_shape + indep_dist.event_shape
                        self.assertEqual(expanded_sample.shape, expected_shape)
                        self.assertEqual(
                            expanded.log_prob(expanded_sample),
                            indep_dist.log_prob(expanded_sample),
                        )
                        self.assertEqual(expanded.event_shape, indep_dist.event_shape)
                        self.assertEqual(expanded.batch_shape, expanded_shape)

    @set_default_dtype(torch.double)
    def test_cdf_icdf_inverse(self):
        # Tests the invertibility property on the distributions
        for Dist, params in _get_examples():
            for i, param in enumerate(params):
                dist = Dist(**param)
                samples = dist.sample(sample_shape=(20,))
                try:
                    cdf = dist.cdf(samples)
                    actual = dist.icdf(cdf)
                except NotImplementedError:
                    continue
                rel_error = torch.abs(actual - samples) / (1e-10 + torch.abs(samples))
                self.assertLess(
                    rel_error.max(),
                    1e-4,
                    msg="\n".join(
                        [
                            f"{Dist.__name__} example {i + 1}/{len(params)}, icdf(cdf(x)) != x",
                            f"x = {samples}",
                            f"cdf(x) = {cdf}",
                            f"icdf(cdf(x)) = {actual}",
                        ]
                    ),
                )

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_gamma_log_prob_at_boundary(self):
        for concentration, log_prob in [(0.5, inf), (1, 0), (2, -inf)]:
            dist = Gamma(concentration, 1)
            scipy_dist = scipy.stats.gamma(concentration)
            self.assertAlmostEqual(dist.log_prob(0), log_prob)
            self.assertAlmostEqual(dist.log_prob(0), scipy_dist.logpdf(0))

    @set_default_dtype(torch.double)
    def test_cdf_log_prob(self):
        # Tests if the differentiation of the CDF gives the PDF at a given value
        for Dist, params in _get_examples():
            for i, param in enumerate(params):
                # We do not need grads wrt params here, e.g. shape of gamma distribution.
                param = {
                    key: value.detach() if isinstance(value, torch.Tensor) else value
                    for key, value in param.items()
                }
                dist = Dist(**param)
                samples = dist.sample()
                if not dist.support.is_discrete:
                    samples.requires_grad_()
                try:
                    cdfs = dist.cdf(samples)
                    pdfs = dist.log_prob(samples).exp()
                except NotImplementedError:
                    continue
                # this should not be wrapped in torch.abs()
                cdfs_derivative = grad(cdfs.sum(), [samples])[0]
                self.assertEqual(
                    cdfs_derivative,
                    pdfs,
                    msg="\n".join(
                        [
                            f"{Dist.__name__} example {i + 1}/{len(params)}, d(cdf)/dx != pdf(x)",
                            f"x = {samples}",
                            f"cdf = {cdfs}",
                            f"pdf = {pdfs}",
                            f"grad(cdf) = {cdfs_derivative}",
                        ]
                    ),
                )

    def test_valid_parameter_broadcasting(self):
        # Test correct broadcasting of parameter sizes for distributions that
        # have multiple parameters.
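        # The expected sample shapes below follow standard broadcasting of the
        # parameter shapes, i.e. batch_shape == torch.broadcast_shapes(*shapes).
        # For example (a sketch, not one of the cases below): loc of shape (2,)
        # against scale of shape (3, 1) broadcasts to a (3, 2) batch, so
        # Normal(torch.zeros(2), torch.ones(3, 1)).sample() has shape (3, 2).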
        # example type (distribution instance, expected sample shape)
        valid_examples = [
            (Normal(loc=torch.tensor([0.0, 0.0]), scale=1), (2,)),
            (Normal(loc=0, scale=torch.tensor([1.0, 1.0])), (2,)),
            (Normal(loc=torch.tensor([0.0, 0.0]), scale=torch.tensor([1.0])), (2,)),
            (
                Normal(
                    loc=torch.tensor([0.0, 0.0]), scale=torch.tensor([[1.0], [1.0]])
                ),
                (2, 2),
            ),
            (Normal(loc=torch.tensor([0.0, 0.0]), scale=torch.tensor([[1.0]])), (1, 2)),
            (Normal(loc=torch.tensor([0.0]), scale=torch.tensor([[1.0]])), (1, 1)),
            (FisherSnedecor(df1=torch.tensor([1.0, 1.0]), df2=1), (2,)),
            (FisherSnedecor(df1=1, df2=torch.tensor([1.0, 1.0])), (2,)),
            (
                FisherSnedecor(df1=torch.tensor([1.0, 1.0]), df2=torch.tensor([1.0])),
                (2,),
            ),
            (
                FisherSnedecor(
                    df1=torch.tensor([1.0, 1.0]), df2=torch.tensor([[1.0], [1.0]])
                ),
                (2, 2),
            ),
            (
                FisherSnedecor(df1=torch.tensor([1.0, 1.0]), df2=torch.tensor([[1.0]])),
                (1, 2),
            ),
            (
                FisherSnedecor(df1=torch.tensor([1.0]), df2=torch.tensor([[1.0]])),
                (1, 1),
            ),
            (Gamma(concentration=torch.tensor([1.0, 1.0]), rate=1), (2,)),
            (Gamma(concentration=1, rate=torch.tensor([1.0, 1.0])), (2,)),
            (
                Gamma(
                    concentration=torch.tensor([1.0, 1.0]),
                    rate=torch.tensor([[1.0], [1.0], [1.0]]),
                ),
                (3, 2),
            ),
            (
                Gamma(
                    concentration=torch.tensor([1.0, 1.0]),
                    rate=torch.tensor([[1.0], [1.0]]),
                ),
                (2, 2),
            ),
            (
                Gamma(
                    concentration=torch.tensor([1.0, 1.0]), rate=torch.tensor([[1.0]])
                ),
                (1, 2),
            ),
            (
                Gamma(concentration=torch.tensor([1.0]), rate=torch.tensor([[1.0]])),
                (1, 1),
            ),
            (Gumbel(loc=torch.tensor([0.0, 0.0]), scale=1), (2,)),
            (Gumbel(loc=0, scale=torch.tensor([1.0, 1.0])), (2,)),
            (Gumbel(loc=torch.tensor([0.0, 0.0]), scale=torch.tensor([1.0])), (2,)),
            (
                Gumbel(
                    loc=torch.tensor([0.0, 0.0]), scale=torch.tensor([[1.0], [1.0]])
                ),
                (2, 2),
            ),
            (Gumbel(loc=torch.tensor([0.0, 0.0]), scale=torch.tensor([[1.0]])), (1, 2)),
            (Gumbel(loc=torch.tensor([0.0]), scale=torch.tensor([[1.0]])), (1, 1)),
            (
                Kumaraswamy(
                    concentration1=torch.tensor([1.0, 1.0]), concentration0=1.0
                ),
                (2,),
            ),
            (
                Kumaraswamy(concentration1=1, concentration0=torch.tensor([1.0, 1.0])),
                (2,),
            ),
            (
                Kumaraswamy(
                    concentration1=torch.tensor([1.0, 1.0]),
                    concentration0=torch.tensor([1.0]),
                ),
                (2,),
            ),
            (
                Kumaraswamy(
                    concentration1=torch.tensor([1.0, 1.0]),
                    concentration0=torch.tensor([[1.0], [1.0]]),
                ),
                (2, 2),
            ),
            (
                Kumaraswamy(
                    concentration1=torch.tensor([1.0, 1.0]),
                    concentration0=torch.tensor([[1.0]]),
                ),
                (1, 2),
            ),
            (
                Kumaraswamy(
                    concentration1=torch.tensor([1.0]),
                    concentration0=torch.tensor([[1.0]]),
                ),
                (1, 1),
            ),
            (Laplace(loc=torch.tensor([0.0, 0.0]), scale=1), (2,)),
            (Laplace(loc=0, scale=torch.tensor([1.0, 1.0])), (2,)),
            (Laplace(loc=torch.tensor([0.0, 0.0]), scale=torch.tensor([1.0])), (2,)),
            (
                Laplace(
                    loc=torch.tensor([0.0, 0.0]), scale=torch.tensor([[1.0], [1.0]])
                ),
                (2, 2),
            ),
            (
                Laplace(loc=torch.tensor([0.0, 0.0]), scale=torch.tensor([[1.0]])),
                (1, 2),
            ),
            (Laplace(loc=torch.tensor([0.0]), scale=torch.tensor([[1.0]])), (1, 1)),
            (Pareto(scale=torch.tensor([1.0, 1.0]), alpha=1), (2,)),
            (Pareto(scale=1, alpha=torch.tensor([1.0, 1.0])), (2,)),
            (Pareto(scale=torch.tensor([1.0, 1.0]), alpha=torch.tensor([1.0])), (2,)),
            (
                Pareto(
                    scale=torch.tensor([1.0, 1.0]), alpha=torch.tensor([[1.0], [1.0]])
                ),
                (2, 2),
            ),
            (
                Pareto(scale=torch.tensor([1.0, 1.0]), alpha=torch.tensor([[1.0]])),
                (1, 2),
            ),
            (Pareto(scale=torch.tensor([1.0]), alpha=torch.tensor([[1.0]])), (1, 1)),
            (StudentT(df=torch.tensor([1.0, 1.0]), loc=1), (2,)),
            (StudentT(df=1, scale=torch.tensor([1.0, 1.0])), (2,)),
            (StudentT(df=torch.tensor([1.0, 1.0]), loc=torch.tensor([1.0])), (2,)),
            (
                StudentT(
                    df=torch.tensor([1.0, 1.0]), scale=torch.tensor([[1.0], [1.0]])
                ),
                (2, 2),
            ),
            (StudentT(df=torch.tensor([1.0, 1.0]), loc=torch.tensor([[1.0]])), (1, 2)),
            (StudentT(df=torch.tensor([1.0]), scale=torch.tensor([[1.0]])), (1, 1)),
            (StudentT(df=1.0, loc=torch.zeros(5, 1), scale=torch.ones(3)), (5, 3)),
        ]

        for dist, expected_size in valid_examples:
            actual_size = dist.sample().size()
            self.assertEqual(
                actual_size,
                expected_size,
                msg=f"{dist} actual size: {actual_size} != expected size: {expected_size}",
            )

            sample_shape = torch.Size((2,))
            expected_size = sample_shape + expected_size
            actual_size = dist.sample(sample_shape).size()
            self.assertEqual(
                actual_size,
                expected_size,
                msg=f"{dist} actual size: {actual_size} != expected size: {expected_size}",
            )

    def test_invalid_parameter_broadcasting(self):
        # invalid broadcasting cases; should throw error
        # example type (distribution class, distribution params)
        invalid_examples = [
            (
                Normal,
                {"loc": torch.tensor([[0, 0]]), "scale": torch.tensor([1, 1, 1, 1])},
            ),
            (
                Normal,
                {
                    "loc": torch.tensor([[[0, 0, 0], [0, 0, 0]]]),
                    "scale": torch.tensor([1, 1]),
                },
            ),
            (
                FisherSnedecor,
                {
                    "df1": torch.tensor([1, 1]),
                    "df2": torch.tensor([1, 1, 1]),
                },
            ),
            (
                Gumbel,
                {"loc": torch.tensor([[0, 0]]), "scale": torch.tensor([1, 1, 1, 1])},
            ),
            (
                Gumbel,
                {
                    "loc": torch.tensor([[[0, 0, 0], [0, 0, 0]]]),
                    "scale": torch.tensor([1, 1]),
                },
            ),
            (
                Gamma,
                {
                    "concentration": torch.tensor([0, 0]),
                    "rate": torch.tensor([1, 1, 1]),
                },
            ),
            (
                Kumaraswamy,
                {
                    "concentration1": torch.tensor([[1, 1]]),
                    "concentration0": torch.tensor([1, 1, 1, 1]),
                },
            ),
            (
                Kumaraswamy,
                {
                    "concentration1": torch.tensor([[[1, 1, 1], [1, 1, 1]]]),
                    "concentration0": torch.tensor([1, 1]),
                },
            ),
            (Laplace, {"loc": torch.tensor([0, 0]), "scale": torch.tensor([1, 1, 1])}),
            (Pareto, {"scale": torch.tensor([1, 1]), "alpha": torch.tensor([1, 1, 1])}),
            (
                StudentT,
                {
                    "df": torch.tensor([1.0, 1.0]),
                    "scale": torch.tensor([1.0, 1.0, 1.0]),
                },
            ),
            (
                StudentT,
                {"df": torch.tensor([1.0, 1.0]), "loc": torch.tensor([1.0, 1.0, 1.0])},
            ),
        ]

        for dist, kwargs in invalid_examples:
            self.assertRaises(RuntimeError, dist, **kwargs)

    def _test_discrete_distribution_mode(self, dist, sanitized_mode, batch_isfinite):
        # We cannot easily check the mode for discrete distributions, but we
        # can look left and right to ensure the log probability is smaller
        # than at the mode.
        for step in [-1, 1]:
            log_prob_mode = dist.log_prob(sanitized_mode)
            if isinstance(dist, OneHotCategorical):
                idx = (dist._categorical.mode + 1) % dist.probs.shape[-1]
                other = torch.nn.functional.one_hot(
                    idx, num_classes=dist.probs.shape[-1]
                ).to(dist.mode)
            else:
                other = dist.mode + step
            mask = batch_isfinite & dist.support.check(other)
            self.assertTrue(mask.any() or dist.mode.unique().numel() == 1)
            # Add a dimension to the right if the event shape is not a scalar,
            # e.g. OneHotCategorical.
            other = torch.where(
                mask[..., None] if mask.ndim < other.ndim else mask,
                other,
                dist.sample(),
            )
            log_prob_other = dist.log_prob(other)
            delta = log_prob_mode - log_prob_other
            # Allow up to 1e-12 rounding error.
            self.assertTrue((-1e-12 < delta[mask].detach()).all())

    def _test_continuous_distribution_mode(self, dist, sanitized_mode, batch_isfinite):
        # We perturb the mode in the unconstrained space and expect the log
        # probability to decrease.
        num_points = 10
        transform = transform_to(dist.support)
        unconstrained_mode = transform.inv(sanitized_mode)
        perturbation = 1e-5 * (
            torch.rand((num_points,) + unconstrained_mode.shape) - 0.5
        )
        perturbed_mode = transform(perturbation + unconstrained_mode)
        log_prob_mode = dist.log_prob(sanitized_mode)
        log_prob_other = dist.log_prob(perturbed_mode)
        delta = log_prob_mode - log_prob_other

        # We pass the test with a small tolerance to allow for rounding and
        # manually set the difference to zero if both log probs are infinite
        # with the same sign.
        both_infinite_with_same_sign = (log_prob_mode == log_prob_other) & (
            log_prob_mode.abs() == inf
        )
        delta[both_infinite_with_same_sign] = 0.0
        ordering = (delta > -1e-12).all(axis=0)
        self.assertTrue(ordering[batch_isfinite].all())

    @set_default_dtype(torch.double)
    def test_mode(self):
        discrete_distributions = (
            Bernoulli,
            Binomial,
            Categorical,
            Geometric,
            NegativeBinomial,
            OneHotCategorical,
            Poisson,
        )
        no_mode_available = (
            ContinuousBernoulli,
            LKJCholesky,
            LogisticNormal,
            MixtureSameFamily,
            Multinomial,
            RelaxedBernoulli,
            RelaxedOneHotCategorical,
        )

        for dist_cls, params in _get_examples():
            for param in params:
                dist = dist_cls(**param)
                if (
                    isinstance(dist, no_mode_available)
                    or type(dist) is TransformedDistribution
                ):
                    with self.assertRaises(NotImplementedError):
                        dist.mode
                    continue

                # Check that either all or no elements in the event shape are
                # nan: the mode cannot be defined for part of an event.
                isfinite = dist.mode.isfinite().reshape(
                    dist.batch_shape + (dist.event_shape.numel(),)
                )
                batch_isfinite = isfinite.all(axis=-1)
                self.assertTrue((batch_isfinite | ~isfinite.any(axis=-1)).all())

                # We sanitize undefined modes by sampling from the distribution.
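                # Sampling is a cheap way to obtain a stand-in value that is
                # guaranteed to lie in the support; the mode-optimality checks
                # above mask out these entries via batch_isfinite, so the
                # substituted values only need to keep log_prob well defined.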
                sanitized_mode = torch.where(
                    ~dist.mode.isnan(), dist.mode, dist.sample()
                )
                if isinstance(dist, discrete_distributions):
                    self._test_discrete_distribution_mode(
                        dist, sanitized_mode, batch_isfinite
                    )
                else:
                    self._test_continuous_distribution_mode(
                        dist, sanitized_mode, batch_isfinite
                    )

                self.assertFalse(dist.log_prob(sanitized_mode).isnan().any())


# These tests are only needed for a few distributions that implement custom
# reparameterized gradients. Most .rsample() implementations simply rely on
# the reparameterization trick and do not need to be tested for accuracy.
@skipIfTorchDynamo("Not a TorchDynamo suitable test")
class TestRsample(DistributionsTestCase):
    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_gamma(self):
        num_samples = 100
        for alpha in [1e-2, 1e-1, 1e0, 1e1, 1e2, 1e3, 1e4]:
            alphas = torch.tensor(
                [alpha] * num_samples, dtype=torch.float, requires_grad=True
            )
            betas = alphas.new_ones(num_samples)
            x = Gamma(alphas, betas).rsample()
            x.sum().backward()
            x, ind = x.sort()
            x = x.detach().numpy()
            actual_grad = alphas.grad[ind].numpy()
            # Compare with expected gradient dx/dalpha along constant cdf(x, alpha).
            cdf = scipy.stats.gamma.cdf
            pdf = scipy.stats.gamma.pdf
            eps = 0.01 * alpha / (1.0 + alpha**0.5)
            cdf_alpha = (cdf(x, alpha + eps) - cdf(x, alpha - eps)) / (2 * eps)
            cdf_x = pdf(x, alpha)
            expected_grad = -cdf_alpha / cdf_x
            rel_error = np.abs(actual_grad - expected_grad) / (expected_grad + 1e-30)
            self.assertLess(
                np.max(rel_error),
                0.0005,
                "\n".join(
                    [
                        f"Bad gradient dx/dalpha for x ~ Gamma({alpha}, 1)",
                        f"x {x}",
                        f"expected {expected_grad}",
                        f"actual {actual_grad}",
                        f"rel error {rel_error}",
                        f"max error {rel_error.max()}",
                        f"at alpha={alpha}, x={x[rel_error.argmax()]}",
                    ]
                ),
            )

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_chi2(self):
        num_samples = 100
        for df in [1e-2, 1e-1, 1e0, 1e1, 1e2, 1e3, 1e4]:
            dfs = torch.tensor(
                [df] * num_samples, dtype=torch.float, requires_grad=True
            )
            x = Chi2(dfs).rsample()
            x.sum().backward()
            x, ind = x.sort()
            x = x.detach().numpy()
            actual_grad = dfs.grad[ind].numpy()
            # Compare with expected gradient dx/ddf along constant cdf(x, df).
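            # A short derivation of the reference gradient: holding the uniform
            # level u = F(x; df) fixed and differentiating F(x(df); df) = u
            # with respect to df gives
            #   0 = dF/ddf + p(x; df) * dx/ddf  =>  dx/ddf = -(dF/ddf) / p(x; df),
            # where p is the pdf. dF/ddf is estimated below with a central
            # finite difference of the scipy CDF.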
            cdf = scipy.stats.chi2.cdf
            pdf = scipy.stats.chi2.pdf
            eps = 0.01 * df / (1.0 + df**0.5)
            cdf_df = (cdf(x, df + eps) - cdf(x, df - eps)) / (2 * eps)
            cdf_x = pdf(x, df)
            expected_grad = -cdf_df / cdf_x
            rel_error = np.abs(actual_grad - expected_grad) / (expected_grad + 1e-30)
            self.assertLess(
                np.max(rel_error),
                0.001,
                "\n".join(
                    [
                        f"Bad gradient dx/ddf for x ~ Chi2({df})",
                        f"x {x}",
                        f"expected {expected_grad}",
                        f"actual {actual_grad}",
                        f"rel error {rel_error}",
                        f"max error {rel_error.max()}",
                    ]
                ),
            )

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_dirichlet_on_diagonal(self):
        num_samples = 20
        grid = [1e-1, 1e0, 1e1]
        for a0, a1, a2 in product(grid, grid, grid):
            alphas = torch.tensor(
                [[a0, a1, a2]] * num_samples, dtype=torch.float, requires_grad=True
            )
            x = Dirichlet(alphas).rsample()[:, 0]
            x.sum().backward()
            x, ind = x.sort()
            x = x.detach().numpy()
            actual_grad = alphas.grad[ind].numpy()[:, 0]
            # Compare with expected gradient dx/dalpha0 along constant cdf(x, alpha).
            # This reduces to a distribution Beta(alpha[0], alpha[1] + alpha[2]).
            cdf = scipy.stats.beta.cdf
            pdf = scipy.stats.beta.pdf
            alpha, beta = a0, a1 + a2
            eps = 0.01 * alpha / (1.0 + np.sqrt(alpha))
            cdf_alpha = (cdf(x, alpha + eps, beta) - cdf(x, alpha - eps, beta)) / (
                2 * eps
            )
            cdf_x = pdf(x, alpha, beta)
            expected_grad = -cdf_alpha / cdf_x
            rel_error = np.abs(actual_grad - expected_grad) / (expected_grad + 1e-30)
            self.assertLess(
                np.max(rel_error),
                0.001,
                "\n".join(
                    [
                        f"Bad gradient dx[0]/dalpha[0] for Dirichlet([{a0}, {a1}, {a2}])",
                        f"x {x}",
                        f"expected {expected_grad}",
                        f"actual {actual_grad}",
                        f"rel error {rel_error}",
                        f"max error {rel_error.max()}",
                        f"at x={x[rel_error.argmax()]}",
                    ]
                ),
            )

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_beta_wrt_alpha(self):
        num_samples = 20
        grid = [1e-2, 1e-1, 1e0, 1e1, 1e2]
        for con1, con0 in product(grid, grid):
            con1s = torch.tensor(
                [con1] * num_samples, dtype=torch.float, requires_grad=True
            )
            con0s = con1s.new_tensor([con0] * num_samples)
            x = Beta(con1s, con0s).rsample()
            x.sum().backward()
            x, ind = x.sort()
            x = x.detach().numpy()
            actual_grad = con1s.grad[ind].numpy()
            # Compare with expected gradient dx/dcon1 along constant cdf(x, con1, con0).
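            # Same implicit-function identity as in test_gamma/test_chi2 above,
            # now with F the Beta(con1, con0) CDF:
            #   dx/dcon1 = -(dF/dcon1) / pdf(x; con1, con0).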
            cdf = scipy.stats.beta.cdf
            pdf = scipy.stats.beta.pdf
            eps = 0.01 * con1 / (1.0 + np.sqrt(con1))
            cdf_alpha = (cdf(x, con1 + eps, con0) - cdf(x, con1 - eps, con0)) / (
                2 * eps
            )
            cdf_x = pdf(x, con1, con0)
            expected_grad = -cdf_alpha / cdf_x
            rel_error = np.abs(actual_grad - expected_grad) / (expected_grad + 1e-30)
            self.assertLess(
                np.max(rel_error),
                0.005,
                "\n".join(
                    [
                        f"Bad gradient dx/dcon1 for x ~ Beta({con1}, {con0})",
                        f"x {x}",
                        f"expected {expected_grad}",
                        f"actual {actual_grad}",
                        f"rel error {rel_error}",
                        f"max error {rel_error.max()}",
                        f"at x = {x[rel_error.argmax()]}",
                    ]
                ),
            )

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_beta_wrt_beta(self):
        num_samples = 20
        grid = [1e-2, 1e-1, 1e0, 1e1, 1e2]
        for con1, con0 in product(grid, grid):
            con0s = torch.tensor(
                [con0] * num_samples, dtype=torch.float, requires_grad=True
            )
            con1s = con0s.new_tensor([con1] * num_samples)
            x = Beta(con1s, con0s).rsample()
            x.sum().backward()
            x, ind = x.sort()
            x = x.detach().numpy()
            actual_grad = con0s.grad[ind].numpy()
            # Compare with expected gradient dx/dcon0 along constant cdf(x, con1, con0).
            cdf = scipy.stats.beta.cdf
            pdf = scipy.stats.beta.pdf
            eps = 0.01 * con0 / (1.0 + np.sqrt(con0))
            cdf_beta = (cdf(x, con1, con0 + eps) - cdf(x, con1, con0 - eps)) / (2 * eps)
            cdf_x = pdf(x, con1, con0)
            expected_grad = -cdf_beta / cdf_x
            rel_error = np.abs(actual_grad - expected_grad) / (expected_grad + 1e-30)
            self.assertLess(
                np.max(rel_error),
                0.005,
                "\n".join(
                    [
                        f"Bad gradient dx/dcon0 for x ~ Beta({con1}, {con0})",
                        f"x {x}",
                        f"expected {expected_grad}",
                        f"actual {actual_grad}",
                        f"rel error {rel_error}",
                        f"max error {rel_error.max()}",
                        f"at x = {x[rel_error.argmax()]!r}",
                    ]
                ),
            )

    def test_dirichlet_multivariate(self):
        alpha_crit = 0.25 * (5.0**0.5 - 1.0)
        num_samples = 100000
        for shift in [-0.1, -0.05, -0.01, 0.0, 0.01, 0.05, 0.10]:
            alpha = alpha_crit + shift
            alpha = torch.tensor([alpha], dtype=torch.float, requires_grad=True)
            alpha_vec = torch.cat([alpha, alpha, alpha.new([1])])
            z = Dirichlet(alpha_vec.expand(num_samples, 3)).rsample()
            mean_z3 = 1.0 / (2.0 * alpha + 1.0)
            loss = torch.pow(z[:, 2] - mean_z3, 2.0).mean()
            actual_grad = grad(loss, [alpha])[0]
            # Compute expected gradient by hand.
            num = 1.0 - 2.0 * alpha - 4.0 * alpha**2
            den = (1.0 + alpha) ** 2 * (1.0 + 2.0 * alpha) ** 3
            expected_grad = num / den
            self.assertEqual(
                actual_grad,
                expected_grad,
                atol=0.002,
                rtol=0,
                msg="\n".join(
                    [
                        "alpha = alpha_c + %.2g" % shift,  # noqa: UP031
                        "expected_grad: %.5g" % expected_grad,  # noqa: UP031
                        "actual_grad: %.5g" % actual_grad,  # noqa: UP031
                        "error = %.2g"  # noqa: UP031
                        % torch.abs(expected_grad - actual_grad).max(),  # noqa: UP031
                    ]
                ),
            )

    @set_default_dtype(torch.double)
    def test_dirichlet_tangent_field(self):
        num_samples = 20
        alpha_grid = [0.5, 1.0, 2.0]

        # v = dx/dalpha[0] is the reparameterized gradient aka tangent field.
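        # The check below verifies that v satisfies the continuity equation
        # d/dalpha p + div(p * v) = 0, which any correct pathwise gradient
        # must obey. Dividing through by p and applying the product rule gives
        # the log-space form that is actually tested:
        #   dlogp/dalpha + (dlogp/dx) . v + div(v) = 0,
        # which avoids the numerically unstable log_prob.exp().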
        def compute_v(x, alpha):
            return torch.stack(
                [
                    _Dirichlet_backward(x, alpha, torch.eye(3, 3)[i].expand_as(x))[:, 0]
                    for i in range(3)
                ],
                dim=-1,
            )

        for a1, a2, a3 in product(alpha_grid, alpha_grid, alpha_grid):
            alpha = torch.tensor([a1, a2, a3], requires_grad=True).expand(
                num_samples, 3
            )
            x = Dirichlet(alpha).rsample()
            dlogp_da = grad(
                [Dirichlet(alpha).log_prob(x.detach()).sum()],
                [alpha],
                retain_graph=True,
            )[0][:, 0]
            dlogp_dx = grad(
                [Dirichlet(alpha.detach()).log_prob(x).sum()], [x], retain_graph=True
            )[0]
            v = torch.stack(
                [
                    grad([x[:, i].sum()], [alpha], retain_graph=True)[0][:, 0]
                    for i in range(3)
                ],
                dim=-1,
            )
            # Compute remaining properties by finite difference.
            self.assertEqual(compute_v(x, alpha), v, msg="Bug in compute_v() helper")
            # dx is an arbitrary orthonormal basis tangent to the simplex.
            dx = torch.tensor([[2.0, -1.0, -1.0], [0.0, 1.0, -1.0]])
            dx /= dx.norm(2, -1, True)
            eps = 1e-2 * x.min(-1, True)[0]  # avoid boundary
            dv0 = (
                compute_v(x + eps * dx[0], alpha) - compute_v(x - eps * dx[0], alpha)
            ) / (2 * eps)
            dv1 = (
                compute_v(x + eps * dx[1], alpha) - compute_v(x - eps * dx[1], alpha)
            ) / (2 * eps)
            div_v = (dv0 * dx[0] + dv1 * dx[1]).sum(-1)
            # This is a modification of the standard continuity equation, using the
            # product rule to allow expression in terms of log_prob rather than the
            # less numerically stable log_prob.exp().
            error = dlogp_da + (dlogp_dx * v).sum(-1) + div_v
            self.assertLess(
                torch.abs(error).max(),
                0.005,
                "\n".join(
                    [
                        f"Dirichlet([{a1}, {a2}, {a3}]) gradient violates continuity equation:",
                        f"error = {error}",
                    ]
                ),
            )


class TestDistributionShapes(DistributionsTestCase):
    def setUp(self):
        super().setUp()
        self.scalar_sample = 1
        self.tensor_sample_1 = torch.ones(3, 2)
        self.tensor_sample_2 = torch.ones(3, 2, 3)

    def test_entropy_shape(self):
        for Dist, params in _get_examples():
            for i, param in enumerate(params):
                dist = Dist(validate_args=False, **param)
                try:
                    actual_shape = dist.entropy().size()
                    expected_shape = (
                        dist.batch_shape if dist.batch_shape else torch.Size()
                    )
                    message = f"{Dist.__name__} example {i + 1}/{len(params)}, shape mismatch. expected {expected_shape}, actual {actual_shape}"  # noqa: B950
                    self.assertEqual(actual_shape, expected_shape, msg=message)
                except NotImplementedError:
                    continue

    def test_bernoulli_shape_scalar_params(self):
        bernoulli = Bernoulli(0.3)
        self.assertEqual(bernoulli._batch_shape, torch.Size())
        self.assertEqual(bernoulli._event_shape, torch.Size())
        self.assertEqual(bernoulli.sample().size(), torch.Size())
        self.assertEqual(bernoulli.sample((3, 2)).size(), torch.Size((3, 2)))
        self.assertRaises(ValueError, bernoulli.log_prob, self.scalar_sample)
        self.assertEqual(
            bernoulli.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
        )
        self.assertEqual(
            bernoulli.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
        )

    def test_bernoulli_shape_tensor_params(self):
        bernoulli = Bernoulli(torch.tensor([[0.6, 0.3], [0.6, 0.3], [0.6, 0.3]]))
        self.assertEqual(bernoulli._batch_shape, torch.Size((3, 2)))
        self.assertEqual(bernoulli._event_shape, torch.Size(()))
        self.assertEqual(bernoulli.sample().size(), torch.Size((3, 2)))
        self.assertEqual(bernoulli.sample((3, 2)).size(), torch.Size((3, 2, 3, 2)))
        self.assertEqual(
            bernoulli.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
        )
        self.assertRaises(ValueError, bernoulli.log_prob, self.tensor_sample_2)
        self.assertEqual(
            bernoulli.log_prob(torch.ones(3, 1, 1)).size(), torch.Size((3, 3, 2))
        )

    def test_geometric_shape_scalar_params(self):
        geometric = Geometric(0.3)
        self.assertEqual(geometric._batch_shape, torch.Size())
        self.assertEqual(geometric._event_shape, torch.Size())
        self.assertEqual(geometric.sample().size(), torch.Size())
        self.assertEqual(geometric.sample((3, 2)).size(), torch.Size((3, 2)))
        self.assertRaises(ValueError, geometric.log_prob, self.scalar_sample)
        self.assertEqual(
            geometric.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
        )
        self.assertEqual(
            geometric.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
        )

    def test_geometric_shape_tensor_params(self):
        geometric = Geometric(torch.tensor([[0.6, 0.3], [0.6, 0.3], [0.6, 0.3]]))
        self.assertEqual(geometric._batch_shape, torch.Size((3, 2)))
        self.assertEqual(geometric._event_shape, torch.Size(()))
        self.assertEqual(geometric.sample().size(), torch.Size((3, 2)))
        self.assertEqual(geometric.sample((3, 2)).size(), torch.Size((3, 2, 3, 2)))
        self.assertEqual(
            geometric.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
        )
        self.assertRaises(ValueError, geometric.log_prob, self.tensor_sample_2)
        self.assertEqual(
            geometric.log_prob(torch.ones(3, 1, 1)).size(), torch.Size((3, 3, 2))
        )

    def test_beta_shape_scalar_params(self):
        dist = Beta(0.1, 0.1)
        self.assertEqual(dist._batch_shape, torch.Size())
        self.assertEqual(dist._event_shape, torch.Size())
        self.assertEqual(dist.sample().size(), torch.Size())
        self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2)))
        self.assertRaises(ValueError, dist.log_prob, self.scalar_sample)
        self.assertEqual(dist.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
        self.assertEqual(
            dist.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
        )

    def test_beta_shape_tensor_params(self):
        dist = Beta(
            torch.tensor([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]),
            torch.tensor([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]),
        )
        self.assertEqual(dist._batch_shape, torch.Size((3, 2)))
        self.assertEqual(dist._event_shape, torch.Size(()))
        self.assertEqual(dist.sample().size(), torch.Size((3, 2)))
        self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2, 3, 2)))
        self.assertEqual(dist.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
        self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
        self.assertEqual(
            dist.log_prob(torch.ones(3, 1, 1)).size(), torch.Size((3, 3, 2))
        )

    def test_binomial_shape(self):
        dist = Binomial(10, torch.tensor([0.6, 0.3]))
        self.assertEqual(dist._batch_shape, torch.Size((2,)))
        self.assertEqual(dist._event_shape, torch.Size(()))
        self.assertEqual(dist.sample().size(), torch.Size((2,)))
        self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2, 2)))
        self.assertEqual(dist.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
        self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)

    def test_binomial_shape_vectorized_n(self):
        dist = Binomial(
            torch.tensor([[10, 3, 1], [4, 8, 4]]), torch.tensor([0.6, 0.3, 0.1])
        )
        self.assertEqual(dist._batch_shape, torch.Size((2, 3)))
        self.assertEqual(dist._event_shape, torch.Size(()))
        self.assertEqual(dist.sample().size(), torch.Size((2, 3)))
        self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2, 2, 3)))
        self.assertEqual(
            dist.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
        )
        self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_1)

    def test_multinomial_shape(self):
        dist = Multinomial(10, torch.tensor([[0.6, 0.3], [0.6, 0.3], [0.6, 0.3]]))
        self.assertEqual(dist._batch_shape, torch.Size((3,)))
        self.assertEqual(dist._event_shape, torch.Size((2,)))
        self.assertEqual(dist.sample().size(), torch.Size((3, 2)))
        self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2, 3, 2)))
        self.assertEqual(dist.log_prob(self.tensor_sample_1).size(), torch.Size((3,)))
        self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
        self.assertEqual(dist.log_prob(torch.ones(3, 1, 2)).size(), torch.Size((3, 3)))

    def test_categorical_shape(self):
        # unbatched
        dist = Categorical(torch.tensor([0.6, 0.3, 0.1]))
        self.assertEqual(dist._batch_shape, torch.Size(()))
        self.assertEqual(dist._event_shape, torch.Size(()))
        self.assertEqual(dist.sample().size(), torch.Size())
        self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2)))
        self.assertEqual(dist.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
        self.assertEqual(
            dist.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
        )
        self.assertEqual(dist.log_prob(torch.ones(3, 1)).size(), torch.Size((3, 1)))
        # batched
        dist = Categorical(torch.tensor([[0.6, 0.3], [0.6, 0.3], [0.6, 0.3]]))
        self.assertEqual(dist._batch_shape, torch.Size((3,)))
        self.assertEqual(dist._event_shape, torch.Size(()))
        self.assertEqual(dist.sample().size(), torch.Size((3,)))
        self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2, 3)))
        self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_1)
        self.assertEqual(
            dist.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
        )
        self.assertEqual(dist.log_prob(torch.ones(3, 1)).size(), torch.Size((3, 3)))

    def test_one_hot_categorical_shape(self):
        # unbatched
        dist = OneHotCategorical(torch.tensor([0.6, 0.3, 0.1]))
        self.assertEqual(dist._batch_shape, torch.Size(()))
        self.assertEqual(dist._event_shape, torch.Size((3,)))
        self.assertEqual(dist.sample().size(), torch.Size((3,)))
        self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2, 3)))
        self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_1)
        sample = torch.tensor([0.0, 1.0, 0.0]).expand(3, 2, 3)
        self.assertEqual(dist.log_prob(sample).size(), torch.Size((3, 2)))
        self.assertEqual(
            dist.log_prob(dist.enumerate_support()).size(), torch.Size((3,))
        )
        sample = torch.eye(3)
        self.assertEqual(dist.log_prob(sample).size(), torch.Size((3,)))
        # batched
        dist = OneHotCategorical(torch.tensor([[0.6, 0.3], [0.6, 0.3], [0.6, 0.3]]))
        self.assertEqual(dist._batch_shape, torch.Size((3,)))
        self.assertEqual(dist._event_shape, torch.Size((2,)))
        self.assertEqual(dist.sample().size(), torch.Size((3, 2)))
        self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2, 3, 2)))
        sample = torch.tensor([0.0, 1.0])
        self.assertEqual(dist.log_prob(sample).size(), torch.Size((3,)))
        self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
        self.assertEqual(
            dist.log_prob(dist.enumerate_support()).size(), torch.Size((2, 3))
        )
        sample = torch.tensor([0.0, 1.0]).expand(3, 1, 2)
        self.assertEqual(dist.log_prob(sample).size(), torch.Size((3, 3)))

    def test_cauchy_shape_scalar_params(self):
        cauchy = Cauchy(0, 1)
        self.assertEqual(cauchy._batch_shape, torch.Size())
        self.assertEqual(cauchy._event_shape, torch.Size())
        self.assertEqual(cauchy.sample().size(), torch.Size())
        self.assertEqual(cauchy.sample(torch.Size((3, 2))).size(), torch.Size((3, 2)))
        self.assertRaises(ValueError, cauchy.log_prob, self.scalar_sample)
        self.assertEqual(
            cauchy.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
        )
        self.assertEqual(
            cauchy.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
        )

    def test_cauchy_shape_tensor_params(self):
        cauchy = Cauchy(torch.tensor([0.0, 0.0]), torch.tensor([1.0, 1.0]))
        self.assertEqual(cauchy._batch_shape, torch.Size((2,)))
        self.assertEqual(cauchy._event_shape, torch.Size(()))
        self.assertEqual(cauchy.sample().size(), torch.Size((2,)))
        self.assertEqual(
            cauchy.sample(torch.Size((3, 2))).size(), torch.Size((3, 2, 2))
        )
        self.assertEqual(
            cauchy.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
        )
        self.assertRaises(ValueError, cauchy.log_prob, self.tensor_sample_2)
        self.assertEqual(cauchy.log_prob(torch.ones(2, 1)).size(), torch.Size((2, 2)))

    def test_halfcauchy_shape_scalar_params(self):
        halfcauchy = HalfCauchy(1)
        self.assertEqual(halfcauchy._batch_shape, torch.Size())
        self.assertEqual(halfcauchy._event_shape, torch.Size())
        self.assertEqual(halfcauchy.sample().size(), torch.Size())
        self.assertEqual(
            halfcauchy.sample(torch.Size((3, 2))).size(), torch.Size((3, 2))
        )
        self.assertRaises(ValueError, halfcauchy.log_prob, self.scalar_sample)
        self.assertEqual(
            halfcauchy.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
        )
        self.assertEqual(
            halfcauchy.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
        )

    def test_halfcauchy_shape_tensor_params(self):
        halfcauchy = HalfCauchy(torch.tensor([1.0, 1.0]))
        self.assertEqual(halfcauchy._batch_shape, torch.Size((2,)))
        self.assertEqual(halfcauchy._event_shape, torch.Size(()))
        self.assertEqual(halfcauchy.sample().size(), torch.Size((2,)))
        self.assertEqual(
            halfcauchy.sample(torch.Size((3, 2))).size(), torch.Size((3, 2, 2))
        )
        self.assertEqual(
            halfcauchy.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
        )
        self.assertRaises(ValueError, halfcauchy.log_prob, self.tensor_sample_2)
        self.assertEqual(
            halfcauchy.log_prob(torch.ones(2, 1)).size(), torch.Size((2, 2))
        )

    def test_dirichlet_shape(self):
        dist = Dirichlet(torch.tensor([[0.6, 0.3], [1.6, 1.3], [2.6, 2.3]]))
        self.assertEqual(dist._batch_shape, torch.Size((3,)))
        self.assertEqual(dist._event_shape, torch.Size((2,)))
        self.assertEqual(dist.sample().size(), torch.Size((3, 2)))
        self.assertEqual(dist.sample((5, 4)).size(), torch.Size((5, 4, 3, 2)))
        simplex_sample = self.tensor_sample_1 / self.tensor_sample_1.sum(
            -1, keepdim=True
        )
        self.assertEqual(dist.log_prob(simplex_sample).size(), torch.Size((3,)))
        self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
        simplex_sample = torch.ones(3, 1, 2)
        simplex_sample = simplex_sample / simplex_sample.sum(-1).unsqueeze(-1)
        self.assertEqual(dist.log_prob(simplex_sample).size(), torch.Size((3, 3)))

    def test_mixture_same_family_shape(self):
        dist = MixtureSameFamily(
            Categorical(torch.rand(5)), Normal(torch.randn(5), torch.rand(5))
        )
        self.assertEqual(dist._batch_shape, torch.Size())
        self.assertEqual(dist._event_shape, torch.Size())
        self.assertEqual(dist.sample().size(), torch.Size())
        self.assertEqual(dist.sample((5, 4)).size(), torch.Size((5, 4)))
        self.assertEqual(dist.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
        self.assertEqual(
            dist.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
        )

    def test_gamma_shape_scalar_params(self):
        gamma = Gamma(1, 1)
        self.assertEqual(gamma._batch_shape, torch.Size())
        self.assertEqual(gamma._event_shape, torch.Size())
        self.assertEqual(gamma.sample().size(), torch.Size())
        self.assertEqual(gamma.sample((3, 2)).size(), torch.Size((3, 2)))
        self.assertEqual(gamma.log_prob(self.scalar_sample).size(), torch.Size())
        self.assertEqual(
            gamma.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
        )
        self.assertEqual(
            gamma.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
        )

    def test_gamma_shape_tensor_params(self):
        gamma = Gamma(torch.tensor([1.0, 1.0]), torch.tensor([1.0, 1.0]))
        self.assertEqual(gamma._batch_shape, torch.Size((2,)))
        self.assertEqual(gamma._event_shape, torch.Size(()))
        self.assertEqual(gamma.sample().size(), torch.Size((2,)))
        self.assertEqual(gamma.sample((3, 2)).size(), torch.Size((3, 2, 2)))
        self.assertEqual(
            gamma.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
        )
        self.assertRaises(ValueError, gamma.log_prob, self.tensor_sample_2)
        self.assertEqual(gamma.log_prob(torch.ones(2, 1)).size(), torch.Size((2, 2)))

    def test_chi2_shape_scalar_params(self):
        chi2 = Chi2(1)
        self.assertEqual(chi2._batch_shape, torch.Size())
    def test_chi2_shape_scalar_params(self):
        chi2 = Chi2(1)
        self.assertEqual(chi2._batch_shape, torch.Size())
        self.assertEqual(chi2._event_shape, torch.Size())
        self.assertEqual(chi2.sample().size(), torch.Size())
        self.assertEqual(chi2.sample((3, 2)).size(), torch.Size((3, 2)))
        self.assertEqual(chi2.log_prob(self.scalar_sample).size(), torch.Size())
        self.assertEqual(chi2.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
        self.assertEqual(
            chi2.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
        )

    def test_chi2_shape_tensor_params(self):
        chi2 = Chi2(torch.tensor([1.0, 1.0]))
        self.assertEqual(chi2._batch_shape, torch.Size((2,)))
        self.assertEqual(chi2._event_shape, torch.Size(()))
        self.assertEqual(chi2.sample().size(), torch.Size((2,)))
        self.assertEqual(chi2.sample((3, 2)).size(), torch.Size((3, 2, 2)))
        self.assertEqual(chi2.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
        self.assertRaises(ValueError, chi2.log_prob, self.tensor_sample_2)
        self.assertEqual(chi2.log_prob(torch.ones(2, 1)).size(), torch.Size((2, 2)))

    def test_studentT_shape_scalar_params(self):
        st = StudentT(1)
        self.assertEqual(st._batch_shape, torch.Size())
        self.assertEqual(st._event_shape, torch.Size())
        self.assertEqual(st.sample().size(), torch.Size())
        self.assertEqual(st.sample((3, 2)).size(), torch.Size((3, 2)))
        self.assertRaises(ValueError, st.log_prob, self.scalar_sample)
        self.assertEqual(st.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
        self.assertEqual(
            st.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
        )

    def test_studentT_shape_tensor_params(self):
        st = StudentT(torch.tensor([1.0, 1.0]))
        self.assertEqual(st._batch_shape, torch.Size((2,)))
        self.assertEqual(st._event_shape, torch.Size(()))
        self.assertEqual(st.sample().size(), torch.Size((2,)))
        self.assertEqual(st.sample((3, 2)).size(), torch.Size((3, 2, 2)))
        self.assertEqual(st.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
        self.assertRaises(ValueError, st.log_prob, self.tensor_sample_2)
        self.assertEqual(st.log_prob(torch.ones(2, 1)).size(), torch.Size((2, 2)))

    def test_pareto_shape_scalar_params(self):
        pareto = Pareto(1, 1)
        self.assertEqual(pareto._batch_shape, torch.Size())
        self.assertEqual(pareto._event_shape, torch.Size())
        self.assertEqual(pareto.sample().size(), torch.Size())
        self.assertEqual(pareto.sample((3, 2)).size(), torch.Size((3, 2)))
        self.assertEqual(
            pareto.log_prob(self.tensor_sample_1 + 1).size(), torch.Size((3, 2))
        )
        self.assertEqual(
            pareto.log_prob(self.tensor_sample_2 + 1).size(), torch.Size((3, 2, 3))
        )

    def test_gumbel_shape_scalar_params(self):
        gumbel = Gumbel(1, 1)
        self.assertEqual(gumbel._batch_shape, torch.Size())
        self.assertEqual(gumbel._event_shape, torch.Size())
        self.assertEqual(gumbel.sample().size(), torch.Size())
        self.assertEqual(gumbel.sample((3, 2)).size(), torch.Size((3, 2)))
        self.assertEqual(
            gumbel.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
        )
        self.assertEqual(
            gumbel.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
        )

    def test_kumaraswamy_shape_scalar_params(self):
        kumaraswamy = Kumaraswamy(1, 1)
        self.assertEqual(kumaraswamy._batch_shape, torch.Size())
        self.assertEqual(kumaraswamy._event_shape, torch.Size())
        self.assertEqual(kumaraswamy.sample().size(), torch.Size())
        self.assertEqual(kumaraswamy.sample((3, 2)).size(), torch.Size((3, 2)))
        self.assertEqual(
            kumaraswamy.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
        )
        self.assertEqual(
            kumaraswamy.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
        )

    def test_vonmises_shape_tensor_params(self):
        von_mises = VonMises(torch.tensor([0.0, 0.0]), torch.tensor([1.0, 1.0]))
        self.assertEqual(von_mises._batch_shape, torch.Size((2,)))
        self.assertEqual(von_mises._event_shape, torch.Size(()))
        self.assertEqual(von_mises.sample().size(), torch.Size((2,)))
        self.assertEqual(
            von_mises.sample(torch.Size((3, 2))).size(), torch.Size((3, 2, 2))
        )
        self.assertEqual(
            von_mises.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
        )
        self.assertEqual(
            von_mises.log_prob(torch.ones(2, 1)).size(), torch.Size((2, 2))
        )

    def test_vonmises_shape_scalar_params(self):
        von_mises = VonMises(0.0, 1.0)
        self.assertEqual(von_mises._batch_shape, torch.Size())
        self.assertEqual(von_mises._event_shape, torch.Size())
        self.assertEqual(von_mises.sample().size(), torch.Size())
        self.assertEqual(
            von_mises.sample(torch.Size((3, 2))).size(), torch.Size((3, 2))
        )
        self.assertEqual(
            von_mises.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
        )
        self.assertEqual(
            von_mises.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
        )

    def test_weibull_shape_scalar_params(self):
        weibull = Weibull(1, 1)
        self.assertEqual(weibull._batch_shape, torch.Size())
        self.assertEqual(weibull._event_shape, torch.Size())
        self.assertEqual(weibull.sample().size(), torch.Size())
        self.assertEqual(weibull.sample((3, 2)).size(), torch.Size((3, 2)))
        self.assertEqual(
            weibull.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
        )
        self.assertEqual(
            weibull.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
        )

    def test_wishart_shape_scalar_params(self):
        wishart = Wishart(torch.tensor(1), torch.tensor([[1.0]]))
        self.assertEqual(wishart._batch_shape, torch.Size())
        self.assertEqual(wishart._event_shape, torch.Size((1, 1)))
        self.assertEqual(wishart.sample().size(), torch.Size((1, 1)))
        self.assertEqual(wishart.sample((3, 2)).size(), torch.Size((3, 2, 1, 1)))
        self.assertRaises(ValueError, wishart.log_prob, self.scalar_sample)

    def test_wishart_shape_tensor_params(self):
        wishart = Wishart(torch.tensor([1.0, 1.0]), torch.tensor([[[1.0]], [[1.0]]]))
        self.assertEqual(wishart._batch_shape, torch.Size((2,)))
        self.assertEqual(wishart._event_shape, torch.Size((1, 1)))
        self.assertEqual(wishart.sample().size(), torch.Size((2, 1, 1)))
        self.assertEqual(wishart.sample((3, 2)).size(), torch.Size((3, 2, 2, 1, 1)))
        self.assertRaises(ValueError, wishart.log_prob, self.tensor_sample_2)
        self.assertEqual(wishart.log_prob(torch.ones(2, 1, 1)).size(), torch.Size((2,)))

    def test_normal_shape_scalar_params(self):
        normal = Normal(0, 1)
        self.assertEqual(normal._batch_shape, torch.Size())
        self.assertEqual(normal._event_shape, torch.Size())
        self.assertEqual(normal.sample().size(), torch.Size())
        self.assertEqual(normal.sample((3, 2)).size(), torch.Size((3, 2)))
        self.assertRaises(ValueError, normal.log_prob, self.scalar_sample)
        self.assertEqual(
            normal.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
        )
        self.assertEqual(
            normal.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
        )

    def test_normal_shape_tensor_params(self):
        normal = Normal(torch.tensor([0.0, 0.0]), torch.tensor([1.0, 1.0]))
        self.assertEqual(normal._batch_shape, torch.Size((2,)))
        self.assertEqual(normal._event_shape, torch.Size(()))
        self.assertEqual(normal.sample().size(), torch.Size((2,)))
        self.assertEqual(normal.sample((3, 2)).size(), torch.Size((3, 2, 2)))
        self.assertEqual(
            normal.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
        )
        self.assertRaises(ValueError, normal.log_prob, self.tensor_sample_2)
        self.assertEqual(normal.log_prob(torch.ones(2, 1)).size(), torch.Size((2, 2)))

    def test_uniform_shape_scalar_params(self):
        uniform = Uniform(0, 1)
        self.assertEqual(uniform._batch_shape, torch.Size())
        self.assertEqual(uniform._event_shape, torch.Size())
        self.assertEqual(uniform.sample().size(), torch.Size())
        self.assertEqual(uniform.sample(torch.Size((3, 2))).size(), torch.Size((3, 2)))
        self.assertRaises(ValueError, uniform.log_prob, self.scalar_sample)
        self.assertEqual(
            uniform.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
        )
        self.assertEqual(
            uniform.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
        )

    def test_uniform_shape_tensor_params(self):
        uniform = Uniform(torch.tensor([0.0, 0.0]), torch.tensor([1.0, 1.0]))
        self.assertEqual(uniform._batch_shape, torch.Size((2,)))
        self.assertEqual(uniform._event_shape, torch.Size(()))
        self.assertEqual(uniform.sample().size(), torch.Size((2,)))
        self.assertEqual(
            uniform.sample(torch.Size((3, 2))).size(), torch.Size((3, 2, 2))
        )
        self.assertEqual(
            uniform.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
        )
        self.assertRaises(ValueError, uniform.log_prob, self.tensor_sample_2)
        self.assertEqual(uniform.log_prob(torch.ones(2, 1)).size(), torch.Size((2, 2)))

    def test_exponential_shape_scalar_param(self):
        expon = Exponential(1.0)
        self.assertEqual(expon._batch_shape, torch.Size())
        self.assertEqual(expon._event_shape, torch.Size())
        self.assertEqual(expon.sample().size(), torch.Size())
        self.assertEqual(expon.sample((3, 2)).size(), torch.Size((3, 2)))
        self.assertRaises(ValueError, expon.log_prob, self.scalar_sample)
        self.assertEqual(
            expon.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
        )
        self.assertEqual(
            expon.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
        )

    def test_exponential_shape_tensor_param(self):
        expon = Exponential(torch.tensor([1.0, 1.0]))
        self.assertEqual(expon._batch_shape, torch.Size((2,)))
        self.assertEqual(expon._event_shape, torch.Size(()))
        self.assertEqual(expon.sample().size(), torch.Size((2,)))
        self.assertEqual(expon.sample((3, 2)).size(), torch.Size((3, 2, 2)))
        self.assertEqual(
            expon.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
        )
        self.assertRaises(ValueError, expon.log_prob, self.tensor_sample_2)
        self.assertEqual(expon.log_prob(torch.ones(2, 2)).size(), torch.Size((2, 2)))

    def test_laplace_shape_scalar_params(self):
        laplace = Laplace(0, 1)
        self.assertEqual(laplace._batch_shape, torch.Size())
        self.assertEqual(laplace._event_shape, torch.Size())
        self.assertEqual(laplace.sample().size(), torch.Size())
        self.assertEqual(laplace.sample((3, 2)).size(), torch.Size((3, 2)))
        self.assertRaises(ValueError, laplace.log_prob, self.scalar_sample)
        self.assertEqual(
            laplace.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
        )
        self.assertEqual(
            laplace.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))
        )

    def test_laplace_shape_tensor_params(self):
        laplace = Laplace(torch.tensor([0.0, 0.0]), torch.tensor([1.0, 1.0]))
        self.assertEqual(laplace._batch_shape, torch.Size((2,)))
        self.assertEqual(laplace._event_shape, torch.Size(()))
        self.assertEqual(laplace.sample().size(), torch.Size((2,)))
        self.assertEqual(laplace.sample((3, 2)).size(), torch.Size((3, 2, 2)))
        self.assertEqual(
            laplace.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))
        )
        self.assertRaises(ValueError, laplace.log_prob, self.tensor_sample_2)
        self.assertEqual(laplace.log_prob(torch.ones(2, 1)).size(), torch.Size((2, 2)))

    def test_continuous_bernoulli_shape_scalar_params(self):
        continuous_bernoulli = ContinuousBernoulli(0.3)
        self.assertEqual(continuous_bernoulli._batch_shape, torch.Size())
        self.assertEqual(continuous_bernoulli._event_shape, torch.Size())
        self.assertEqual(continuous_bernoulli.sample().size(), torch.Size())
        self.assertEqual(continuous_bernoulli.sample((3, 2)).size(), torch.Size((3, 2)))
        self.assertRaises(ValueError, continuous_bernoulli.log_prob, self.scalar_sample)
        self.assertEqual(
            continuous_bernoulli.log_prob(self.tensor_sample_1).size(),
            torch.Size((3, 2)),
        )
        self.assertEqual(
            continuous_bernoulli.log_prob(self.tensor_sample_2).size(),
            torch.Size((3, 2, 3)),
        )

    def test_continuous_bernoulli_shape_tensor_params(self):
        continuous_bernoulli = ContinuousBernoulli(
            torch.tensor([[0.6, 0.3], [0.6, 0.3], [0.6, 0.3]])
        )
        self.assertEqual(continuous_bernoulli._batch_shape, torch.Size((3, 2)))
        self.assertEqual(continuous_bernoulli._event_shape, torch.Size(()))
        self.assertEqual(continuous_bernoulli.sample().size(), torch.Size((3, 2)))
        self.assertEqual(
            continuous_bernoulli.sample((3, 2)).size(), torch.Size((3, 2, 3, 2))
        )
        self.assertEqual(
            continuous_bernoulli.log_prob(self.tensor_sample_1).size(),
            torch.Size((3, 2)),
        )
        self.assertRaises(
            ValueError, continuous_bernoulli.log_prob, self.tensor_sample_2
        )
        self.assertEqual(
            continuous_bernoulli.log_prob(torch.ones(3, 1, 1)).size(),
            torch.Size((3, 3, 2)),
        )

    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
    def test_mixture_same_family_mean_shape(self):
        mix_distribution = Categorical(torch.ones([3, 1, 3]))
        component_distribution = Normal(torch.zeros([3, 3, 3]), torch.ones([3, 3, 3]))
        gmm = MixtureSameFamily(mix_distribution, component_distribution)
        self.assertEqual(len(gmm.mean.shape), 2)
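

# A quick illustration of the `pairwise` helper that TestKL below relies on
# (illustrative sketch with arbitrary parameter values; not invoked by the
# suite): the first member of the returned pair varies along columns and the
# second along rows, so elementwise kl_divergence covers every ordered pair
# of parameter values.
def _example_pairwise_grid():
    p, q = pairwise(Exponential, [1.0, 2.5, 5.0, 10.0])
    assert p.batch_shape == q.batch_shape == torch.Size((4, 4))
    # kl[i, j] == KL(Exponential(params[j]) || Exponential(params[i]))
    return kl_divergence(p, q)
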
@skipIfTorchDynamo("Not a TorchDynamo suitable test")
class TestKL(DistributionsTestCase):
    def setUp(self):
        super().setUp()

        class Binomial30(Binomial):
            def __init__(self, probs):
                super().__init__(30, probs)

        # These are pairs of distributions with 4 x 4 parameters as specified.
        # The first of each pair (e.g. bernoulli[0]) varies column-wise and the
        # second (e.g. bernoulli[1]) varies row-wise; that way we test all
        # pairs of parameter values against each other.
        bernoulli = pairwise(Bernoulli, [0.1, 0.2, 0.6, 0.9])
        binomial30 = pairwise(Binomial30, [0.1, 0.2, 0.6, 0.9])
        binomial_vectorized_count = (
            Binomial(torch.tensor([3, 4]), torch.tensor([0.4, 0.6])),
            Binomial(torch.tensor([3, 4]), torch.tensor([0.5, 0.8])),
        )
        beta = pairwise(Beta, [1.0, 2.5, 1.0, 2.5], [1.5, 1.5, 3.5, 3.5])
        categorical = pairwise(
            Categorical,
            [[0.4, 0.3, 0.3], [0.2, 0.7, 0.1], [0.33, 0.33, 0.34], [0.2, 0.2, 0.6]],
        )
        cauchy = pairwise(Cauchy, [-2.0, 2.0, -3.0, 3.0], [1.0, 2.0, 1.0, 2.0])
        chi2 = pairwise(Chi2, [1.0, 2.0, 2.5, 5.0])
        dirichlet = pairwise(
            Dirichlet,
            [[0.1, 0.2, 0.7], [0.5, 0.4, 0.1], [0.33, 0.33, 0.34], [0.2, 0.2, 0.4]],
        )
        exponential = pairwise(Exponential, [1.0, 2.5, 5.0, 10.0])
        gamma = pairwise(Gamma, [1.0, 2.5, 1.0, 2.5], [1.5, 1.5, 3.5, 3.5])
        gumbel = pairwise(Gumbel, [-2.0, 4.0, -3.0, 6.0], [1.0, 2.5, 1.0, 2.5])
        halfnormal = pairwise(HalfNormal, [1.0, 2.0, 1.0, 2.0])
        inversegamma = pairwise(
            InverseGamma, [1.0, 2.5, 1.0, 2.5], [1.5, 1.5, 3.5, 3.5]
        )
        laplace = pairwise(Laplace, [-2.0, 4.0, -3.0, 6.0], [1.0, 2.5, 1.0, 2.5])
        lognormal = pairwise(LogNormal, [-2.0, 2.0, -3.0, 3.0], [1.0, 2.0, 1.0, 2.0])
        normal = pairwise(Normal, [-2.0, 2.0, -3.0, 3.0], [1.0, 2.0, 1.0, 2.0])
        independent = (Independent(normal[0], 1), Independent(normal[1], 1))
        onehotcategorical = pairwise(
            OneHotCategorical,
            [[0.4, 0.3, 0.3], [0.2, 0.7, 0.1], [0.33, 0.33, 0.34], [0.2, 0.2, 0.6]],
        )
        pareto = (
            Pareto(
                torch.tensor([2.5, 4.0, 2.5, 4.0]).expand(4, 4),
                torch.tensor([2.25, 3.75, 2.25, 3.75]).expand(4, 4),
            ),
            Pareto(
                torch.tensor([2.25, 3.75, 2.25, 3.8]).expand(4, 4),
                torch.tensor([2.25, 3.75, 2.25, 3.75]).expand(4, 4),
            ),
        )
        poisson = pairwise(Poisson, [0.3, 1.0, 5.0, 10.0])
        uniform_within_unit = pairwise(
            Uniform, [0.1, 0.9, 0.2, 0.75], [0.15, 0.95, 0.25, 0.8]
        )
        uniform_positive = pairwise(Uniform, [1, 1.5, 2, 4], [1.2, 2.0, 3, 7])
        uniform_real = pairwise(Uniform, [-2.0, -1, 0, 2], [-1.0, 1, 1, 4])
        uniform_pareto = pairwise(Uniform, [6.5, 7.5, 6.5, 8.5], [7.5, 8.5, 9.5, 9.5])
        continuous_bernoulli = pairwise(ContinuousBernoulli, [0.1, 0.2, 0.5, 0.9])

        # These tests should pass with precision = 0.01, but that makes tests
        # very expensive. Instead, we test with precision = 0.1 and only test
        # with higher precision locally when adding a new KL implementation.
        # The following pairs are not tested due to the very high variance of
        # the Monte Carlo estimator; their implementations have been reviewed
        # with extra care:
        # - (pareto, normal)
        self.precision = 0.1  # Set this to 0.01 when testing a new KL implementation.
        self.max_samples = int(1e07)  # Increase this when testing at smaller precision.
        self.samples_per_batch = int(1e04)
        self.finite_examples = [
            (bernoulli, bernoulli),
            (bernoulli, poisson),
            (beta, beta),
            (beta, chi2),
            (beta, exponential),
            (beta, gamma),
            (beta, normal),
            (binomial30, binomial30),
            (binomial_vectorized_count, binomial_vectorized_count),
            (categorical, categorical),
            (cauchy, cauchy),
            (chi2, chi2),
            (chi2, exponential),
            (chi2, gamma),
            (chi2, normal),
            (dirichlet, dirichlet),
            (exponential, chi2),
            (exponential, exponential),
            (exponential, gamma),
            (exponential, gumbel),
            (exponential, normal),
            (gamma, chi2),
            (gamma, exponential),
            (gamma, gamma),
            (gamma, gumbel),
            (gamma, normal),
            (gumbel, gumbel),
            (gumbel, normal),
            (halfnormal, halfnormal),
            (independent, independent),
            (inversegamma, inversegamma),
            (laplace, laplace),
            (lognormal, lognormal),
            (laplace, normal),
            (normal, gumbel),
            (normal, laplace),
            (normal, normal),
            (onehotcategorical, onehotcategorical),
            (pareto, chi2),
            (pareto, pareto),
            (pareto, exponential),
            (pareto, gamma),
            (poisson, poisson),
            (uniform_within_unit, beta),
            (uniform_positive, chi2),
            (uniform_positive, exponential),
            (uniform_positive, gamma),
            (uniform_real, gumbel),
            (uniform_real, normal),
            (uniform_pareto, pareto),
            (continuous_bernoulli, continuous_bernoulli),
            (continuous_bernoulli, exponential),
            (continuous_bernoulli, normal),
            (beta, continuous_bernoulli),
        ]

        self.infinite_examples = [
            (Bernoulli(0), Bernoulli(1)),
            (Bernoulli(1), Bernoulli(0)),
            (
                Categorical(torch.tensor([0.9, 0.1])),
                Categorical(torch.tensor([1.0, 0.0])),
            ),
            (
                Categorical(torch.tensor([[0.9, 0.1], [0.9, 0.1]])),
                Categorical(torch.tensor([1.0, 0.0])),
            ),
            (Beta(1, 2), Uniform(0.25, 1)),
            (Beta(1, 2), Uniform(0, 0.75)),
            (Beta(1, 2), Uniform(0.25, 0.75)),
            (Beta(1, 2), Pareto(1, 2)),
            (Binomial(31, 0.7), Binomial(30, 0.3)),
            (
                Binomial(torch.tensor([3, 4]), torch.tensor([0.4, 0.6])),
                Binomial(torch.tensor([2, 3]), torch.tensor([0.5, 0.8])),
            ),
            (Chi2(1), Beta(2, 3)),
            (Chi2(1), Pareto(2, 3)),
            (Chi2(1), Uniform(-2, 3)),
            (Exponential(1), Beta(2, 3)),
            (Exponential(1), Pareto(2, 3)),
            (Exponential(1), Uniform(-2, 3)),
            (Gamma(1, 2), Beta(3, 4)),
            (Gamma(1, 2), Pareto(3, 4)),
            (Gamma(1, 2), Uniform(-3, 4)),
            (Gumbel(-1, 2), Beta(3, 4)),
            (Gumbel(-1, 2), Chi2(3)),
            (Gumbel(-1, 2), Exponential(3)),
            (Gumbel(-1, 2), Gamma(3, 4)),
            (Gumbel(-1, 2), Pareto(3, 4)),
            (Gumbel(-1, 2), Uniform(-3, 4)),
            (Laplace(-1, 2), Beta(3, 4)),
            (Laplace(-1, 2), Chi2(3)),
            (Laplace(-1, 2), Exponential(3)),
            (Laplace(-1, 2), Gamma(3, 4)),
            (Laplace(-1, 2), Pareto(3, 4)),
            (Laplace(-1, 2), Uniform(-3, 4)),
            (Normal(-1, 2), Beta(3, 4)),
            (Normal(-1, 2), Chi2(3)),
            (Normal(-1, 2), Exponential(3)),
            (Normal(-1, 2), Gamma(3, 4)),
            (Normal(-1, 2), Pareto(3, 4)),
            (Normal(-1, 2), Uniform(-3, 4)),
            (Pareto(2, 1), Chi2(3)),
            (Pareto(2, 1), Exponential(3)),
            (Pareto(2, 1), Gamma(3, 4)),
            (Pareto(1, 2), Normal(-3, 4)),
            (Pareto(1, 2), Pareto(3, 4)),
            (Poisson(2), Bernoulli(0.5)),
            (Poisson(2.3), Binomial(10, 0.2)),
            (Uniform(-1, 1), Beta(2, 2)),
            (Uniform(0, 2), Beta(3, 4)),
            (Uniform(-1, 2), Beta(3, 4)),
            (Uniform(-1, 2), Chi2(3)),
            (Uniform(-1, 2), Exponential(3)),
            (Uniform(-1, 2), Gamma(3, 4)),
            (Uniform(-1, 2), Pareto(3, 4)),
            (ContinuousBernoulli(0.25), Uniform(0.25, 1)),
            (ContinuousBernoulli(0.25), Uniform(0, 0.75)),
            (ContinuousBernoulli(0.25), Uniform(0.25, 0.75)),
            (ContinuousBernoulli(0.25), Pareto(1, 2)),
            (Exponential(1), ContinuousBernoulli(0.75)),
            (Gamma(1, 2), ContinuousBernoulli(0.75)),
            (Gumbel(-1, 2), ContinuousBernoulli(0.75)),
            (Laplace(-1, 2), ContinuousBernoulli(0.75)),
            (Normal(-1, 2), ContinuousBernoulli(0.75)),
            (Uniform(-1, 1), ContinuousBernoulli(0.75)),
            (Uniform(0, 2), ContinuousBernoulli(0.75)),
            (Uniform(-1, 2), ContinuousBernoulli(0.75)),
        ]

    def test_kl_monte_carlo(self):
        set_rng_seed(0)  # see Note [Randomized statistical tests]
        for (p, _), (_, q) in self.finite_examples:
            actual = kl_divergence(p, q)
            numerator = 0
            denominator = 0
            while denominator < self.max_samples:
                x = p.sample(sample_shape=(self.samples_per_batch,))
                numerator += (p.log_prob(x) - q.log_prob(x)).sum(0)
                denominator += x.size(0)
                expected = numerator / denominator
                error = torch.abs(expected - actual) / (1 + expected)
                if error[error == error].max() < self.precision:
                    break
            self.assertLess(
                error[error == error].max(),
                self.precision,
                "\n".join(
                    [
                        f"Incorrect KL({type(p).__name__}, {type(q).__name__}).",
                        f"Expected ({denominator} Monte Carlo samples): {expected}",
                        f"Actual (analytic): {actual}",
                    ]
                ),
            )
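
    # The Monte Carlo check above uses the identity
    #     KL(p || q) = E_{x ~ p}[log p(x) - log q(x)],
    # estimated by a sample mean. A minimal self-contained sketch (illustrative
    # helper with an arbitrary sample count; not invoked by the suite):
    def _example_mc_kl_estimate(self):
        p, q = Normal(0.0, 1.0), Normal(1.0, 2.0)
        x = p.sample((100000,))
        estimate = (p.log_prob(x) - q.log_prob(x)).mean()
        assert torch.isclose(estimate, kl_divergence(p, q), atol=0.05)
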
    # Multivariate normal has a separate Monte Carlo based test due to the
    # requirement of random generation of positive (semi) definite matrices.
    # n is set to 5, but can be increased during testing.
    def test_kl_multivariate_normal(self):
        set_rng_seed(0)  # see Note [Randomized statistical tests]
        n = 5  # Number of tests for multivariate_normal
        for i in range(0, n):
            loc = [torch.randn(4) for _ in range(0, 2)]
            scale_tril = [
                transform_to(constraints.lower_cholesky)(torch.randn(4, 4))
                for _ in range(0, 2)
            ]
            p = MultivariateNormal(loc=loc[0], scale_tril=scale_tril[0])
            q = MultivariateNormal(loc=loc[1], scale_tril=scale_tril[1])
            actual = kl_divergence(p, q)
            numerator = 0
            denominator = 0
            while denominator < self.max_samples:
                x = p.sample(sample_shape=(self.samples_per_batch,))
                numerator += (p.log_prob(x) - q.log_prob(x)).sum(0)
                denominator += x.size(0)
                expected = numerator / denominator
                error = torch.abs(expected - actual) / (1 + expected)
                if error[error == error].max() < self.precision:
                    break
            self.assertLess(
                error[error == error].max(),
                self.precision,
                "\n".join(
                    [
                        f"Incorrect KL(MultivariateNormal, MultivariateNormal) instance {i + 1}/{n}",
                        f"Expected ({denominator} Monte Carlo samples): {expected}",
                        f"Actual (analytic): {actual}",
                    ]
                ),
            )

    def test_kl_multivariate_normal_batched(self):
        b = 7  # Number of batches
        loc = [torch.randn(b, 3) for _ in range(0, 2)]
        scale_tril = [
            transform_to(constraints.lower_cholesky)(torch.randn(b, 3, 3))
            for _ in range(0, 2)
        ]
        expected_kl = torch.stack(
            [
                kl_divergence(
                    MultivariateNormal(loc[0][i], scale_tril=scale_tril[0][i]),
                    MultivariateNormal(loc[1][i], scale_tril=scale_tril[1][i]),
                )
                for i in range(0, b)
            ]
        )
        actual_kl = kl_divergence(
            MultivariateNormal(loc[0], scale_tril=scale_tril[0]),
            MultivariateNormal(loc[1], scale_tril=scale_tril[1]),
        )
        self.assertEqual(expected_kl, actual_kl)

    def test_kl_multivariate_normal_batched_broadcasted(self):
        b = 7  # Number of batches
        loc = [torch.randn(b, 3) for _ in range(0, 2)]
        scale_tril = [
            transform_to(constraints.lower_cholesky)(torch.randn(b, 3, 3)),
            transform_to(constraints.lower_cholesky)(torch.randn(3, 3)),
        ]
        expected_kl = torch.stack(
            [
                kl_divergence(
                    MultivariateNormal(loc[0][i], scale_tril=scale_tril[0][i]),
                    MultivariateNormal(loc[1][i], scale_tril=scale_tril[1]),
                )
                for i in range(0, b)
            ]
        )
        actual_kl = kl_divergence(
            MultivariateNormal(loc[0], scale_tril=scale_tril[0]),
            MultivariateNormal(loc[1], scale_tril=scale_tril[1]),
        )
        self.assertEqual(expected_kl, actual_kl)
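
    # Context for the low-rank tests below: LowRankMultivariateNormal
    # parameterizes the covariance as cov_factor @ cov_factor.T + diag(cov_diag),
    # so every low-rank instance has an exact full-rank MultivariateNormal
    # counterpart to compare KL values against. A minimal sketch (illustrative
    # helper, not invoked by the suite):
    def _example_lowrank_equivalence(self):
        loc, W, d = torch.zeros(4), torch.randn(4, 2), torch.rand(4) + 0.1
        p_lowrank = LowRankMultivariateNormal(loc, W, d)
        p_full = MultivariateNormal(loc, covariance_matrix=W @ W.T + d.diag())
        x = torch.randn(4)
        assert torch.allclose(p_lowrank.log_prob(x), p_full.log_prob(x), atol=1e-4)
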
    def test_kl_lowrank_multivariate_normal(self):
        set_rng_seed(0)  # see Note [Randomized statistical tests]
        n = 5  # Number of tests for lowrank_multivariate_normal
        for i in range(0, n):
            loc = [torch.randn(4) for _ in range(0, 2)]
            cov_factor = [torch.randn(4, 3) for _ in range(0, 2)]
            cov_diag = [
                transform_to(constraints.positive)(torch.randn(4)) for _ in range(0, 2)
            ]
            covariance_matrix = [
                cov_factor[i].matmul(cov_factor[i].t()) + cov_diag[i].diag()
                for i in range(0, 2)
            ]
            p = LowRankMultivariateNormal(loc[0], cov_factor[0], cov_diag[0])
            q = LowRankMultivariateNormal(loc[1], cov_factor[1], cov_diag[1])
            p_full = MultivariateNormal(loc[0], covariance_matrix[0])
            q_full = MultivariateNormal(loc[1], covariance_matrix[1])
            expected = kl_divergence(p_full, q_full)

            actual_lowrank_lowrank = kl_divergence(p, q)
            actual_lowrank_full = kl_divergence(p, q_full)
            actual_full_lowrank = kl_divergence(p_full, q)

            error_lowrank_lowrank = torch.abs(actual_lowrank_lowrank - expected).max()
            self.assertLess(
                error_lowrank_lowrank,
                self.precision,
                "\n".join(
                    [
                        f"Incorrect KL(LowRankMultivariateNormal, LowRankMultivariateNormal) instance {i + 1}/{n}",
                        f"Expected (from KL MultivariateNormal): {expected}",
                        f"Actual (analytic): {actual_lowrank_lowrank}",
                    ]
                ),
            )

            error_lowrank_full = torch.abs(actual_lowrank_full - expected).max()
            self.assertLess(
                error_lowrank_full,
                self.precision,
                "\n".join(
                    [
                        f"Incorrect KL(LowRankMultivariateNormal, MultivariateNormal) instance {i + 1}/{n}",
                        f"Expected (from KL MultivariateNormal): {expected}",
                        f"Actual (analytic): {actual_lowrank_full}",
                    ]
                ),
            )

            error_full_lowrank = torch.abs(actual_full_lowrank - expected).max()
            self.assertLess(
                error_full_lowrank,
                self.precision,
                "\n".join(
                    [
                        f"Incorrect KL(MultivariateNormal, LowRankMultivariateNormal) instance {i + 1}/{n}",
                        f"Expected (from KL MultivariateNormal): {expected}",
                        f"Actual (analytic): {actual_full_lowrank}",
                    ]
                ),
            )

    def test_kl_lowrank_multivariate_normal_batched(self):
        b = 7  # Number of batches
        loc = [torch.randn(b, 3) for _ in range(0, 2)]
        cov_factor = [torch.randn(b, 3, 2) for _ in range(0, 2)]
        cov_diag = [
            transform_to(constraints.positive)(torch.randn(b, 3)) for _ in range(0, 2)
        ]
        expected_kl = torch.stack(
            [
                kl_divergence(
                    LowRankMultivariateNormal(
                        loc[0][i], cov_factor[0][i], cov_diag[0][i]
                    ),
                    LowRankMultivariateNormal(
                        loc[1][i], cov_factor[1][i], cov_diag[1][i]
                    ),
                )
                for i in range(0, b)
            ]
        )
        actual_kl = kl_divergence(
            LowRankMultivariateNormal(loc[0], cov_factor[0], cov_diag[0]),
            LowRankMultivariateNormal(loc[1], cov_factor[1], cov_diag[1]),
        )
        self.assertEqual(expected_kl, actual_kl)

    def test_kl_exponential_family(self):
        for (p, _), (_, q) in self.finite_examples:
            if type(p) == type(q) and issubclass(type(p), ExponentialFamily):
                actual = kl_divergence(p, q)
                expected = _kl_expfamily_expfamily(p, q)
                self.assertEqual(
                    actual,
                    expected,
                    msg="\n".join(
                        [
                            f"Incorrect KL({type(p).__name__}, {type(q).__name__}).",
                            f"Expected (using Bregman Divergence) {expected}",
                            f"Actual (analytic) {actual}",
                            f"max error = {torch.abs(actual - expected).max()}",
                        ]
                    ),
                )
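
    # For two members of the same exponential family with natural parameters
    # eta_p, eta_q and log normalizer A, the KL divergence is the Bregman
    # divergence of A:
    #     KL(p || q) = A(eta_q) - A(eta_p) - <grad A(eta_p), eta_q - eta_p>,
    # which is what _kl_expfamily_expfamily evaluates via autograd. A minimal
    # sketch (illustrative helper, not invoked by the suite):
    def _example_bregman_kl(self):
        p = Normal(torch.tensor(0.0), torch.tensor(1.0))
        q = Normal(torch.tensor(1.0), torch.tensor(2.0))
        self.assertEqual(kl_divergence(p, q), _kl_expfamily_expfamily(p, q))
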
    def test_kl_infinite(self):
        for p, q in self.infinite_examples:
            self.assertTrue(
                (kl_divergence(p, q) == inf).all(),
                f"Incorrect KL({type(p).__name__}, {type(q).__name__})",
            )

    def test_kl_edgecases(self):
        self.assertEqual(kl_divergence(Bernoulli(0), Bernoulli(0)), 0)
        self.assertEqual(kl_divergence(Bernoulli(1), Bernoulli(1)), 0)
        self.assertEqual(
            kl_divergence(
                Categorical(torch.tensor([0.0, 1.0])),
                Categorical(torch.tensor([0.0, 1.0])),
            ),
            0,
        )

    def test_kl_shape(self):
        for Dist, params in _get_examples():
            for i, param in enumerate(params):
                dist = Dist(**param)
                try:
                    kl = kl_divergence(dist, dist)
                except NotImplementedError:
                    continue
                expected_shape = dist.batch_shape if dist.batch_shape else torch.Size()
                self.assertEqual(
                    kl.shape,
                    expected_shape,
                    msg="\n".join(
                        [
                            f"{Dist.__name__} example {i + 1}/{len(params)}",
                            f"Expected {expected_shape}",
                            f"Actual {kl.shape}",
                        ]
                    ),
                )

    def test_kl_transformed(self):
        # Regression test for https://github.com/pytorch/pytorch/issues/34859
        scale = torch.ones(2, 3)
        loc = torch.zeros(2, 3)
        normal = Normal(loc=loc, scale=scale)
        diag_normal = Independent(normal, reinterpreted_batch_ndims=1)
        trans_dist = TransformedDistribution(
            diag_normal, AffineTransform(loc=0.0, scale=2.0)
        )
        self.assertEqual(kl_divergence(diag_normal, diag_normal).shape, (2,))
        self.assertEqual(kl_divergence(trans_dist, trans_dist).shape, (2,))

    @set_default_dtype(torch.double)
    def test_entropy_monte_carlo(self):
        set_rng_seed(0)  # see Note [Randomized statistical tests]
        for Dist, params in _get_examples():
            for i, param in enumerate(params):
                dist = Dist(**param)
                try:
                    actual = dist.entropy()
                except NotImplementedError:
                    continue
                x = dist.sample(sample_shape=(60000,))
                expected = -dist.log_prob(x).mean(0)
                ignore = (expected == inf) | (expected == -inf)
                expected[ignore] = actual[ignore]
                self.assertEqual(
                    actual,
                    expected,
                    atol=0.2,
                    rtol=0,
                    msg="\n".join(
                        [
                            f"{Dist.__name__} example {i + 1}/{len(params)}, incorrect .entropy().",
                            f"Expected (monte carlo) {expected}",
                            f"Actual (analytic) {actual}",
                            f"max error = {torch.abs(actual - expected).max()}",
                        ]
                    ),
                )

    @set_default_dtype(torch.double)
    def test_entropy_exponential_family(self):
        for Dist, params in _get_examples():
            if not issubclass(Dist, ExponentialFamily):
                continue
            for i, param in enumerate(params):
                dist = Dist(**param)
                try:
                    actual = dist.entropy()
                except NotImplementedError:
                    continue
                try:
                    expected = ExponentialFamily.entropy(dist)
                except NotImplementedError:
                    continue
                self.assertEqual(
                    actual,
                    expected,
                    msg="\n".join(
                        [
                            f"{Dist.__name__} example {i + 1}/{len(params)}, incorrect .entropy().",
                            f"Expected (Bregman Divergence) {expected}",
                            f"Actual (analytic) {actual}",
                            f"max error = {torch.abs(actual - expected).max()}",
                        ]
                    ),
                )


class TestConstraints(DistributionsTestCase):
    def test_params_constraints(self):
        normalize_probs_dists = (
            Categorical,
            Multinomial,
            OneHotCategorical,
            OneHotCategoricalStraightThrough,
            RelaxedOneHotCategorical,
        )

        for Dist, params in _get_examples():
            for i, param in enumerate(params):
                dist = Dist(**param)
                for name, value in param.items():
                    if isinstance(value, numbers.Number):
                        value = torch.tensor([value])
                    if Dist in normalize_probs_dists and name == "probs":
                        # These distributions accept positive probs, but elsewhere we
                        # use a stricter constraint to the simplex.
                        value = value / value.sum(-1, True)
                    try:
                        constraint = dist.arg_constraints[name]
                    except KeyError:
                        continue  # ignore optional parameters

                    # Check param shape is compatible with distribution shape.
                    self.assertGreaterEqual(value.dim(), constraint.event_dim)
                    value_batch_shape = value.shape[
                        : value.dim() - constraint.event_dim
                    ]
                    torch.broadcast_shapes(dist.batch_shape, value_batch_shape)

                    if is_dependent(constraint):
                        continue

                    message = f"{Dist.__name__} example {i + 1}/{len(params)} parameter {name} = {value}"
                    self.assertTrue(constraint.check(value).all(), msg=message)

    def test_support_constraints(self):
        for Dist, params in _get_examples():
            self.assertIsInstance(Dist.support, Constraint)
            for i, param in enumerate(params):
                dist = Dist(**param)
                value = dist.sample()
                constraint = dist.support
                message = (
                    f"{Dist.__name__} example {i + 1}/{len(params)} sample = {value}"
                )
                self.assertEqual(
                    constraint.event_dim, len(dist.event_shape), msg=message
                )
                ok = constraint.check(value)
                self.assertEqual(ok.shape, dist.batch_shape, msg=message)
                self.assertTrue(ok.all(), msg=message)
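

# The constraint checks above pair naturally with the transform_to registry:
# transform_to(c) maps unconstrained real tensors into the support described
# by constraint c. A minimal sketch (illustrative helper, not invoked by the
# suite):
def _example_transform_to_constraint():
    t = transform_to(constraints.simplex)
    x = t(torch.randn(5))  # arbitrary unconstrained input
    assert constraints.simplex.check(x).all()
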

@skipIfTorchDynamo("Not a TorchDynamo suitable test")
class TestNumericalStability(DistributionsTestCase):
    def _test_pdf_score(
        self,
        dist_class,
        x,
        expected_value,
        probs=None,
        logits=None,
        expected_gradient=None,
        atol=1e-5,
    ):
        if probs is not None:
            p = probs.detach().requires_grad_()
            dist = dist_class(p)
        else:
            p = logits.detach().requires_grad_()
            dist = dist_class(logits=p)
        log_pdf = dist.log_prob(x)
        log_pdf.sum().backward()
        self.assertEqual(
            log_pdf,
            expected_value,
            atol=atol,
            rtol=0,
            msg=f"Incorrect value for tensor type: {type(x)}. Expected = {expected_value}, Actual = {log_pdf}",
        )
        if expected_gradient is not None:
            self.assertEqual(
                p.grad,
                expected_gradient,
                atol=atol,
                rtol=0,
                msg=f"Incorrect gradient for tensor type: {type(x)}. Expected = {expected_gradient}, Actual = {p.grad}",
            )

    def test_bernoulli_gradient(self):
        for tensor_type in [torch.FloatTensor, torch.DoubleTensor]:
            self._test_pdf_score(
                dist_class=Bernoulli,
                probs=tensor_type([0]),
                x=tensor_type([0]),
                expected_value=tensor_type([0]),
                expected_gradient=tensor_type([0]),
            )

            self._test_pdf_score(
                dist_class=Bernoulli,
                probs=tensor_type([0]),
                x=tensor_type([1]),
                expected_value=tensor_type(
                    [torch.finfo(tensor_type([]).dtype).eps]
                ).log(),
                expected_gradient=tensor_type([0]),
            )

            self._test_pdf_score(
                dist_class=Bernoulli,
                probs=tensor_type([1e-4]),
                x=tensor_type([1]),
                expected_value=tensor_type([math.log(1e-4)]),
                expected_gradient=tensor_type([10000]),
            )

            # Lower precision due to:
            # >>> 1 / (1 - torch.FloatTensor([0.9999]))
            # 9998.3408
            # [torch.FloatTensor of size 1]
            self._test_pdf_score(
                dist_class=Bernoulli,
                probs=tensor_type([1 - 1e-4]),
                x=tensor_type([0]),
                expected_value=tensor_type([math.log(1e-4)]),
                expected_gradient=tensor_type([-10000]),
                atol=2,
            )

            self._test_pdf_score(
                dist_class=Bernoulli,
                logits=tensor_type([math.log(9999)]),
                x=tensor_type([0]),
                expected_value=tensor_type([math.log(1e-4)]),
                expected_gradient=tensor_type([-1]),
                atol=1e-3,
            )

    def test_bernoulli_with_logits_underflow(self):
        for tensor_type, lim in [
            (torch.FloatTensor, -1e38),
            (torch.DoubleTensor, -1e308),
        ]:
            self._test_pdf_score(
                dist_class=Bernoulli,
                logits=tensor_type([lim]),
                x=tensor_type([0]),
                expected_value=tensor_type([0]),
                expected_gradient=tensor_type([0]),
            )

    def test_bernoulli_with_logits_overflow(self):
        for tensor_type, lim in [
            (torch.FloatTensor, 1e38),
            (torch.DoubleTensor, 1e308),
        ]:
            self._test_pdf_score(
                dist_class=Bernoulli,
                logits=tensor_type([lim]),
                x=tensor_type([1]),
                expected_value=tensor_type([0]),
                expected_gradient=tensor_type([0]),
            )

    def test_categorical_log_prob(self):
        for dtype in [torch.float, torch.double]:
            p = torch.tensor([0, 1], dtype=dtype, requires_grad=True)
            categorical = OneHotCategorical(p)
            log_pdf = categorical.log_prob(torch.tensor([0, 1], dtype=dtype))
            self.assertEqual(log_pdf.item(), 0)

    def test_categorical_log_prob_with_logits(self):
        for dtype in [torch.float, torch.double]:
            p = torch.tensor([-inf, 0], dtype=dtype, requires_grad=True)
            categorical = OneHotCategorical(logits=p)
            log_pdf_prob_1 = categorical.log_prob(torch.tensor([0, 1], dtype=dtype))
            self.assertEqual(log_pdf_prob_1.item(), 0)
            log_pdf_prob_0 = categorical.log_prob(torch.tensor([1, 0], dtype=dtype))
            self.assertEqual(log_pdf_prob_0.item(), -inf)

    def test_multinomial_log_prob(self):
        for dtype in [torch.float, torch.double]:
            p = torch.tensor([0, 1], dtype=dtype, requires_grad=True)
            s = torch.tensor([0, 10], dtype=dtype)
            multinomial = Multinomial(10, p)
            log_pdf = multinomial.log_prob(s)
            self.assertEqual(log_pdf.item(), 0)

    def test_multinomial_log_prob_with_logits(self):
        for dtype in [torch.float, torch.double]:
            p = torch.tensor([-inf, 0], dtype=dtype, requires_grad=True)
            multinomial = Multinomial(10, logits=p)
            log_pdf_prob_1 = multinomial.log_prob(torch.tensor([0, 10], dtype=dtype))
            self.assertEqual(log_pdf_prob_1.item(), 0)
            log_pdf_prob_0 = multinomial.log_prob(torch.tensor([10, 0], dtype=dtype))
            self.assertEqual(log_pdf_prob_0.item(), -inf)
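
    # Context for the test below: ContinuousBernoulli(probs=p) has density
    #     f(x) = C(p) * p**x * (1 - p)**(1 - x) for x in [0, 1],
    # where the normalizing constant C(p) = 2 * atanh(1 - 2p) / (1 - 2p) for p
    # away from 0.5 and is otherwise evaluated by a Taylor expansion around
    # p = 0.5 for numerical stability; expec_val/expec_grad below mirror both
    # branches in closed form.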
    def test_continuous_bernoulli_gradient(self):
        def expec_val(x, probs=None, logits=None):
            assert not (probs is None and logits is None)
            if logits is not None:
                probs = 1.0 / (1.0 + math.exp(-logits))
            bern_log_lik = x * math.log(probs) + (1.0 - x) * math.log1p(-probs)
            if probs < 0.499 or probs > 0.501:  # using default values of lims here
                log_norm_const = (
                    math.log(math.fabs(math.atanh(1.0 - 2.0 * probs)))
                    - math.log(math.fabs(1.0 - 2.0 * probs))
                    + math.log(2.0)
                )
            else:
                aux = math.pow(probs - 0.5, 2)
                log_norm_const = math.log(2.0) + (4.0 / 3.0 + 104.0 / 45.0 * aux) * aux
            log_lik = bern_log_lik + log_norm_const
            return log_lik

        def expec_grad(x, probs=None, logits=None):
            assert not (probs is None and logits is None)
            if logits is not None:
                probs = 1.0 / (1.0 + math.exp(-logits))
            grad_bern_log_lik = x / probs - (1.0 - x) / (1.0 - probs)
            if probs < 0.499 or probs > 0.501:  # using default values of lims here
                grad_log_c = (
                    2.0 * probs
                    - 4.0 * (probs - 1.0) * probs * math.atanh(1.0 - 2.0 * probs)
                    - 1.0
                )
                grad_log_c /= (
                    2.0
                    * (probs - 1.0)
                    * probs
                    * (2.0 * probs - 1.0)
                    * math.atanh(1.0 - 2.0 * probs)
                )
            else:
                grad_log_c = 8.0 / 3.0 * (probs - 0.5) + 416.0 / 45.0 * math.pow(
                    probs - 0.5, 3
                )
            grad = grad_bern_log_lik + grad_log_c
            if logits is not None:
                grad *= 1.0 / (1.0 + math.exp(logits)) - 1.0 / math.pow(
                    1.0 + math.exp(logits), 2
                )
            return grad

        for tensor_type in [torch.FloatTensor, torch.DoubleTensor]:
            self._test_pdf_score(
                dist_class=ContinuousBernoulli,
                probs=tensor_type([0.1]),
                x=tensor_type([0.1]),
                expected_value=tensor_type([expec_val(0.1, probs=0.1)]),
                expected_gradient=tensor_type([expec_grad(0.1, probs=0.1)]),
            )

            self._test_pdf_score(
                dist_class=ContinuousBernoulli,
                probs=tensor_type([0.1]),
                x=tensor_type([1.0]),
                expected_value=tensor_type([expec_val(1.0, probs=0.1)]),
                expected_gradient=tensor_type([expec_grad(1.0, probs=0.1)]),
            )

            self._test_pdf_score(
                dist_class=ContinuousBernoulli,
                probs=tensor_type([0.4999]),
                x=tensor_type([0.9]),
                expected_value=tensor_type([expec_val(0.9, probs=0.4999)]),
                expected_gradient=tensor_type([expec_grad(0.9, probs=0.4999)]),
            )

            self._test_pdf_score(
                dist_class=ContinuousBernoulli,
                probs=tensor_type([1e-4]),
                x=tensor_type([1]),
                expected_value=tensor_type([expec_val(1, probs=1e-4)]),
                expected_gradient=tensor_type([expec_grad(1, probs=1e-4)]),
                atol=1e-3,
            )

            self._test_pdf_score(
                dist_class=ContinuousBernoulli,
                probs=tensor_type([1 - 1e-4]),
                x=tensor_type([0.1]),
                expected_value=tensor_type([expec_val(0.1, probs=1 - 1e-4)]),
                expected_gradient=tensor_type([expec_grad(0.1, probs=1 - 1e-4)]),
                atol=2,
            )

            self._test_pdf_score(
                dist_class=ContinuousBernoulli,
                logits=tensor_type([math.log(9999)]),
                x=tensor_type([0]),
                expected_value=tensor_type([expec_val(0, logits=math.log(9999))]),
                expected_gradient=tensor_type([expec_grad(0, logits=math.log(9999))]),
                atol=1e-3,
            )

            self._test_pdf_score(
                dist_class=ContinuousBernoulli,
                logits=tensor_type([0.001]),
                x=tensor_type([0.5]),
                expected_value=tensor_type([expec_val(0.5, logits=0.001)]),
                expected_gradient=tensor_type([expec_grad(0.5, logits=0.001)]),
            )

    def test_continuous_bernoulli_with_logits_underflow(self):
        for tensor_type, lim, expected in [
            (torch.FloatTensor, -1e38, 2.76898),
            (torch.DoubleTensor, -1e308, 3.58473),
        ]:
            self._test_pdf_score(
                dist_class=ContinuousBernoulli,
                logits=tensor_type([lim]),
                x=tensor_type([0]),
                expected_value=tensor_type([expected]),
                expected_gradient=tensor_type([0.0]),
            )

    def test_continuous_bernoulli_with_logits_overflow(self):
        for tensor_type, lim, expected in [
            (torch.FloatTensor, 1e38, 2.76898),
            (torch.DoubleTensor, 1e308, 3.58473),
        ]:
            self._test_pdf_score(
                dist_class=ContinuousBernoulli,
                logits=tensor_type([lim]),
                x=tensor_type([1]),
                expected_value=tensor_type([expected]),
                expected_gradient=tensor_type([0.0]),
            )


# TODO: make this a pytest parameterized test
class TestLazyLogitsInitialization(DistributionsTestCase):
    def setUp(self):
        super().setUp()
        # ContinuousBernoulli is not tested because its log_prob is not
        # computed from 'logits' alone; 'probs' is needed as well.
        self.examples = [
            e
            for e in _get_examples()
            if e.Dist
            in (Categorical, OneHotCategorical, Bernoulli, Binomial, Multinomial)
        ]

    def test_lazy_logits_initialization(self):
        for Dist, params in self.examples:
            param = params[0].copy()
            if "probs" not in param:
                continue
            probs = param.pop("probs")
            param["logits"] = probs_to_logits(probs)
            dist = Dist(**param)
            # Create new instance to generate a valid sample
            dist.log_prob(Dist(**param).sample())
            message = f"Failed for {Dist.__name__} example 0/{len(params)}"
            self.assertNotIn("probs", dist.__dict__, msg=message)
            try:
                dist.enumerate_support()
            except NotImplementedError:
                pass
            self.assertNotIn("probs", dist.__dict__, msg=message)
            batch_shape, event_shape = dist.batch_shape, dist.event_shape
            self.assertNotIn("probs", dist.__dict__, msg=message)

    def test_lazy_probs_initialization(self):
        for Dist, params in self.examples:
            param = params[0].copy()
            if "probs" not in param:
                continue
            dist = Dist(**param)
            dist.sample()
            message = f"Failed for {Dist.__name__} example 0/{len(params)}"
            self.assertNotIn("logits", dist.__dict__, msg=message)
            try:
                dist.enumerate_support()
            except NotImplementedError:
                pass
            self.assertNotIn("logits", dist.__dict__, msg=message)
            batch_shape, event_shape = dist.batch_shape, dist.event_shape
            self.assertNotIn("logits", dist.__dict__, msg=message)


@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
@skipIfTorchDynamo("FIXME: Tries to trace through SciPy and fails")
class TestAgainstScipy(DistributionsTestCase):
    def setUp(self):
        super().setUp()
        positive_var = torch.randn(20, dtype=torch.double).exp()
        positive_var2 = torch.randn(20, dtype=torch.double).exp()
        random_var = torch.randn(20, dtype=torch.double)
        simplex_tensor = softmax(torch.randn(20, dtype=torch.double), dim=-1)
        cov_tensor = torch.randn(20, 20, dtype=torch.double)
        cov_tensor = cov_tensor @ cov_tensor.mT
        self.distribution_pairs = [
            (Bernoulli(simplex_tensor), scipy.stats.bernoulli(simplex_tensor)),
            (
                Beta(positive_var, positive_var2),
                scipy.stats.beta(positive_var, positive_var2),
            ),
            (
                Binomial(10, simplex_tensor),
                scipy.stats.binom(
                    10 * np.ones(simplex_tensor.shape), simplex_tensor.numpy()
                ),
            ),
            (
                Cauchy(random_var, positive_var),
                scipy.stats.cauchy(loc=random_var, scale=positive_var),
            ),
            (Dirichlet(positive_var), scipy.stats.dirichlet(positive_var)),
            (
                Exponential(positive_var),
                scipy.stats.expon(scale=positive_var.reciprocal()),
            ),
            (
                FisherSnedecor(
                    positive_var, 4 + positive_var2
                ),  # var for df2 <= 4 is undefined
                scipy.stats.f(positive_var, 4 + positive_var2),
            ),
            (
                Gamma(positive_var, positive_var2),
                scipy.stats.gamma(positive_var, scale=positive_var2.reciprocal()),
            ),
            (Geometric(simplex_tensor), scipy.stats.geom(simplex_tensor, loc=-1)),
            (
                Gumbel(random_var, positive_var2),
                scipy.stats.gumbel_r(random_var, positive_var2),
            ),
            (HalfCauchy(positive_var), scipy.stats.halfcauchy(scale=positive_var)),
            (HalfNormal(positive_var2), scipy.stats.halfnorm(scale=positive_var2)),
            (
                InverseGamma(positive_var, positive_var2),
                scipy.stats.invgamma(positive_var, scale=positive_var2),
            ),
            (
                Laplace(random_var, positive_var2),
                scipy.stats.laplace(random_var, positive_var2),
            ),
            (
                # Tests fail 1e-5 threshold if scale > 3
                LogNormal(random_var, positive_var.clamp(max=3)),
                scipy.stats.lognorm(
                    s=positive_var.clamp(max=3), scale=random_var.exp()
                ),
            ),
            (
                LowRankMultivariateNormal(
                    random_var, torch.zeros(20, 1, dtype=torch.double), positive_var2
                ),
                scipy.stats.multivariate_normal(random_var, torch.diag(positive_var2)),
            ),
            (
                Multinomial(10, simplex_tensor),
                scipy.stats.multinomial(10, simplex_tensor),
            ),
            (
                MultivariateNormal(random_var, torch.diag(positive_var2)),
                scipy.stats.multivariate_normal(random_var, torch.diag(positive_var2)),
            ),
            (
                MultivariateNormal(random_var, cov_tensor),
                scipy.stats.multivariate_normal(random_var, cov_tensor),
            ),
            (
                Normal(random_var, positive_var2),
                scipy.stats.norm(random_var, positive_var2),
            ),
            (
                OneHotCategorical(simplex_tensor),
                scipy.stats.multinomial(1, simplex_tensor),
            ),
            (
                Pareto(positive_var, 2 + positive_var2),
                scipy.stats.pareto(2 + positive_var2, scale=positive_var),
            ),
            (Poisson(positive_var), scipy.stats.poisson(positive_var)),
            (
                StudentT(2 + positive_var, random_var, positive_var2),
                scipy.stats.t(2 + positive_var, random_var, positive_var2),
            ),
            (
                Uniform(random_var, random_var + positive_var),
                scipy.stats.uniform(random_var, positive_var),
            ),
            (
                VonMises(random_var, positive_var),
                scipy.stats.vonmises(positive_var, loc=random_var),
            ),
            (
                # scipy's Weibull implementation only supports scalar parameters
                Weibull(positive_var[0], positive_var2[0]),
                scipy.stats.weibull_min(c=positive_var2[0], scale=positive_var[0]),
            ),
            (
                # scipy's Wishart implementation only supports scalar df.
                # SciPy allows ndim - 1 < df < ndim for the Wishart
                # distribution after version 1.7.0.
                Wishart(
                    (
                        20
                        if version.parse(scipy.__version__) < version.parse("1.7.0")
                        else 19
                    )
                    + positive_var[0],
                    cov_tensor,
                ),
                scipy.stats.wishart(
                    (
                        20
                        if version.parse(scipy.__version__) < version.parse("1.7.0")
                        else 19
                    )
                    + positive_var[0].item(),
                    cov_tensor,
                ),
            ),
        ]

    def test_mean(self):
        for pytorch_dist, scipy_dist in self.distribution_pairs:
            if isinstance(pytorch_dist, (Cauchy, HalfCauchy)):
                # Cauchy, HalfCauchy distributions' mean is nan, skipping check
                continue
            elif isinstance(
                pytorch_dist, (LowRankMultivariateNormal, MultivariateNormal)
            ):
                self.assertEqual(pytorch_dist.mean, scipy_dist.mean, msg=pytorch_dist)
            else:
                self.assertEqual(pytorch_dist.mean, scipy_dist.mean(), msg=pytorch_dist)

    def test_variance_stddev(self):
        for pytorch_dist, scipy_dist in self.distribution_pairs:
            if isinstance(pytorch_dist, (Cauchy, HalfCauchy, VonMises)):
                # Cauchy, HalfCauchy distributions' standard deviation is nan, skipping check
                # VonMises variance is circular and scipy doesn't produce a correct result
                continue
            elif isinstance(pytorch_dist, (Multinomial, OneHotCategorical)):
                self.assertEqual(
                    pytorch_dist.variance, np.diag(scipy_dist.cov()), msg=pytorch_dist
                )
                self.assertEqual(
                    pytorch_dist.stddev,
                    np.diag(scipy_dist.cov()) ** 0.5,
                    msg=pytorch_dist,
                )
            elif isinstance(
                pytorch_dist, (LowRankMultivariateNormal, MultivariateNormal)
            ):
                self.assertEqual(
                    pytorch_dist.variance, np.diag(scipy_dist.cov), msg=pytorch_dist
                )
                self.assertEqual(
                    pytorch_dist.stddev,
                    np.diag(scipy_dist.cov) ** 0.5,
                    msg=pytorch_dist,
                )
            else:
                self.assertEqual(
                    pytorch_dist.variance, scipy_dist.var(), msg=pytorch_dist
                )
                self.assertEqual(
                    pytorch_dist.stddev, scipy_dist.var() ** 0.5, msg=pytorch_dist
                )

    @set_default_dtype(torch.double)
    def test_cdf(self):
        for pytorch_dist, scipy_dist in self.distribution_pairs:
            samples = pytorch_dist.sample((5,))
            try:
                cdf = pytorch_dist.cdf(samples)
            except NotImplementedError:
                continue
            self.assertEqual(cdf, scipy_dist.cdf(samples), msg=pytorch_dist)

    def test_icdf(self):
        for pytorch_dist, scipy_dist in self.distribution_pairs:
            samples = torch.rand((5,) + pytorch_dist.batch_shape, dtype=torch.double)
            try:
                icdf = pytorch_dist.icdf(samples)
            except NotImplementedError:
                continue
            self.assertEqual(icdf, scipy_dist.ppf(samples), msg=pytorch_dist)
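

# The functor tests below exercise CatTransform and StackTransform, which
# apply a sequence of component transforms to slices of a concatenated
# (respectively stacked) input along `dim`. A minimal sketch (illustrative
# helper, not invoked by the suite):
def _example_cat_transform():
    t = CatTransform([ExpTransform(), identity_transform], dim=0)
    x = torch.tensor([0.0, 3.0])
    y = t(x)  # exp applied to x[0:1], identity to x[1:2]
    assert torch.allclose(y, torch.tensor([1.0, 3.0]))
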
class TestFunctors(DistributionsTestCase):
    def test_cat_transform(self):
        x1 = -1 * torch.arange(1, 101, dtype=torch.float).view(-1, 100)
        x2 = (torch.arange(1, 101, dtype=torch.float).view(-1, 100) - 1) / 100
        x3 = torch.arange(1, 101, dtype=torch.float).view(-1, 100)
        t1, t2, t3 = ExpTransform(), AffineTransform(1, 100), identity_transform
        dim = 0
        x = torch.cat([x1, x2, x3], dim=dim)
        t = CatTransform([t1, t2, t3], dim=dim)
        actual_dom_check = t.domain.check(x)
        expected_dom_check = torch.cat(
            [t1.domain.check(x1), t2.domain.check(x2), t3.domain.check(x3)], dim=dim
        )
        self.assertEqual(expected_dom_check, actual_dom_check)
        actual = t(x)
        expected = torch.cat([t1(x1), t2(x2), t3(x3)], dim=dim)
        self.assertEqual(expected, actual)
        y1 = torch.arange(1, 101, dtype=torch.float).view(-1, 100)
        y2 = torch.arange(1, 101, dtype=torch.float).view(-1, 100)
        y3 = torch.arange(1, 101, dtype=torch.float).view(-1, 100)
        y = torch.cat([y1, y2, y3], dim=dim)
        actual_cod_check = t.codomain.check(y)
        expected_cod_check = torch.cat(
            [t1.codomain.check(y1), t2.codomain.check(y2), t3.codomain.check(y3)],
            dim=dim,
        )
        self.assertEqual(actual_cod_check, expected_cod_check)
        actual_inv = t.inv(y)
        expected_inv = torch.cat([t1.inv(y1), t2.inv(y2), t3.inv(y3)], dim=dim)
        self.assertEqual(expected_inv, actual_inv)
        actual_jac = t.log_abs_det_jacobian(x, y)
        expected_jac = torch.cat(
            [
                t1.log_abs_det_jacobian(x1, y1),
                t2.log_abs_det_jacobian(x2, y2),
                t3.log_abs_det_jacobian(x3, y3),
            ],
            dim=dim,
        )
        self.assertEqual(actual_jac, expected_jac)

    def test_cat_transform_non_uniform(self):
        x1 = -1 * torch.arange(1, 101, dtype=torch.float).view(-1, 100)
        x2 = torch.cat(
            [
                (torch.arange(1, 101, dtype=torch.float).view(-1, 100) - 1) / 100,
                torch.arange(1, 101, dtype=torch.float).view(-1, 100),
            ]
        )
        t1 = ExpTransform()
        t2 = CatTransform([AffineTransform(1, 100), identity_transform], dim=0)
        dim = 0
        x = torch.cat([x1, x2], dim=dim)
        t = CatTransform([t1, t2], dim=dim, lengths=[1, 2])
        actual_dom_check = t.domain.check(x)
        expected_dom_check = torch.cat(
            [t1.domain.check(x1), t2.domain.check(x2)], dim=dim
        )
        self.assertEqual(expected_dom_check, actual_dom_check)
        actual = t(x)
        expected = torch.cat([t1(x1), t2(x2)], dim=dim)
        self.assertEqual(expected, actual)
        y1 = torch.arange(1, 101, dtype=torch.float).view(-1, 100)
        y2 = torch.cat(
            [
                torch.arange(1, 101, dtype=torch.float).view(-1, 100),
                torch.arange(1, 101, dtype=torch.float).view(-1, 100),
            ]
        )
        y = torch.cat([y1, y2], dim=dim)
        actual_cod_check = t.codomain.check(y)
        expected_cod_check = torch.cat(
            [t1.codomain.check(y1), t2.codomain.check(y2)], dim=dim
        )
        self.assertEqual(actual_cod_check, expected_cod_check)
        actual_inv = t.inv(y)
        expected_inv = torch.cat([t1.inv(y1), t2.inv(y2)], dim=dim)
        self.assertEqual(expected_inv, actual_inv)
        actual_jac = t.log_abs_det_jacobian(x, y)
        expected_jac = torch.cat(
            [t1.log_abs_det_jacobian(x1, y1), t2.log_abs_det_jacobian(x2, y2)], dim=dim
        )
        self.assertEqual(actual_jac, expected_jac)

    def test_cat_event_dim(self):
        t1 = AffineTransform(0, 2 * torch.ones(2), event_dim=1)
        t2 = AffineTransform(0, 2 * torch.ones(2), event_dim=1)
        dim = 1
        bs = 16
        x1 = torch.randn(bs, 2)
        x2 = torch.randn(bs, 2)
        x = torch.cat([x1, x2], dim=1)
        t = CatTransform([t1, t2], dim=dim, lengths=[2, 2])
        y1 = t1(x1)
        y2 = t2(x2)
        y = t(x)
        actual_jac = t.log_abs_det_jacobian(x, y)
        expected_jac = sum(
            [t1.log_abs_det_jacobian(x1, y1), t2.log_abs_det_jacobian(x2, y2)]
        )
        self.assertEqual(actual_jac, expected_jac)

    def test_stack_transform(self):
        x1 = -1 * torch.arange(1, 101, dtype=torch.float)
        x2 = (torch.arange(1, 101, dtype=torch.float) - 1) / 100
        x3 = torch.arange(1, 101, dtype=torch.float)
        t1, t2, t3 = ExpTransform(), AffineTransform(1, 100), identity_transform
        dim = 0
        x = torch.stack([x1, x2, x3], dim=dim)
        t = StackTransform([t1, t2, t3], dim=dim)
        actual_dom_check = t.domain.check(x)
        expected_dom_check = torch.stack(
            [t1.domain.check(x1), t2.domain.check(x2), t3.domain.check(x3)], dim=dim
        )
        self.assertEqual(expected_dom_check, actual_dom_check)
        actual = t(x)
        expected = torch.stack([t1(x1), t2(x2), t3(x3)], dim=dim)
        self.assertEqual(expected, actual)
        y1 = torch.arange(1, 101, dtype=torch.float)
        y2 = torch.arange(1, 101, dtype=torch.float)
        y3 = torch.arange(1, 101, dtype=torch.float)
        y = torch.stack([y1, y2, y3], dim=dim)
        actual_cod_check = t.codomain.check(y)
        expected_cod_check = torch.stack(
            [t1.codomain.check(y1), t2.codomain.check(y2), t3.codomain.check(y3)],
            dim=dim,
        )
        self.assertEqual(actual_cod_check, expected_cod_check)
        actual_inv = t.inv(x)
        expected_inv = torch.stack([t1.inv(x1), t2.inv(x2), t3.inv(x3)], dim=dim)
        self.assertEqual(expected_inv, actual_inv)
        actual_jac = t.log_abs_det_jacobian(x, y)
        expected_jac = torch.stack(
            [
                t1.log_abs_det_jacobian(x1, y1),
                t2.log_abs_det_jacobian(x2, y2),
                t3.log_abs_det_jacobian(x3, y3),
            ],
            dim=dim,
        )
        self.assertEqual(actual_jac, expected_jac)


class TestValidation(DistributionsTestCase):
    def test_valid(self):
        for Dist, params in _get_examples():
            for param in params:
                Dist(validate_args=True, **param)

    @set_default_dtype(torch.double)
    def test_invalid_log_probs_arg(self):
        # Check that validation errors are indeed disabled,
        # but they might raise another error
        for Dist, params in _get_examples():
            if Dist == TransformedDistribution:
                # TransformedDistribution has a distribution instance
                # as the argument, so we cannot do much about that
                continue
            for i, param in enumerate(params):
                d_nonval = Dist(validate_args=False, **param)
                d_val = Dist(validate_args=True, **param)
                for v in torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0]):
                    # samples with incorrect shape must throw ValueError only
                    try:
                        log_prob = d_val.log_prob(v)
                    except ValueError:
                        pass
                    # get sample of correct shape
                    val = torch.full(d_val.batch_shape + d_val.event_shape, v)
                    # check samples with incorrect support
                    try:
                        log_prob = d_val.log_prob(val)
                    except ValueError as e:
                        if e.args and "must be within the support" in e.args[0]:
                            try:
                                log_prob = d_nonval.log_prob(val)
                            except RuntimeError:
                                pass

                # check correct samples are ok
                valid_value = d_val.sample()
                d_val.log_prob(valid_value)
                # check invalid values raise ValueError
                if valid_value.dtype == torch.long:
                    valid_value = valid_value.float()
                invalid_value = torch.full_like(valid_value, math.nan)
                try:
                    with self.assertRaisesRegex(
                        ValueError,
                        "Expected value argument .* to be within the support .*",
                    ):
                        d_val.log_prob(invalid_value)
                except AssertionError as e:
                    fail_string = "Support ValueError not raised for {} example {}/{}"
                    raise AssertionError(
                        fail_string.format(Dist.__name__, i + 1, len(params))
                    ) from e

    @set_default_dtype(torch.double)
    def test_invalid(self):
        for Dist, params in _get_bad_examples():
            for i, param in enumerate(params):
                try:
                    with self.assertRaises(ValueError):
                        Dist(validate_args=True, **param)
                except AssertionError as e:
                    fail_string = "ValueError not raised for {} example {}/{}"
                    raise AssertionError(
                        fail_string.format(Dist.__name__, i + 1, len(params))
                    ) from e
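
    # With validate_args=True, construction and log_prob raise ValueError on
    # inputs that violate the declared constraints; with validate_args=False
    # the checks are skipped entirely. A minimal sketch (illustrative helper,
    # not invoked by the suite):
    def _example_validate_args(self):
        with self.assertRaises(ValueError):
            Normal(0.0, -1.0, validate_args=True)  # negative scale is rejected
        Normal(0.0, -1.0, validate_args=False)  # no validation, no error here
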

    def test_warning_unimplemented_constraints(self):
        class Delta(Distribution):
            def __init__(self, validate_args=True):
                super().__init__(validate_args=validate_args)

            def sample(self, sample_shape=torch.Size()):
                return torch.tensor(0.0).expand(sample_shape)

            def log_prob(self, value):
                if self._validate_args:
                    self._validate_sample(value)
                value[value != 0.0] = -float("inf")
                value[value == 0.0] = 0.0
                return value

        # Delta implements neither arg_constraints nor support, so both
        # construction and validated log_prob should warn
        with self.assertWarns(UserWarning):
            d = Delta()
        sample = d.sample((2,))
        with self.assertWarns(UserWarning):
            d.log_prob(sample)


class TestJit(DistributionsTestCase):
    def _examples(self):
        for Dist, params in _get_examples():
            for param in params:
                keys = param.keys()
                values = tuple(param[key] for key in keys)
                if not all(isinstance(x, torch.Tensor) for x in values):
                    continue
                sample = Dist(**param).sample()
                yield Dist, keys, values, sample

    def _perturb_tensor(self, value, constraint):
        if isinstance(constraint, constraints._IntegerGreaterThan):
            return value + 1
        if isinstance(
            constraint,
            (constraints._PositiveDefinite, constraints._PositiveSemidefinite),
        ):
            return value + torch.eye(value.shape[-1])
        if value.dtype in [torch.float, torch.double]:
            transform = transform_to(constraint)
            delta = value.new(value.shape).normal_()
            return transform(transform.inv(value) + delta)
        if value.dtype == torch.long:
            result = value.clone()
            result[value == 0] = 1
            result[value == 1] = 0
            return result
        raise NotImplementedError

    def _perturb(self, Dist, keys, values, sample):
        with torch.no_grad():
            if Dist is Uniform:
                param = dict(zip(keys, values))
                param["low"] = param["low"] - torch.rand(param["low"].shape)
                param["high"] = param["high"] + torch.rand(param["high"].shape)
                values = [param[key] for key in keys]
            else:
                values = [
                    self._perturb_tensor(
                        value, Dist.arg_constraints.get(key, constraints.real)
                    )
                    for key, value in zip(keys, values)
                ]
            param = dict(zip(keys, values))
            sample = Dist(**param).sample()
            return values, sample

    @set_default_dtype(torch.double)
    def test_sample(self):
        for Dist, keys, values, sample in self._examples():

            def f(*values):
                param = dict(zip(keys, values))
                dist = Dist(**param)
                return dist.sample()

            traced_f = torch.jit.trace(f, values, check_trace=False)

            # FIXME Schema not found for node
            xfail = [
                Cauchy,  # aten::cauchy(Double(2,1), float, float, Generator)
                HalfCauchy,  # aten::cauchy(Double(2, 1), float, float, Generator)
                VonMises,  # Variance is not Euclidean
            ]
            if Dist in xfail:
                continue

            with torch.random.fork_rng():
                sample = f(*values)
            traced_sample = traced_f(*values)
            self.assertEqual(sample, traced_sample)

            # FIXME no nondeterministic nodes found in trace
            xfail = [Beta, Dirichlet]
            if Dist not in xfail:
                self.assertTrue(
                    any(n.isNondeterministic() for n in traced_f.graph.nodes())
                )
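
    # A minimal sketch (added for illustration; hypothetical helper, not
    # collected by the runner): torch.random.fork_rng() restores the global
    # RNG state on exit, which is why the eager call above is wrapped in a
    # fork and the traced call afterwards still consumes the identical
    # random stream.
    def _sketch_fork_rng_replays_stream(self):
        with torch.random.fork_rng():
            a = torch.rand(3)
        b = torch.rand(3)  # starts from the same state the fork saved
        self.assertEqual(a, b)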

    def test_rsample(self):
        for Dist, keys, values, sample in self._examples():
            if not Dist.has_rsample:
                continue

            def f(*values):
                param = dict(zip(keys, values))
                dist = Dist(**param)
                return dist.rsample()

            traced_f = torch.jit.trace(f, values, check_trace=False)

            # FIXME Schema not found for node
            xfail = [
                Cauchy,  # aten::cauchy(Double(2,1), float, float, Generator)
                HalfCauchy,  # aten::cauchy(Double(2, 1), float, float, Generator)
            ]
            if Dist in xfail:
                continue

            with torch.random.fork_rng():
                sample = f(*values)
            traced_sample = traced_f(*values)
            self.assertEqual(sample, traced_sample)

            # FIXME no nondeterministic nodes found in trace
            xfail = [Beta, Dirichlet]
            if Dist not in xfail:
                self.assertTrue(
                    any(n.isNondeterministic() for n in traced_f.graph.nodes())
                )

    @set_default_dtype(torch.double)
    def test_log_prob(self):
        for Dist, keys, values, sample in self._examples():
            # FIXME traced functions produce incorrect results
            xfail = [LowRankMultivariateNormal, MultivariateNormal]
            if Dist in xfail:
                continue

            def f(sample, *values):
                param = dict(zip(keys, values))
                dist = Dist(**param)
                return dist.log_prob(sample)

            traced_f = torch.jit.trace(f, (sample,) + values)

            # check on different data
            values, sample = self._perturb(Dist, keys, values, sample)
            expected = f(sample, *values)
            actual = traced_f(sample, *values)
            self.assertEqual(
                expected,
                actual,
                msg=f"{Dist.__name__}\nExpected:\n{expected}\nActual:\n{actual}",
            )

    def test_enumerate_support(self):
        for Dist, keys, values, sample in self._examples():
            # FIXME traced functions produce incorrect results
            xfail = [Binomial]
            if Dist in xfail:
                continue

            def f(*values):
                param = dict(zip(keys, values))
                dist = Dist(**param)
                return dist.enumerate_support()

            try:
                traced_f = torch.jit.trace(f, values)
            except NotImplementedError:
                continue

            # check on different data
            values, sample = self._perturb(Dist, keys, values, sample)
            expected = f(*values)
            actual = traced_f(*values)
            self.assertEqual(
                expected,
                actual,
                msg=f"{Dist.__name__}\nExpected:\n{expected}\nActual:\n{actual}",
            )

    def test_mean(self):
        for Dist, keys, values, sample in self._examples():

            def f(*values):
                param = dict(zip(keys, values))
                dist = Dist(**param)
                return dist.mean

            try:
                traced_f = torch.jit.trace(f, values)
            except NotImplementedError:
                continue

            # check on different data
            values, sample = self._perturb(Dist, keys, values, sample)
            # clone before the in-place masking below so cached property
            # tensors are not mutated (matches test_variance)
            expected = f(*values).clone()
            actual = traced_f(*values).clone()
            expected[expected == float("inf")] = 0.0
            actual[actual == float("inf")] = 0.0
            self.assertEqual(
                expected,
                actual,
                msg=f"{Dist.__name__}\nExpected:\n{expected}\nActual:\n{actual}",
            )

    def test_variance(self):
        for Dist, keys, values, sample in self._examples():
            if Dist in [Cauchy, HalfCauchy]:
                continue  # infinite variance

            def f(*values):
                param = dict(zip(keys, values))
                dist = Dist(**param)
                return dist.variance

            try:
                traced_f = torch.jit.trace(f, values)
            except NotImplementedError:
                continue

            # check on different data
            values, sample = self._perturb(Dist, keys, values, sample)
            expected = f(*values).clone()
            actual = traced_f(*values).clone()
            expected[expected == float("inf")] = 0.0
            actual[actual == float("inf")] = 0.0
            self.assertEqual(
                expected,
                actual,
                msg=f"{Dist.__name__}\nExpected:\n{expected}\nActual:\n{actual}",
            )
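
    # A minimal sketch (added for illustration; hypothetical helper, not
    # collected by the runner): some distributions legitimately report
    # infinite moments, e.g. StudentT variance is inf for 1 < df <= 2, which
    # is why test_mean and test_variance mask infinities to 0.0 on cloned
    # tensors before comparing eager and traced results.
    def _sketch_mask_infinite_moments(self):
        v = StudentT(torch.tensor([1.5])).variance
        self.assertEqual(v, torch.tensor([math.inf]))
        masked = v.clone()
        masked[masked == float("inf")] = 0.0
        self.assertEqual(masked, torch.tensor([0.0]))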

    @set_default_dtype(torch.double)
    def test_entropy(self):
        for Dist, keys, values, sample in self._examples():
            # FIXME traced functions produce incorrect results
            xfail = [LowRankMultivariateNormal, MultivariateNormal]
            if Dist in xfail:
                continue

            def f(*values):
                param = dict(zip(keys, values))
                dist = Dist(**param)
                return dist.entropy()

            try:
                traced_f = torch.jit.trace(f, values)
            except NotImplementedError:
                continue

            # check on different data
            values, sample = self._perturb(Dist, keys, values, sample)
            expected = f(*values)
            actual = traced_f(*values)
            self.assertEqual(
                expected,
                actual,
                msg=f"{Dist.__name__}\nExpected:\n{expected}\nActual:\n{actual}",
            )

    @set_default_dtype(torch.double)
    def test_cdf(self):
        for Dist, keys, values, sample in self._examples():

            def f(sample, *values):
                param = dict(zip(keys, values))
                dist = Dist(**param)
                cdf = dist.cdf(sample)
                return dist.icdf(cdf)

            try:
                traced_f = torch.jit.trace(f, (sample,) + values)
            except NotImplementedError:
                continue

            # check on different data
            values, sample = self._perturb(Dist, keys, values, sample)
            expected = f(sample, *values)
            actual = traced_f(sample, *values)
            self.assertEqual(
                expected,
                actual,
                msg=f"{Dist.__name__}\nExpected:\n{expected}\nActual:\n{actual}",
            )


if __name__ == "__main__" and torch._C.has_lapack:
    TestCase._default_dtype_check_enabled = True
    run_tests()
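
# Usage note (added for illustration): run_tests() dispatches through
# unittest, so individual tests can be selected from the command line, e.g.
#
#   python test_distributions.py TestJit.test_sample
#
# The torch._C.has_lapack guard above skips the whole suite on builds
# compiled without LAPACK, which the linear-algebra-backed distributions
# (e.g. MultivariateNormal, Wishart) need.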