xref: /aosp_15_r20/external/pytorch/torch/distributions/distribution.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import warnings
3from typing import Any, Dict, Optional
4from typing_extensions import deprecated
5
6import torch
7from torch.distributions import constraints
8from torch.distributions.utils import lazy_property
9from torch.types import _size
10
11
12__all__ = ["Distribution"]
13
14
class Distribution:
    r"""
    Distribution is the abstract base class for probability distributions.
    """

    # Subclasses override these class attributes to advertise capabilities.
    has_rsample = False
    has_enumerate_support = False
    # Mirrors Python's ``assert``: validation is on unless running under -O.
    _validate_args = __debug__

    @staticmethod
    def set_default_validate_args(value: bool) -> None:
        """
        Sets whether validation is enabled or disabled.

        The default behavior mimics Python's ``assert`` statement: validation
        is on by default, but is disabled if Python is run in optimized mode
        (via ``python -O``). Validation may be expensive, so you may want to
        disable it once a model is working.

        Args:
            value (bool): Whether to enable validation.

        Raises:
            ValueError: If ``value`` is not ``True`` or ``False``.
        """
        if value not in [True, False]:
            # Fixed: previously raised a bare ValueError with no message.
            raise ValueError(
                f"Expected `value` to be True or False, but got {value!r}."
            )
        Distribution._validate_args = value

    def __init__(
        self,
        batch_shape: torch.Size = torch.Size(),
        event_shape: torch.Size = torch.Size(),
        validate_args: Optional[bool] = None,
    ):
        """
        Args:
            batch_shape (torch.Size): The shape over which parameters are
                batched.
            event_shape (torch.Size): The shape of a single sample (without
                batching).
            validate_args (bool, optional): If given, overrides the class-wide
                default (see :meth:`set_default_validate_args`) for this
                instance. When validation is enabled, each parameter listed in
                :attr:`arg_constraints` is checked at construction time.

        Raises:
            ValueError: If validation is enabled and a parameter value violates
                its constraint.
        """
        self._batch_shape = batch_shape
        self._event_shape = event_shape
        if validate_args is not None:
            self._validate_args = validate_args
        if self._validate_args:
            try:
                arg_constraints = self.arg_constraints
            except NotImplementedError:
                arg_constraints = {}
                warnings.warn(
                    f"{self.__class__} does not define `arg_constraints`. "
                    + "Please set `arg_constraints = {}` or initialize the distribution "
                    + "with `validate_args=False` to turn off validation."
                )
            for param, constraint in arg_constraints.items():
                if constraints.is_dependent(constraint):
                    continue  # skip constraints that cannot be checked
                if param not in self.__dict__ and isinstance(
                    getattr(type(self), param), lazy_property
                ):
                    continue  # skip checking lazily-constructed args
                value = getattr(self, param)
                valid = constraint.check(value)
                if not valid.all():
                    raise ValueError(
                        f"Expected parameter {param} "
                        f"({type(value).__name__} of shape {tuple(value.shape)}) "
                        f"of distribution {repr(self)} "
                        f"to satisfy the constraint {repr(constraint)}, "
                        f"but found invalid values:\n{value}"
                    )
        super().__init__()

    def expand(self, batch_shape: _size, _instance=None):
        """
        Returns a new distribution instance (or populates an existing instance
        provided by a derived class) with batch dimensions expanded to
        `batch_shape`. This method calls :class:`~torch.Tensor.expand` on
        the distribution's parameters. As such, this does not allocate new
        memory for the expanded distribution instance. Additionally,
        this does not repeat any args checking or parameter broadcasting in
        `__init__`, when an instance is first created.

        Args:
            batch_shape (torch.Size): the desired expanded size.
            _instance: new instance provided by subclasses that
                need to override `.expand`.

        Returns:
            New distribution instance with batch dimensions expanded to
            `batch_shape`.
        """
        raise NotImplementedError

    @property
    def batch_shape(self) -> torch.Size:
        """
        Returns the shape over which parameters are batched.
        """
        return self._batch_shape

    @property
    def event_shape(self) -> torch.Size:
        """
        Returns the shape of a single sample (without batching).
        """
        return self._event_shape

    @property
    def arg_constraints(self) -> Dict[str, constraints.Constraint]:
        """
        Returns a dictionary from argument names to
        :class:`~torch.distributions.constraints.Constraint` objects that
        should be satisfied by each argument of this distribution. Args that
        are not tensors need not appear in this dict.
        """
        raise NotImplementedError

    @property
    def support(self) -> Optional[Any]:
        """
        Returns a :class:`~torch.distributions.constraints.Constraint` object
        representing this distribution's support.
        """
        raise NotImplementedError

    @property
    def mean(self) -> torch.Tensor:
        """
        Returns the mean of the distribution.
        """
        raise NotImplementedError

    @property
    def mode(self) -> torch.Tensor:
        """
        Returns the mode of the distribution.
        """
        raise NotImplementedError(f"{self.__class__} does not implement mode")

    @property
    def variance(self) -> torch.Tensor:
        """
        Returns the variance of the distribution.
        """
        raise NotImplementedError

    @property
    def stddev(self) -> torch.Tensor:
        """
        Returns the standard deviation of the distribution.
        """
        return self.variance.sqrt()

    def sample(self, sample_shape: _size = torch.Size()) -> torch.Tensor:
        """
        Generates a sample_shape shaped sample or sample_shape shaped batch of
        samples if the distribution parameters are batched.
        """
        # Gradients are never propagated through plain sampling; subclasses
        # providing a reparameterized `rsample` get `sample` for free.
        with torch.no_grad():
            return self.rsample(sample_shape)

    def rsample(self, sample_shape: _size = torch.Size()) -> torch.Tensor:
        """
        Generates a sample_shape shaped reparameterized sample or sample_shape
        shaped batch of reparameterized samples if the distribution parameters
        are batched.
        """
        raise NotImplementedError

    @deprecated(
        "`sample_n(n)` will be deprecated. Use `sample((n,))` instead.",
        category=FutureWarning,
    )
    def sample_n(self, n: int) -> torch.Tensor:
        """
        Generates n samples or n batches of samples if the distribution
        parameters are batched.
        """
        return self.sample(torch.Size((n,)))

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        """
        Returns the log of the probability density/mass function evaluated at
        `value`.

        Args:
            value (Tensor): the value at which to evaluate.
        """
        raise NotImplementedError

    def cdf(self, value: torch.Tensor) -> torch.Tensor:
        """
        Returns the cumulative density/mass function evaluated at
        `value`.

        Args:
            value (Tensor): the value at which to evaluate.
        """
        raise NotImplementedError

    def icdf(self, value: torch.Tensor) -> torch.Tensor:
        """
        Returns the inverse cumulative density/mass function evaluated at
        `value`.

        Args:
            value (Tensor): the value at which to evaluate.
        """
        raise NotImplementedError

    def enumerate_support(self, expand: bool = True) -> torch.Tensor:
        """
        Returns tensor containing all values supported by a discrete
        distribution. The result will enumerate over dimension 0, so the shape
        of the result will be `(cardinality,) + batch_shape + event_shape`
        (where `event_shape = ()` for univariate distributions).

        Note that this enumerates over all batched tensors in lock-step
        `[[0, 0], [1, 1], ...]`. With `expand=False`, enumeration happens
        along dim 0, but with the remaining batch dimensions being
        singleton dimensions, `[[0], [1], ..`.

        To iterate over the full Cartesian product use
        `itertools.product(m.enumerate_support())`.

        Args:
            expand (bool): whether to expand the support over the
                batch dims to match the distribution's `batch_shape`.

        Returns:
            Tensor iterating over dimension 0.
        """
        raise NotImplementedError

    def entropy(self) -> torch.Tensor:
        """
        Returns entropy of distribution, batched over batch_shape.

        Returns:
            Tensor of shape batch_shape.
        """
        raise NotImplementedError

    def perplexity(self) -> torch.Tensor:
        """
        Returns perplexity of distribution, batched over batch_shape.

        Returns:
            Tensor of shape batch_shape.
        """
        return torch.exp(self.entropy())

    def _extended_shape(self, sample_shape: _size = torch.Size()) -> torch.Size:
        """
        Returns the size of the sample returned by the distribution, given
        a `sample_shape`. Note, that the batch and event shapes of a distribution
        instance are fixed at the time of construction. If this is empty, the
        returned shape is upcast to (1,).

        Args:
            sample_shape (torch.Size): the size of the sample to be drawn.
        """
        if not isinstance(sample_shape, torch.Size):
            sample_shape = torch.Size(sample_shape)
        return torch.Size(sample_shape + self._batch_shape + self._event_shape)

    def _validate_sample(self, value: torch.Tensor) -> None:
        """
        Argument validation for distribution methods such as `log_prob`,
        `cdf` and `icdf`. The rightmost dimensions of a value to be
        scored via these methods must agree with the distribution's batch
        and event shapes.

        Args:
            value (Tensor): the tensor whose log probability is to be
                computed by the `log_prob` method.
        Raises
            ValueError: when the rightmost dimensions of `value` do not match the
                distribution's batch and event shapes.
        """
        if not isinstance(value, torch.Tensor):
            raise ValueError("The value argument to log_prob must be a Tensor")

        # The trailing dims of `value` must equal event_shape exactly.
        event_dim_start = len(value.size()) - len(self._event_shape)
        if value.size()[event_dim_start:] != self._event_shape:
            raise ValueError(
                f"The right-most size of value must match event_shape: {value.size()} vs {self._event_shape}."
            )

        # The remaining dims must broadcast against batch_shape + event_shape.
        actual_shape = value.size()
        expected_shape = self._batch_shape + self._event_shape
        for i, j in zip(reversed(actual_shape), reversed(expected_shape)):
            if i != 1 and j != 1 and i != j:
                raise ValueError(
                    f"Value is not broadcastable with batch_shape+event_shape: {actual_shape} vs {expected_shape}."
                )
        try:
            support = self.support
        except NotImplementedError:
            warnings.warn(
                f"{self.__class__} does not define `support` to enable "
                + "sample validation. Please initialize the distribution with "
                + "`validate_args=False` to turn off validation."
            )
            return
        assert support is not None
        valid = support.check(value)
        if not valid.all():
            raise ValueError(
                "Expected value argument "
                f"({type(value).__name__} of shape {tuple(value.shape)}) "
                f"to be within the support ({repr(support)}) "
                f"of the distribution {repr(self)}, "
                f"but found invalid values:\n{value}"
            )

    def _get_checked_instance(self, cls, _instance=None):
        # A subclass with a custom __init__ cannot be expanded generically,
        # because expand() bypasses __init__'s checking/broadcasting.
        if _instance is None and type(self).__init__ != cls.__init__:
            raise NotImplementedError(
                f"Subclass {self.__class__.__name__} of {cls.__name__} that defines a custom __init__ method "
                "must also define a custom .expand() method."
            )
        return self.__new__(type(self)) if _instance is None else _instance

    def __repr__(self) -> str:
        # Only show parameters that are concretely stored on the instance.
        param_names = [k for k, _ in self.arg_constraints.items() if k in self.__dict__]
        args_string = ", ".join(
            [
                f"{p}: {self.__dict__[p] if self.__dict__[p].numel() == 1 else self.__dict__[p].size()}"
                for p in param_names
            ]
        )
        return self.__class__.__name__ + "(" + args_string + ")"