xref: /aosp_15_r20/external/pytorch/torch/distributions/exponential.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from numbers import Number
3
4import torch
5from torch.distributions import constraints
6from torch.distributions.exp_family import ExponentialFamily
7from torch.distributions.utils import broadcast_all
8from torch.types import _size
9
10
11__all__ = ["Exponential"]
12
13
14class Exponential(ExponentialFamily):
15    r"""
16    Creates a Exponential distribution parameterized by :attr:`rate`.
17
18    Example::
19
20        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
21        >>> m = Exponential(torch.tensor([1.0]))
22        >>> m.sample()  # Exponential distributed with rate=1
23        tensor([ 0.1046])
24
25    Args:
26        rate (float or Tensor): rate = 1 / scale of the distribution
27    """
28    arg_constraints = {"rate": constraints.positive}
29    support = constraints.nonnegative
30    has_rsample = True
31    _mean_carrier_measure = 0
32
33    @property
34    def mean(self):
35        return self.rate.reciprocal()
36
37    @property
38    def mode(self):
39        return torch.zeros_like(self.rate)
40
41    @property
42    def stddev(self):
43        return self.rate.reciprocal()
44
45    @property
46    def variance(self):
47        return self.rate.pow(-2)
48
49    def __init__(self, rate, validate_args=None):
50        (self.rate,) = broadcast_all(rate)
51        batch_shape = torch.Size() if isinstance(rate, Number) else self.rate.size()
52        super().__init__(batch_shape, validate_args=validate_args)
53
54    def expand(self, batch_shape, _instance=None):
55        new = self._get_checked_instance(Exponential, _instance)
56        batch_shape = torch.Size(batch_shape)
57        new.rate = self.rate.expand(batch_shape)
58        super(Exponential, new).__init__(batch_shape, validate_args=False)
59        new._validate_args = self._validate_args
60        return new
61
62    def rsample(self, sample_shape: _size = torch.Size()) -> torch.Tensor:
63        shape = self._extended_shape(sample_shape)
64        return self.rate.new(shape).exponential_() / self.rate
65
66    def log_prob(self, value):
67        if self._validate_args:
68            self._validate_sample(value)
69        return self.rate.log() - self.rate * value
70
71    def cdf(self, value):
72        if self._validate_args:
73            self._validate_sample(value)
74        return 1 - torch.exp(-self.rate * value)
75
76    def icdf(self, value):
77        return -torch.log1p(-value) / self.rate
78
79    def entropy(self):
80        return 1.0 - torch.log(self.rate)
81
82    @property
83    def _natural_params(self):
84        return (-self.rate,)
85
86    def _log_normalizer(self, x):
87        return -torch.log(-x)
88