xref: /aosp_15_r20/external/pytorch/torch/distributions/bernoulli.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-defs
2*da0073e9SAndroid Build Coastguard Workerfrom numbers import Number
3*da0073e9SAndroid Build Coastguard Worker
4*da0073e9SAndroid Build Coastguard Workerimport torch
5*da0073e9SAndroid Build Coastguard Workerfrom torch import nan
6*da0073e9SAndroid Build Coastguard Workerfrom torch.distributions import constraints
7*da0073e9SAndroid Build Coastguard Workerfrom torch.distributions.exp_family import ExponentialFamily
8*da0073e9SAndroid Build Coastguard Workerfrom torch.distributions.utils import (
9*da0073e9SAndroid Build Coastguard Worker    broadcast_all,
10*da0073e9SAndroid Build Coastguard Worker    lazy_property,
11*da0073e9SAndroid Build Coastguard Worker    logits_to_probs,
12*da0073e9SAndroid Build Coastguard Worker    probs_to_logits,
13*da0073e9SAndroid Build Coastguard Worker)
14*da0073e9SAndroid Build Coastguard Workerfrom torch.nn.functional import binary_cross_entropy_with_logits
15*da0073e9SAndroid Build Coastguard Worker
16*da0073e9SAndroid Build Coastguard Worker
17*da0073e9SAndroid Build Coastguard Worker__all__ = ["Bernoulli"]
18*da0073e9SAndroid Build Coastguard Worker
19*da0073e9SAndroid Build Coastguard Worker
class Bernoulli(ExponentialFamily):
    r"""
    Creates a Bernoulli distribution parameterized by :attr:`probs`
    or :attr:`logits` (but not both).

    Samples are binary (0 or 1). They take the value `1` with probability `p`
    and `0` with probability `1 - p`.

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = Bernoulli(torch.tensor([0.3]))
        >>> m.sample()  # 30% chance 1; 70% chance 0
        tensor([ 0.])

    Args:
        probs (Number, Tensor): the probability of sampling `1`
        logits (Number, Tensor): the log-odds of sampling `1`
        validate_args (bool, optional): forwarded unchanged to the base
            :class:`~torch.distributions.Distribution` constructor.
    """
    arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
    support = constraints.boolean
    has_enumerate_support = True
    _mean_carrier_measure = 0

    def __init__(self, probs=None, logits=None, validate_args=None):
        """Build the distribution from exactly one of ``probs`` / ``logits``.

        Raises:
            ValueError: if both or neither of ``probs`` and ``logits`` are given.
        """
        if (probs is None) == (logits is None):
            raise ValueError(
                "Either `probs` or `logits` must be specified, but not both."
            )
        if probs is not None:
            is_scalar = isinstance(probs, Number)
            (self.probs,) = broadcast_all(probs)
        else:
            is_scalar = isinstance(logits, Number)
            (self.logits,) = broadcast_all(logits)
        # Keep a handle on whichever tensor the user supplied; the other
        # parameterization is derived on demand by the lazy properties below.
        self._param = self.probs if probs is not None else self.logits
        # A plain Python number yields an empty batch shape; a tensor's own
        # shape becomes the batch shape.
        batch_shape = torch.Size() if is_scalar else self._param.size()
        super().__init__(batch_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        """Return a Bernoulli whose parameters are ``expand``-ed to ``batch_shape``.

        Only the parameterizations already materialized in ``self.__dict__``
        (``probs`` and/or ``logits``) are expanded; the other stays lazy.
        """
        new = self._get_checked_instance(Bernoulli, _instance)
        batch_shape = torch.Size(batch_shape)
        if "probs" in self.__dict__:
            new.probs = self.probs.expand(batch_shape)
            new._param = new.probs
        if "logits" in self.__dict__:
            new.logits = self.logits.expand(batch_shape)
            new._param = new.logits
        super(Bernoulli, new).__init__(batch_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    def _new(self, *args, **kwargs):
        """Construct a new tensor via ``Tensor.new`` on the stored parameter."""
        return self._param.new(*args, **kwargs)

    @property
    def mean(self):
        """Mean of the distribution, i.e. ``probs``."""
        return self.probs

    @property
    def mode(self):
        """1 where ``probs > 0.5``, 0 where ``probs < 0.5``, NaN at exact ties (0.5)."""
        mode = (self.probs >= 0.5).to(self.probs)
        mode[self.probs == 0.5] = nan
        return mode

    @property
    def variance(self):
        """Variance ``probs * (1 - probs)``."""
        return self.probs * (1 - self.probs)

    @lazy_property
    def logits(self):
        """Log-odds, derived from ``probs`` when built from ``probs``."""
        return probs_to_logits(self.probs, is_binary=True)

    @lazy_property
    def probs(self):
        """Probability of sampling 1, derived from ``logits`` when built from ``logits``."""
        return logits_to_probs(self.logits, is_binary=True)

    @property
    def param_shape(self):
        """Shape of the stored parameter tensor."""
        return self._param.size()

    def sample(self, sample_shape=torch.Size()):
        """Draw 0/1 samples of shape ``self._extended_shape(sample_shape)``.

        Sampling runs under ``torch.no_grad()``, so the result carries no graph.
        """
        shape = self._extended_shape(sample_shape)
        with torch.no_grad():
            return torch.bernoulli(self.probs.expand(shape))

    def log_prob(self, value):
        """Log-probability of ``value`` (broadcast against ``logits``).

        Computed as the negated element-wise BCE-with-logits, which works in
        logit space and avoids taking ``log`` of ``probs`` directly.
        """
        if self._validate_args:
            self._validate_sample(value)
        logits, value = broadcast_all(self.logits, value)
        return -binary_cross_entropy_with_logits(logits, value, reduction="none")

    def entropy(self):
        """Entropy, computed as the BCE-with-logits of the distribution with itself."""
        return binary_cross_entropy_with_logits(
            self.logits, self.probs, reduction="none"
        )

    def enumerate_support(self, expand=True):
        """Return the support ``[0, 1]`` as a tensor of the parameter's dtype/device.

        The result has shape ``(2, 1, ..., 1)`` (one singleton per batch
        dimension), or ``(2,) + batch_shape`` when ``expand`` is true.
        """
        values = torch.arange(2, dtype=self._param.dtype, device=self._param.device)
        values = values.view((-1,) + (1,) * len(self._batch_shape))
        if expand:
            values = values.expand((-1,) + self._batch_shape)
        return values

    @property
    def _natural_params(self):
        """Natural parameter of the exponential-family form: the log-odds.

        Uses ``self.logits`` directly instead of ``torch.logit(self.probs)``:
        for a logits-parameterized instance with large |logits| the sigmoid
        saturates to exactly 0 or 1 and the round trip would return +/-inf.
        This also matches ``log_prob``/``entropy``, which use ``self.logits``.
        """
        return (self.logits,)

    def _log_normalizer(self, x):
        """Log-normalizer ``log(1 + exp(x))`` of the natural parameter ``x``.

        ``logaddexp(x, 0)`` is the numerically stable form: the naive
        ``log1p(exp(x))`` overflows to ``inf`` once ``exp(x)`` exceeds the
        dtype's range, although the true value is approximately ``x``.
        """
        return torch.logaddexp(x, torch.zeros_like(x))
132*da0073e9SAndroid Build Coastguard Worker        return torch.log1p(torch.exp(x))
133