# xref: /aosp_15_r20/external/pytorch/torch/distributions/transforms.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import functools
3import math
4import numbers
5import operator
6import weakref
7from typing import List
8
9import torch
10import torch.nn.functional as F
11from torch.distributions import constraints
12from torch.distributions.utils import (
13    _sum_rightmost,
14    broadcast_all,
15    lazy_property,
16    tril_matrix_to_vec,
17    vec_to_tril_matrix,
18)
19from torch.nn.functional import pad, softplus
20
21
22__all__ = [
23    "AbsTransform",
24    "AffineTransform",
25    "CatTransform",
26    "ComposeTransform",
27    "CorrCholeskyTransform",
28    "CumulativeDistributionTransform",
29    "ExpTransform",
30    "IndependentTransform",
31    "LowerCholeskyTransform",
32    "PositiveDefiniteTransform",
33    "PowerTransform",
34    "ReshapeTransform",
35    "SigmoidTransform",
36    "SoftplusTransform",
37    "TanhTransform",
38    "SoftmaxTransform",
39    "StackTransform",
40    "StickBreakingTransform",
41    "Transform",
42    "identity_transform",
43]
44
45
46class Transform:
47    """
48    Abstract class for invertible transformations with computable log
49    det jacobians. They are primarily used in
50    :class:`torch.distributions.TransformedDistribution`.
51
52    Caching is useful for transforms whose inverses are either expensive or
53    numerically unstable. Note that care must be taken with memoized values
54    since the autograd graph may be reversed. For example, while the following
55    works with or without caching::
56
57        y = t(x)
58        t.log_abs_det_jacobian(x, y).backward()  # x will receive gradients.
59
60    However, the following will error when caching due to dependency reversal::
61
62        y = t(x)
63        z = t.inv(y)
64        grad(z.sum(), [y])  # error because z is x
65
66    Derived classes should implement one or both of :meth:`_call` or
67    :meth:`_inverse`. Derived classes that set `bijective=True` should also
68    implement :meth:`log_abs_det_jacobian`.
69
70    Args:
71        cache_size (int): Size of cache. If zero, no caching is done. If one,
72            the latest single value is cached. Only 0 and 1 are supported.
73
74    Attributes:
75        domain (:class:`~torch.distributions.constraints.Constraint`):
76            The constraint representing valid inputs to this transform.
77        codomain (:class:`~torch.distributions.constraints.Constraint`):
78            The constraint representing valid outputs to this transform
79            which are inputs to the inverse transform.
80        bijective (bool): Whether this transform is bijective. A transform
81            ``t`` is bijective iff ``t.inv(t(x)) == x`` and
82            ``t(t.inv(y)) == y`` for every ``x`` in the domain and ``y`` in
83            the codomain. Transforms that are not bijective should at least
84            maintain the weaker pseudoinverse properties
85            ``t(t.inv(t(x))) == t(x)`` and ``t.inv(t(t.inv(y))) == t.inv(y)``.
86        sign (int or Tensor): For bijective univariate transforms, this
87            should be +1 or -1 depending on whether transform is monotone
88            increasing or decreasing.
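
    Example (a minimal sketch of a custom bijection, not part of the library;
    ``ShiftByOne`` is a hypothetical name used only for illustration)::

        class ShiftByOne(Transform):
            domain = constraints.real
            codomain = constraints.real
            bijective = True
            sign = +1

            def _call(self, x):
                return x + 1.0

            def _inverse(self, y):
                return y - 1.0

            def log_abs_det_jacobian(self, x, y):
                # dy/dx == 1, so the log-abs-det is zero everywhere.
                return torch.zeros_like(x)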
89    """
90
91    bijective = False
92    domain: constraints.Constraint
93    codomain: constraints.Constraint
94
95    def __init__(self, cache_size=0):
96        self._cache_size = cache_size
97        self._inv = None
98        if cache_size == 0:
99            pass  # default behavior
100        elif cache_size == 1:
101            self._cached_x_y = None, None
102        else:
103            raise ValueError("cache_size must be 0 or 1")
104        super().__init__()
105
106    def __getstate__(self):
107        state = self.__dict__.copy()
108        state["_inv"] = None
109        return state
110
111    @property
112    def event_dim(self):
113        if self.domain.event_dim == self.codomain.event_dim:
114            return self.domain.event_dim
115        raise ValueError("Please use either .domain.event_dim or .codomain.event_dim")
116
117    @property
118    def inv(self):
119        """
120        Returns the inverse :class:`Transform` of this transform.
121        This should satisfy ``t.inv.inv is t``.
122        """
123        inv = None
124        if self._inv is not None:
125            inv = self._inv()
126        if inv is None:
127            inv = _InverseTransform(self)
128            self._inv = weakref.ref(inv)
129        return inv
130
131    @property
132    def sign(self):
133        """
134        Returns the sign of the determinant of the Jacobian, if applicable.
135        In general this only makes sense for bijective transforms.
136        """
137        raise NotImplementedError
138
139    def with_cache(self, cache_size=1):
140        if self._cache_size == cache_size:
141            return self
142        if type(self).__init__ is Transform.__init__:
143            return type(self)(cache_size=cache_size)
144        raise NotImplementedError(f"{type(self)}.with_cache is not implemented")
145
146    def __eq__(self, other):
147        return self is other
148
149    def __ne__(self, other):
150        # Necessary for Python2
151        return not self.__eq__(other)
152
153    def __call__(self, x):
154        """
155        Computes the transform `x => y`.
156        """
157        if self._cache_size == 0:
158            return self._call(x)
159        x_old, y_old = self._cached_x_y
160        if x is x_old:
161            return y_old
162        y = self._call(x)
163        self._cached_x_y = x, y
164        return y
165
166    def _inv_call(self, y):
167        """
168        Inverts the transform `y => x`.
169        """
170        if self._cache_size == 0:
171            return self._inverse(y)
172        x_old, y_old = self._cached_x_y
173        if y is y_old:
174            return x_old
175        x = self._inverse(y)
176        self._cached_x_y = x, y
177        return x
178
179    def _call(self, x):
180        """
181        Abstract method to compute forward transformation.
182        """
183        raise NotImplementedError
184
185    def _inverse(self, y):
186        """
187        Abstract method to compute inverse transformation.
188        """
189        raise NotImplementedError
190
191    def log_abs_det_jacobian(self, x, y):
192        """
193        Computes the log det jacobian `log |dy/dx|` given input and output.
194        """
195        raise NotImplementedError
196
197    def __repr__(self):
198        return self.__class__.__name__ + "()"
199
200    def forward_shape(self, shape):
201        """
202        Infers the shape of the forward computation, given the input shape.
203        Defaults to preserving shape.
204        """
205        return shape
206
207    def inverse_shape(self, shape):
208        """
209        Infers the shape of the inverse computation, given the output shape.
210        Defaults to preserving shape.
211        """
212        return shape
213
214
215class _InverseTransform(Transform):
216    """
217    Inverts a single :class:`Transform`.
218    This class is private; please instead use the ``Transform.inv`` property.
219    """
220
221    def __init__(self, transform: Transform):
222        super().__init__(cache_size=transform._cache_size)
223        self._inv: Transform = transform
224
225    @constraints.dependent_property(is_discrete=False)
226    def domain(self):
227        assert self._inv is not None
228        return self._inv.codomain
229
230    @constraints.dependent_property(is_discrete=False)
231    def codomain(self):
232        assert self._inv is not None
233        return self._inv.domain
234
235    @property
236    def bijective(self):
237        assert self._inv is not None
238        return self._inv.bijective
239
240    @property
241    def sign(self):
242        assert self._inv is not None
243        return self._inv.sign
244
245    @property
246    def inv(self):
247        return self._inv
248
249    def with_cache(self, cache_size=1):
250        assert self._inv is not None
251        return self.inv.with_cache(cache_size).inv
252
253    def __eq__(self, other):
254        if not isinstance(other, _InverseTransform):
255            return False
256        assert self._inv is not None
257        return self._inv == other._inv
258
259    def __repr__(self):
260        return f"{self.__class__.__name__}({repr(self._inv)})"
261
262    def __call__(self, x):
263        assert self._inv is not None
264        return self._inv._inv_call(x)
265
266    def log_abs_det_jacobian(self, x, y):
267        assert self._inv is not None
268        return -self._inv.log_abs_det_jacobian(y, x)
269
270    def forward_shape(self, shape):
271        return self._inv.inverse_shape(shape)
272
273    def inverse_shape(self, shape):
274        return self._inv.forward_shape(shape)
275
276
277class ComposeTransform(Transform):
278    """
279    Composes multiple transforms in a chain.
280    The transforms being composed are responsible for caching.
281
282    Args:
283        parts (list of :class:`Transform`): A list of transforms to compose.
284        cache_size (int): Size of cache. If zero, no caching is done. If one,
285            the latest single value is cached. Only 0 and 1 are supported.
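
    Example (an illustrative sketch; the composed map below is
    ``y = exp(1 + 2 * x)``)::

        t = ComposeTransform([AffineTransform(1.0, 2.0), ExpTransform()])
        x = torch.tensor([0.0, 1.0])
        y = t(x)                              # exp(1 + 2 * x)
        x_back = t.inv(y)                     # recovers x
        ldj = t.log_abs_det_jacobian(x, y)    # log(2) + (1 + 2 * x)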
286    """
287
288    def __init__(self, parts: List[Transform], cache_size=0):
289        if cache_size:
290            parts = [part.with_cache(cache_size) for part in parts]
291        super().__init__(cache_size=cache_size)
292        self.parts = parts
293
294    def __eq__(self, other):
295        if not isinstance(other, ComposeTransform):
296            return False
297        return self.parts == other.parts
298
299    @constraints.dependent_property(is_discrete=False)
300    def domain(self):
301        if not self.parts:
302            return constraints.real
303        domain = self.parts[0].domain
304        # Adjust event_dim to be maximum among all parts.
305        event_dim = self.parts[-1].codomain.event_dim
306        for part in reversed(self.parts):
307            event_dim += part.domain.event_dim - part.codomain.event_dim
308            event_dim = max(event_dim, part.domain.event_dim)
309        assert event_dim >= domain.event_dim
310        if event_dim > domain.event_dim:
311            domain = constraints.independent(domain, event_dim - domain.event_dim)
312        return domain
313
314    @constraints.dependent_property(is_discrete=False)
315    def codomain(self):
316        if not self.parts:
317            return constraints.real
318        codomain = self.parts[-1].codomain
319        # Adjust event_dim to be maximum among all parts.
320        event_dim = self.parts[0].domain.event_dim
321        for part in self.parts:
322            event_dim += part.codomain.event_dim - part.domain.event_dim
323            event_dim = max(event_dim, part.codomain.event_dim)
324        assert event_dim >= codomain.event_dim
325        if event_dim > codomain.event_dim:
326            codomain = constraints.independent(codomain, event_dim - codomain.event_dim)
327        return codomain
328
329    @lazy_property
330    def bijective(self):
331        return all(p.bijective for p in self.parts)
332
333    @lazy_property
334    def sign(self):
335        sign = 1
336        for p in self.parts:
337            sign = sign * p.sign
338        return sign
339
340    @property
341    def inv(self):
342        inv = None
343        if self._inv is not None:
344            inv = self._inv()
345        if inv is None:
346            inv = ComposeTransform([p.inv for p in reversed(self.parts)])
347            self._inv = weakref.ref(inv)
348            inv._inv = weakref.ref(self)
349        return inv
350
351    def with_cache(self, cache_size=1):
352        if self._cache_size == cache_size:
353            return self
354        return ComposeTransform(self.parts, cache_size=cache_size)
355
356    def __call__(self, x):
357        for part in self.parts:
358            x = part(x)
359        return x
360
361    def log_abs_det_jacobian(self, x, y):
362        if not self.parts:
363            return torch.zeros_like(x)
364
365        # Compute intermediates. This will be free if parts[:-1] are all cached.
366        xs = [x]
367        for part in self.parts[:-1]:
368            xs.append(part(xs[-1]))
369        xs.append(y)
370
371        terms = []
372        event_dim = self.domain.event_dim
373        for part, x, y in zip(self.parts, xs[:-1], xs[1:]):
374            terms.append(
375                _sum_rightmost(
376                    part.log_abs_det_jacobian(x, y), event_dim - part.domain.event_dim
377                )
378            )
379            event_dim += part.codomain.event_dim - part.domain.event_dim
380        return functools.reduce(operator.add, terms)
381
382    def forward_shape(self, shape):
383        for part in self.parts:
384            shape = part.forward_shape(shape)
385        return shape
386
387    def inverse_shape(self, shape):
388        for part in reversed(self.parts):
389            shape = part.inverse_shape(shape)
390        return shape
391
392    def __repr__(self):
393        fmt_string = self.__class__.__name__ + "(\n    "
394        fmt_string += ",\n    ".join([p.__repr__() for p in self.parts])
395        fmt_string += "\n)"
396        return fmt_string
397
398
399identity_transform = ComposeTransform([])
400
401
402class IndependentTransform(Transform):
403    """
404    Wrapper around another transform to treat
405    ``reinterpreted_batch_ndims``-many extra rightmost dimensions as
406    dependent. This has no effect on the forward or backward transforms, but
407    does sum out ``reinterpreted_batch_ndims``-many of the rightmost dimensions
408    in :meth:`log_abs_det_jacobian`.
409
410    Args:
411        base_transform (:class:`Transform`): A base transform.
412        reinterpreted_batch_ndims (int): The number of extra rightmost
413            dimensions to treat as dependent.
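
    Example (an illustrative sketch; the shapes are arbitrary)::

        base = ExpTransform()
        t = IndependentTransform(base, reinterpreted_batch_ndims=1)
        x = torch.randn(3, 4)
        # base.log_abs_det_jacobian(x, t(x)) has shape (3, 4); the wrapper
        # sums out the last dimension, so the result has shape (3,).
        t.log_abs_det_jacobian(x, t(x)).shape  # torch.Size([3])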
414    """
415
416    def __init__(self, base_transform, reinterpreted_batch_ndims, cache_size=0):
417        super().__init__(cache_size=cache_size)
418        self.base_transform = base_transform.with_cache(cache_size)
419        self.reinterpreted_batch_ndims = reinterpreted_batch_ndims
420
421    def with_cache(self, cache_size=1):
422        if self._cache_size == cache_size:
423            return self
424        return IndependentTransform(
425            self.base_transform, self.reinterpreted_batch_ndims, cache_size=cache_size
426        )
427
428    @constraints.dependent_property(is_discrete=False)
429    def domain(self):
430        return constraints.independent(
431            self.base_transform.domain, self.reinterpreted_batch_ndims
432        )
433
434    @constraints.dependent_property(is_discrete=False)
435    def codomain(self):
436        return constraints.independent(
437            self.base_transform.codomain, self.reinterpreted_batch_ndims
438        )
439
440    @property
441    def bijective(self):
442        return self.base_transform.bijective
443
444    @property
445    def sign(self):
446        return self.base_transform.sign
447
448    def _call(self, x):
449        if x.dim() < self.domain.event_dim:
450            raise ValueError("Too few dimensions on input")
451        return self.base_transform(x)
452
453    def _inverse(self, y):
454        if y.dim() < self.codomain.event_dim:
455            raise ValueError("Too few dimensions on input")
456        return self.base_transform.inv(y)
457
458    def log_abs_det_jacobian(self, x, y):
459        result = self.base_transform.log_abs_det_jacobian(x, y)
460        result = _sum_rightmost(result, self.reinterpreted_batch_ndims)
461        return result
462
463    def __repr__(self):
464        return f"{self.__class__.__name__}({repr(self.base_transform)}, {self.reinterpreted_batch_ndims})"
465
466    def forward_shape(self, shape):
467        return self.base_transform.forward_shape(shape)
468
469    def inverse_shape(self, shape):
470        return self.base_transform.inverse_shape(shape)
471
472
473class ReshapeTransform(Transform):
474    """
475    Unit Jacobian transform to reshape the rightmost part of a tensor.
476
477    Note that ``in_shape`` and ``out_shape`` must have the same number of
478    elements, just as for :meth:`torch.Tensor.reshape`.
479
480    Arguments:
481        in_shape (torch.Size): The input event shape.
482        out_shape (torch.Size): The output event shape.
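
    Example (illustrative shapes only)::

        t = ReshapeTransform((4, 5), (20,))
        x = torch.randn(2, 4, 5)    # batch shape (2,) + in_shape (4, 5)
        y = t(x)                    # shape (2, 20)
        t.inv(y).shape              # torch.Size([2, 4, 5])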
483    """
484
485    bijective = True
486
487    def __init__(self, in_shape, out_shape, cache_size=0):
488        self.in_shape = torch.Size(in_shape)
489        self.out_shape = torch.Size(out_shape)
490        if self.in_shape.numel() != self.out_shape.numel():
491            raise ValueError("in_shape, out_shape have different numbers of elements")
492        super().__init__(cache_size=cache_size)
493
494    @constraints.dependent_property
495    def domain(self):
496        return constraints.independent(constraints.real, len(self.in_shape))
497
498    @constraints.dependent_property
499    def codomain(self):
500        return constraints.independent(constraints.real, len(self.out_shape))
501
502    def with_cache(self, cache_size=1):
503        if self._cache_size == cache_size:
504            return self
505        return ReshapeTransform(self.in_shape, self.out_shape, cache_size=cache_size)
506
507    def _call(self, x):
508        batch_shape = x.shape[: x.dim() - len(self.in_shape)]
509        return x.reshape(batch_shape + self.out_shape)
510
511    def _inverse(self, y):
512        batch_shape = y.shape[: y.dim() - len(self.out_shape)]
513        return y.reshape(batch_shape + self.in_shape)
514
515    def log_abs_det_jacobian(self, x, y):
516        batch_shape = x.shape[: x.dim() - len(self.in_shape)]
517        return x.new_zeros(batch_shape)
518
519    def forward_shape(self, shape):
520        if len(shape) < len(self.in_shape):
521            raise ValueError("Too few dimensions on input")
522        cut = len(shape) - len(self.in_shape)
523        if shape[cut:] != self.in_shape:
524            raise ValueError(
525                f"Shape mismatch: expected {self.in_shape} but got {shape[cut:]}"
526            )
527        return shape[:cut] + self.out_shape
528
529    def inverse_shape(self, shape):
530        if len(shape) < len(self.out_shape):
531            raise ValueError("Too few dimensions on input")
532        cut = len(shape) - len(self.out_shape)
533        if shape[cut:] != self.out_shape:
534            raise ValueError(
535                f"Shape mismatch: expected {self.out_shape} but got {shape[cut:]}"
536            )
537        return shape[:cut] + self.in_shape
538
539
540class ExpTransform(Transform):
541    r"""
542    Transform via the mapping :math:`y = \exp(x)`.
543    """
544    domain = constraints.real
545    codomain = constraints.positive
546    bijective = True
547    sign = +1
548
549    def __eq__(self, other):
550        return isinstance(other, ExpTransform)
551
552    def _call(self, x):
553        return x.exp()
554
555    def _inverse(self, y):
556        return y.log()
557
558    def log_abs_det_jacobian(self, x, y):
559        return x
560
561
562class PowerTransform(Transform):
563    r"""
564    Transform via the mapping :math:`y = x^{\text{exponent}}`.
565    """
566    domain = constraints.positive
567    codomain = constraints.positive
568    bijective = True
569
570    def __init__(self, exponent, cache_size=0):
571        super().__init__(cache_size=cache_size)
572        (self.exponent,) = broadcast_all(exponent)
573
574    def with_cache(self, cache_size=1):
575        if self._cache_size == cache_size:
576            return self
577        return PowerTransform(self.exponent, cache_size=cache_size)
578
579    @lazy_property
580    def sign(self):
581        return self.exponent.sign()
582
583    def __eq__(self, other):
584        if not isinstance(other, PowerTransform):
585            return False
586        return self.exponent.eq(other.exponent).all().item()
587
588    def _call(self, x):
589        return x.pow(self.exponent)
590
591    def _inverse(self, y):
592        return y.pow(1 / self.exponent)
593
594    def log_abs_det_jacobian(self, x, y):
595        return (self.exponent * y / x).abs().log()
596
597    def forward_shape(self, shape):
598        return torch.broadcast_shapes(shape, getattr(self.exponent, "shape", ()))
599
600    def inverse_shape(self, shape):
601        return torch.broadcast_shapes(shape, getattr(self.exponent, "shape", ()))
602
603
604def _clipped_sigmoid(x):
605    finfo = torch.finfo(x.dtype)
606    return torch.clamp(torch.sigmoid(x), min=finfo.tiny, max=1.0 - finfo.eps)
607
608
609class SigmoidTransform(Transform):
610    r"""
611    Transform via the mapping :math:`y = \frac{1}{1 + \exp(-x)}` and :math:`x = \text{logit}(y)`.
612    """
613    domain = constraints.real
614    codomain = constraints.unit_interval
615    bijective = True
616    sign = +1
617
618    def __eq__(self, other):
619        return isinstance(other, SigmoidTransform)
620
621    def _call(self, x):
622        return _clipped_sigmoid(x)
623
624    def _inverse(self, y):
625        finfo = torch.finfo(y.dtype)
626        y = y.clamp(min=finfo.tiny, max=1.0 - finfo.eps)
627        return y.log() - (-y).log1p()
628
629    def log_abs_det_jacobian(self, x, y):
630        return -F.softplus(-x) - F.softplus(x)
631
632
633class SoftplusTransform(Transform):
634    r"""
635    Transform via the mapping :math:`\text{Softplus}(x) = \log(1 + \exp(x))`.
636    The implementation reverts to the linear function when :math:`x > 20`.
637    """
638    domain = constraints.real
639    codomain = constraints.positive
640    bijective = True
641    sign = +1
642
643    def __eq__(self, other):
644        return isinstance(other, SoftplusTransform)
645
646    def _call(self, x):
647        return softplus(x)
648
649    def _inverse(self, y):
650        return (-y).expm1().neg().log() + y
651
652    def log_abs_det_jacobian(self, x, y):
653        return -softplus(-x)
654
655
656class TanhTransform(Transform):
657    r"""
658    Transform via the mapping :math:`y = \tanh(x)`.
659
660    It is equivalent to::
661
662        ComposeTransform([AffineTransform(0., 2.), SigmoidTransform(), AffineTransform(-1., 2.)])
663
664    However, this composition might not be numerically stable, so it is
665    recommended to use `TanhTransform` instead.
666
667    Note that one should use `cache_size=1` to avoid `NaN`/`Inf` values when inverting near the boundaries.
668
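    Example (a minimal sketch; with ``cache_size=1`` the inverse returns the
    cached input instead of calling ``atanh`` on values rounded to ``-1`` or ``1``)::

        t = TanhTransform(cache_size=1)
        x = torch.randn(4)
        y = t(x)        # values in (-1, 1)
        t.inv(y)        # returns the cached x
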
669    """
670    domain = constraints.real
671    codomain = constraints.interval(-1.0, 1.0)
672    bijective = True
673    sign = +1
674
675    def __eq__(self, other):
676        return isinstance(other, TanhTransform)
677
678    def _call(self, x):
679        return x.tanh()
680
681    def _inverse(self, y):
682        # We do not clamp to the boundary here as it may degrade the performance of certain algorithms;
683        # one should use `cache_size=1` instead.
684        return torch.atanh(y)
685
686    def log_abs_det_jacobian(self, x, y):
687        # We use a formula that is more numerically stable, see details in the following link
688        # https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/bijectors/tanh.py#L69-L80
689        return 2.0 * (math.log(2.0) - x - softplus(-2.0 * x))
690
691
692class AbsTransform(Transform):
693    r"""
694    Transform via the mapping :math:`y = |x|`.
695    """
696    domain = constraints.real
697    codomain = constraints.positive
698
699    def __eq__(self, other):
700        return isinstance(other, AbsTransform)
701
702    def _call(self, x):
703        return x.abs()
704
705    def _inverse(self, y):
706        return y
707
708
709class AffineTransform(Transform):
710    r"""
711    Transform via the pointwise affine mapping :math:`y = \text{loc} + \text{scale} \times x`.
712
713    Args:
714        loc (Tensor or float): Location parameter.
715        scale (Tensor or float): Scale parameter.
716        event_dim (int): Optional size of `event_shape`. This should be zero
717            for univariate random variables, 1 for distributions over vectors,
718            2 for distributions over matrices, etc.
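
    Example (illustrative values)::

        t = AffineTransform(loc=3.0, scale=2.0)
        x = torch.tensor([0.0, 1.0])
        y = t(x)                        # 3 + 2 * x -> tensor([3., 5.])
        t.inv(y)                        # recovers x
        t.log_abs_det_jacobian(x, y)    # log(2) for every element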
719    """
720    bijective = True
721
722    def __init__(self, loc, scale, event_dim=0, cache_size=0):
723        super().__init__(cache_size=cache_size)
724        self.loc = loc
725        self.scale = scale
726        self._event_dim = event_dim
727
728    @property
729    def event_dim(self):
730        return self._event_dim
731
732    @constraints.dependent_property(is_discrete=False)
733    def domain(self):
734        if self.event_dim == 0:
735            return constraints.real
736        return constraints.independent(constraints.real, self.event_dim)
737
738    @constraints.dependent_property(is_discrete=False)
739    def codomain(self):
740        if self.event_dim == 0:
741            return constraints.real
742        return constraints.independent(constraints.real, self.event_dim)
743
744    def with_cache(self, cache_size=1):
745        if self._cache_size == cache_size:
746            return self
747        return AffineTransform(
748            self.loc, self.scale, self.event_dim, cache_size=cache_size
749        )
750
751    def __eq__(self, other):
752        if not isinstance(other, AffineTransform):
753            return False
754
755        if isinstance(self.loc, numbers.Number) and isinstance(
756            other.loc, numbers.Number
757        ):
758            if self.loc != other.loc:
759                return False
760        else:
761            if not (self.loc == other.loc).all().item():
762                return False
763
764        if isinstance(self.scale, numbers.Number) and isinstance(
765            other.scale, numbers.Number
766        ):
767            if self.scale != other.scale:
768                return False
769        else:
770            if not (self.scale == other.scale).all().item():
771                return False
772
773        return True
774
775    @property
776    def sign(self):
777        if isinstance(self.scale, numbers.Real):
778            return 1 if float(self.scale) > 0 else -1 if float(self.scale) < 0 else 0
779        return self.scale.sign()
780
781    def _call(self, x):
782        return self.loc + self.scale * x
783
784    def _inverse(self, y):
785        return (y - self.loc) / self.scale
786
787    def log_abs_det_jacobian(self, x, y):
788        shape = x.shape
789        scale = self.scale
790        if isinstance(scale, numbers.Real):
791            result = torch.full_like(x, math.log(abs(scale)))
792        else:
793            result = torch.abs(scale).log()
794        if self.event_dim:
795            result_size = result.size()[: -self.event_dim] + (-1,)
796            result = result.view(result_size).sum(-1)
797            shape = shape[: -self.event_dim]
798        return result.expand(shape)
799
800    def forward_shape(self, shape):
801        return torch.broadcast_shapes(
802            shape, getattr(self.loc, "shape", ()), getattr(self.scale, "shape", ())
803        )
804
805    def inverse_shape(self, shape):
806        return torch.broadcast_shapes(
807            shape, getattr(self.loc, "shape", ()), getattr(self.scale, "shape", ())
808        )
809
810
811class CorrCholeskyTransform(Transform):
812    r"""
813    Transforms an unconstrained real vector :math:`x` with length :math:`D*(D-1)/2` into the
814    Cholesky factor of a :math:`D`-dimensional correlation matrix. This Cholesky factor is a lower
815    triangular matrix with positive diagonals and unit Euclidean norm for each row.
816    The transform is processed as follows:
817
818        1. First we convert :math:`x` into a lower triangular matrix in row order.
819        2. For each row :math:`X_i` of the lower triangular part, we apply a *signed* version of
820           :class:`StickBreakingTransform` to transform :math:`X_i` into a
821           unit Euclidean length vector using the following steps:
822           - Scales into the interval :math:`(-1, 1)` domain: :math:`r_i = \tanh(X_i)`.
823           - Transforms into an unsigned domain: :math:`z_i = r_i^2`.
824           - Applies :math:`s_i = StickBreakingTransform(z_i)`.
825           - Transforms back into signed domain: :math:`y_i = \operatorname{sign}(r_i) * \sqrt{s_i}`.
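
    Example (a minimal sketch; for :math:`D = 3` the input vector has
    :math:`D*(D-1)/2 = 3` entries)::

        t = CorrCholeskyTransform()
        x = torch.randn(3)
        L = t(x)              # 3 x 3 lower Cholesky factor with unit-norm rows
        C = L @ L.mT          # correlation matrix with unit diagonal
        x_back = t.inv(L)     # recovers x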
826    """
827    domain = constraints.real_vector
828    codomain = constraints.corr_cholesky
829    bijective = True
830
831    def _call(self, x):
832        x = torch.tanh(x)
833        eps = torch.finfo(x.dtype).eps
834        x = x.clamp(min=-1 + eps, max=1 - eps)
835        r = vec_to_tril_matrix(x, diag=-1)
836        # apply stick-breaking on the squared values
837        # Note that y = sign(r) * sqrt(z * z1m_cumprod)
838        #             = (sign(r) * sqrt(z)) * sqrt(z1m_cumprod) = r * sqrt(z1m_cumprod)
839        z = r**2
840        z1m_cumprod_sqrt = (1 - z).sqrt().cumprod(-1)
841        # Diagonal elements must be 1.
842        r = r + torch.eye(r.shape[-1], dtype=r.dtype, device=r.device)
843        y = r * pad(z1m_cumprod_sqrt[..., :-1], [1, 0], value=1)
844        return y
845
846    def _inverse(self, y):
847        # inverse stick-breaking
848        # See: https://mc-stan.org/docs/2_18/reference-manual/cholesky-factors-of-correlation-matrices-1.html
849        y_cumsum = 1 - torch.cumsum(y * y, dim=-1)
850        y_cumsum_shifted = pad(y_cumsum[..., :-1], [1, 0], value=1)
851        y_vec = tril_matrix_to_vec(y, diag=-1)
852        y_cumsum_vec = tril_matrix_to_vec(y_cumsum_shifted, diag=-1)
853        t = y_vec / (y_cumsum_vec).sqrt()
854        # inverse of tanh
855        x = (t.log1p() - t.neg().log1p()) / 2
856        return x
857
858    def log_abs_det_jacobian(self, x, y, intermediates=None):
859        # Because domain and codomain are two spaces with different dimensions, determinant of
860        # Jacobian is not well-defined. We return `log_abs_det_jacobian` of `x` and the
861        # flattened lower triangular part of `y`.
862
863        # See: https://mc-stan.org/docs/2_18/reference-manual/cholesky-factors-of-correlation-matrices-1.html
864        y1m_cumsum = 1 - (y * y).cumsum(dim=-1)
865        # by taking diagonal=-2, we don't need to shift z_cumprod to the right
866        # also works for 2 x 2 matrix
867        y1m_cumsum_tril = tril_matrix_to_vec(y1m_cumsum, diag=-2)
868        stick_breaking_logdet = 0.5 * (y1m_cumsum_tril).log().sum(-1)
869        tanh_logdet = -2 * (x + softplus(-2 * x) - math.log(2.0)).sum(dim=-1)
870        return stick_breaking_logdet + tanh_logdet
871
872    def forward_shape(self, shape):
873        # Reshape from (..., N) to (..., D, D).
874        if len(shape) < 1:
875            raise ValueError("Too few dimensions on input")
876        N = shape[-1]
877        D = round((0.25 + 2 * N) ** 0.5 + 0.5)
878        if D * (D - 1) // 2 != N:
879            raise ValueError("Input is not a flattened lower-diagonal number")
880        return shape[:-1] + (D, D)
881
882    def inverse_shape(self, shape):
883        # Reshape from (..., D, D) to (..., N).
884        if len(shape) < 2:
885            raise ValueError("Too few dimensions on input")
886        if shape[-2] != shape[-1]:
887            raise ValueError("Input is not square")
888        D = shape[-1]
889        N = D * (D - 1) // 2
890        return shape[:-2] + (N,)
891
892
893class SoftmaxTransform(Transform):
894    r"""
895    Transform from unconstrained space to the simplex via :math:`y = \exp(x)` then
896    normalizing.
897
898    This is not bijective and cannot be used for HMC. However this acts mostly
899    coordinate-wise (except for the final normalization), and thus is
900    appropriate for coordinate-wise optimization algorithms.
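
    Example (illustrative; because the transform is not bijective,
    ``t.inv(t(x))`` recovers ``x`` only up to an additive constant)::

        t = SoftmaxTransform()
        x = torch.randn(5)
        y = t(x)           # y >= 0 and y.sum() == 1 (up to rounding)
        x2 = t.inv(y)      # equals x - x.logsumexp(-1), not x itself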
901    """
902    domain = constraints.real_vector
903    codomain = constraints.simplex
904
905    def __eq__(self, other):
906        return isinstance(other, SoftmaxTransform)
907
908    def _call(self, x):
909        logprobs = x
910        probs = (logprobs - logprobs.max(-1, True)[0]).exp()
911        return probs / probs.sum(-1, True)
912
913    def _inverse(self, y):
914        probs = y
915        return probs.log()
916
917    def forward_shape(self, shape):
918        if len(shape) < 1:
919            raise ValueError("Too few dimensions on input")
920        return shape
921
922    def inverse_shape(self, shape):
923        if len(shape) < 1:
924            raise ValueError("Too few dimensions on input")
925        return shape
926
927
928class StickBreakingTransform(Transform):
929    """
930    Transform from unconstrained space to the simplex of one additional
931    dimension via a stick-breaking process.
932
933    This transform arises as an iterated sigmoid transform in a stick-breaking
934    construction of the `Dirichlet` distribution: the first logit is
935    transformed via sigmoid to the first probability and the probability of
936    everything else, and then the process recurses.
937
938    This is bijective and appropriate for use in HMC; however it mixes
939    coordinates together and is less appropriate for optimization.
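
    Example (illustrative; a length-3 input maps to a point on the
    4-dimensional simplex)::

        t = StickBreakingTransform()
        x = torch.randn(3)
        y = t(x)        # shape (4,), nonnegative, sums to 1
        t.inv(y)        # recovers x (up to numerical precision)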
940    """
941
942    domain = constraints.real_vector
943    codomain = constraints.simplex
944    bijective = True
945
946    def __eq__(self, other):
947        return isinstance(other, StickBreakingTransform)
948
949    def _call(self, x):
950        offset = x.shape[-1] + 1 - x.new_ones(x.shape[-1]).cumsum(-1)
951        z = _clipped_sigmoid(x - offset.log())
952        z_cumprod = (1 - z).cumprod(-1)
953        y = pad(z, [0, 1], value=1) * pad(z_cumprod, [1, 0], value=1)
954        return y
955
956    def _inverse(self, y):
957        y_crop = y[..., :-1]
958        offset = y.shape[-1] - y.new_ones(y_crop.shape[-1]).cumsum(-1)
959        sf = 1 - y_crop.cumsum(-1)
960        # we clamp to make sure that sf is positive which sometimes does not
961        # happen when y[-1] ~ 0 or y[:-1].sum() ~ 1
962        sf = torch.clamp(sf, min=torch.finfo(y.dtype).tiny)
963        x = y_crop.log() - sf.log() + offset.log()
964        return x
965
966    def log_abs_det_jacobian(self, x, y):
967        offset = x.shape[-1] + 1 - x.new_ones(x.shape[-1]).cumsum(-1)
968        x = x - offset.log()
969        # use the identity 1 - sigmoid(x) = exp(-x) * sigmoid(x)
970        detJ = (-x + F.logsigmoid(x) + y[..., :-1].log()).sum(-1)
971        return detJ
972
973    def forward_shape(self, shape):
974        if len(shape) < 1:
975            raise ValueError("Too few dimensions on input")
976        return shape[:-1] + (shape[-1] + 1,)
977
978    def inverse_shape(self, shape):
979        if len(shape) < 1:
980            raise ValueError("Too few dimensions on input")
981        return shape[:-1] + (shape[-1] - 1,)
982
983
984class LowerCholeskyTransform(Transform):
985    """
986    Transform from unconstrained matrices to lower-triangular matrices with
987    nonnegative diagonal entries.
988
989    This is useful for parameterizing positive definite matrices in terms of
990    their Cholesky factorization.
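
    Example (illustrative)::

        t = LowerCholeskyTransform()
        x = torch.randn(3, 3)
        L = t(x)        # lower triangular; diagonal is exp of x's diagonal
        t.inv(L)        # recovers the lower-triangular part of x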
991    """
992
993    domain = constraints.independent(constraints.real, 2)
994    codomain = constraints.lower_cholesky
995
996    def __eq__(self, other):
997        return isinstance(other, LowerCholeskyTransform)
998
999    def _call(self, x):
1000        return x.tril(-1) + x.diagonal(dim1=-2, dim2=-1).exp().diag_embed()
1001
1002    def _inverse(self, y):
1003        return y.tril(-1) + y.diagonal(dim1=-2, dim2=-1).log().diag_embed()
1004
1005
1006class PositiveDefiniteTransform(Transform):
1007    """
1008    Transform from unconstrained matrices to positive-definite matrices.
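
    Example (illustrative)::

        t = PositiveDefiniteTransform()
        x = torch.randn(2, 2)
        y = t(x)        # symmetric positive-definite
        t.inv(y)        # recovers the lower-triangular part of x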
1009    """
1010
1011    domain = constraints.independent(constraints.real, 2)
1012    codomain = constraints.positive_definite  # type: ignore[assignment]
1013
1014    def __eq__(self, other):
1015        return isinstance(other, PositiveDefiniteTransform)
1016
1017    def _call(self, x):
1018        x = LowerCholeskyTransform()(x)
1019        return x @ x.mT
1020
1021    def _inverse(self, y):
1022        y = torch.linalg.cholesky(y)
1023        return LowerCholeskyTransform().inv(y)
1024
1025
1026class CatTransform(Transform):
1027    """
1028    Transform functor that applies a sequence of transforms `tseq`
1029    component-wise to each slice along `dim`, with slice lengths given by `lengths`,
1030    in a way compatible with :func:`torch.cat`.
1031
1032    Example::
1033
1034       x0 = torch.cat([torch.arange(1., 11.), torch.arange(1., 11.)], dim=0)
1035       x = torch.cat([x0, x0], dim=0)
1036       t0 = CatTransform([ExpTransform(), identity_transform], dim=0, lengths=[10, 10])
1037       t = CatTransform([t0, t0], dim=0, lengths=[20, 20])
1038       y = t(x)
1039    """
1040
1041    transforms: List[Transform]
1042
1043    def __init__(self, tseq, dim=0, lengths=None, cache_size=0):
1044        assert all(isinstance(t, Transform) for t in tseq)
1045        if cache_size:
1046            tseq = [t.with_cache(cache_size) for t in tseq]
1047        super().__init__(cache_size=cache_size)
1048        self.transforms = list(tseq)
1049        if lengths is None:
1050            lengths = [1] * len(self.transforms)
1051        self.lengths = list(lengths)
1052        assert len(self.lengths) == len(self.transforms)
1053        self.dim = dim
1054
1055    @lazy_property
1056    def event_dim(self):
1057        return max(t.event_dim for t in self.transforms)
1058
1059    @lazy_property
1060    def length(self):
1061        return sum(self.lengths)
1062
1063    def with_cache(self, cache_size=1):
1064        if self._cache_size == cache_size:
1065            return self
1066        return CatTransform(self.transforms, self.dim, self.lengths, cache_size)
1067
1068    def _call(self, x):
1069        assert -x.dim() <= self.dim < x.dim()
1070        assert x.size(self.dim) == self.length
1071        yslices = []
1072        start = 0
1073        for trans, length in zip(self.transforms, self.lengths):
1074            xslice = x.narrow(self.dim, start, length)
1075            yslices.append(trans(xslice))
1076            start = start + length  # avoid += for jit compat
1077        return torch.cat(yslices, dim=self.dim)
1078
1079    def _inverse(self, y):
1080        assert -y.dim() <= self.dim < y.dim()
1081        assert y.size(self.dim) == self.length
1082        xslices = []
1083        start = 0
1084        for trans, length in zip(self.transforms, self.lengths):
1085            yslice = y.narrow(self.dim, start, length)
1086            xslices.append(trans.inv(yslice))
1087            start = start + length  # avoid += for jit compat
1088        return torch.cat(xslices, dim=self.dim)
1089
1090    def log_abs_det_jacobian(self, x, y):
1091        assert -x.dim() <= self.dim < x.dim()
1092        assert x.size(self.dim) == self.length
1093        assert -y.dim() <= self.dim < y.dim()
1094        assert y.size(self.dim) == self.length
1095        logdetjacs = []
1096        start = 0
1097        for trans, length in zip(self.transforms, self.lengths):
1098            xslice = x.narrow(self.dim, start, length)
1099            yslice = y.narrow(self.dim, start, length)
1100            logdetjac = trans.log_abs_det_jacobian(xslice, yslice)
1101            if trans.event_dim < self.event_dim:
1102                logdetjac = _sum_rightmost(logdetjac, self.event_dim - trans.event_dim)
1103            logdetjacs.append(logdetjac)
1104            start = start + length  # avoid += for jit compat
1105        # Decide whether to concatenate or sum.
1106        dim = self.dim
1107        if dim >= 0:
1108            dim = dim - x.dim()
1109        dim = dim + self.event_dim
1110        if dim < 0:
1111            return torch.cat(logdetjacs, dim=dim)
1112        else:
1113            return sum(logdetjacs)
1114
1115    @property
1116    def bijective(self):
1117        return all(t.bijective for t in self.transforms)
1118
1119    @constraints.dependent_property
1120    def domain(self):
1121        return constraints.cat(
1122            [t.domain for t in self.transforms], self.dim, self.lengths
1123        )
1124
1125    @constraints.dependent_property
1126    def codomain(self):
1127        return constraints.cat(
1128            [t.codomain for t in self.transforms], self.dim, self.lengths
1129        )
1130
1131
1132class StackTransform(Transform):
1133    """
1134    Transform functor that applies a sequence of transforms `tseq`
1135    component-wise to each slice along `dim`
1136    in a way compatible with :func:`torch.stack`.
1137
1138    Example::
1139
1140       x = torch.stack([torch.arange(1., 11.), torch.arange(1., 11.)], dim=1)
1141       t = StackTransform([ExpTransform(), identity_transform], dim=1)
1142       y = t(x)
1143    """
1144
1145    transforms: List[Transform]
1146
1147    def __init__(self, tseq, dim=0, cache_size=0):
1148        assert all(isinstance(t, Transform) for t in tseq)
1149        if cache_size:
1150            tseq = [t.with_cache(cache_size) for t in tseq]
1151        super().__init__(cache_size=cache_size)
1152        self.transforms = list(tseq)
1153        self.dim = dim
1154
1155    def with_cache(self, cache_size=1):
1156        if self._cache_size == cache_size:
1157            return self
1158        return StackTransform(self.transforms, self.dim, cache_size)
1159
1160    def _slice(self, z):
1161        return [z.select(self.dim, i) for i in range(z.size(self.dim))]
1162
1163    def _call(self, x):
1164        assert -x.dim() <= self.dim < x.dim()
1165        assert x.size(self.dim) == len(self.transforms)
1166        yslices = []
1167        for xslice, trans in zip(self._slice(x), self.transforms):
1168            yslices.append(trans(xslice))
1169        return torch.stack(yslices, dim=self.dim)
1170
1171    def _inverse(self, y):
1172        assert -y.dim() <= self.dim < y.dim()
1173        assert y.size(self.dim) == len(self.transforms)
1174        xslices = []
1175        for yslice, trans in zip(self._slice(y), self.transforms):
1176            xslices.append(trans.inv(yslice))
1177        return torch.stack(xslices, dim=self.dim)
1178
1179    def log_abs_det_jacobian(self, x, y):
1180        assert -x.dim() <= self.dim < x.dim()
1181        assert x.size(self.dim) == len(self.transforms)
1182        assert -y.dim() <= self.dim < y.dim()
1183        assert y.size(self.dim) == len(self.transforms)
1184        logdetjacs = []
1185        yslices = self._slice(y)
1186        xslices = self._slice(x)
1187        for xslice, yslice, trans in zip(xslices, yslices, self.transforms):
1188            logdetjacs.append(trans.log_abs_det_jacobian(xslice, yslice))
1189        return torch.stack(logdetjacs, dim=self.dim)
1190
1191    @property
1192    def bijective(self):
1193        return all(t.bijective for t in self.transforms)
1194
1195    @constraints.dependent_property
1196    def domain(self):
1197        return constraints.stack([t.domain for t in self.transforms], self.dim)
1198
1199    @constraints.dependent_property
1200    def codomain(self):
1201        return constraints.stack([t.codomain for t in self.transforms], self.dim)
1202
1203
1204class CumulativeDistributionTransform(Transform):
1205    """
1206    Transform via the cumulative distribution function of a probability distribution.
1207
1208    Args:
1209        distribution (Distribution): Distribution whose cumulative distribution function to use for
1210            the transformation.
1211
1212    Example::
1213
1214        # Construct a Gaussian copula from a multivariate normal.
1215        base_dist = MultivariateNormal(
1216            loc=torch.zeros(2),
1217            scale_tril=LKJCholesky(2).sample(),
1218        )
1219        transform = CumulativeDistributionTransform(Normal(0, 1))
1220        copula = TransformedDistribution(base_dist, [transform])
1221    """
1222
1223    bijective = True
1224    codomain = constraints.unit_interval
1225    sign = +1
1226
1227    def __init__(self, distribution, cache_size=0):
1228        super().__init__(cache_size=cache_size)
1229        self.distribution = distribution
1230
1231    @property
1232    def domain(self):
1233        return self.distribution.support
1234
1235    def _call(self, x):
1236        return self.distribution.cdf(x)
1237
1238    def _inverse(self, y):
1239        return self.distribution.icdf(y)
1240
1241    def log_abs_det_jacobian(self, x, y):
1242        return self.distribution.log_prob(x)
1243
1244    def with_cache(self, cache_size=1):
1245        if self._cache_size == cache_size:
1246            return self
1247        return CumulativeDistributionTransform(self.distribution, cache_size=cache_size)
1248