xref: /aosp_15_r20/external/pytorch/torch/distributions/relaxed_bernoulli.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from numbers import Number
3
4import torch
5from torch.distributions import constraints
6from torch.distributions.distribution import Distribution
7from torch.distributions.transformed_distribution import TransformedDistribution
8from torch.distributions.transforms import SigmoidTransform
9from torch.distributions.utils import (
10    broadcast_all,
11    clamp_probs,
12    lazy_property,
13    logits_to_probs,
14    probs_to_logits,
15)
16from torch.types import _size
17
18
19__all__ = ["LogitRelaxedBernoulli", "RelaxedBernoulli"]
20
21
class LogitRelaxedBernoulli(Distribution):
    r"""
    Creates a LogitRelaxedBernoulli distribution parameterized by :attr:`probs`
    or :attr:`logits` (but not both), which is the logit of a RelaxedBernoulli
    distribution.

    Samples are logits of values in (0, 1). See [1] for more details.

    Args:
        temperature (Tensor): relaxation temperature
        probs (Number, Tensor): the probability of sampling `1`
        logits (Number, Tensor): the log-odds of sampling `1`

    [1] The Concrete Distribution: A Continuous Relaxation of Discrete Random
    Variables (Maddison et al., 2017)

    [2] Categorical Reparametrization with Gumbel-Softmax
    (Jang et al., 2017)
    """
    arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
    support = constraints.real

    def __init__(self, temperature, probs=None, logits=None, validate_args=None):
        self.temperature = temperature
        # Exactly one of `probs` / `logits` may be given; the missing one is
        # derived on first access by the lazy_property pair below.
        if (probs is None) == (logits is None):
            raise ValueError(
                "Either `probs` or `logits` must be specified, but not both."
            )
        if probs is not None:
            is_scalar = isinstance(probs, Number)
            # broadcast_all promotes a Number to a tensor; a tensor argument
            # keeps its shape.
            (self.probs,) = broadcast_all(probs)
        else:
            is_scalar = isinstance(logits, Number)
            (self.logits,) = broadcast_all(logits)
        # Handle to whichever parameter was actually supplied; used for
        # shape/dtype/device queries (`param_shape`, `_new`).
        self._param = self.probs if probs is not None else self.logits
        if is_scalar:
            # A scalar parameter gives an empty batch shape.
            batch_shape = torch.Size()
        else:
            batch_shape = self._param.size()
        super().__init__(batch_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        # Return a new instance with parameters expanded to `batch_shape`.
        # Only the parameter(s) already materialized in __dict__ are
        # expanded, so a lazily-derived counterpart stays lazy.
        new = self._get_checked_instance(LogitRelaxedBernoulli, _instance)
        batch_shape = torch.Size(batch_shape)
        new.temperature = self.temperature
        if "probs" in self.__dict__:
            new.probs = self.probs.expand(batch_shape)
            new._param = new.probs
        if "logits" in self.__dict__:
            new.logits = self.logits.expand(batch_shape)
            new._param = new.logits
        super(LogitRelaxedBernoulli, new).__init__(batch_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    def _new(self, *args, **kwargs):
        # Allocate a tensor matching the parameter's dtype/device.
        return self._param.new(*args, **kwargs)

    @lazy_property
    def logits(self):
        # Computed from `probs` on first access, then cached on the instance.
        return probs_to_logits(self.probs, is_binary=True)

    @lazy_property
    def probs(self):
        # Computed from `logits` on first access, then cached on the instance.
        return logits_to_probs(self.logits, is_binary=True)

    @property
    def param_shape(self):
        return self._param.size()

    def rsample(self, sample_shape: _size = torch.Size()) -> torch.Tensor:
        # Reparameterized sampling via the Logistic trick:
        #   (logit(U) + logit(p)) / temperature,  with U ~ Uniform(0, 1).
        # clamp_probs keeps both U and p strictly inside (0, 1) so the
        # log / log1p terms below remain finite.
        shape = self._extended_shape(sample_shape)
        probs = clamp_probs(self.probs.expand(shape))
        uniforms = clamp_probs(
            torch.rand(shape, dtype=probs.dtype, device=probs.device)
        )
        return (
            uniforms.log() - (-uniforms).log1p() + probs.log() - (-probs).log1p()
        ) / self.temperature

    def log_prob(self, value):
        # Log-density in logit space:
        #   log(t) + d - 2 * log(1 + exp(d)),  where d = logits - t * value.
        if self._validate_args:
            self._validate_sample(value)
        logits, value = broadcast_all(self.logits, value)
        diff = logits - value.mul(self.temperature)
        return self.temperature.log() + diff - 2 * diff.exp().log1p()
108
109
class RelaxedBernoulli(TransformedDistribution):
    r"""
    Creates a RelaxedBernoulli distribution, parametrized by
    :attr:`temperature`, and either :attr:`probs` or :attr:`logits`
    (but not both). This is a relaxed version of the `Bernoulli` distribution,
    so the values are in (0, 1), and has reparametrizable samples.

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = RelaxedBernoulli(torch.tensor([2.2]),
        ...                      torch.tensor([0.1, 0.2, 0.3, 0.99]))
        >>> m.sample()
        tensor([ 0.2951,  0.3442,  0.8918,  0.9021])

    Args:
        temperature (Tensor): relaxation temperature
        probs (Number, Tensor): the probability of sampling `1`
        logits (Number, Tensor): the log-odds of sampling `1`
    """
    arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
    support = constraints.unit_interval
    has_rsample = True

    def __init__(self, temperature, probs=None, logits=None, validate_args=None):
        # A RelaxedBernoulli sample is a LogitRelaxedBernoulli draw pushed
        # through a sigmoid, which maps the logit-space value into (0, 1).
        super().__init__(
            LogitRelaxedBernoulli(temperature, probs, logits),
            SigmoidTransform(),
            validate_args=validate_args,
        )

    def expand(self, batch_shape, _instance=None):
        # Delegate expansion to TransformedDistribution after type-checking
        # (or allocating) the target instance.
        instance = self._get_checked_instance(RelaxedBernoulli, _instance)
        return super().expand(batch_shape, _instance=instance)

    @property
    def temperature(self):
        # All parameters live on the underlying LogitRelaxedBernoulli.
        return self.base_dist.temperature

    @property
    def logits(self):
        return self.base_dist.logits

    @property
    def probs(self):
        return self.base_dist.probs
153