# mypy: allow-untyped-defs
from numbers import Number, Real

import torch
from torch.distributions import constraints
from torch.distributions.dirichlet import Dirichlet
from torch.distributions.exp_family import ExponentialFamily
from torch.distributions.utils import broadcast_all
from torch.types import _size


__all__ = ["Beta"]


class Beta(ExponentialFamily):
    r"""
    Beta distribution parameterized by :attr:`concentration1` and :attr:`concentration0`.

    Internally the distribution is stored as a two-component
    :class:`~torch.distributions.dirichlet.Dirichlet` whose first component
    holds ``concentration1`` and whose second component holds
    ``concentration0``; sampling, log-density, entropy and mode are delegated
    to it and the first component is read off where needed.

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = Beta(torch.tensor([0.5]), torch.tensor([0.5]))
        >>> m.sample()  # Beta distributed with concentration concentration1 and concentration0
        tensor([ 0.1046])

    Args:
        concentration1 (float or Tensor): 1st concentration parameter of the distribution
            (often referred to as alpha)
        concentration0 (float or Tensor): 2nd concentration parameter of the distribution
            (often referred to as beta)
    """

    arg_constraints = {
        "concentration1": constraints.positive,
        "concentration0": constraints.positive,
    }
    support = constraints.unit_interval
    has_rsample = True

    def __init__(self, concentration1, concentration0, validate_args=None):
        both_scalars = isinstance(concentration1, Real) and isinstance(
            concentration0, Real
        )
        if not both_scalars:
            # Tensor (or mixed tensor/scalar) inputs: broadcast to a common
            # shape, then pack the pair along a new trailing axis.
            alpha, beta = broadcast_all(concentration1, concentration0)
            packed = torch.stack([alpha, beta], -1)
        else:
            # Two plain Python numbers: build a single length-2 concentration vector.
            packed = torch.tensor([float(concentration1), float(concentration0)])
        self._dirichlet = Dirichlet(packed, validate_args=validate_args)
        super().__init__(self._dirichlet._batch_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        """Return a ``Beta`` whose underlying Dirichlet is expanded to ``batch_shape``."""
        expanded = self._get_checked_instance(Beta, _instance)
        target_shape = torch.Size(batch_shape)
        expanded._dirichlet = self._dirichlet.expand(target_shape)
        # Initialise the base class without re-running argument validation,
        # then copy over this instance's validation setting.
        super(Beta, expanded).__init__(target_shape, validate_args=False)
        expanded._validate_args = self._validate_args
        return expanded

    @property
    def mean(self):
        """Mean ``concentration1 / (concentration1 + concentration0)``."""
        alpha = self.concentration1
        return alpha / (alpha + self.concentration0)

    @property
    def mode(self):
        """Mode, taken as the first component of the underlying Dirichlet's mode."""
        return self._dirichlet.mode.select(-1, 0)

    @property
    def variance(self):
        """Variance ``a * b / ((a + b)^2 * (a + b + 1))`` with ``a``/``b`` the concentrations."""
        alpha, beta = self.concentration1, self.concentration0
        total = alpha + beta
        return alpha * beta / (total.pow(2) * (total + 1))

    def rsample(self, sample_shape: _size = ()) -> torch.Tensor:
        """Draw reparameterized samples: the first component of Dirichlet samples."""
        draws = self._dirichlet.rsample(sample_shape)
        return draws[..., 0]

    def log_prob(self, value):
        """Log-density at ``value``, evaluated as the Dirichlet density of ``(value, 1 - value)``."""
        if self._validate_args:
            self._validate_sample(value)
        complement = 1.0 - value
        return self._dirichlet.log_prob(torch.stack([value, complement], -1))

    def entropy(self):
        """Entropy, delegated to the underlying two-component Dirichlet."""
        return self._dirichlet.entropy()

    def _concentration_at(self, index):
        """Return component ``index`` of the underlying Dirichlet concentration."""
        component = self._dirichlet.concentration[..., index]
        # Indexing a tensor normally yields a tensor; the Number guard is kept
        # so a plain Python number would still be wrapped as a 1-element tensor.
        return torch.tensor([component]) if isinstance(component, Number) else component

    @property
    def concentration1(self):
        """First concentration parameter (alpha)."""
        return self._concentration_at(0)

    @property
    def concentration0(self):
        """Second concentration parameter (beta)."""
        return self._concentration_at(1)

    @property
    def _natural_params(self):
        # Natural parameters as used by the ExponentialFamily machinery.
        return (self.concentration1, self.concentration0)

    def _log_normalizer(self, x, y):
        # log B(x, y) = lgamma(x) + lgamma(y) - lgamma(x + y)
        lgamma = torch.lgamma
        return lgamma(x) + lgamma(y) - lgamma(x + y)