# mypy: allow-untyped-defs
from numbers import Number

import torch
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.utils import (
    broadcast_all,
    lazy_property,
    logits_to_probs,
    probs_to_logits,
)
from torch.nn.functional import binary_cross_entropy_with_logits


__all__ = ["Geometric"]


class Geometric(Distribution):
    r"""
    Creates a Geometric distribution parameterized by :attr:`probs`,
    where :attr:`probs` is the probability of success of Bernoulli trials.

    .. math::

        P(X=k) = (1-p)^{k} p, k = 0, 1, ...

    .. note::
        For :func:`torch.distributions.geometric.Geometric` the :math:`(k+1)`-th trial
        is the first success, hence it draws samples in :math:`\{0, 1, \ldots\}`, whereas
        for :func:`torch.Tensor.geometric_` the `k`-th trial is the first success, hence
        it draws samples in :math:`\{1, 2, \ldots\}`.

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = Geometric(torch.tensor([0.3]))
        >>> m.sample()  # underlying Bernoulli has 30% chance 1; 70% chance 0
        tensor([ 2.])

    Args:
        probs (Number, Tensor): the probability of sampling `1`. Must be in range (0, 1]
        logits (Number, Tensor): the log-odds of sampling `1`.
    """
    arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
    support = constraints.nonnegative_integer

    def __init__(self, probs=None, logits=None, validate_args=None):
        """Build the distribution from exactly one of ``probs`` or ``logits``.

        Raises:
            ValueError: if both or neither of ``probs`` / ``logits`` are given,
                or if validation is enabled and any entry of ``probs`` is not
                strictly positive.
        """
        if (probs is None) == (logits is None):
            raise ValueError(
                "Either `probs` or `logits` must be specified, but not both."
            )
        if probs is not None:
            (self.probs,) = broadcast_all(probs)
        else:
            (self.logits,) = broadcast_all(logits)
        probs_or_logits = probs if probs is not None else logits
        # A plain Python number yields an empty batch shape; otherwise the
        # batch shape is the size of the parameter that was passed in.
        if isinstance(probs_or_logits, Number):
            batch_shape = torch.Size()
        else:
            batch_shape = probs_or_logits.size()
        super().__init__(batch_shape, validate_args=validate_args)
        if self._validate_args and probs is not None:
            # Add an extra check beyond unit_interval
            value = self.probs
            valid = value > 0
            if not valid.all():
                invalid_value = value.data[~valid]
                raise ValueError(
                    "Expected parameter probs "
                    f"({type(value).__name__} of shape {tuple(value.shape)}) "
                    f"of distribution {repr(self)} "
                    f"to be positive but found invalid values:\n{invalid_value}"
                )

    def expand(self, batch_shape, _instance=None):
        """Return a ``Geometric`` whose parameters are expanded to ``batch_shape``.

        Only the parameterization(s) already stored on this instance
        (``probs`` and/or ``logits``) are expanded and copied to the new one.
        """
        new = self._get_checked_instance(Geometric, _instance)
        batch_shape = torch.Size(batch_shape)
        # Look in __dict__ so the lazy properties are not evaluated as a side
        # effect of expanding.
        if "probs" in self.__dict__:
            new.probs = self.probs.expand(batch_shape)
        if "logits" in self.__dict__:
            new.logits = self.logits.expand(batch_shape)
        # Base-class init runs with validation off; the flag is then copied
        # from this instance.
        super(Geometric, new).__init__(batch_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    @property
    def mean(self):
        """Mean of the distribution, ``1 / probs - 1``."""
        return 1.0 / self.probs - 1.0

    @property
    def mode(self):
        """Mode of the distribution: zeros shaped like ``probs``."""
        return torch.zeros_like(self.probs)

    @property
    def variance(self):
        """Variance of the distribution, ``(1 / probs - 1) / probs``."""
        return (1.0 / self.probs - 1.0) / self.probs

    @lazy_property
    def logits(self):
        """Log-odds derived from ``probs`` (computed lazily)."""
        return probs_to_logits(self.probs, is_binary=True)

    @lazy_property
    def probs(self):
        """Success probability derived from ``logits`` (computed lazily)."""
        return logits_to_probs(self.logits, is_binary=True)

    def sample(self, sample_shape=torch.Size()):
        """Draw samples of shape ``sample_shape + batch_shape`` without gradients.

        Samples are computed as ``floor(log(u) / log1p(-probs))`` with ``u``
        uniform, bounded below by the smallest positive normal value of the
        dtype so that ``log(u)`` stays finite.
        """
        shape = self._extended_shape(sample_shape)
        tiny = torch.finfo(self.probs.dtype).tiny
        with torch.no_grad():
            if torch._C._get_tracing_state():
                # [JIT WORKAROUND] lack of support for .uniform_()
                u = torch.rand(shape, dtype=self.probs.dtype, device=self.probs.device)
                u = u.clamp(min=tiny)
            else:
                u = self.probs.new(shape).uniform_(tiny, 1)
            return (u.log() / (-self.probs).log1p()).floor()

    def log_prob(self, value):
        """Return ``value * log1p(-probs) + log(probs)`` evaluated at ``value``.

        ``value`` is validated against the support when argument validation
        is enabled.
        """
        if self._validate_args:
            self._validate_sample(value)
        value, probs = broadcast_all(value, self.probs)
        # Work on a copy so the in-place masking below never touches self.probs.
        probs = probs.clone(memory_format=torch.contiguous_format)
        # With probs == 1 and value == 0 the product below would be
        # 0 * log1p(-1) = 0 * -inf = nan; zeroing probs there makes it 0.
        probs[(probs == 1) & (value == 0)] = 0
        return value * (-probs).log1p() + self.probs.log()

    def entropy(self):
        """Return the element-wise binary cross entropy of ``probs`` against ``logits``, divided by ``probs``."""
        return (
            binary_cross_entropy_with_logits(self.logits, self.probs, reduction="none")
            / self.probs
        )