# mypy: allow-untyped-defs
import torch
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.distributions import constraints
from torch.distributions.exp_family import ExponentialFamily
from torch.types import _size


__all__ = ["Dirichlet"]


# This helper is exposed for testing.
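# torch._dirichlet_grad(x, concentration, total) yields the derivative of the
# sample x w.r.t. the concentration (implicit reparameterization); the upstream
# gradient is first adjusted for the simplex constraint (x sums to 1).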
def _Dirichlet_backward(x, concentration, grad_output):
    total = concentration.sum(-1, True).expand_as(concentration)
    grad = torch._dirichlet_grad(x, concentration, total)
    return grad * (grad_output - (x * grad_output).sum(-1, True))


class _Dirichlet(Function):
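    # forward draws a sample with torch._sample_dirichlet; backward routes the
    # incoming gradient through _Dirichlet_backward. @once_differentiable marks
    # the backward itself as not differentiable (no second-order gradients).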
    @staticmethod
    def forward(ctx, concentration):
        x = torch._sample_dirichlet(concentration)
        ctx.save_for_backward(x, concentration)
        return x

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        x, concentration = ctx.saved_tensors
        return _Dirichlet_backward(x, concentration, grad_output)


class Dirichlet(ExponentialFamily):
35    r"""
36    Creates a Dirichlet distribution parameterized by concentration :attr:`concentration`.
37
38    Example::
39
40        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
41        >>> m = Dirichlet(torch.tensor([0.5, 0.5]))
42        >>> m.sample()  # Dirichlet distributed with concentration [0.5, 0.5]
43        tensor([ 0.1046,  0.8954])
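        >>> lp = m.log_prob(m.sample())  # log density of a draw
        >>> r = m.rsample()  # reparameterized sample, differentiable w.r.t. concentration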

    Args:
        concentration (Tensor): concentration parameter of the distribution
            (often referred to as alpha)
    """
    arg_constraints = {
        "concentration": constraints.independent(constraints.positive, 1)
    }
    support = constraints.simplex
    has_rsample = True

    def __init__(self, concentration, validate_args=None):
        if concentration.dim() < 1:
            raise ValueError(
                "`concentration` parameter must be at least one-dimensional."
            )
        self.concentration = concentration
        batch_shape, event_shape = concentration.shape[:-1], concentration.shape[-1:]
        super().__init__(batch_shape, event_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(Dirichlet, _instance)
        batch_shape = torch.Size(batch_shape)
        new.concentration = self.concentration.expand(batch_shape + self.event_shape)
        super(Dirichlet, new).__init__(
            batch_shape, self.event_shape, validate_args=False
        )
        new._validate_args = self._validate_args
        return new

    def rsample(self, sample_shape: _size = ()) -> torch.Tensor:
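        # Reparameterized sampling: the custom _Dirichlet Function makes the
        # sample differentiable w.r.t. `concentration`.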
        shape = self._extended_shape(sample_shape)
        concentration = self.concentration.expand(shape)
        return _Dirichlet.apply(concentration)

    def log_prob(self, value):
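        # log p(x) = sum((alpha - 1) * log x) + lgamma(sum(alpha)) - sum(lgamma(alpha))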
        if self._validate_args:
            self._validate_sample(value)
        return (
            torch.xlogy(self.concentration - 1.0, value).sum(-1)
            + torch.lgamma(self.concentration.sum(-1))
            - torch.lgamma(self.concentration).sum(-1)
        )

    @property
    def mean(self):
        # E[x_i] = alpha_i / alpha_0, where alpha_0 = sum(alpha)
        return self.concentration / self.concentration.sum(-1, True)

    @property
    def mode(self):
        # The interior mode (alpha_i - 1) / (alpha_0 - k) exists only when every
        # alpha_i > 1; clamping keeps the expression defined otherwise.
        concentrationm1 = (self.concentration - 1).clamp(min=0.0)
        mode = concentrationm1 / concentrationm1.sum(-1, True)
        # When all alpha_i < 1 the density peaks at a vertex of the simplex and the
        # expression above is 0/0; substitute a one-hot vertex instead.
        mask = (self.concentration < 1).all(dim=-1)
        mode[mask] = torch.nn.functional.one_hot(
            mode[mask].argmax(dim=-1), concentrationm1.shape[-1]
        ).to(mode)
        return mode

    @property
    def variance(self):
        # Var[x_i] = alpha_i * (alpha_0 - alpha_i) / (alpha_0^2 * (alpha_0 + 1))
        con0 = self.concentration.sum(-1, True)
        return (
            self.concentration
            * (con0 - self.concentration)
            / (con0.pow(2) * (con0 + 1))
        )

    def entropy(self):
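        # Closed-form entropy:
        #   H = sum(lgamma(alpha)) - lgamma(a0) - (k - a0) * digamma(a0)
        #       - sum((alpha - 1) * digamma(alpha)),  a0 = sum(alpha)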
        k = self.concentration.size(-1)
        a0 = self.concentration.sum(-1)
        return (
            torch.lgamma(self.concentration).sum(-1)
            - torch.lgamma(a0)
            - (k - a0) * torch.digamma(a0)
            - ((self.concentration - 1.0) * torch.digamma(self.concentration)).sum(-1)
        )

    @property
    def _natural_params(self):
        return (self.concentration,)
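
    # A(alpha) = sum(lgamma(alpha)) - lgamma(sum(alpha)) is the log normalizer of
    # the exponential-family form p(x) = prod(x_i^-1) * exp(<alpha, log x> - A(alpha)),
    # with the concentration itself as the natural parameter.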
    def _log_normalizer(self, x):
        return x.lgamma().sum(-1) - torch.lgamma(x.sum(-1))
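

# Minimal usage sketch (illustrative; not part of the library surface):
if __name__ == "__main__":
    m = Dirichlet(torch.tensor([0.5, 0.5]))
    s = m.sample()
    assert torch.isclose(s.sum(), torch.tensor(1.0))  # samples lie on the simplex
    print(s, m.log_prob(s), m.entropy())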
127