xref: /aosp_15_r20/external/pytorch/torch/distributions/beta.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-defs
2*da0073e9SAndroid Build Coastguard Workerfrom numbers import Number, Real
3*da0073e9SAndroid Build Coastguard Worker
4*da0073e9SAndroid Build Coastguard Workerimport torch
5*da0073e9SAndroid Build Coastguard Workerfrom torch.distributions import constraints
6*da0073e9SAndroid Build Coastguard Workerfrom torch.distributions.dirichlet import Dirichlet
7*da0073e9SAndroid Build Coastguard Workerfrom torch.distributions.exp_family import ExponentialFamily
8*da0073e9SAndroid Build Coastguard Workerfrom torch.distributions.utils import broadcast_all
9*da0073e9SAndroid Build Coastguard Workerfrom torch.types import _size
10*da0073e9SAndroid Build Coastguard Worker
11*da0073e9SAndroid Build Coastguard Worker
# Public names exported by ``from ... import *``: only the Beta distribution.
__all__ = ["Beta"]
13*da0073e9SAndroid Build Coastguard Worker
14*da0073e9SAndroid Build Coastguard Worker
class Beta(ExponentialFamily):
    r"""
    Beta distribution parameterized by :attr:`concentration1` and :attr:`concentration0`.

    Implementation note: this class wraps a two-component
    :class:`~torch.distributions.Dirichlet`. Its concentration is
    ``(concentration1, concentration0)``, stacked along a trailing axis of
    size 2. A Beta value ``x`` corresponds to the Dirichlet point ``(x, 1 - x)``.
    Sampling, mode, log-density and entropy are delegated to that Dirichlet.

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = Beta(torch.tensor([0.5]), torch.tensor([0.5]))
        >>> m.sample()  # Beta distributed with concentration concentration1 and concentration0
        tensor([ 0.1046])

    Args:
        concentration1 (float or Tensor): 1st concentration parameter of the distribution
            (often referred to as alpha)
        concentration0 (float or Tensor): 2nd concentration parameter of the distribution
            (often referred to as beta)
    """

    arg_constraints = {
        "concentration1": constraints.positive,
        "concentration0": constraints.positive,
    }
    support = constraints.unit_interval
    has_rsample = True

    def __init__(self, concentration1, concentration0, validate_args=None):
        if isinstance(concentration1, Real) and isinstance(concentration0, Real):
            # Two Python scalars: build a single 2-element concentration vector
            # in the default floating dtype.
            concentration1_concentration0 = torch.tensor(
                [float(concentration1), float(concentration0)]
            )
        else:
            # At least one tensor: broadcast both to a common shape, then pack
            # them along a new trailing axis of size 2 (alpha at index 0,
            # beta at index 1).
            concentration1, concentration0 = broadcast_all(
                concentration1, concentration0
            )
            concentration1_concentration0 = torch.stack(
                [concentration1, concentration0], -1
            )
        self._dirichlet = Dirichlet(
            concentration1_concentration0, validate_args=validate_args
        )
        # The Beta's batch shape is exactly the wrapped Dirichlet's batch shape.
        super().__init__(self._dirichlet._batch_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        """Return a new Beta whose wrapped Dirichlet is expanded to ``batch_shape``."""
        new = self._get_checked_instance(Beta, _instance)
        batch_shape = torch.Size(batch_shape)
        new._dirichlet = self._dirichlet.expand(batch_shape)
        # Skip argument validation for the expanded copy. The parameters were
        # already validated (if requested) on ``self``; the original flag is
        # restored below.
        super(Beta, new).__init__(batch_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    @property
    def mean(self):
        """Mean ``alpha / (alpha + beta)``."""
        return self.concentration1 / (self.concentration1 + self.concentration0)

    @property
    def mode(self):
        """Mode, taken as the first component of the wrapped Dirichlet's mode."""
        return self._dirichlet.mode[..., 0]

    @property
    def variance(self):
        """Variance ``alpha * beta / ((alpha + beta)**2 * (alpha + beta + 1))``."""
        total = self.concentration1 + self.concentration0
        return self.concentration1 * self.concentration0 / (total.pow(2) * (total + 1))

    def rsample(self, sample_shape: _size = ()) -> torch.Tensor:
        """Draw reparameterized samples.

        Samples ``(x, 1 - x)`` from the wrapped Dirichlet and keeps ``x``
        (index 0 of the trailing axis).
        """
        return self._dirichlet.rsample(sample_shape).select(-1, 0)

    def log_prob(self, value):
        """Log-density of ``value``.

        Computed as the Dirichlet log-density at ``(value, 1 - value)``.
        With argument validation enabled, ``value`` is checked against
        :attr:`support` first.
        """
        if self._validate_args:
            self._validate_sample(value)
        heads_tails = torch.stack([value, 1.0 - value], -1)
        return self._dirichlet.log_prob(heads_tails)

    def entropy(self):
        """Entropy, delegated to the wrapped Dirichlet."""
        return self._dirichlet.entropy()

    @property
    def concentration1(self):
        """First concentration parameter (alpha).

        Read from index 0 of the Dirichlet concentration's trailing axis.
        """
        # Indexing a tensor always yields a tensor (0-dim at minimum), so no
        # scalar fallback is needed here.
        return self._dirichlet.concentration[..., 0]

    @property
    def concentration0(self):
        """Second concentration parameter (beta).

        Read from index 1 of the Dirichlet concentration's trailing axis.
        """
        return self._dirichlet.concentration[..., 1]

    @property
    def _natural_params(self):
        """Natural parameters used by :class:`ExponentialFamily`.

        These are simply ``(concentration1, concentration0)``.
        """
        return (self.concentration1, self.concentration0)

    def _log_normalizer(self, x, y):
        """Log of the Beta function: ``lgamma(x) + lgamma(y) - lgamma(x + y)``."""
        return torch.lgamma(x) + torch.lgamma(y) - torch.lgamma(x + y)
110*da0073e9SAndroid Build Coastguard Worker        return torch.lgamma(x) + torch.lgamma(y) - torch.lgamma(x + y)
111