# mypy: allow-untyped-defs
from numbers import Number, Real

import torch
from torch.distributions import constraints
from torch.distributions.dirichlet import Dirichlet
from torch.distributions.exp_family import ExponentialFamily
from torch.distributions.utils import broadcast_all
from torch.types import _size


__all__ = ["Beta"]


class Beta(ExponentialFamily):
    r"""
    Beta distribution parameterized by :attr:`concentration1` and :attr:`concentration0`.

    Internally the distribution is represented as a two-component
    :class:`~torch.distributions.dirichlet.Dirichlet` whose trailing dimension
    holds ``(concentration1, concentration0)``; samples, the mode and the
    entropy are taken from that Dirichlet.

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = Beta(torch.tensor([0.5]), torch.tensor([0.5]))
        >>> m.sample()  # Beta distributed with concentration concentration1 and concentration0
        tensor([ 0.1046])

    Args:
        concentration1 (float or Tensor): 1st concentration parameter of the distribution
            (often referred to as alpha)
        concentration0 (float or Tensor): 2nd concentration parameter of the distribution
            (often referred to as beta)
    """
    arg_constraints = {
        "concentration1": constraints.positive,
        "concentration0": constraints.positive,
    }
    support = constraints.unit_interval
    has_rsample = True

    def __init__(self, concentration1, concentration0, validate_args=None):
        if isinstance(concentration1, Real) and isinstance(concentration0, Real):
            # Two Python scalars: a single (2,)-shaped tensor, so the
            # resulting Dirichlet (and this Beta) has an empty batch shape.
            concentration1_concentration0 = torch.tensor(
                [float(concentration1), float(concentration0)]
            )
        else:
            # At least one tensor: broadcast both to a common shape, then stack
            # them along a new trailing dimension consumed by the Dirichlet.
            concentration1, concentration0 = broadcast_all(
                concentration1, concentration0
            )
            concentration1_concentration0 = torch.stack(
                [concentration1, concentration0], -1
            )
        self._dirichlet = Dirichlet(
            concentration1_concentration0, validate_args=validate_args
        )
        super().__init__(self._dirichlet._batch_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        """Return a new ``Beta`` whose underlying Dirichlet is expanded to ``batch_shape``."""
        new = self._get_checked_instance(Beta, _instance)
        batch_shape = torch.Size(batch_shape)
        new._dirichlet = self._dirichlet.expand(batch_shape)
        # Skip re-validation of already-validated parameters, then restore
        # the original validation flag on the new instance.
        super(Beta, new).__init__(batch_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    @property
    def mean(self):
        """Mean ``alpha / (alpha + beta)``."""
        return self.concentration1 / (self.concentration1 + self.concentration0)

    @property
    def mode(self):
        """Mode, taken as the first component of the underlying Dirichlet's mode."""
        return self._dirichlet.mode[..., 0]

    @property
    def variance(self):
        """Variance ``alpha * beta / ((alpha + beta)**2 * (alpha + beta + 1))``."""
        total = self.concentration1 + self.concentration0
        return self.concentration1 * self.concentration0 / (total.pow(2) * (total + 1))

    def rsample(self, sample_shape: _size = ()) -> torch.Tensor:
        """Draw reparameterized samples as the first coordinate of a Dirichlet sample."""
        return self._dirichlet.rsample(sample_shape).select(-1, 0)

    def log_prob(self, value):
        """Log density at ``value``, evaluated as the Dirichlet log density of
        the point ``(value, 1 - value)``."""
        if self._validate_args:
            self._validate_sample(value)
        heads_tails = torch.stack([value, 1.0 - value], -1)
        return self._dirichlet.log_prob(heads_tails)

    def entropy(self):
        """Entropy, delegated to the underlying two-component Dirichlet."""
        return self._dirichlet.entropy()

    @property
    def concentration1(self):
        """First concentration parameter (alpha): entry 0 of the Dirichlet concentration."""
        # Indexing a tensor always yields a tensor (0-dim for a scalar batch),
        # so no special-casing of Python numbers is required.
        return self._dirichlet.concentration[..., 0]

    @property
    def concentration0(self):
        """Second concentration parameter (beta): entry 1 of the Dirichlet concentration."""
        # See ``concentration1``: tensor indexing never returns a Python number.
        return self._dirichlet.concentration[..., 1]

    @property
    def _natural_params(self):
        """Parameters passed to ``_log_normalizer``: ``(alpha, beta)``."""
        return (self.concentration1, self.concentration0)

    def _log_normalizer(self, x, y):
        """Log of the Beta function, ``lgamma(x) + lgamma(y) - lgamma(x + y)``."""
        return torch.lgamma(x) + torch.lgamma(y) - torch.lgamma(x + y)