# mypy: allow-untyped-defs
import functools
import math
import numbers
import operator
import weakref
from typing import List

import torch
import torch.nn.functional as F
from torch.distributions import constraints
from torch.distributions.utils import (
    _sum_rightmost,
    broadcast_all,
    lazy_property,
    tril_matrix_to_vec,
    vec_to_tril_matrix,
)
from torch.nn.functional import pad, softplus


__all__ = [
    "AbsTransform",
    "AffineTransform",
    "CatTransform",
    "ComposeTransform",
    "CorrCholeskyTransform",
    "CumulativeDistributionTransform",
    "ExpTransform",
    "IndependentTransform",
    "LowerCholeskyTransform",
    "PositiveDefiniteTransform",
    "PowerTransform",
    "ReshapeTransform",
    "SigmoidTransform",
    "SoftplusTransform",
    "TanhTransform",
    "SoftmaxTransform",
    "StackTransform",
    "StickBreakingTransform",
    "Transform",
    "identity_transform",
]


class Transform:
    """
    Abstract class for invertible transformations with computable log
    det jacobians. They are primarily used in
    :class:`torch.distributions.TransformedDistribution`.

    Caching is useful for transforms whose inverses are either expensive or
    numerically unstable. Note that care must be taken with memoized values
    since the autograd graph may be reversed. For example, the following
    works with or without caching::

        y = t(x)
        t.log_abs_det_jacobian(x, y).backward()  # x will receive gradients.

    However, the following will error when caching due to dependency reversal::

        y = t(x)
        z = t.inv(y)
        grad(z.sum(), [y])  # error because z is x

    Derived classes should implement one or both of :meth:`_call` or
    :meth:`_inverse`. Derived classes that set `bijective=True` should also
    implement :meth:`log_abs_det_jacobian`.

    Args:
        cache_size (int): Size of cache. If zero, no caching is done. If one,
            the latest single value is cached. Only 0 and 1 are supported.

    Attributes:
        domain (:class:`~torch.distributions.constraints.Constraint`):
            The constraint representing valid inputs to this transform.
        codomain (:class:`~torch.distributions.constraints.Constraint`):
            The constraint representing valid outputs to this transform
            which are inputs to the inverse transform.
        bijective (bool): Whether this transform is bijective. A transform
            ``t`` is bijective iff ``t.inv(t(x)) == x`` and
            ``t(t.inv(y)) == y`` for every ``x`` in the domain and ``y`` in
            the codomain. Transforms that are not bijective should at least
            maintain the weaker pseudoinverse properties
            ``t(t.inv(t(x))) == t(x)`` and ``t.inv(t(t.inv(y))) == t.inv(y)``.
        sign (int or Tensor): For bijective univariate transforms, this
            should be +1 or -1 depending on whether the transform is monotone
            increasing or decreasing.
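
    Example (a minimal sketch of the caching behavior described above, using
    :class:`ExpTransform` purely for illustration)::

        t = ExpTransform(cache_size=1)
        x = torch.randn(3)
        y = t(x)       # the pair (x, y) is memoized
        x2 = t.inv(y)  # the memoized x is returned, so ``x2 is x``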
89 """ 90 91 bijective = False 92 domain: constraints.Constraint 93 codomain: constraints.Constraint 94 95 def __init__(self, cache_size=0): 96 self._cache_size = cache_size 97 self._inv = None 98 if cache_size == 0: 99 pass # default behavior 100 elif cache_size == 1: 101 self._cached_x_y = None, None 102 else: 103 raise ValueError("cache_size must be 0 or 1") 104 super().__init__() 105 106 def __getstate__(self): 107 state = self.__dict__.copy() 108 state["_inv"] = None 109 return state 110 111 @property 112 def event_dim(self): 113 if self.domain.event_dim == self.codomain.event_dim: 114 return self.domain.event_dim 115 raise ValueError("Please use either .domain.event_dim or .codomain.event_dim") 116 117 @property 118 def inv(self): 119 """ 120 Returns the inverse :class:`Transform` of this transform. 121 This should satisfy ``t.inv.inv is t``. 122 """ 123 inv = None 124 if self._inv is not None: 125 inv = self._inv() 126 if inv is None: 127 inv = _InverseTransform(self) 128 self._inv = weakref.ref(inv) 129 return inv 130 131 @property 132 def sign(self): 133 """ 134 Returns the sign of the determinant of the Jacobian, if applicable. 135 In general this only makes sense for bijective transforms. 136 """ 137 raise NotImplementedError 138 139 def with_cache(self, cache_size=1): 140 if self._cache_size == cache_size: 141 return self 142 if type(self).__init__ is Transform.__init__: 143 return type(self)(cache_size=cache_size) 144 raise NotImplementedError(f"{type(self)}.with_cache is not implemented") 145 146 def __eq__(self, other): 147 return self is other 148 149 def __ne__(self, other): 150 # Necessary for Python2 151 return not self.__eq__(other) 152 153 def __call__(self, x): 154 """ 155 Computes the transform `x => y`. 156 """ 157 if self._cache_size == 0: 158 return self._call(x) 159 x_old, y_old = self._cached_x_y 160 if x is x_old: 161 return y_old 162 y = self._call(x) 163 self._cached_x_y = x, y 164 return y 165 166 def _inv_call(self, y): 167 """ 168 Inverts the transform `y => x`. 169 """ 170 if self._cache_size == 0: 171 return self._inverse(y) 172 x_old, y_old = self._cached_x_y 173 if y is y_old: 174 return x_old 175 x = self._inverse(y) 176 self._cached_x_y = x, y 177 return x 178 179 def _call(self, x): 180 """ 181 Abstract method to compute forward transformation. 182 """ 183 raise NotImplementedError 184 185 def _inverse(self, y): 186 """ 187 Abstract method to compute inverse transformation. 188 """ 189 raise NotImplementedError 190 191 def log_abs_det_jacobian(self, x, y): 192 """ 193 Computes the log det jacobian `log |dy/dx|` given input and output. 194 """ 195 raise NotImplementedError 196 197 def __repr__(self): 198 return self.__class__.__name__ + "()" 199 200 def forward_shape(self, shape): 201 """ 202 Infers the shape of the forward computation, given the input shape. 203 Defaults to preserving shape. 204 """ 205 return shape 206 207 def inverse_shape(self, shape): 208 """ 209 Infers the shapes of the inverse computation, given the output shape. 210 Defaults to preserving shape. 211 """ 212 return shape 213 214 215class _InverseTransform(Transform): 216 """ 217 Inverts a single :class:`Transform`. 218 This class is private; please instead use the ``Transform.inv`` property. 
219 """ 220 221 def __init__(self, transform: Transform): 222 super().__init__(cache_size=transform._cache_size) 223 self._inv: Transform = transform 224 225 @constraints.dependent_property(is_discrete=False) 226 def domain(self): 227 assert self._inv is not None 228 return self._inv.codomain 229 230 @constraints.dependent_property(is_discrete=False) 231 def codomain(self): 232 assert self._inv is not None 233 return self._inv.domain 234 235 @property 236 def bijective(self): 237 assert self._inv is not None 238 return self._inv.bijective 239 240 @property 241 def sign(self): 242 assert self._inv is not None 243 return self._inv.sign 244 245 @property 246 def inv(self): 247 return self._inv 248 249 def with_cache(self, cache_size=1): 250 assert self._inv is not None 251 return self.inv.with_cache(cache_size).inv 252 253 def __eq__(self, other): 254 if not isinstance(other, _InverseTransform): 255 return False 256 assert self._inv is not None 257 return self._inv == other._inv 258 259 def __repr__(self): 260 return f"{self.__class__.__name__}({repr(self._inv)})" 261 262 def __call__(self, x): 263 assert self._inv is not None 264 return self._inv._inv_call(x) 265 266 def log_abs_det_jacobian(self, x, y): 267 assert self._inv is not None 268 return -self._inv.log_abs_det_jacobian(y, x) 269 270 def forward_shape(self, shape): 271 return self._inv.inverse_shape(shape) 272 273 def inverse_shape(self, shape): 274 return self._inv.forward_shape(shape) 275 276 277class ComposeTransform(Transform): 278 """ 279 Composes multiple transforms in a chain. 280 The transforms being composed are responsible for caching. 281 282 Args: 283 parts (list of :class:`Transform`): A list of transforms to compose. 284 cache_size (int): Size of cache. If zero, no caching is done. If one, 285 the latest single value is cached. Only 0 and 1 are supported. 286 """ 287 288 def __init__(self, parts: List[Transform], cache_size=0): 289 if cache_size: 290 parts = [part.with_cache(cache_size) for part in parts] 291 super().__init__(cache_size=cache_size) 292 self.parts = parts 293 294 def __eq__(self, other): 295 if not isinstance(other, ComposeTransform): 296 return False 297 return self.parts == other.parts 298 299 @constraints.dependent_property(is_discrete=False) 300 def domain(self): 301 if not self.parts: 302 return constraints.real 303 domain = self.parts[0].domain 304 # Adjust event_dim to be maximum among all parts. 305 event_dim = self.parts[-1].codomain.event_dim 306 for part in reversed(self.parts): 307 event_dim += part.domain.event_dim - part.codomain.event_dim 308 event_dim = max(event_dim, part.domain.event_dim) 309 assert event_dim >= domain.event_dim 310 if event_dim > domain.event_dim: 311 domain = constraints.independent(domain, event_dim - domain.event_dim) 312 return domain 313 314 @constraints.dependent_property(is_discrete=False) 315 def codomain(self): 316 if not self.parts: 317 return constraints.real 318 codomain = self.parts[-1].codomain 319 # Adjust event_dim to be maximum among all parts. 
        event_dim = self.parts[0].domain.event_dim
        for part in self.parts:
            event_dim += part.codomain.event_dim - part.domain.event_dim
            event_dim = max(event_dim, part.codomain.event_dim)
        assert event_dim >= codomain.event_dim
        if event_dim > codomain.event_dim:
            codomain = constraints.independent(codomain, event_dim - codomain.event_dim)
        return codomain

    @lazy_property
    def bijective(self):
        return all(p.bijective for p in self.parts)

    @lazy_property
    def sign(self):
        sign = 1
        for p in self.parts:
            sign = sign * p.sign
        return sign

    @property
    def inv(self):
        inv = None
        if self._inv is not None:
            inv = self._inv()
        if inv is None:
            inv = ComposeTransform([p.inv for p in reversed(self.parts)])
            self._inv = weakref.ref(inv)
            inv._inv = weakref.ref(self)
        return inv

    def with_cache(self, cache_size=1):
        if self._cache_size == cache_size:
            return self
        return ComposeTransform(self.parts, cache_size=cache_size)

    def __call__(self, x):
        for part in self.parts:
            x = part(x)
        return x

    def log_abs_det_jacobian(self, x, y):
        if not self.parts:
            return torch.zeros_like(x)

        # Compute intermediates. This will be free if parts[:-1] are all cached.
        xs = [x]
        for part in self.parts[:-1]:
            xs.append(part(xs[-1]))
        xs.append(y)

        terms = []
        event_dim = self.domain.event_dim
        for part, x, y in zip(self.parts, xs[:-1], xs[1:]):
            terms.append(
                _sum_rightmost(
                    part.log_abs_det_jacobian(x, y), event_dim - part.domain.event_dim
                )
            )
            event_dim += part.codomain.event_dim - part.domain.event_dim
        return functools.reduce(operator.add, terms)

    def forward_shape(self, shape):
        for part in self.parts:
            shape = part.forward_shape(shape)
        return shape

    def inverse_shape(self, shape):
        for part in reversed(self.parts):
            shape = part.inverse_shape(shape)
        return shape

    def __repr__(self):
        fmt_string = self.__class__.__name__ + "(\n    "
        fmt_string += ",\n    ".join([p.__repr__() for p in self.parts])
        fmt_string += "\n)"
        return fmt_string


identity_transform = ComposeTransform([])


class IndependentTransform(Transform):
    """
    Wrapper around another transform to treat
    ``reinterpreted_batch_ndims``-many extra of the rightmost dimensions as
    dependent. This has no effect on the forward or backward transforms, but
    does sum out ``reinterpreted_batch_ndims``-many of the rightmost dimensions
    in :meth:`log_abs_det_jacobian`.

    Args:
        base_transform (:class:`Transform`): A base transform.
        reinterpreted_batch_ndims (int): The number of extra rightmost
            dimensions to treat as dependent.
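
    Example (a small sketch of how the extra dimension is summed out of the
    log det jacobian; :class:`ExpTransform` and the shapes are illustrative only)::

        base = ExpTransform()
        t = IndependentTransform(base, reinterpreted_batch_ndims=1)
        x = torch.randn(4, 3)
        y = t(x)                                    # identical to base(x)
        base.log_abs_det_jacobian(x, y).shape       # torch.Size([4, 3])
        t.log_abs_det_jacobian(x, y).shape          # torch.Size([4])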
414 """ 415 416 def __init__(self, base_transform, reinterpreted_batch_ndims, cache_size=0): 417 super().__init__(cache_size=cache_size) 418 self.base_transform = base_transform.with_cache(cache_size) 419 self.reinterpreted_batch_ndims = reinterpreted_batch_ndims 420 421 def with_cache(self, cache_size=1): 422 if self._cache_size == cache_size: 423 return self 424 return IndependentTransform( 425 self.base_transform, self.reinterpreted_batch_ndims, cache_size=cache_size 426 ) 427 428 @constraints.dependent_property(is_discrete=False) 429 def domain(self): 430 return constraints.independent( 431 self.base_transform.domain, self.reinterpreted_batch_ndims 432 ) 433 434 @constraints.dependent_property(is_discrete=False) 435 def codomain(self): 436 return constraints.independent( 437 self.base_transform.codomain, self.reinterpreted_batch_ndims 438 ) 439 440 @property 441 def bijective(self): 442 return self.base_transform.bijective 443 444 @property 445 def sign(self): 446 return self.base_transform.sign 447 448 def _call(self, x): 449 if x.dim() < self.domain.event_dim: 450 raise ValueError("Too few dimensions on input") 451 return self.base_transform(x) 452 453 def _inverse(self, y): 454 if y.dim() < self.codomain.event_dim: 455 raise ValueError("Too few dimensions on input") 456 return self.base_transform.inv(y) 457 458 def log_abs_det_jacobian(self, x, y): 459 result = self.base_transform.log_abs_det_jacobian(x, y) 460 result = _sum_rightmost(result, self.reinterpreted_batch_ndims) 461 return result 462 463 def __repr__(self): 464 return f"{self.__class__.__name__}({repr(self.base_transform)}, {self.reinterpreted_batch_ndims})" 465 466 def forward_shape(self, shape): 467 return self.base_transform.forward_shape(shape) 468 469 def inverse_shape(self, shape): 470 return self.base_transform.inverse_shape(shape) 471 472 473class ReshapeTransform(Transform): 474 """ 475 Unit Jacobian transform to reshape the rightmost part of a tensor. 476 477 Note that ``in_shape`` and ``out_shape`` must have the same number of 478 elements, just as for :meth:`torch.Tensor.reshape`. 479 480 Arguments: 481 in_shape (torch.Size): The input event shape. 482 out_shape (torch.Size): The output event shape. 
483 """ 484 485 bijective = True 486 487 def __init__(self, in_shape, out_shape, cache_size=0): 488 self.in_shape = torch.Size(in_shape) 489 self.out_shape = torch.Size(out_shape) 490 if self.in_shape.numel() != self.out_shape.numel(): 491 raise ValueError("in_shape, out_shape have different numbers of elements") 492 super().__init__(cache_size=cache_size) 493 494 @constraints.dependent_property 495 def domain(self): 496 return constraints.independent(constraints.real, len(self.in_shape)) 497 498 @constraints.dependent_property 499 def codomain(self): 500 return constraints.independent(constraints.real, len(self.out_shape)) 501 502 def with_cache(self, cache_size=1): 503 if self._cache_size == cache_size: 504 return self 505 return ReshapeTransform(self.in_shape, self.out_shape, cache_size=cache_size) 506 507 def _call(self, x): 508 batch_shape = x.shape[: x.dim() - len(self.in_shape)] 509 return x.reshape(batch_shape + self.out_shape) 510 511 def _inverse(self, y): 512 batch_shape = y.shape[: y.dim() - len(self.out_shape)] 513 return y.reshape(batch_shape + self.in_shape) 514 515 def log_abs_det_jacobian(self, x, y): 516 batch_shape = x.shape[: x.dim() - len(self.in_shape)] 517 return x.new_zeros(batch_shape) 518 519 def forward_shape(self, shape): 520 if len(shape) < len(self.in_shape): 521 raise ValueError("Too few dimensions on input") 522 cut = len(shape) - len(self.in_shape) 523 if shape[cut:] != self.in_shape: 524 raise ValueError( 525 f"Shape mismatch: expected {shape[cut:]} but got {self.in_shape}" 526 ) 527 return shape[:cut] + self.out_shape 528 529 def inverse_shape(self, shape): 530 if len(shape) < len(self.out_shape): 531 raise ValueError("Too few dimensions on input") 532 cut = len(shape) - len(self.out_shape) 533 if shape[cut:] != self.out_shape: 534 raise ValueError( 535 f"Shape mismatch: expected {shape[cut:]} but got {self.out_shape}" 536 ) 537 return shape[:cut] + self.in_shape 538 539 540class ExpTransform(Transform): 541 r""" 542 Transform via the mapping :math:`y = \exp(x)`. 543 """ 544 domain = constraints.real 545 codomain = constraints.positive 546 bijective = True 547 sign = +1 548 549 def __eq__(self, other): 550 return isinstance(other, ExpTransform) 551 552 def _call(self, x): 553 return x.exp() 554 555 def _inverse(self, y): 556 return y.log() 557 558 def log_abs_det_jacobian(self, x, y): 559 return x 560 561 562class PowerTransform(Transform): 563 r""" 564 Transform via the mapping :math:`y = x^{\text{exponent}}`. 
565 """ 566 domain = constraints.positive 567 codomain = constraints.positive 568 bijective = True 569 570 def __init__(self, exponent, cache_size=0): 571 super().__init__(cache_size=cache_size) 572 (self.exponent,) = broadcast_all(exponent) 573 574 def with_cache(self, cache_size=1): 575 if self._cache_size == cache_size: 576 return self 577 return PowerTransform(self.exponent, cache_size=cache_size) 578 579 @lazy_property 580 def sign(self): 581 return self.exponent.sign() 582 583 def __eq__(self, other): 584 if not isinstance(other, PowerTransform): 585 return False 586 return self.exponent.eq(other.exponent).all().item() 587 588 def _call(self, x): 589 return x.pow(self.exponent) 590 591 def _inverse(self, y): 592 return y.pow(1 / self.exponent) 593 594 def log_abs_det_jacobian(self, x, y): 595 return (self.exponent * y / x).abs().log() 596 597 def forward_shape(self, shape): 598 return torch.broadcast_shapes(shape, getattr(self.exponent, "shape", ())) 599 600 def inverse_shape(self, shape): 601 return torch.broadcast_shapes(shape, getattr(self.exponent, "shape", ())) 602 603 604def _clipped_sigmoid(x): 605 finfo = torch.finfo(x.dtype) 606 return torch.clamp(torch.sigmoid(x), min=finfo.tiny, max=1.0 - finfo.eps) 607 608 609class SigmoidTransform(Transform): 610 r""" 611 Transform via the mapping :math:`y = \frac{1}{1 + \exp(-x)}` and :math:`x = \text{logit}(y)`. 612 """ 613 domain = constraints.real 614 codomain = constraints.unit_interval 615 bijective = True 616 sign = +1 617 618 def __eq__(self, other): 619 return isinstance(other, SigmoidTransform) 620 621 def _call(self, x): 622 return _clipped_sigmoid(x) 623 624 def _inverse(self, y): 625 finfo = torch.finfo(y.dtype) 626 y = y.clamp(min=finfo.tiny, max=1.0 - finfo.eps) 627 return y.log() - (-y).log1p() 628 629 def log_abs_det_jacobian(self, x, y): 630 return -F.softplus(-x) - F.softplus(x) 631 632 633class SoftplusTransform(Transform): 634 r""" 635 Transform via the mapping :math:`\text{Softplus}(x) = \log(1 + \exp(x))`. 636 The implementation reverts to the linear function when :math:`x > 20`. 637 """ 638 domain = constraints.real 639 codomain = constraints.positive 640 bijective = True 641 sign = +1 642 643 def __eq__(self, other): 644 return isinstance(other, SoftplusTransform) 645 646 def _call(self, x): 647 return softplus(x) 648 649 def _inverse(self, y): 650 return (-y).expm1().neg().log() + y 651 652 def log_abs_det_jacobian(self, x, y): 653 return -softplus(-x) 654 655 656class TanhTransform(Transform): 657 r""" 658 Transform via the mapping :math:`y = \tanh(x)`. 659 660 It is equivalent to 661 ``` 662 ComposeTransform([AffineTransform(0., 2.), SigmoidTransform(), AffineTransform(-1., 2.)]) 663 ``` 664 However this might not be numerically stable, thus it is recommended to use `TanhTransform` 665 instead. 666 667 Note that one should use `cache_size=1` when it comes to `NaN/Inf` values. 668 669 """ 670 domain = constraints.real 671 codomain = constraints.interval(-1.0, 1.0) 672 bijective = True 673 sign = +1 674 675 def __eq__(self, other): 676 return isinstance(other, TanhTransform) 677 678 def _call(self, x): 679 return x.tanh() 680 681 def _inverse(self, y): 682 # We do not clamp to the boundary here as it may degrade the performance of certain algorithms. 
        # instead, one should use `cache_size=1`.
        return torch.atanh(y)

    def log_abs_det_jacobian(self, x, y):
        # We use a formula that is more numerically stable, see details in the following link
        # https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/bijectors/tanh.py#L69-L80
        return 2.0 * (math.log(2.0) - x - softplus(-2.0 * x))


class AbsTransform(Transform):
    r"""
    Transform via the mapping :math:`y = |x|`.
    """
    domain = constraints.real
    codomain = constraints.positive

    def __eq__(self, other):
        return isinstance(other, AbsTransform)

    def _call(self, x):
        return x.abs()

    def _inverse(self, y):
        return y


class AffineTransform(Transform):
    r"""
    Transform via the pointwise affine mapping :math:`y = \text{loc} + \text{scale} \times x`.

    Args:
        loc (Tensor or float): Location parameter.
        scale (Tensor or float): Scale parameter.
        event_dim (int): Optional size of `event_shape`. This should be zero
            for univariate random variables, 1 for distributions over vectors,
            2 for distributions over matrices, etc.
    """
    bijective = True

    def __init__(self, loc, scale, event_dim=0, cache_size=0):
        super().__init__(cache_size=cache_size)
        self.loc = loc
        self.scale = scale
        self._event_dim = event_dim

    @property
    def event_dim(self):
        return self._event_dim

    @constraints.dependent_property(is_discrete=False)
    def domain(self):
        if self.event_dim == 0:
            return constraints.real
        return constraints.independent(constraints.real, self.event_dim)

    @constraints.dependent_property(is_discrete=False)
    def codomain(self):
        if self.event_dim == 0:
            return constraints.real
        return constraints.independent(constraints.real, self.event_dim)

    def with_cache(self, cache_size=1):
        if self._cache_size == cache_size:
            return self
        return AffineTransform(
            self.loc, self.scale, self.event_dim, cache_size=cache_size
        )

    def __eq__(self, other):
        if not isinstance(other, AffineTransform):
            return False

        if isinstance(self.loc, numbers.Number) and isinstance(
            other.loc, numbers.Number
        ):
            if self.loc != other.loc:
                return False
        else:
            if not (self.loc == other.loc).all().item():
                return False

        if isinstance(self.scale, numbers.Number) and isinstance(
            other.scale, numbers.Number
        ):
            if self.scale != other.scale:
                return False
        else:
            if not (self.scale == other.scale).all().item():
                return False

        return True

    @property
    def sign(self):
        if isinstance(self.scale, numbers.Real):
            return 1 if float(self.scale) > 0 else -1 if float(self.scale) < 0 else 0
        return self.scale.sign()

    def _call(self, x):
        return self.loc + self.scale * x

    def _inverse(self, y):
        return (y - self.loc) / self.scale

    def log_abs_det_jacobian(self, x, y):
        shape = x.shape
        scale = self.scale
        if isinstance(scale, numbers.Real):
            result = torch.full_like(x, math.log(abs(scale)))
        else:
            result = torch.abs(scale).log()
        if self.event_dim:
            result_size = result.size()[: -self.event_dim] + (-1,)
            result = result.view(result_size).sum(-1)
            shape = shape[: -self.event_dim]
        return result.expand(shape)

    def forward_shape(self, shape):
        return torch.broadcast_shapes(
            shape, getattr(self.loc, "shape", ()), getattr(self.scale, "shape", ())
        )

    def inverse_shape(self, shape):
        return torch.broadcast_shapes(
            shape, getattr(self.loc, "shape", ()), getattr(self.scale, "shape", ())
        )


class CorrCholeskyTransform(Transform):
    r"""
    Transforms an unconstrained real vector :math:`x` with length :math:`D*(D-1)/2` into the
    Cholesky factor of a D-dimensional correlation matrix. This Cholesky factor is a lower
    triangular matrix with positive diagonals and unit Euclidean norm for each row.
    The transform is processed as follows:

        1. First we convert :math:`x` into a lower triangular matrix in row order.
        2. For each row :math:`X_i` of the lower triangular part, we apply a *signed* version of
           class :class:`StickBreakingTransform` to transform :math:`X_i` into a
           unit Euclidean length vector using the following steps:

           - Scales into the interval :math:`(-1, 1)` domain: :math:`r_i = \tanh(X_i)`.
           - Transforms into an unsigned domain: :math:`z_i = r_i^2`.
           - Applies :math:`s_i = StickBreakingTransform(z_i)`.
           - Transforms back into signed domain: :math:`y_i = \text{sign}(r_i) * \sqrt{s_i}`.
    """
    domain = constraints.real_vector
    codomain = constraints.corr_cholesky
    bijective = True

    def _call(self, x):
        x = torch.tanh(x)
        eps = torch.finfo(x.dtype).eps
        x = x.clamp(min=-1 + eps, max=1 - eps)
        r = vec_to_tril_matrix(x, diag=-1)
        # apply stick-breaking on the squared values
        # Note that y = sign(r) * sqrt(z * z1m_cumprod)
        #             = (sign(r) * sqrt(z)) * sqrt(z1m_cumprod) = r * sqrt(z1m_cumprod)
        z = r**2
        z1m_cumprod_sqrt = (1 - z).sqrt().cumprod(-1)
        # Diagonal elements must be 1.
        r = r + torch.eye(r.shape[-1], dtype=r.dtype, device=r.device)
        y = r * pad(z1m_cumprod_sqrt[..., :-1], [1, 0], value=1)
        return y

    def _inverse(self, y):
        # inverse stick-breaking
        # See: https://mc-stan.org/docs/2_18/reference-manual/cholesky-factors-of-correlation-matrices-1.html
        y_cumsum = 1 - torch.cumsum(y * y, dim=-1)
        y_cumsum_shifted = pad(y_cumsum[..., :-1], [1, 0], value=1)
        y_vec = tril_matrix_to_vec(y, diag=-1)
        y_cumsum_vec = tril_matrix_to_vec(y_cumsum_shifted, diag=-1)
        t = y_vec / (y_cumsum_vec).sqrt()
        # inverse of tanh
        x = (t.log1p() - t.neg().log1p()) / 2
        return x

    def log_abs_det_jacobian(self, x, y, intermediates=None):
        # Because domain and codomain are two spaces with different dimensions, the determinant of
        # the Jacobian is not well-defined. We return `log_abs_det_jacobian` of `x` and the
        # flattened lower triangular part of `y`.

        # See: https://mc-stan.org/docs/2_18/reference-manual/cholesky-factors-of-correlation-matrices-1.html
        y1m_cumsum = 1 - (y * y).cumsum(dim=-1)
        # by taking diagonal=-2, we don't need to shift z_cumprod to the right
        # also works for 2 x 2 matrix
        y1m_cumsum_tril = tril_matrix_to_vec(y1m_cumsum, diag=-2)
        stick_breaking_logdet = 0.5 * (y1m_cumsum_tril).log().sum(-1)
        tanh_logdet = -2 * (x + softplus(-2 * x) - math.log(2.0)).sum(dim=-1)
        return stick_breaking_logdet + tanh_logdet

    def forward_shape(self, shape):
        # Reshape from (..., N) to (..., D, D).
        if len(shape) < 1:
            raise ValueError("Too few dimensions on input")
        N = shape[-1]
        D = round((0.25 + 2 * N) ** 0.5 + 0.5)
        if D * (D - 1) // 2 != N:
            raise ValueError("Input is not a flattened lower-diagonal number")
        return shape[:-1] + (D, D)

    def inverse_shape(self, shape):
        # Reshape from (..., D, D) to (..., N).
        if len(shape) < 2:
            raise ValueError("Too few dimensions on input")
        if shape[-2] != shape[-1]:
            raise ValueError("Input is not square")
        D = shape[-1]
        N = D * (D - 1) // 2
        return shape[:-2] + (N,)


class SoftmaxTransform(Transform):
    r"""
    Transform from unconstrained space to the simplex via :math:`y = \exp(x)` then
    normalizing.

    This is not bijective and cannot be used for HMC. However, this acts mostly
    coordinate-wise (except for the final normalization), and thus is
    appropriate for coordinate-wise optimization algorithms.
    """
    domain = constraints.real_vector
    codomain = constraints.simplex

    def __eq__(self, other):
        return isinstance(other, SoftmaxTransform)

    def _call(self, x):
        logprobs = x
        probs = (logprobs - logprobs.max(-1, True)[0]).exp()
        return probs / probs.sum(-1, True)

    def _inverse(self, y):
        probs = y
        return probs.log()

    def forward_shape(self, shape):
        if len(shape) < 1:
            raise ValueError("Too few dimensions on input")
        return shape

    def inverse_shape(self, shape):
        if len(shape) < 1:
            raise ValueError("Too few dimensions on input")
        return shape


class StickBreakingTransform(Transform):
    """
    Transform from unconstrained space to the simplex of one additional
    dimension via a stick-breaking process.

    This transform arises as an iterated sigmoid transform in a stick-breaking
    construction of the `Dirichlet` distribution: the first logit is
    transformed via sigmoid to the first probability and the probability of
    everything else, and then the process recurses.

    This is bijective and appropriate for use in HMC; however it mixes
    coordinates together and is less appropriate for optimization.
    """

    domain = constraints.real_vector
    codomain = constraints.simplex
    bijective = True

    def __eq__(self, other):
        return isinstance(other, StickBreakingTransform)

    def _call(self, x):
        offset = x.shape[-1] + 1 - x.new_ones(x.shape[-1]).cumsum(-1)
        z = _clipped_sigmoid(x - offset.log())
        z_cumprod = (1 - z).cumprod(-1)
        y = pad(z, [0, 1], value=1) * pad(z_cumprod, [1, 0], value=1)
        return y

    def _inverse(self, y):
        y_crop = y[..., :-1]
        offset = y.shape[-1] - y.new_ones(y_crop.shape[-1]).cumsum(-1)
        sf = 1 - y_crop.cumsum(-1)
        # we clamp to make sure that sf is positive, which sometimes does not
        # happen when y[-1] ~ 0 or y[:-1].sum() ~ 1
        sf = torch.clamp(sf, min=torch.finfo(y.dtype).tiny)
        x = y_crop.log() - sf.log() + offset.log()
        return x

    def log_abs_det_jacobian(self, x, y):
        offset = x.shape[-1] + 1 - x.new_ones(x.shape[-1]).cumsum(-1)
        x = x - offset.log()
        # use the identity 1 - sigmoid(x) = exp(-x) * sigmoid(x)
        detJ = (-x + F.logsigmoid(x) + y[..., :-1].log()).sum(-1)
        return detJ

    def forward_shape(self, shape):
        if len(shape) < 1:
            raise ValueError("Too few dimensions on input")
        return shape[:-1] + (shape[-1] + 1,)

    def inverse_shape(self, shape):
        if len(shape) < 1:
            raise ValueError("Too few dimensions on input")
        return shape[:-1] + (shape[-1] - 1,)


class LowerCholeskyTransform(Transform):
    """
    Transform from unconstrained matrices to lower-triangular matrices with
    nonnegative diagonal entries.

    This is useful for parameterizing positive definite matrices in terms of
    their Cholesky factorization.
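
    Example (a small sketch of using the output to build a positive definite
    matrix; the input values are arbitrary)::

        t = LowerCholeskyTransform()
        x = torch.randn(3, 3)          # unconstrained square matrix
        L = t(x)                       # lower triangular with positive diagonal
        sigma = L @ L.mT               # a positive definite matrix
        assert torch.allclose(t.inv(L), x.tril())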
991 """ 992 993 domain = constraints.independent(constraints.real, 2) 994 codomain = constraints.lower_cholesky 995 996 def __eq__(self, other): 997 return isinstance(other, LowerCholeskyTransform) 998 999 def _call(self, x): 1000 return x.tril(-1) + x.diagonal(dim1=-2, dim2=-1).exp().diag_embed() 1001 1002 def _inverse(self, y): 1003 return y.tril(-1) + y.diagonal(dim1=-2, dim2=-1).log().diag_embed() 1004 1005 1006class PositiveDefiniteTransform(Transform): 1007 """ 1008 Transform from unconstrained matrices to positive-definite matrices. 1009 """ 1010 1011 domain = constraints.independent(constraints.real, 2) 1012 codomain = constraints.positive_definite # type: ignore[assignment] 1013 1014 def __eq__(self, other): 1015 return isinstance(other, PositiveDefiniteTransform) 1016 1017 def _call(self, x): 1018 x = LowerCholeskyTransform()(x) 1019 return x @ x.mT 1020 1021 def _inverse(self, y): 1022 y = torch.linalg.cholesky(y) 1023 return LowerCholeskyTransform().inv(y) 1024 1025 1026class CatTransform(Transform): 1027 """ 1028 Transform functor that applies a sequence of transforms `tseq` 1029 component-wise to each submatrix at `dim`, of length `lengths[dim]`, 1030 in a way compatible with :func:`torch.cat`. 1031 1032 Example:: 1033 1034 x0 = torch.cat([torch.range(1, 10), torch.range(1, 10)], dim=0) 1035 x = torch.cat([x0, x0], dim=0) 1036 t0 = CatTransform([ExpTransform(), identity_transform], dim=0, lengths=[10, 10]) 1037 t = CatTransform([t0, t0], dim=0, lengths=[20, 20]) 1038 y = t(x) 1039 """ 1040 1041 transforms: List[Transform] 1042 1043 def __init__(self, tseq, dim=0, lengths=None, cache_size=0): 1044 assert all(isinstance(t, Transform) for t in tseq) 1045 if cache_size: 1046 tseq = [t.with_cache(cache_size) for t in tseq] 1047 super().__init__(cache_size=cache_size) 1048 self.transforms = list(tseq) 1049 if lengths is None: 1050 lengths = [1] * len(self.transforms) 1051 self.lengths = list(lengths) 1052 assert len(self.lengths) == len(self.transforms) 1053 self.dim = dim 1054 1055 @lazy_property 1056 def event_dim(self): 1057 return max(t.event_dim for t in self.transforms) 1058 1059 @lazy_property 1060 def length(self): 1061 return sum(self.lengths) 1062 1063 def with_cache(self, cache_size=1): 1064 if self._cache_size == cache_size: 1065 return self 1066 return CatTransform(self.transforms, self.dim, self.lengths, cache_size) 1067 1068 def _call(self, x): 1069 assert -x.dim() <= self.dim < x.dim() 1070 assert x.size(self.dim) == self.length 1071 yslices = [] 1072 start = 0 1073 for trans, length in zip(self.transforms, self.lengths): 1074 xslice = x.narrow(self.dim, start, length) 1075 yslices.append(trans(xslice)) 1076 start = start + length # avoid += for jit compat 1077 return torch.cat(yslices, dim=self.dim) 1078 1079 def _inverse(self, y): 1080 assert -y.dim() <= self.dim < y.dim() 1081 assert y.size(self.dim) == self.length 1082 xslices = [] 1083 start = 0 1084 for trans, length in zip(self.transforms, self.lengths): 1085 yslice = y.narrow(self.dim, start, length) 1086 xslices.append(trans.inv(yslice)) 1087 start = start + length # avoid += for jit compat 1088 return torch.cat(xslices, dim=self.dim) 1089 1090 def log_abs_det_jacobian(self, x, y): 1091 assert -x.dim() <= self.dim < x.dim() 1092 assert x.size(self.dim) == self.length 1093 assert -y.dim() <= self.dim < y.dim() 1094 assert y.size(self.dim) == self.length 1095 logdetjacs = [] 1096 start = 0 1097 for trans, length in zip(self.transforms, self.lengths): 1098 xslice = x.narrow(self.dim, start, length) 1099 yslice 
= y.narrow(self.dim, start, length) 1100 logdetjac = trans.log_abs_det_jacobian(xslice, yslice) 1101 if trans.event_dim < self.event_dim: 1102 logdetjac = _sum_rightmost(logdetjac, self.event_dim - trans.event_dim) 1103 logdetjacs.append(logdetjac) 1104 start = start + length # avoid += for jit compat 1105 # Decide whether to concatenate or sum. 1106 dim = self.dim 1107 if dim >= 0: 1108 dim = dim - x.dim() 1109 dim = dim + self.event_dim 1110 if dim < 0: 1111 return torch.cat(logdetjacs, dim=dim) 1112 else: 1113 return sum(logdetjacs) 1114 1115 @property 1116 def bijective(self): 1117 return all(t.bijective for t in self.transforms) 1118 1119 @constraints.dependent_property 1120 def domain(self): 1121 return constraints.cat( 1122 [t.domain for t in self.transforms], self.dim, self.lengths 1123 ) 1124 1125 @constraints.dependent_property 1126 def codomain(self): 1127 return constraints.cat( 1128 [t.codomain for t in self.transforms], self.dim, self.lengths 1129 ) 1130 1131 1132class StackTransform(Transform): 1133 """ 1134 Transform functor that applies a sequence of transforms `tseq` 1135 component-wise to each submatrix at `dim` 1136 in a way compatible with :func:`torch.stack`. 1137 1138 Example:: 1139 1140 x = torch.stack([torch.range(1, 10), torch.range(1, 10)], dim=1) 1141 t = StackTransform([ExpTransform(), identity_transform], dim=1) 1142 y = t(x) 1143 """ 1144 1145 transforms: List[Transform] 1146 1147 def __init__(self, tseq, dim=0, cache_size=0): 1148 assert all(isinstance(t, Transform) for t in tseq) 1149 if cache_size: 1150 tseq = [t.with_cache(cache_size) for t in tseq] 1151 super().__init__(cache_size=cache_size) 1152 self.transforms = list(tseq) 1153 self.dim = dim 1154 1155 def with_cache(self, cache_size=1): 1156 if self._cache_size == cache_size: 1157 return self 1158 return StackTransform(self.transforms, self.dim, cache_size) 1159 1160 def _slice(self, z): 1161 return [z.select(self.dim, i) for i in range(z.size(self.dim))] 1162 1163 def _call(self, x): 1164 assert -x.dim() <= self.dim < x.dim() 1165 assert x.size(self.dim) == len(self.transforms) 1166 yslices = [] 1167 for xslice, trans in zip(self._slice(x), self.transforms): 1168 yslices.append(trans(xslice)) 1169 return torch.stack(yslices, dim=self.dim) 1170 1171 def _inverse(self, y): 1172 assert -y.dim() <= self.dim < y.dim() 1173 assert y.size(self.dim) == len(self.transforms) 1174 xslices = [] 1175 for yslice, trans in zip(self._slice(y), self.transforms): 1176 xslices.append(trans.inv(yslice)) 1177 return torch.stack(xslices, dim=self.dim) 1178 1179 def log_abs_det_jacobian(self, x, y): 1180 assert -x.dim() <= self.dim < x.dim() 1181 assert x.size(self.dim) == len(self.transforms) 1182 assert -y.dim() <= self.dim < y.dim() 1183 assert y.size(self.dim) == len(self.transforms) 1184 logdetjacs = [] 1185 yslices = self._slice(y) 1186 xslices = self._slice(x) 1187 for xslice, yslice, trans in zip(xslices, yslices, self.transforms): 1188 logdetjacs.append(trans.log_abs_det_jacobian(xslice, yslice)) 1189 return torch.stack(logdetjacs, dim=self.dim) 1190 1191 @property 1192 def bijective(self): 1193 return all(t.bijective for t in self.transforms) 1194 1195 @constraints.dependent_property 1196 def domain(self): 1197 return constraints.stack([t.domain for t in self.transforms], self.dim) 1198 1199 @constraints.dependent_property 1200 def codomain(self): 1201 return constraints.stack([t.codomain for t in self.transforms], self.dim) 1202 1203 1204class CumulativeDistributionTransform(Transform): 1205 """ 1206 Transform 
via the cumulative distribution function of a probability distribution. 1207 1208 Args: 1209 distribution (Distribution): Distribution whose cumulative distribution function to use for 1210 the transformation. 1211 1212 Example:: 1213 1214 # Construct a Gaussian copula from a multivariate normal. 1215 base_dist = MultivariateNormal( 1216 loc=torch.zeros(2), 1217 scale_tril=LKJCholesky(2).sample(), 1218 ) 1219 transform = CumulativeDistributionTransform(Normal(0, 1)) 1220 copula = TransformedDistribution(base_dist, [transform]) 1221 """ 1222 1223 bijective = True 1224 codomain = constraints.unit_interval 1225 sign = +1 1226 1227 def __init__(self, distribution, cache_size=0): 1228 super().__init__(cache_size=cache_size) 1229 self.distribution = distribution 1230 1231 @property 1232 def domain(self): 1233 return self.distribution.support 1234 1235 def _call(self, x): 1236 return self.distribution.cdf(x) 1237 1238 def _inverse(self, y): 1239 return self.distribution.icdf(y) 1240 1241 def log_abs_det_jacobian(self, x, y): 1242 return self.distribution.log_prob(x) 1243 1244 def with_cache(self, cache_size=1): 1245 if self._cache_size == cache_size: 1246 return self 1247 return CumulativeDistributionTransform(self.distribution, cache_size=cache_size) 1248