# mypy: allow-untyped-defs
import torch
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.distributions import constraints
from torch.distributions.exp_family import ExponentialFamily
from torch.types import _size


__all__ = ["Dirichlet"]


# This helper is exposed for testing.
def _Dirichlet_backward(x, concentration, grad_output):
    """Vector-Jacobian product for a reparameterized Dirichlet sample.

    Args:
        x: sample previously drawn from ``torch._sample_dirichlet(concentration)``.
        concentration: concentration parameters that produced ``x``.
        grad_output: gradient of the loss with respect to ``x``.

    Returns:
        Gradient of the loss with respect to ``concentration`` (same shape as
        ``concentration``).
    """
    # Every element needs the total concentration of its own distribution, so
    # the last-dim sum is broadcast back to the full parameter shape.
    total = concentration.sum(-1, True).expand_as(concentration)
    grad = torch._dirichlet_grad(x, concentration, total)
    # grad_output is centred by its x-weighted sum over the last dim before
    # being scaled element-wise by the ATen-provided derivative term.
    return grad * (grad_output - (x * grad_output).sum(-1, True))


class _Dirichlet(Function):
    """Autograd function: draw a Dirichlet sample with a reparameterized gradient."""

    @staticmethod
    def forward(ctx, concentration):
        x = torch._sample_dirichlet(concentration)
        # Both the sample and its parameters are needed to form the gradient.
        ctx.save_for_backward(x, concentration)
        return x

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        x, concentration = ctx.saved_tensors
        return _Dirichlet_backward(x, concentration, grad_output)


class Dirichlet(ExponentialFamily):
    r"""
    Creates a Dirichlet distribution parameterized by concentration :attr:`concentration`.

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = Dirichlet(torch.tensor([0.5, 0.5]))
        >>> m.sample()  # Dirichlet distributed with concentration [0.5, 0.5]
        tensor([ 0.1046,  0.8954])

    Args:
        concentration (Tensor): concentration parameter of the distribution
            (often referred to as alpha)
    """

    arg_constraints = {
        "concentration": constraints.independent(constraints.positive, 1)
    }
    support = constraints.simplex
    has_rsample = True

    def __init__(self, concentration, validate_args=None):
        """Build the distribution.

        The last dimension of ``concentration`` is the event dimension. All
        leading dimensions are batch dimensions.

        Raises:
            ValueError: if ``concentration`` is a 0-d tensor.
        """
        if concentration.dim() < 1:
            raise ValueError(
                "`concentration` parameter must be at least one-dimensional."
            )
        self.concentration = concentration
        batch_shape, event_shape = concentration.shape[:-1], concentration.shape[-1:]
        super().__init__(batch_shape, event_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        """Return a copy whose concentration is broadcast (not copied) to ``batch_shape``."""
        new = self._get_checked_instance(Dirichlet, _instance)
        batch_shape = torch.Size(batch_shape)
        new.concentration = self.concentration.expand(batch_shape + self.event_shape)
        # Skip re-validation of already-validated parameters, then restore the
        # caller's validation setting on the new instance.
        super(Dirichlet, new).__init__(
            batch_shape, self.event_shape, validate_args=False
        )
        new._validate_args = self._validate_args
        return new

    def rsample(self, sample_shape: _size = ()) -> torch.Tensor:
        """Draw a sample of shape ``sample_shape + batch_shape + event_shape`` that supports gradients."""
        shape = self._extended_shape(sample_shape)
        concentration = self.concentration.expand(shape)
        return _Dirichlet.apply(concentration)

    def log_prob(self, value) -> torch.Tensor:
        """Log density of ``value``; the event dimension is reduced away."""
        if self._validate_args:
            self._validate_sample(value)
        # xlogy yields 0 where concentration == 1 even if value == 0.
        return (
            torch.xlogy(self.concentration - 1.0, value).sum(-1)
            + torch.lgamma(self.concentration.sum(-1))
            - torch.lgamma(self.concentration).sum(-1)
        )

    @property
    def mean(self) -> torch.Tensor:
        """Mean ``alpha_i / sum(alpha)``."""
        return self.concentration / self.concentration.sum(-1, True)

    @property
    def mode(self) -> torch.Tensor:
        """Mode ``(alpha_i - 1)_+ / sum((alpha - 1)_+)``, with a vertex fallback.

        When every concentration in a batch entry is below 1, all clamped
        numerators are zero, so the ratio is 0/0 (NaN). Those entries are then
        replaced by a one-hot vertex of the simplex.
        NOTE(review): entries mixing values < 1 with values exactly 1 also hit
        a zero denominator but are not masked, so they stay NaN.
        """
        concentrationm1 = (self.concentration - 1).clamp(min=0.0)
        mode = concentrationm1 / concentrationm1.sum(-1, True)
        mask = (self.concentration < 1).all(dim=-1)
        mode[mask] = torch.nn.functional.one_hot(
            mode[mask].argmax(dim=-1), concentrationm1.shape[-1]
        ).to(mode)
        return mode

    @property
    def variance(self) -> torch.Tensor:
        """Per-component variance ``alpha_i (a0 - alpha_i) / (a0^2 (a0 + 1))`` with ``a0 = sum(alpha)``."""
        con0 = self.concentration.sum(-1, True)
        return (
            self.concentration
            * (con0 - self.concentration)
            / (con0.pow(2) * (con0 + 1))
        )

    def entropy(self) -> torch.Tensor:
        """Differential entropy.

        Computes ``log B(alpha) + (a0 - K) digamma(a0) - sum((alpha_i - 1) digamma(alpha_i))``,
        where ``log B(alpha) = sum(lgamma(alpha_i)) - lgamma(a0)`` and ``K`` is the
        event size.
        """
        k = self.concentration.size(-1)
        a0 = self.concentration.sum(-1)
        return (
            torch.lgamma(self.concentration).sum(-1)
            - torch.lgamma(a0)
            - (k - a0) * torch.digamma(a0)
            - ((self.concentration - 1.0) * torch.digamma(self.concentration)).sum(-1)
        )

    @property
    def _natural_params(self):
        """Natural parameters exposed to the ``ExponentialFamily`` base class."""
        return (self.concentration,)

    def _log_normalizer(self, x):
        """Log-normalizer ``sum(lgamma(x)) - lgamma(sum(x))`` over the last dimension."""
        return x.lgamma().sum(-1) - torch.lgamma(x.sum(-1))