# mypy: allow-untyped-defs
from typing import Dict

import torch
from torch.distributions import Categorical, constraints
from torch.distributions.distribution import Distribution


__all__ = ["MixtureSameFamily"]


class MixtureSameFamily(Distribution):
    r"""
    The `MixtureSameFamily` distribution implements a (batch of) mixture
    distribution where all components are from different parameterizations
    of the same distribution type. It is parameterized by a `Categorical`
    "selecting distribution" (over `k` components) and a component
    distribution, i.e., a `Distribution` with a rightmost batch shape
    (equal to `[k]`) which indexes each (batch of) component.

    Examples::

        >>> # xdoctest: +SKIP("undefined vars")
        >>> # Construct Gaussian Mixture Model in 1D consisting of 5 equally
        >>> # weighted normal distributions
        >>> mix = D.Categorical(torch.ones(5,))
        >>> comp = D.Normal(torch.randn(5,), torch.rand(5,))
        >>> gmm = MixtureSameFamily(mix, comp)

        >>> # Construct Gaussian Mixture Model in 2D consisting of 5 equally
        >>> # weighted bivariate normal distributions
        >>> mix = D.Categorical(torch.ones(5,))
        >>> comp = D.Independent(D.Normal(
        ...          torch.randn(5,2), torch.rand(5,2)), 1)
        >>> gmm = MixtureSameFamily(mix, comp)

        >>> # Construct a batch of 3 Gaussian Mixture Models in 2D each
        >>> # consisting of 5 random weighted bivariate normal distributions
        >>> mix = D.Categorical(torch.rand(3,5))
        >>> comp = D.Independent(D.Normal(
        ...         torch.randn(3,5,2), torch.rand(3,5,2)), 1)
        >>> gmm = MixtureSameFamily(mix, comp)
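
        >>> # Illustrative usage of the batched GMM above: draw samples and
        >>> # evaluate log-densities. Shape comments assume the 3-batch, 2D
        >>> # example; `(4,)` is an arbitrary sample shape.
        >>> x = gmm.sample((4,))  # shape: [4, 3, 2]
        >>> gmm.log_prob(x).shape
        torch.Size([4, 3])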

    Args:
        mixture_distribution: `torch.distributions.Categorical`-like
            instance. Manages the probability of selecting components.
            The number of categories must match the rightmost batch
            dimension of the `component_distribution`. Must have either
            scalar `batch_shape` or `batch_shape` matching
            `component_distribution.batch_shape[:-1]`.
        component_distribution: `torch.distributions.Distribution`-like
            instance. Rightmost batch dimension indexes components.
    """
    arg_constraints: Dict[str, constraints.Constraint] = {}
    has_rsample = False

    def __init__(
        self, mixture_distribution, component_distribution, validate_args=None
    ):
        self._mixture_distribution = mixture_distribution
        self._component_distribution = component_distribution

        if not isinstance(self._mixture_distribution, Categorical):
            raise ValueError(
                "The Mixture distribution needs to be an "
                "instance of torch.distributions.Categorical"
            )

        if not isinstance(self._component_distribution, Distribution):
            raise ValueError(
                "The Component distribution needs to be an "
                "instance of torch.distributions.Distribution"
            )

        # Check that the batch shapes are broadcast-compatible
        mdbs = self._mixture_distribution.batch_shape
        cdbs = self._component_distribution.batch_shape[:-1]
        for size1, size2 in zip(reversed(mdbs), reversed(cdbs)):
            if size1 != 1 and size2 != 1 and size1 != size2:
                raise ValueError(
                    f"`mixture_distribution.batch_shape` ({mdbs}) is not "
                    "compatible with `component_distribution."
                    f"batch_shape` ({cdbs})"
                )

        # Check that the number of mixture components matches
        km = self._mixture_distribution.logits.shape[-1]
        kc = self._component_distribution.batch_shape[-1]
        if km is not None and kc is not None and km != kc:
            raise ValueError(
                f"`mixture_distribution` component count ({km}) does not"
                " equal `component_distribution.batch_shape[-1]`"
                f" ({kc})"
            )
        self._num_component = km

        event_shape = self._component_distribution.event_shape
        self._event_ndims = len(event_shape)
        super().__init__(
            batch_shape=cdbs, event_shape=event_shape, validate_args=validate_args
        )

    def expand(self, batch_shape, _instance=None):
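        # Expand to `batch_shape`; the component distribution is expanded to
        # `batch_shape + (k,)` so it keeps its trailing component dimension.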
        batch_shape = torch.Size(batch_shape)
        batch_shape_comp = batch_shape + (self._num_component,)
        new = self._get_checked_instance(MixtureSameFamily, _instance)
        new._component_distribution = self._component_distribution.expand(
            batch_shape_comp
        )
        new._mixture_distribution = self._mixture_distribution.expand(batch_shape)
        new._num_component = self._num_component
        new._event_ndims = self._event_ndims
        event_shape = new._component_distribution.event_shape
        super(MixtureSameFamily, new).__init__(
            batch_shape=batch_shape, event_shape=event_shape, validate_args=False
        )
        new._validate_args = self._validate_args
        return new

    @constraints.dependent_property
    def support(self):
        # FIXME this may have the wrong shape when support contains batched
        # parameters
        return self._component_distribution.support

    @property
    def mixture_distribution(self):
        return self._mixture_distribution

    @property
    def component_distribution(self):
        return self._component_distribution

    @property
    def mean(self):
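        # Mixture mean: E[x] = sum_k pi_k * mean_k, summing out the
        # component dimension.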
        probs = self._pad_mixture_dimensions(self.mixture_distribution.probs)
        return torch.sum(
            probs * self.component_distribution.mean, dim=-1 - self._event_ndims
        )  # [B, E]

    @property
    def variance(self):
        # Law of total variance: Var(Y) = E[Var(Y|X)] + Var(E[Y|X])
        probs = self._pad_mixture_dimensions(self.mixture_distribution.probs)
        mean_cond_var = torch.sum(
            probs * self.component_distribution.variance, dim=-1 - self._event_ndims
        )
        var_cond_mean = torch.sum(
            probs * (self.component_distribution.mean - self._pad(self.mean)).pow(2.0),
            dim=-1 - self._event_ndims,
        )
        return mean_cond_var + var_cond_mean

    def cdf(self, x):
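        # Mixture CDF is the probability-weighted sum of component CDFs:
        # F(x) = sum_k pi_k * F_k(x).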
        x = self._pad(x)
        cdf_x = self.component_distribution.cdf(x)
        mix_prob = self.mixture_distribution.probs

        return torch.sum(cdf_x * mix_prob, dim=-1)

    def log_prob(self, x):
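        # Mixture log-density via logsumexp over the component dimension:
        # log p(x) = logsumexp_k(log pi_k + log p_k(x)).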
        if self._validate_args:
            self._validate_sample(x)
        x = self._pad(x)
        log_prob_x = self.component_distribution.log_prob(x)  # [S, B, k]
        log_mix_prob = torch.log_softmax(
            self.mixture_distribution.logits, dim=-1
        )  # [B, k]
        return torch.logsumexp(log_prob_x + log_mix_prob, dim=-1)  # [S, B]

    def sample(self, sample_shape=torch.Size()):
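        # Sample component indices from the mixing distribution, then gather
        # the matching draws from the component distribution along the
        # component dimension.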
        with torch.no_grad():
            sample_len = len(sample_shape)
            batch_len = len(self.batch_shape)
            gather_dim = sample_len + batch_len
            es = self.event_shape

            # mixture samples [n, B]
            mix_sample = self.mixture_distribution.sample(sample_shape)
            mix_shape = mix_sample.shape

            # component samples [n, B, k, E]
            comp_samples = self.component_distribution.sample(sample_shape)

            # Gather along the k dimension
            mix_sample_r = mix_sample.reshape(
                mix_shape + torch.Size([1] * (len(es) + 1))
            )
            mix_sample_r = mix_sample_r.repeat(
                torch.Size([1] * len(mix_shape)) + torch.Size([1]) + es
            )

            samples = torch.gather(comp_samples, gather_dim, mix_sample_r)
            return samples.squeeze(gather_dim)

    def _pad(self, x):
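        # Insert a singleton component dimension just left of the event dims
        # so that `x` broadcasts against the `k` components.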
        return x.unsqueeze(-1 - self._event_ndims)

    def _pad_mixture_dimensions(self, x):
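        # Reshape the mixture probabilities so they broadcast against
        # per-component statistics: pad with singleton dims for any missing
        # batch dims and append singleton dims for the event dims.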
        dist_batch_ndims = len(self.batch_shape)
        cat_batch_ndims = len(self.mixture_distribution.batch_shape)
        pad_ndims = 0 if cat_batch_ndims == 1 else dist_batch_ndims - cat_batch_ndims
        xs = x.shape
        x = x.reshape(
            xs[:-1]
            + torch.Size(pad_ndims * [1])
            + xs[-1:]
            + torch.Size(self._event_ndims * [1])
        )
        return x

    def __repr__(self):
        args_string = (
            f"\n  {self.mixture_distribution},\n  {self.component_distribution}"
        )
        return "MixtureSameFamily" + "(" + args_string + ")"
217