# mypy: allow-untyped-defs
import torch
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.distributions import constraints
from torch.distributions.exp_family import ExponentialFamily
from torch.types import _size


__all__ = ["Dirichlet"]


# This helper is exposed for testing.
def _Dirichlet_backward(x, concentration, grad_output):
    total = concentration.sum(-1, True).expand_as(concentration)
    grad = torch._dirichlet_grad(x, concentration, total)
    return grad * (grad_output - (x * grad_output).sum(-1, True))


class _Dirichlet(Function):
    @staticmethod
    def forward(ctx, concentration):
        x = torch._sample_dirichlet(concentration)
        ctx.save_for_backward(x, concentration)
        return x

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        x, concentration = ctx.saved_tensors
        return _Dirichlet_backward(x, concentration, grad_output)


class Dirichlet(ExponentialFamily):
    r"""
    Creates a Dirichlet distribution parameterized by concentration :attr:`concentration`.

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = Dirichlet(torch.tensor([0.5, 0.5]))
        >>> m.sample()  # Dirichlet distributed with concentration [0.5, 0.5]
        tensor([ 0.1046,  0.8954])

    Args:
        concentration (Tensor): concentration parameter of the distribution
            (often referred to as alpha)
    """

    arg_constraints = {
        "concentration": constraints.independent(constraints.positive, 1)
    }
    support = constraints.simplex
    has_rsample = True

    def __init__(self, concentration, validate_args=None):
        if concentration.dim() < 1:
            raise ValueError(
                "`concentration` parameter must be at least one-dimensional."
            )
        self.concentration = concentration
        batch_shape, event_shape = concentration.shape[:-1], concentration.shape[-1:]
        super().__init__(batch_shape, event_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(Dirichlet, _instance)
        batch_shape = torch.Size(batch_shape)
        new.concentration = self.concentration.expand(batch_shape + self.event_shape)
        super(Dirichlet, new).__init__(
            batch_shape, self.event_shape, validate_args=False
        )
        new._validate_args = self._validate_args
        return new

    def rsample(self, sample_shape: _size = ()) -> torch.Tensor:
        shape = self._extended_shape(sample_shape)
        concentration = self.concentration.expand(shape)
        return _Dirichlet.apply(concentration)

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        return (
            torch.xlogy(self.concentration - 1.0, value).sum(-1)
            + torch.lgamma(self.concentration.sum(-1))
            - torch.lgamma(self.concentration).sum(-1)
        )

    @property
    def mean(self):
        return self.concentration / self.concentration.sum(-1, True)

    @property
    def mode(self):
        concentrationm1 = (self.concentration - 1).clamp(min=0.0)
        mode = concentrationm1 / concentrationm1.sum(-1, True)
        mask = (self.concentration < 1).all(axis=-1)
        mode[mask] = torch.nn.functional.one_hot(
            mode[mask].argmax(axis=-1), concentrationm1.shape[-1]
        ).to(mode)
        return mode

    @property
    def variance(self):
        con0 = self.concentration.sum(-1, True)
        return (
            self.concentration
            * (con0 - self.concentration)
            / (con0.pow(2) * (con0 + 1))
        )

    def entropy(self):
        k = self.concentration.size(-1)
        a0 = self.concentration.sum(-1)
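        # Differential entropy of Dirichlet(a), with a0 = a.sum(-1) and k the
        # event size:
        #   H = sum_i lgamma(a_i) - lgamma(a0) - (k - a0) * digamma(a0)
        #       - sum_i (a_i - 1) * digamma(a_i)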
        return (
            torch.lgamma(self.concentration).sum(-1)
            - torch.lgamma(a0)
            - (k - a0) * torch.digamma(a0)
            - ((self.concentration - 1.0) * torch.digamma(self.concentration)).sum(-1)
        )

    @property
    def _natural_params(self):
        return (self.concentration,)

    def _log_normalizer(self, x):
        return x.lgamma().sum(-1) - torch.lgamma(x.sum(-1))
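

# A minimal usage sketch (illustrative only, not part of the module's API):
# draws a reparameterized sample via rsample (which routes through
# _Dirichlet.apply) and backpropagates log_prob, so gradients reach
# `concentration`. Guarded so it never runs on import.
if __name__ == "__main__":
    concentration = torch.tensor([0.5, 1.0, 2.0], requires_grad=True)
    dist = Dirichlet(concentration)
    sample = dist.rsample()  # lies on the probability simplex; sums to 1
    dist.log_prob(sample).backward()
    print(sample, sample.sum().item(), concentration.grad)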