xref: /aosp_15_r20/external/pytorch/torch/distributions/beta.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from numbers import Number, Real
3
4import torch
5from torch.distributions import constraints
6from torch.distributions.dirichlet import Dirichlet
7from torch.distributions.exp_family import ExponentialFamily
8from torch.distributions.utils import broadcast_all
9from torch.types import _size
10
11
# Names exported by ``from torch.distributions.beta import *``.
__all__ = [
    "Beta",
]
13
14
class Beta(ExponentialFamily):
    r"""
    Beta distribution parameterized by :attr:`concentration1` and :attr:`concentration0`.

    Internally the distribution is stored as a two-component
    :class:`~torch.distributions.dirichlet.Dirichlet` whose concentration is
    ``stack([concentration1, concentration0], -1)``; sampling, mode, entropy and
    log-density are all delegated to that Dirichlet, keeping its first component.

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = Beta(torch.tensor([0.5]), torch.tensor([0.5]))
        >>> m.sample()  # Beta distributed with concentration concentration1 and concentration0
        tensor([ 0.1046])

    Args:
        concentration1 (float or Tensor): 1st concentration parameter of the distribution
            (often referred to as alpha)
        concentration0 (float or Tensor): 2nd concentration parameter of the distribution
            (often referred to as beta)
        validate_args (bool, optional): forwarded both to the underlying Dirichlet and
            to the base-class constructor.
    """
    arg_constraints = {
        "concentration1": constraints.positive,
        "concentration0": constraints.positive,
    }
    support = constraints.unit_interval
    has_rsample = True

    def __init__(self, concentration1, concentration0, validate_args=None):
        if isinstance(concentration1, Real) and isinstance(concentration0, Real):
            # Two plain scalars: build the length-2 Dirichlet concentration
            # directly (default dtype).
            concentration1_concentration0 = torch.tensor(
                [float(concentration1), float(concentration0)]
            )
        else:
            # At least one tensor: broadcast to a common shape, then pair the two
            # parameters along a new trailing dimension.
            concentration1, concentration0 = broadcast_all(
                concentration1, concentration0
            )
            concentration1_concentration0 = torch.stack(
                [concentration1, concentration0], -1
            )
        self._dirichlet = Dirichlet(
            concentration1_concentration0, validate_args=validate_args
        )
        # The Beta batch shape is exactly the Dirichlet batch shape; the trailing
        # pair dimension is the Dirichlet's event dimension.
        super().__init__(self._dirichlet._batch_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        """Return a Beta with ``batch_shape`` by expanding the underlying Dirichlet.

        Base-class validation is skipped while constructing the expanded instance;
        afterwards it inherits this instance's ``_validate_args`` setting.
        """
        new = self._get_checked_instance(Beta, _instance)
        batch_shape = torch.Size(batch_shape)
        new._dirichlet = self._dirichlet.expand(batch_shape)
        super(Beta, new).__init__(batch_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    @property
    def mean(self):
        """``concentration1 / (concentration1 + concentration0)``."""
        c1, c0 = self.concentration1, self.concentration0
        return c1 / (c1 + c0)

    @property
    def mode(self):
        """First component of the underlying Dirichlet's mode."""
        return self._dirichlet.mode[..., 0]

    @property
    def variance(self):
        """``c1 * c0 / ((c1 + c0)**2 * (c1 + c0 + 1))`` for ``c1, c0`` the concentrations."""
        c1, c0 = self.concentration1, self.concentration0
        total = c1 + c0
        return c1 * c0 / (total.pow(2) * (total + 1))

    def rsample(self, sample_shape: _size = ()) -> torch.Tensor:
        """Draw a reparameterized sample of shape ``sample_shape + batch_shape``.

        A sample is drawn from the two-component Dirichlet and only its first
        component is kept (the second is ``1 - first``).
        """
        return self._dirichlet.rsample(sample_shape).select(-1, 0)

    def log_prob(self, value):
        """Log-density at ``value``, evaluated as the Dirichlet log-density of
        ``(value, 1 - value)``.

        ``value`` is expected to be a Tensor; ``_validate_sample`` is called on it
        first when argument validation is enabled.
        """
        if self._validate_args:
            self._validate_sample(value)
        heads_tails = torch.stack([value, 1.0 - value], -1)
        return self._dirichlet.log_prob(heads_tails)

    def entropy(self):
        """Entropy, delegated to the underlying Dirichlet."""
        return self._dirichlet.entropy()

    @property
    def concentration1(self):
        """First concentration parameter (alpha): ``concentration[..., 0]`` of the Dirichlet."""
        # Indexing a Tensor always yields a Tensor, so no scalar fallback is needed.
        return self._dirichlet.concentration[..., 0]

    @property
    def concentration0(self):
        """Second concentration parameter (beta): ``concentration[..., 1]`` of the Dirichlet."""
        # Indexing a Tensor always yields a Tensor, so no scalar fallback is needed.
        return self._dirichlet.concentration[..., 1]

    @property
    def _natural_params(self):
        """``(concentration1, concentration0)``, the arguments fed to :meth:`_log_normalizer`."""
        return (self.concentration1, self.concentration0)

    def _log_normalizer(self, x, y):
        """``log B(x, y) = lgamma(x) + lgamma(y) - lgamma(x + y)``."""
        return torch.lgamma(x) + torch.lgamma(y) - torch.lgamma(x + y)
111