# Owner(s): ["module: distributions"]

import io
from numbers import Number

import pytest

import torch
from torch.autograd import grad
from torch.autograd.functional import jacobian
from torch.distributions import (
    constraints,
    Dirichlet,
    Independent,
    Normal,
    TransformedDistribution,
)
from torch.distributions.transforms import (
    _InverseTransform,
    AbsTransform,
    AffineTransform,
    ComposeTransform,
    CorrCholeskyTransform,
    CumulativeDistributionTransform,
    ExpTransform,
    identity_transform,
    IndependentTransform,
    LowerCholeskyTransform,
    PositiveDefiniteTransform,
    PowerTransform,
    ReshapeTransform,
    SigmoidTransform,
    SoftmaxTransform,
    SoftplusTransform,
    StickBreakingTransform,
    TanhTransform,
    Transform,
)
from torch.distributions.utils import tril_matrix_to_vec, vec_to_tril_matrix
from torch.testing._internal.common_utils import run_tests


def get_transforms(cache_size):
    transforms = [
        AbsTransform(cache_size=cache_size),
        ExpTransform(cache_size=cache_size),
        PowerTransform(exponent=2, cache_size=cache_size),
        PowerTransform(exponent=-2, cache_size=cache_size),
        PowerTransform(exponent=torch.tensor(5.0).normal_(), cache_size=cache_size),
        PowerTransform(exponent=torch.tensor(5.0).normal_(), cache_size=cache_size),
        SigmoidTransform(cache_size=cache_size),
        TanhTransform(cache_size=cache_size),
        AffineTransform(0, 1, cache_size=cache_size),
        AffineTransform(1, -2, cache_size=cache_size),
        AffineTransform(torch.randn(5), torch.randn(5), cache_size=cache_size),
        AffineTransform(torch.randn(4, 5), torch.randn(4, 5), cache_size=cache_size),
        SoftmaxTransform(cache_size=cache_size),
        SoftplusTransform(cache_size=cache_size),
        StickBreakingTransform(cache_size=cache_size),
        LowerCholeskyTransform(cache_size=cache_size),
        CorrCholeskyTransform(cache_size=cache_size),
        PositiveDefiniteTransform(cache_size=cache_size),
        ComposeTransform(
            [
                AffineTransform(
                    torch.randn(4, 5), torch.randn(4, 5), cache_size=cache_size
                ),
            ]
        ),
        ComposeTransform(
            [
                AffineTransform(
                    torch.randn(4, 5), torch.randn(4, 5), cache_size=cache_size
                ),
                ExpTransform(cache_size=cache_size),
            ]
        ),
        ComposeTransform(
            [
                AffineTransform(0, 1, cache_size=cache_size),
                AffineTransform(
                    torch.randn(4, 5), torch.randn(4, 5), cache_size=cache_size
                ),
                AffineTransform(1, -2, cache_size=cache_size),
                AffineTransform(
                    torch.randn(4, 5), torch.randn(4, 5), cache_size=cache_size
                ),
            ]
        ),
        ReshapeTransform((4, 5), (2, 5, 2)),
        IndependentTransform(
            AffineTransform(torch.randn(5), torch.randn(5), cache_size=cache_size), 1
        ),
        CumulativeDistributionTransform(Normal(0, 1)),
    ]
    transforms += [t.inv for t in transforms]
    return transforms


def reshape_transform(transform, shape):
    # Needed to squash batch dims for testing jacobian
    if isinstance(transform, AffineTransform):
        if isinstance(transform.loc, Number):
            return transform
        try:
            return AffineTransform(
                transform.loc.expand(shape),
                transform.scale.expand(shape),
                cache_size=transform._cache_size,
            )
        except RuntimeError:
            return AffineTransform(
                transform.loc.reshape(shape),
                transform.scale.reshape(shape),
                cache_size=transform._cache_size,
            )
    if isinstance(transform, ComposeTransform):
        reshaped_parts = []
        for p in transform.parts:
            reshaped_parts.append(reshape_transform(p, shape))
        return ComposeTransform(reshaped_parts, cache_size=transform._cache_size)
    if isinstance(transform.inv, AffineTransform):
        return reshape_transform(transform.inv, shape).inv
    if isinstance(transform.inv, ComposeTransform):
        return reshape_transform(transform.inv, shape).inv
    return transform


# Generate pytest ids
def transform_id(x):
    assert isinstance(x, Transform)
    name = (
        f"Inv({type(x._inv).__name__})"
        if isinstance(x, _InverseTransform)
        else f"{type(x).__name__}"
    )
    return f"{name}(cache_size={x._cache_size})"


def generate_data(transform):
    torch.manual_seed(1)
    while isinstance(transform, IndependentTransform):
        transform = transform.base_transform
    if isinstance(transform, ReshapeTransform):
        return torch.randn(transform.in_shape)
    if isinstance(transform.inv, ReshapeTransform):
        return torch.randn(transform.inv.out_shape)
    domain = transform.domain
    while (
        isinstance(domain, constraints.independent)
        and domain is not constraints.real_vector
    ):
        domain = domain.base_constraint
    codomain = transform.codomain
    x = torch.empty(4, 5)
    positive_definite_constraints = [
        constraints.lower_cholesky,
        constraints.positive_definite,
    ]
    if domain in positive_definite_constraints:
        x = torch.randn(6, 6)
        x = x.tril(-1) + x.diag().exp().diag_embed()
        if domain is constraints.positive_definite:
            return x @ x.T
        return x
    elif codomain in positive_definite_constraints:
        return torch.randn(6, 6)
    elif domain is constraints.real:
        return x.normal_()
    elif domain is constraints.real_vector:
        # For corr_cholesky the last dim in the vector
        # must be of size (dim * dim) // 2
        x = torch.empty(3, 6)
        x = x.normal_()
        return x
    elif domain is constraints.positive:
        return x.normal_().exp()
    elif domain is constraints.unit_interval:
        return x.uniform_()
    elif isinstance(domain, constraints.interval):
        x = x.uniform_()
        x = x.mul_(domain.upper_bound - domain.lower_bound).add_(domain.lower_bound)
        return x
    elif domain is constraints.simplex:
        x = x.normal_().exp()
        x /= x.sum(-1, True)
        return x
    elif domain is constraints.corr_cholesky:
        x = torch.empty(4, 5, 5)
        x = x.normal_().tril()
        x /= x.norm(dim=-1, keepdim=True)
        x.diagonal(dim1=-1).copy_(x.diagonal(dim1=-1).abs())
        return x
    raise ValueError(f"Unsupported domain: {domain}")


TRANSFORMS_CACHE_ACTIVE = get_transforms(cache_size=1)
TRANSFORMS_CACHE_INACTIVE = get_transforms(cache_size=0)
ALL_TRANSFORMS = (
    TRANSFORMS_CACHE_ACTIVE + TRANSFORMS_CACHE_INACTIVE + [identity_transform]
)


@pytest.mark.parametrize("transform", ALL_TRANSFORMS, ids=transform_id)
def test_inv_inv(transform):
    assert transform.inv.inv is transform


@pytest.mark.parametrize("x", TRANSFORMS_CACHE_INACTIVE, ids=transform_id)
@pytest.mark.parametrize("y", TRANSFORMS_CACHE_INACTIVE, ids=transform_id)
def test_equality(x, y):
    if x is y:
        assert x == y
    else:
        assert x != y
    assert identity_transform == identity_transform.inv


@pytest.mark.parametrize("transform", ALL_TRANSFORMS, ids=transform_id)
def test_with_cache(transform):
    if transform._cache_size == 0:
        transform = transform.with_cache(1)
    assert transform._cache_size == 1
    x = generate_data(transform).requires_grad_()
    try:
        y = transform(x)
    except NotImplementedError:
        pytest.skip("Not implemented.")
    y2 = transform(x)
    assert y2 is y
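

# For bijective transforms, t.inv(t(x)) must reproduce x up to numerical tolerance;
# non-bijective transforms only guarantee the weaker pseudo-inverse property
# t(t.inv(t(x))) == t(x). For example, AbsTransform maps both x and -x to |x|, so x
# itself cannot be recovered, but the pseudo-inverse round trip still holds.
# test_forward_inverse below checks whichever property applies, both through the
# inversion cache and with the cache bypassed.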
@pytest.mark.parametrize("transform", ALL_TRANSFORMS, ids=transform_id)
@pytest.mark.parametrize("test_cached", [True, False])
def test_forward_inverse(transform, test_cached):
    x = generate_data(transform).requires_grad_()
    assert transform.domain.check(x).all()  # verify that the input data are valid
    try:
        y = transform(x)
    except NotImplementedError:
        pytest.skip("Not implemented.")
    assert y.shape == transform.forward_shape(x.shape)
    if test_cached:
        x2 = transform.inv(y)  # should be implemented at least by caching
    else:
        try:
            x2 = transform.inv(y.clone())  # bypass cache
        except NotImplementedError:
            pytest.skip("Not implemented.")
    assert x2.shape == transform.inverse_shape(y.shape)
    y2 = transform(x2)
    if transform.bijective:
        # verify function inverse
        assert torch.allclose(x2, x, atol=1e-4, equal_nan=True), "\n".join(
            [
                f"{transform} t.inv(t(-)) error",
                f"x = {x}",
                f"y = t(x) = {y}",
                f"x2 = t.inv(y) = {x2}",
            ]
        )
    else:
        # verify weaker function pseudo-inverse
        assert torch.allclose(y2, y, atol=1e-4, equal_nan=True), "\n".join(
            [
                f"{transform} t(t.inv(t(-))) error",
                f"x = {x}",
                f"y = t(x) = {y}",
                f"x2 = t.inv(y) = {x2}",
                f"y2 = t(x2) = {y2}",
            ]
        )


def test_compose_transform_shapes():
    transform0 = ExpTransform()
    transform1 = SoftmaxTransform()
    transform2 = LowerCholeskyTransform()

    assert transform0.event_dim == 0
    assert transform1.event_dim == 1
    assert transform2.event_dim == 2
    assert ComposeTransform([transform0, transform1]).event_dim == 1
    assert ComposeTransform([transform0, transform2]).event_dim == 2
    assert ComposeTransform([transform1, transform2]).event_dim == 2


transform0 = ExpTransform()
transform1 = SoftmaxTransform()
transform2 = LowerCholeskyTransform()
base_dist0 = Normal(torch.zeros(4, 4), torch.ones(4, 4))
base_dist1 = Dirichlet(torch.ones(4, 4))
base_dist2 = Normal(torch.zeros(3, 4, 4), torch.ones(3, 4, 4))


@pytest.mark.parametrize(
    ("batch_shape", "event_shape", "dist"),
    [
        ((4, 4), (), base_dist0),
        ((4,), (4,), base_dist1),
        ((4, 4), (), TransformedDistribution(base_dist0, [transform0])),
        ((4,), (4,), TransformedDistribution(base_dist0, [transform1])),
        ((4,), (4,), TransformedDistribution(base_dist0, [transform0, transform1])),
        ((), (4, 4), TransformedDistribution(base_dist0, [transform0, transform2])),
        ((4,), (4,), TransformedDistribution(base_dist0, [transform1, transform0])),
        ((), (4, 4), TransformedDistribution(base_dist0, [transform1, transform2])),
        ((), (4, 4), TransformedDistribution(base_dist0, [transform2, transform0])),
        ((), (4, 4), TransformedDistribution(base_dist0, [transform2, transform1])),
        ((4,), (4,), TransformedDistribution(base_dist1, [transform0])),
        ((4,), (4,), TransformedDistribution(base_dist1, [transform1])),
        ((), (4, 4), TransformedDistribution(base_dist1, [transform2])),
        ((4,), (4,), TransformedDistribution(base_dist1, [transform0, transform1])),
        ((), (4, 4), TransformedDistribution(base_dist1, [transform0, transform2])),
        ((4,), (4,), TransformedDistribution(base_dist1, [transform1, transform0])),
        ((), (4, 4), TransformedDistribution(base_dist1, [transform1, transform2])),
        ((), (4, 4), TransformedDistribution(base_dist1, [transform2, transform0])),
        ((), (4, 4), TransformedDistribution(base_dist1, [transform2, transform1])),
        ((3, 4, 4), (), base_dist2),
        ((3,), (4, 4), TransformedDistribution(base_dist2, [transform2])),
        ((3,), (4, 4), TransformedDistribution(base_dist2, [transform0, transform2])),
        ((3,), (4, 4), TransformedDistribution(base_dist2, [transform1, transform2])),
        ((3,), (4, 4), TransformedDistribution(base_dist2, [transform2, transform0])),
        ((3,), (4, 4), TransformedDistribution(base_dist2, [transform2, transform1])),
    ],
)
def test_transformed_distribution_shapes(batch_shape, event_shape, dist):
    assert dist.batch_shape == batch_shape
    assert dist.event_shape == event_shape
    x = dist.rsample()
    try:
        dist.log_prob(x)  # this should not crash
    except NotImplementedError:
        pytest.skip("Not implemented.")
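

# The three test_jit_* tests below trace forward, inverse, and log_abs_det_jacobian
# with torch.jit.trace and compare each traced function against eager execution on
# freshly generated inputs; transforms that raise NotImplementedError are skipped.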
@pytest.mark.parametrize("transform", TRANSFORMS_CACHE_INACTIVE, ids=transform_id)
def test_jit_fwd(transform):
    x = generate_data(transform).requires_grad_()

    def f(x):
        return transform(x)

    try:
        traced_f = torch.jit.trace(f, (x,))
    except NotImplementedError:
        pytest.skip("Not implemented.")

    # check on different inputs
    x = generate_data(transform).requires_grad_()
    assert torch.allclose(f(x), traced_f(x), atol=1e-5, equal_nan=True)


@pytest.mark.parametrize("transform", TRANSFORMS_CACHE_INACTIVE, ids=transform_id)
def test_jit_inv(transform):
    y = generate_data(transform.inv).requires_grad_()

    def f(y):
        return transform.inv(y)

    try:
        traced_f = torch.jit.trace(f, (y,))
    except NotImplementedError:
        pytest.skip("Not implemented.")

    # check on different inputs
    y = generate_data(transform.inv).requires_grad_()
    assert torch.allclose(f(y), traced_f(y), atol=1e-5, equal_nan=True)


@pytest.mark.parametrize("transform", TRANSFORMS_CACHE_INACTIVE, ids=transform_id)
def test_jit_jacobian(transform):
    x = generate_data(transform).requires_grad_()

    def f(x):
        y = transform(x)
        return transform.log_abs_det_jacobian(x, y)

    try:
        traced_f = torch.jit.trace(f, (x,))
    except NotImplementedError:
        pytest.skip("Not implemented.")

    # check on different inputs
    x = generate_data(transform).requires_grad_()
    assert torch.allclose(f(x), traced_f(x), atol=1e-5, equal_nan=True)


@pytest.mark.parametrize("transform", ALL_TRANSFORMS, ids=transform_id)
def test_jacobian(transform):
    x = generate_data(transform)
    try:
        y = transform(x)
        actual = transform.log_abs_det_jacobian(x, y)
    except NotImplementedError:
        pytest.skip("Not implemented.")
    # Test shape
    target_shape = x.shape[: x.dim() - transform.domain.event_dim]
    assert actual.shape == target_shape

    # Expand if required
    transform = reshape_transform(transform, x.shape)
    ndims = len(x.shape)
    event_dim = ndims - transform.domain.event_dim
    x_ = x.view((-1,) + x.shape[event_dim:])
    n = x_.shape[0]
    # Reshape to squash batch dims to a single batch dim
    transform = reshape_transform(transform, x_.shape)
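
    # The analytic log_abs_det_jacobian computed above is validated against an autograd
    # reference below: the Jacobian is either trivially unit (ReshapeTransform, zero
    # log-det), diagonal (elementwise transforms with event_dim == 0), or a full batched
    # matrix whose log|det| is taken via torch.slogdet. For example, ExpTransform has
    # dy/dx = exp(x), so its log|dy/dx| = x falls under the diagonal case.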
    # 1. Transforms with unit jacobian
    if isinstance(transform, ReshapeTransform) or isinstance(
        transform.inv, ReshapeTransform
    ):
        expected = x.new_zeros(x.shape[x.dim() - transform.domain.event_dim])
    # 2. Transforms with 0 off-diagonal elements
    elif transform.domain.event_dim == 0:
        jac = jacobian(transform, x_)
        # assert off-diagonal elements are zero
        assert torch.allclose(jac, jac.diagonal().diag_embed())
        expected = jac.diagonal().abs().log().reshape(x.shape)
    # 3. Transforms with non-0 off-diagonal elements
    else:
        if isinstance(transform, CorrCholeskyTransform):
            jac = jacobian(lambda x: tril_matrix_to_vec(transform(x), diag=-1), x_)
        elif isinstance(transform.inv, CorrCholeskyTransform):
            jac = jacobian(
                lambda x: transform(vec_to_tril_matrix(x, diag=-1)),
                tril_matrix_to_vec(x_, diag=-1),
            )
        elif isinstance(transform, StickBreakingTransform):
            jac = jacobian(lambda x: transform(x)[..., :-1], x_)
        else:
            jac = jacobian(transform, x_)

        # Note that jacobian will have shape (batch_dims, y_event_dims, batch_dims, x_event_dims).
        # However, batches are independent, so this can be converted into a
        # (batch_dims, event_dims, event_dims) tensor after reshaping the event dims (see above)
        # to give a batched square matrix whose determinant can be computed.
        gather_idx_shape = list(jac.shape)
        gather_idx_shape[-2] = 1
        gather_idxs = (
            torch.arange(n)
            .reshape((n,) + (1,) * (len(jac.shape) - 1))
            .expand(gather_idx_shape)
        )
        jac = jac.gather(-2, gather_idxs).squeeze(-2)
        out_ndims = jac.shape[-2]
        jac = jac[
            ..., :out_ndims
        ]  # Remove extra zero-valued dims (for inverse stick-breaking).
        expected = torch.slogdet(jac).logabsdet

    assert torch.allclose(actual, expected, atol=1e-5)


@pytest.mark.parametrize(
    "event_dims", [(0,), (1,), (2, 3), (0, 1, 2), (1, 2, 0), (2, 0, 1)], ids=str
)
def test_compose_affine(event_dims):
    transforms = [
        AffineTransform(torch.zeros((1,) * e), 1, event_dim=e) for e in event_dims
    ]
    transform = ComposeTransform(transforms)
    assert transform.codomain.event_dim == max(event_dims)
    assert transform.domain.event_dim == max(event_dims)

    base_dist = Normal(0, 1)
    if transform.domain.event_dim:
        base_dist = base_dist.expand((1,) * transform.domain.event_dim)
    dist = TransformedDistribution(base_dist, transform.parts)
    assert dist.support.event_dim == max(event_dims)

    base_dist = Dirichlet(torch.ones(5))
    if transform.domain.event_dim > 1:
        base_dist = base_dist.expand((1,) * (transform.domain.event_dim - 1))
    dist = TransformedDistribution(base_dist, transforms)
    assert dist.support.event_dim == max(1, *event_dims)


@pytest.mark.parametrize("batch_shape", [(), (6,), (5, 4)], ids=str)
def test_compose_reshape(batch_shape):
    transforms = [
        ReshapeTransform((), ()),
        ReshapeTransform((2,), (1, 2)),
        ReshapeTransform((3, 1, 2), (6,)),
        ReshapeTransform((6,), (2, 3)),
    ]
    transform = ComposeTransform(transforms)
    assert transform.codomain.event_dim == 2
    assert transform.domain.event_dim == 2
    data = torch.randn(batch_shape + (3, 2))
    assert transform(data).shape == batch_shape + (2, 3)

    dist = TransformedDistribution(Normal(data, 1), transforms)
    assert dist.batch_shape == batch_shape
    assert dist.event_shape == (2, 3)
    assert dist.support.event_dim == 2
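

# test_transformed_distribution sweeps combinations of base batch/event dims, transform
# event dims, and number of stacked transforms. Combinations whose base distribution is
# too low-dimensional for the composed transform must raise ValueError in
# TransformedDistribution.__init__; valid combinations are checked for sample shape,
# sample diversity, and log_prob shape on full and partial samples.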
@pytest.mark.parametrize("sample_shape", [(), (7,)], ids=str)
@pytest.mark.parametrize("transform_dim", [0, 1, 2])
@pytest.mark.parametrize("base_batch_dim", [0, 1, 2])
@pytest.mark.parametrize("base_event_dim", [0, 1, 2])
@pytest.mark.parametrize("num_transforms", [0, 1, 2, 3])
def test_transformed_distribution(
    base_batch_dim, base_event_dim, transform_dim, num_transforms, sample_shape
):
    shape = torch.Size([2, 3, 4, 5])
    base_dist = Normal(0, 1)
    base_dist = base_dist.expand(shape[4 - base_batch_dim - base_event_dim :])
    if base_event_dim:
        base_dist = Independent(base_dist, base_event_dim)
    transforms = [
        AffineTransform(torch.zeros(shape[4 - transform_dim :]), 1),
        ReshapeTransform((4, 5), (20,)),
        ReshapeTransform((3, 20), (6, 10)),
    ]
    transforms = transforms[:num_transforms]
    transform = ComposeTransform(transforms)

    # Check validation in .__init__().
    if base_batch_dim + base_event_dim < transform.domain.event_dim:
        with pytest.raises(ValueError):
            TransformedDistribution(base_dist, transforms)
        return
    d = TransformedDistribution(base_dist, transforms)

    # Check sampling is sufficiently expanded.
    x = d.sample(sample_shape)
    assert x.shape == sample_shape + d.batch_shape + d.event_shape
    num_unique = len(set(x.reshape(-1).tolist()))
    assert num_unique >= 0.9 * x.numel()

    # Check log_prob shape on full samples.
    log_prob = d.log_prob(x)
    assert log_prob.shape == sample_shape + d.batch_shape

    # Check log_prob shape on partial samples.
    y = x
    while y.dim() > len(d.event_shape):
        y = y[0]
    log_prob = d.log_prob(y)
    assert log_prob.shape == d.batch_shape


def test_save_load_transform():
    # Evaluating `log_prob` will create a weakref `_inv` which cannot be pickled. Here, we check
    # that `__getstate__` correctly handles the weakref, and that we can evaluate the density after.
    dist = TransformedDistribution(Normal(0, 1), [AffineTransform(2, 3)])
    x = torch.linspace(0, 1, 10)
    log_prob = dist.log_prob(x)
    stream = io.BytesIO()
    torch.save(dist, stream)
    stream.seek(0)
    other = torch.load(stream)
    assert torch.allclose(log_prob, other.log_prob(x))


@pytest.mark.parametrize("transform", ALL_TRANSFORMS, ids=transform_id)
def test_transform_sign(transform: Transform):
    try:
        sign = transform.sign
    except NotImplementedError:
        pytest.skip("Not implemented.")

    x = generate_data(transform).requires_grad_()
    y = transform(x).sum()
    (derivatives,) = grad(y, [x])
    assert torch.less(torch.as_tensor(0.0), derivatives * sign).all()


if __name__ == "__main__":
    run_tests()