xref: /aosp_15_r20/external/pytorch/torch/distributions/independent.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from typing import Dict
3
4import torch
5from torch.distributions import constraints
6from torch.distributions.distribution import Distribution
7from torch.distributions.utils import _sum_rightmost
8from torch.types import _size
9
10
# Public names exported by ``from <this module> import *``.
__all__ = ["Independent"]
12
13
class Independent(Distribution):
    r"""
    Reinterprets some of the batch dims of a distribution as event dims.

    This is mainly useful for changing the shape of the result of
    :meth:`log_prob`. For example, to create a diagonal Normal distribution with
    the same shape as a Multivariate Normal distribution (so they are
    interchangeable), you can::

        >>> from torch.distributions.multivariate_normal import MultivariateNormal
        >>> from torch.distributions.normal import Normal
        >>> loc = torch.zeros(3)
        >>> scale = torch.ones(3)
        >>> mvn = MultivariateNormal(loc, scale_tril=torch.diag(scale))
        >>> [mvn.batch_shape, mvn.event_shape]
        [torch.Size([]), torch.Size([3])]
        >>> normal = Normal(loc, scale)
        >>> [normal.batch_shape, normal.event_shape]
        [torch.Size([3]), torch.Size([])]
        >>> diagn = Independent(normal, 1)
        >>> [diagn.batch_shape, diagn.event_shape]
        [torch.Size([]), torch.Size([3])]

    The total shape ``batch_shape + event_shape`` is the same as the base
    distribution's; only the split point between batch and event dims moves.

    Args:
        base_distribution (torch.distributions.distribution.Distribution): a
            base distribution
        reinterpreted_batch_ndims (int): the number of batch dims to
            reinterpret as event dims; must satisfy
            ``0 <= reinterpreted_batch_ndims <= len(base_distribution.batch_shape)``
        validate_args (bool, optional): forwarded to the
            :class:`~torch.distributions.distribution.Distribution` constructor

    Raises:
        ValueError: if ``reinterpreted_batch_ndims`` is negative or larger than
            the number of batch dims of ``base_distribution``.
    """
    # Independent has no tensor parameters of its own to constrain.
    arg_constraints: Dict[str, constraints.Constraint] = {}

    def __init__(
        self, base_distribution, reinterpreted_batch_ndims, validate_args=None
    ):
        # A negative count slips past the upper-bound check below and yields a
        # nonsensical split (e.g. -1 with a scalar event makes the whole shape
        # "batch" and the event shape empty), so reject it explicitly.
        if reinterpreted_batch_ndims < 0:
            raise ValueError(
                "Expected reinterpreted_batch_ndims >= 0, "
                f"actual {reinterpreted_batch_ndims}"
            )
        if reinterpreted_batch_ndims > len(base_distribution.batch_shape):
            raise ValueError(
                "Expected reinterpreted_batch_ndims <= len(base_distribution.batch_shape), "
                f"actual {reinterpreted_batch_ndims} vs {len(base_distribution.batch_shape)}"
            )
        shape = base_distribution.batch_shape + base_distribution.event_shape
        # The new event consists of the rightmost reinterpreted batch dims
        # followed by the base distribution's own event dims.
        event_dim = reinterpreted_batch_ndims + len(base_distribution.event_shape)
        batch_shape = shape[: len(shape) - event_dim]
        event_shape = shape[len(shape) - event_dim :]
        self.base_dist = base_distribution
        self.reinterpreted_batch_ndims = reinterpreted_batch_ndims
        super().__init__(batch_shape, event_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        """Return a copy with the batch shape expanded to ``batch_shape``.

        The base distribution is expanded to ``batch_shape`` followed by the
        reinterpreted dims, which are the leading
        ``reinterpreted_batch_ndims`` entries of ``self.event_shape``.
        """
        new = self._get_checked_instance(Independent, _instance)
        batch_shape = torch.Size(batch_shape)
        new.base_dist = self.base_dist.expand(
            batch_shape + self.event_shape[: self.reinterpreted_batch_ndims]
        )
        new.reinterpreted_batch_ndims = self.reinterpreted_batch_ndims
        # Skip re-validation of the (already validated) instance, then restore
        # the original validation flag.
        super(Independent, new).__init__(
            batch_shape, self.event_shape, validate_args=False
        )
        new._validate_args = self._validate_args
        return new

    @property
    def has_rsample(self):
        """Whether reparameterized sampling is available (same as the base)."""
        return self.base_dist.has_rsample

    @property
    def has_enumerate_support(self):
        """Enumeration is only supported when no batch dims are reinterpreted."""
        if self.reinterpreted_batch_ndims > 0:
            return False
        return self.base_dist.has_enumerate_support

    @constraints.dependent_property
    def support(self):
        """The base support, wrapped so the reinterpreted dims count as event dims."""
        result = self.base_dist.support
        if self.reinterpreted_batch_ndims:
            result = constraints.independent(result, self.reinterpreted_batch_ndims)
        return result

    @property
    def mean(self):
        """Elementwise mean; the total shape is unchanged, so the base value is reused."""
        return self.base_dist.mean

    @property
    def mode(self):
        """Elementwise mode of the base distribution."""
        return self.base_dist.mode

    @property
    def variance(self):
        """Elementwise variance of the base distribution."""
        return self.base_dist.variance

    def sample(self, sample_shape: _size = torch.Size()) -> torch.Tensor:
        """Draw samples from the base distribution.

        Base samples already have shape ``sample_shape + batch_shape +
        event_shape`` because the total shape is unchanged.
        """
        return self.base_dist.sample(sample_shape)

    def rsample(self, sample_shape: _size = torch.Size()) -> torch.Tensor:
        """Draw reparameterized samples from the base distribution."""
        return self.base_dist.rsample(sample_shape)

    def log_prob(self, value):
        """Base log-probabilities summed over the reinterpreted rightmost dims."""
        log_prob = self.base_dist.log_prob(value)
        return _sum_rightmost(log_prob, self.reinterpreted_batch_ndims)

    def entropy(self):
        """Base entropies summed over the reinterpreted rightmost dims."""
        entropy = self.base_dist.entropy()
        return _sum_rightmost(entropy, self.reinterpreted_batch_ndims)

    def enumerate_support(self, expand=True):
        """Delegate to the base distribution.

        Raises:
            NotImplementedError: if any batch dims are reinterpreted, because
                that would require enumerating a cartesian product.
        """
        if self.reinterpreted_batch_ndims > 0:
            raise NotImplementedError(
                "Enumeration over cartesian product is not implemented"
            )
        return self.base_dist.enumerate_support(expand=expand)

    def __repr__(self):
        return (
            self.__class__.__name__
            + f"({self.base_dist}, {self.reinterpreted_batch_ndims})"
        )
129