# mypy: allow-untyped-defs
import torch
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.utils import (
    broadcast_all,
    lazy_property,
    logits_to_probs,
    probs_to_logits,
)


__all__ = ["Binomial"]


def _clamp_by_zero(x):
    """Return ``max(x, 0)`` elementwise, with a gradient of 0.5 at ``x == 0``.

    For ``x > 0`` the expression evaluates to ``(x + x - 0) / 2 == x``; for
    ``x < 0`` it evaluates to ``(0 + x - x) / 2 == 0``.
    """
    # Works like clamp(x, min=0), but the gradient at 0 is 0.5.
    return (x.clamp(min=0) + x - x.clamp(max=0)) / 2


class Binomial(Distribution):
    r"""
    Creates a Binomial distribution parameterized by :attr:`total_count` and
    either :attr:`probs` or :attr:`logits` (but not both). :attr:`total_count` must be
    broadcastable with :attr:`probs`/:attr:`logits`.

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = Binomial(100, torch.tensor([0 , .2, .8, 1]))
        >>> x = m.sample()
        tensor([   0.,   22.,   71.,  100.])

        >>> m = Binomial(torch.tensor([[5.], [10.]]), torch.tensor([0.5, 0.8]))
        >>> x = m.sample()
        tensor([[ 4.,  5.],
                [ 7.,  6.]])

    Args:
        total_count (int or Tensor): number of Bernoulli trials
        probs (Tensor): Event probabilities
        logits (Tensor): Event log-odds
    """

    arg_constraints = {
        "total_count": constraints.nonnegative_integer,
        "probs": constraints.unit_interval,
        "logits": constraints.real,
    }
    has_enumerate_support = True

    def __init__(self, total_count=1, probs=None, logits=None, validate_args=None):
        """Broadcast the parameters together and record the batch shape.

        Raises:
            ValueError: if both or neither of ``probs`` and ``logits`` are given.
        """
        if (probs is None) == (logits is None):
            raise ValueError(
                "Either `probs` or `logits` must be specified, but not both."
            )
        if probs is not None:
            (
                self.total_count,
                self.probs,
            ) = broadcast_all(total_count, probs)
            # Keep the trial count in the same tensor type as the parameter.
            self.total_count = self.total_count.type_as(self.probs)
        else:
            (
                self.total_count,
                self.logits,
            ) = broadcast_all(total_count, logits)
            self.total_count = self.total_count.type_as(self.logits)

        # ``_param`` is whichever parameterization the caller supplied; it is
        # used below for shape, dtype and device queries.
        self._param = self.probs if probs is not None else self.logits
        batch_shape = self._param.size()
        super().__init__(batch_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        """Return a new instance whose parameters are expanded to ``batch_shape``.

        Only the parameterizations already stored on this instance (``probs``
        and/or ``logits`` present in ``__dict__``) are expanded and copied over.
        """
        new = self._get_checked_instance(Binomial, _instance)
        batch_shape = torch.Size(batch_shape)
        new.total_count = self.total_count.expand(batch_shape)
        if "probs" in self.__dict__:
            new.probs = self.probs.expand(batch_shape)
            new._param = new.probs
        if "logits" in self.__dict__:
            new.logits = self.logits.expand(batch_shape)
            new._param = new.logits
        # Initialise the base class without validation, then carry over this
        # instance's validation setting.
        super(Binomial, new).__init__(batch_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    def _new(self, *args, **kwargs):
        """Forward to ``Tensor.new`` on the stored parameter tensor."""
        return self._param.new(*args, **kwargs)

    @constraints.dependent_property(is_discrete=True, event_dim=0)
    def support(self):
        """Integer interval constraint from 0 to ``total_count``."""
        return constraints.integer_interval(0, self.total_count)

    @property
    def mean(self):
        """``total_count * probs``."""
        return self.total_count * self.probs

    @property
    def mode(self):
        """``floor((total_count + 1) * probs)``, capped at ``total_count``."""
        return ((self.total_count + 1) * self.probs).floor().clamp(max=self.total_count)

    @property
    def variance(self):
        """``total_count * probs * (1 - probs)``."""
        return self.total_count * self.probs * (1 - self.probs)

    @lazy_property
    def logits(self):
        """Log-odds derived from ``probs`` (computed lazily)."""
        return probs_to_logits(self.probs, is_binary=True)

    @lazy_property
    def probs(self):
        """Probabilities derived from ``logits`` (computed lazily)."""
        return logits_to_probs(self.logits, is_binary=True)

    @property
    def param_shape(self):
        """Size of the stored parameter tensor."""
        return self._param.size()

    def sample(self, sample_shape=torch.Size()):
        """Draw samples with ``torch.binomial`` under ``torch.no_grad()``.

        The parameters are expanded to ``self._extended_shape(sample_shape)``
        before sampling.
        """
        shape = self._extended_shape(sample_shape)
        with torch.no_grad():
            return torch.binomial(
                self.total_count.expand(shape), self.probs.expand(shape)
            )

    def log_prob(self, value):
        """Return the log probability mass of ``value``.

        ``value`` is checked with ``_validate_sample`` only when argument
        validation is enabled on this instance.
        """
        if self._validate_args:
            self._validate_sample(value)
        log_factorial_n = torch.lgamma(self.total_count + 1)
        log_factorial_k = torch.lgamma(value + 1)
        log_factorial_nmk = torch.lgamma(self.total_count - value + 1)
        # k * log(p) + (n - k) * log(1 - p) = k * (log(p) - log(1 - p)) + n * log(1 - p)
        #     (case logit < 0)              = k * logit - n * log1p(e^logit)
        #     (case logit > 0)              = k * logit - n * (log(p) - log(1 - p)) + n * log(p)
        #                                   = k * logit - n * logit - n * log1p(e^-logit)
        #     (merge two cases)             = k * logit - n * max(logit, 0) - n * log1p(e^-|logit|)
        normalize_term = (
            self.total_count * _clamp_by_zero(self.logits)
            + self.total_count * torch.log1p(torch.exp(-torch.abs(self.logits)))
            - log_factorial_n
        )
        return (
            value * self.logits - log_factorial_k - log_factorial_nmk - normalize_term
        )

    def entropy(self):
        """Return the entropy, summed over the enumerated support.

        Raises:
            NotImplementedError: if ``total_count`` is not the same value
                across the whole batch.
        """
        total_count = int(self.total_count.max())
        if not self.total_count.min() == total_count:
            raise NotImplementedError(
                "Inhomogeneous total count not supported by `entropy`."
            )

        # The support is enumerated along dim 0, so summing over dim 0
        # accumulates -p * log(p) over every possible outcome.
        log_prob = self.log_prob(self.enumerate_support(False))
        return -(torch.exp(log_prob) * log_prob).sum(0)

    def enumerate_support(self, expand=True):
        """Return the values ``0 .. total_count`` stacked along a new dim 0.

        The result has one singleton dimension per batch dimension; when
        ``expand`` is true those dimensions are expanded to the batch shape.

        Raises:
            NotImplementedError: if ``total_count`` is not the same value
                across the whole batch.
        """
        total_count = int(self.total_count.max())
        if not self.total_count.min() == total_count:
            raise NotImplementedError(
                "Inhomogeneous total count not supported by `enumerate_support`."
            )
        values = torch.arange(
            1 + total_count, dtype=self._param.dtype, device=self._param.device
        )
        values = values.view((-1,) + (1,) * len(self._batch_shape))
        if expand:
            values = values.expand((-1,) + self._batch_shape)
        return values