# mypy: allow-untyped-defs
from numbers import Number

import torch
from torch import nan
from torch.distributions import constraints
from torch.distributions.exp_family import ExponentialFamily
from torch.distributions.utils import (
    broadcast_all,
    lazy_property,
    logits_to_probs,
    probs_to_logits,
)
from torch.nn.functional import binary_cross_entropy_with_logits


__all__ = ["Bernoulli"]


class Bernoulli(ExponentialFamily):
    r"""
    Creates a Bernoulli distribution parameterized by :attr:`probs`
    or :attr:`logits` (but not both).

    Samples are binary (0 or 1). They take the value `1` with probability `p`
    and `0` with probability `1 - p`.

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = Bernoulli(torch.tensor([0.3]))
        >>> m.sample()  # 30% chance 1; 70% chance 0
        tensor([ 0.])

    Args:
        probs (Number, Tensor): the probability of sampling `1`
        logits (Number, Tensor): the log-odds of sampling `1`
        validate_args (bool, optional): forwarded to the base
            ``Distribution`` constructor to control argument validation
    """
    arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
    support = constraints.boolean
    has_enumerate_support = True
    _mean_carrier_measure = 0

    def __init__(self, probs=None, logits=None, validate_args=None):
        # Exactly one parameterization must be supplied.
        if (probs is None) == (logits is None):
            raise ValueError(
                "Either `probs` or `logits` must be specified, but not both."
            )
        if probs is not None:
            is_scalar = isinstance(probs, Number)
            (self.probs,) = broadcast_all(probs)
        else:
            is_scalar = isinstance(logits, Number)
            (self.logits,) = broadcast_all(logits)
        # ``_param`` is whichever tensor the user supplied; it fixes the batch
        # shape, dtype and device used by ``_new``/``enumerate_support``.
        self._param = self.probs if probs is not None else self.logits
        if is_scalar:
            batch_shape = torch.Size()
        else:
            batch_shape = self._param.size()
        super().__init__(batch_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        """Return a new ``Bernoulli`` whose parameters are expanded to ``batch_shape``.

        Only parameters already materialised in ``self.__dict__`` are expanded;
        the other one is recomputed lazily on the new instance.
        """
        new = self._get_checked_instance(Bernoulli, _instance)
        batch_shape = torch.Size(batch_shape)
        if "probs" in self.__dict__:
            new.probs = self.probs.expand(batch_shape)
            new._param = new.probs
        if "logits" in self.__dict__:
            new.logits = self.logits.expand(batch_shape)
            new._param = new.logits
        super(Bernoulli, new).__init__(batch_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    def _new(self, *args, **kwargs):
        """Create a new tensor with the same dtype/device as the stored parameter."""
        return self._param.new(*args, **kwargs)

    @property
    def mean(self):
        """Mean of the distribution, equal to ``probs``."""
        return self.probs

    @property
    def mode(self):
        """Mode: 1 where ``probs > 0.5``, 0 where ``probs < 0.5``, NaN at exactly 0.5."""
        mode = (self.probs >= 0.5).to(self.probs)
        # Both outcomes are equally likely at p == 0.5, so the mode is undefined.
        mode[self.probs == 0.5] = nan
        return mode

    @property
    def variance(self):
        """Variance ``p * (1 - p)``."""
        return self.probs * (1 - self.probs)

    @lazy_property
    def logits(self):
        """Log-odds, derived from ``probs`` on first access when not given."""
        return probs_to_logits(self.probs, is_binary=True)

    @lazy_property
    def probs(self):
        """Probability of 1, derived from ``logits`` on first access when not given."""
        return logits_to_probs(self.logits, is_binary=True)

    @property
    def param_shape(self):
        """Shape of the user-supplied parameter tensor."""
        return self._param.size()

    def sample(self, sample_shape=torch.Size()):
        """Draw samples of shape ``sample_shape + batch_shape`` (no gradient)."""
        shape = self._extended_shape(sample_shape)
        with torch.no_grad():
            return torch.bernoulli(self.probs.expand(shape))

    def log_prob(self, value):
        """Log-probability of ``value``.

        Computed as the negated elementwise binary cross-entropy on logits,
        i.e. ``value * log(p) + (1 - value) * log(1 - p)``.
        """
        if self._validate_args:
            self._validate_sample(value)
        logits, value = broadcast_all(self.logits, value)
        return -binary_cross_entropy_with_logits(logits, value, reduction="none")

    def entropy(self):
        """Entropy ``-(p * log(p) + (1 - p) * log(1 - p))``, computed from logits."""
        return binary_cross_entropy_with_logits(
            self.logits, self.probs, reduction="none"
        )

    def enumerate_support(self, expand=True):
        """Return the support values ``[0, 1]``.

        The result has shape ``(2,) + (1,) * len(batch_shape)``. When ``expand``
        is true, it is broadcast to ``(2,) + batch_shape``.
        """
        values = torch.arange(2, dtype=self._param.dtype, device=self._param.device)
        values = values.view((-1,) + (1,) * len(self._batch_shape))
        if expand:
            values = values.expand((-1,) + self._batch_shape)
        return values

    @property
    def _natural_params(self):
        # The natural parameter of a Bernoulli is its log-odds. Reading
        # ``self.logits`` directly avoids a lossy logits -> probs -> logits
        # round trip. For example, logits=20 gives probs == 1.0 in float32,
        # and ``torch.logit(1.0)`` is ``inf``. When built from ``probs``,
        # ``self.logits`` comes from ``probs_to_logits`` (clamped), so probs of
        # exactly 0 or 1 map to large finite values rather than +/-inf.
        return (self.logits,)

    def _log_normalizer(self, x):
        # A(x) = log(1 + exp(x)). ``torch.log1p(torch.exp(x))`` overflows to
        # inf once exp(x) exceeds the dtype's range (x > ~88 in float32).
        # ``logaddexp(0, x)`` evaluates the same quantity stably, and its
        # gradient with respect to x is sigmoid(x) == probs.
        return torch.logaddexp(torch.zeros_like(x), x)