xref: /aosp_15_r20/external/pytorch/torch/distributions/exp_family.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import torch
3from torch.distributions.distribution import Distribution
4
5
6__all__ = ["ExponentialFamily"]
7
8
class ExponentialFamily(Distribution):
    r"""
    Abstract base class for probability distributions that belong to an
    exponential family, i.e. those whose probability mass/density function
    can be written as

    .. math::

        p_{F}(x; \theta) = \exp(\langle t(x), \theta\rangle - F(\theta) + k(x))

    where :math:`\theta` are the natural parameters, :math:`t(x)` is the
    sufficient statistic, :math:`F(\theta)` is the log normalizer of the
    family, and :math:`k(x)` is the carrier measure.

    Note:
        This class sits between :class:`Distribution` and the concrete
        exponential-family distributions mainly to check the correctness of
        their `.entropy()` and analytic KL divergence methods. Entropy and KL
        divergence are computed with the AD framework and Bregman divergences
        (courtesy of: Frank Nielsen and Richard Nock, Entropies and
        Cross-entropies of Exponential Families).
    """

    @property
    def _natural_params(self):
        """
        Abstract property. Subclasses return a tuple of Tensors holding the
        distribution's natural parameters.
        """
        raise NotImplementedError

    def _log_normalizer(self, *natural_params):
        """
        Abstract method. Subclasses evaluate the family's log normalizer
        :math:`F` at the given natural parameters.
        """
        raise NotImplementedError

    @property
    def _mean_carrier_measure(self):
        """
        Abstract property. Subclasses return the expected carrier measure
        :math:`E[k(x)]`, which is needed to compute the entropy.
        """
        raise NotImplementedError

    def entropy(self):
        """
        Compute the entropy using the Bregman divergence of the log normalizer:
        :math:`H = -E[k(x)] + F(\theta) - \langle\theta, \nabla F(\theta)\rangle`,
        where the gradient of :math:`F` is obtained through autograd.
        """
        # Detach and re-mark the natural parameters as requiring grad so we
        # can differentiate the log normalizer with respect to them alone.
        params = [p.detach().requires_grad_() for p in self._natural_params]
        log_norm = self._log_normalizer(*params)
        grads = torch.autograd.grad(log_norm.sum(), params, create_graph=True)
        result = -self._mean_carrier_measure + log_norm
        # Subtract the inner product <theta, grad F>, reducing over all
        # trailing (event) dimensions so one value per batch element remains.
        for param, grad in zip(params, grads):
            result = result - (param * grad).reshape(self._batch_shape + (-1,)).sum(-1)
        return result
65