# mypy: allow-untyped-defs
import torch
from torch.distributions.distribution import Distribution


__all__ = ["ExponentialFamily"]


class ExponentialFamily(Distribution):
    r"""
    ExponentialFamily is the abstract base class for probability distributions belonging to an
    exponential family, whose probability mass/density function has the form defined below

    .. math::

        p_{F}(x; \theta) = \exp(\langle t(x), \theta\rangle - F(\theta) + k(x))

    where :math:`\theta` denotes the natural parameters, :math:`t(x)` denotes the sufficient statistic,
    :math:`F(\theta)` is the log normalizer function for a given family and :math:`k(x)` is the carrier
    measure.

    Note:
        This class is an intermediary between the `Distribution` class and distributions which belong
        to an exponential family, mainly to check the correctness of the `.entropy()` and analytic KL
        divergence methods. We use this class to compute the entropy and KL divergence using the AD
        framework and Bregman divergences (courtesy of: Frank Nielsen and Richard Nock, Entropies and
        Cross-entropies of Exponential Families).
    """

    @property
    def _natural_params(self):
        """
        Abstract method for natural parameters. Returns a tuple of Tensors based
        on the distribution.
        """
        raise NotImplementedError

    def _log_normalizer(self, *natural_params):
        """
        Abstract method for the log normalizer function. Returns the log normalizer based on
        the distribution and input.
        """
        raise NotImplementedError

    @property
    def _mean_carrier_measure(self):
        """
        Abstract method for the expected carrier measure, which is required for computing
        entropy.
        """
        raise NotImplementedError

    def entropy(self):
        """
        Method to compute the entropy using the Bregman divergence of the log normalizer.
        """
        # Computes H(p) = -E[k(x)] + F(theta) - <theta, grad F(theta)>,
        # with the gradient of the log normalizer obtained via autograd.
        result = -self._mean_carrier_measure
        nparams = [p.detach().requires_grad_() for p in self._natural_params]
        lg_normal = self._log_normalizer(*nparams)
        gradients = torch.autograd.grad(lg_normal.sum(), nparams, create_graph=True)
        result += lg_normal
        for np, g in zip(nparams, gradients):
            result -= (np * g).reshape(self._batch_shape + (-1,)).sum(-1)
        return result
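

# Illustrative sketch, not part of the public API: a hypothetical toy
# exponential distribution (density ``rate * exp(-rate * x)`` for ``x >= 0``)
# that implements the three hooks above, so the Bregman-divergence entropy
# can be checked against the closed form ``1 - log(rate)``. The class name
# ``_ToyExponential`` and the parameter ``rate`` are assumptions made for
# this example only.
if __name__ == "__main__":
    from torch.distributions import constraints

    class _ToyExponential(ExponentialFamily):
        # Natural parameter theta = -rate, sufficient statistic t(x) = x,
        # log normalizer F(theta) = -log(-theta), carrier measure k(x) = 0.
        arg_constraints = {"rate": constraints.positive}
        _mean_carrier_measure = 0

        def __init__(self, rate):
            self.rate = rate
            super().__init__(batch_shape=rate.shape, validate_args=False)

        @property
        def _natural_params(self):
            return (-self.rate,)

        def _log_normalizer(self, theta):
            return -torch.log(-theta)

    rate = torch.tensor([0.5, 1.0, 2.0])
    toy = _ToyExponential(rate)
    print(toy.entropy())   # entropy via the Bregman-divergence route above
    print(1 - rate.log())  # closed-form entropy of the toy distribution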