# mypy: allow-untyped-defs
from typing import Dict

import torch
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.independent import Independent
from torch.distributions.transforms import ComposeTransform, Transform
from torch.distributions.utils import _sum_rightmost
from torch.types import _size


__all__ = ["TransformedDistribution"]


class TransformedDistribution(Distribution):
    r"""
    Extension of the Distribution class, which applies a sequence of Transforms
    to a base distribution.  Let f be the composition of transforms applied::

        X ~ BaseDistribution
        Y = f(X) ~ TransformedDistribution(BaseDistribution, f)
        log p(Y) = log p(X) + log |det (dX/dY)|

    Note that the ``.event_shape`` of a :class:`TransformedDistribution` has
    the maximum event dimensionality among its base distribution and its
    transforms, since transforms can introduce correlations among events.

    An example of using :class:`TransformedDistribution`::

        # Building a Logistic Distribution
        # X ~ Uniform(0, 1)
        # f = a + b * logit(X)
        # Y = f(X) ~ Logistic(a, b)
        base_distribution = Uniform(0, 1)
        transforms = [SigmoidTransform().inv, AffineTransform(loc=a, scale=b)]
        logistic = TransformedDistribution(base_distribution, transforms)

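    The resulting object supports the standard :class:`Distribution` API; a
    minimal continuation of the sketch above (``a`` and ``b`` are assumed to
    be floats or broadcastable tensors defined elsewhere)::

        y = logistic.sample((5,))  # sample X ~ Uniform(0, 1), then apply f
        logistic.log_prob(y)       # inverts f and applies the formula above
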
    For more examples, please look at the implementations of
    :class:`~torch.distributions.gumbel.Gumbel`,
    :class:`~torch.distributions.half_cauchy.HalfCauchy`,
    :class:`~torch.distributions.half_normal.HalfNormal`,
    :class:`~torch.distributions.log_normal.LogNormal`,
    :class:`~torch.distributions.pareto.Pareto`,
    :class:`~torch.distributions.weibull.Weibull`,
    :class:`~torch.distributions.relaxed_bernoulli.RelaxedBernoulli` and
    :class:`~torch.distributions.relaxed_categorical.RelaxedOneHotCategorical`
    """
    arg_constraints: Dict[str, constraints.Constraint] = {}

    def __init__(self, base_distribution, transforms, validate_args=None):
        if isinstance(transforms, Transform):
            self.transforms = [
                transforms,
            ]
        elif isinstance(transforms, list):
            if not all(isinstance(t, Transform) for t in transforms):
                raise ValueError(
                    "transforms must be a Transform or a list of Transforms"
                )
            self.transforms = transforms
        else:
            raise ValueError(
                f"transforms must be a Transform or list, but was {transforms}"
            )

        # Reshape base_distribution according to transforms.
        base_shape = base_distribution.batch_shape + base_distribution.event_shape
        base_event_dim = len(base_distribution.event_shape)
        transform = ComposeTransform(self.transforms)
        if len(base_shape) < transform.domain.event_dim:
            raise ValueError(
                f"base_distribution needs to have shape with size at least {transform.domain.event_dim}, but got {base_shape}."
            )
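        # Round-tripping base_shape through forward_shape() and inverse_shape()
        # exposes any broadcasting done by the transforms' parameters, so the
        # base distribution can be expanded to match below.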
        forward_shape = transform.forward_shape(base_shape)
        expanded_base_shape = transform.inverse_shape(forward_shape)
        if base_shape != expanded_base_shape:
            base_batch_shape = expanded_base_shape[
                : len(expanded_base_shape) - base_event_dim
            ]
            base_distribution = base_distribution.expand(base_batch_shape)
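
        # If the transforms consume more event dims than the base provides
        # (e.g. a vector-valued transform over a scalar-event base), promote
        # the rightmost batch dims of the base to event dims via Independent.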
        reinterpreted_batch_ndims = transform.domain.event_dim - base_event_dim
        if reinterpreted_batch_ndims > 0:
            base_distribution = Independent(
                base_distribution, reinterpreted_batch_ndims
            )
        self.base_dist = base_distribution

        # Compute shapes.
        transform_change_in_event_dim = (
            transform.codomain.event_dim - transform.domain.event_dim
        )
        event_dim = max(
            transform.codomain.event_dim,  # the transform is coupled
            base_event_dim + transform_change_in_event_dim,  # the base dist is coupled
        )
        assert len(forward_shape) >= event_dim
        cut = len(forward_shape) - event_dim
        batch_shape = forward_shape[:cut]
        event_shape = forward_shape[cut:]
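        # Illustrative example (not executed): for a Normal(zeros(3), ones(3))
        # base pushed through StickBreakingTransform() (domain event_dim 1),
        # the base is wrapped in Independent(..., 1), forward_shape is (4,),
        # and event_dim is 1, so batch_shape == () and event_shape == (4,).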
        super().__init__(batch_shape, event_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(TransformedDistribution, _instance)
        batch_shape = torch.Size(batch_shape)
        shape = batch_shape + self.event_shape
        for t in reversed(self.transforms):
            shape = t.inverse_shape(shape)
        base_batch_shape = shape[: len(shape) - len(self.base_dist.event_shape)]
        new.base_dist = self.base_dist.expand(base_batch_shape)
        new.transforms = self.transforms
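        # Initialize with validation disabled, then restore the original flag
        # so the expanded copy validates exactly when the source did.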
        super(TransformedDistribution, new).__init__(
            batch_shape, self.event_shape, validate_args=False
        )
        new._validate_args = self._validate_args
        return new

    @constraints.dependent_property(is_discrete=False)
    def support(self):
        if not self.transforms:
            return self.base_dist.support
        support = self.transforms[-1].codomain
        if len(self.event_shape) > support.event_dim:
            support = constraints.independent(
                support, len(self.event_shape) - support.event_dim
            )
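        # For example, an ExpTransform over a Normal base yields
        # constraints.positive here (as in HalfNormal and LogNormal).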
        return support

    @property
    def has_rsample(self):
        return self.base_dist.has_rsample

    def sample(self, sample_shape=torch.Size()):
        """
        Generates a sample_shape shaped sample or sample_shape shaped batch of
        samples if the distribution parameters are batched. Samples first from
        the base distribution and then applies `transform()` for every
        transform in the list.
        """
        with torch.no_grad():
            x = self.base_dist.sample(sample_shape)
            for transform in self.transforms:
                x = transform(x)
            return x

    def rsample(self, sample_shape: _size = torch.Size()) -> torch.Tensor:
        """
        Generates a sample_shape shaped reparameterized sample or sample_shape
        shaped batch of reparameterized samples if the distribution parameters
        are batched. Samples first from the base distribution and then applies
        `transform()` for every transform in the list.
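
        Gradients flow through the result when the base distribution supports
        :meth:`rsample` and every transform is differentiable. A minimal
        sketch (assuming ``loc`` and ``scale`` are tensors with
        ``requires_grad=True``)::

            d = TransformedDistribution(Normal(loc, scale), [ExpTransform()])
            d.rsample().sum().backward()  # gradients reach loc and scale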
        """
        x = self.base_dist.rsample(sample_shape)
        for transform in self.transforms:
            x = transform(x)
        return x

    def log_prob(self, value):
        """
        Scores the sample by inverting the transform(s) and computing the log
        probability from the base distribution's log probability and the log
        abs det Jacobian of each transform.
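
        For instance, with an ``ExpTransform`` over a ``Normal`` base (i.e. a
        log-normal), this reduces to ``base.log_prob(y.log()) - y.log()``,
        since the log abs det Jacobian of ``exp`` at ``x`` is ``x``.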
        """
        if self._validate_args:
            self._validate_sample(value)
        event_dim = len(self.event_shape)
        log_prob = 0.0
        y = value
        for transform in reversed(self.transforms):
            x = transform.inv(y)
            event_dim += transform.domain.event_dim - transform.codomain.event_dim
            log_prob = log_prob - _sum_rightmost(
                transform.log_abs_det_jacobian(x, y),
                event_dim - transform.domain.event_dim,
            )
            y = x

        log_prob = log_prob + _sum_rightmost(
            self.base_dist.log_prob(y), event_dim - len(self.base_dist.event_shape)
        )
        return log_prob

    def _monotonize_cdf(self, value):
        """
        This conditionally flips ``value -> 1-value`` to ensure :meth:`cdf` is
        monotone increasing.
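
        For example, composing with ``AffineTransform(0.0, -1.0)`` yields a
        decreasing map, so ``sign == -1`` and :meth:`cdf` returns
        ``1 - base_cdf``. When a transform's ``sign`` is a tensor (e.g. an
        affine scale with mixed signs), the flip is applied elementwise.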
        """
        sign = 1
        for transform in self.transforms:
            sign = sign * transform.sign
        if isinstance(sign, int) and sign == 1:
            return value
        return sign * (value - 0.5) + 0.5

    def cdf(self, value):
        """
        Computes the cumulative distribution function by inverting the
        transform(s) and evaluating the CDF of the base distribution.
        """
        for transform in reversed(self.transforms):
            value = transform.inv(value)
        if self._validate_args:
            self.base_dist._validate_sample(value)
        value = self.base_dist.cdf(value)
        value = self._monotonize_cdf(value)
        return value

    def icdf(self, value):
        """
        Computes the inverse cumulative distribution function by applying the
        inverse CDF of the base distribution and then the transform(s).
        """
        value = self._monotonize_cdf(value)
        value = self.base_dist.icdf(value)
        for transform in self.transforms:
            value = transform(value)
        return value