xref: /aosp_15_r20/external/pytorch/torch/distributions/kl.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-defs
2*da0073e9SAndroid Build Coastguard Workerimport math
3*da0073e9SAndroid Build Coastguard Workerimport warnings
4*da0073e9SAndroid Build Coastguard Workerfrom functools import total_ordering
5*da0073e9SAndroid Build Coastguard Workerfrom typing import Callable, Dict, Tuple, Type
6*da0073e9SAndroid Build Coastguard Worker
7*da0073e9SAndroid Build Coastguard Workerimport torch
8*da0073e9SAndroid Build Coastguard Workerfrom torch import inf
9*da0073e9SAndroid Build Coastguard Worker
10*da0073e9SAndroid Build Coastguard Workerfrom .bernoulli import Bernoulli
11*da0073e9SAndroid Build Coastguard Workerfrom .beta import Beta
12*da0073e9SAndroid Build Coastguard Workerfrom .binomial import Binomial
13*da0073e9SAndroid Build Coastguard Workerfrom .categorical import Categorical
14*da0073e9SAndroid Build Coastguard Workerfrom .cauchy import Cauchy
15*da0073e9SAndroid Build Coastguard Workerfrom .continuous_bernoulli import ContinuousBernoulli
16*da0073e9SAndroid Build Coastguard Workerfrom .dirichlet import Dirichlet
17*da0073e9SAndroid Build Coastguard Workerfrom .distribution import Distribution
18*da0073e9SAndroid Build Coastguard Workerfrom .exp_family import ExponentialFamily
19*da0073e9SAndroid Build Coastguard Workerfrom .exponential import Exponential
20*da0073e9SAndroid Build Coastguard Workerfrom .gamma import Gamma
21*da0073e9SAndroid Build Coastguard Workerfrom .geometric import Geometric
22*da0073e9SAndroid Build Coastguard Workerfrom .gumbel import Gumbel
23*da0073e9SAndroid Build Coastguard Workerfrom .half_normal import HalfNormal
24*da0073e9SAndroid Build Coastguard Workerfrom .independent import Independent
25*da0073e9SAndroid Build Coastguard Workerfrom .laplace import Laplace
26*da0073e9SAndroid Build Coastguard Workerfrom .lowrank_multivariate_normal import (
27*da0073e9SAndroid Build Coastguard Worker    _batch_lowrank_logdet,
28*da0073e9SAndroid Build Coastguard Worker    _batch_lowrank_mahalanobis,
29*da0073e9SAndroid Build Coastguard Worker    LowRankMultivariateNormal,
30*da0073e9SAndroid Build Coastguard Worker)
31*da0073e9SAndroid Build Coastguard Workerfrom .multivariate_normal import _batch_mahalanobis, MultivariateNormal
32*da0073e9SAndroid Build Coastguard Workerfrom .normal import Normal
33*da0073e9SAndroid Build Coastguard Workerfrom .one_hot_categorical import OneHotCategorical
34*da0073e9SAndroid Build Coastguard Workerfrom .pareto import Pareto
35*da0073e9SAndroid Build Coastguard Workerfrom .poisson import Poisson
36*da0073e9SAndroid Build Coastguard Workerfrom .transformed_distribution import TransformedDistribution
37*da0073e9SAndroid Build Coastguard Workerfrom .uniform import Uniform
38*da0073e9SAndroid Build Coastguard Workerfrom .utils import _sum_rightmost, euler_constant as _euler_gamma
39*da0073e9SAndroid Build Coastguard Worker
40*da0073e9SAndroid Build Coastguard Worker
# Source of truth mapping a few general (type, type) pairs to functions.
# Populated by ``register_kl``.
_KL_REGISTRY: Dict[Tuple[Type, Type], Callable] = {}

# Memoized version mapping many specific (type, type) pairs to functions.
# Filled lazily by ``kl_divergence`` and cleared on every new registration.
_KL_MEMOIZE: Dict[Tuple[Type, Type], Callable] = {}
47*da0073e9SAndroid Build Coastguard Worker
# Public API of this module; everything else is an implementation detail.
__all__ = ["register_kl", "kl_divergence"]
49*da0073e9SAndroid Build Coastguard Worker
50*da0073e9SAndroid Build Coastguard Worker
def register_kl(type_p, type_q):
    """
    Decorator to register a pairwise function with :meth:`kl_divergence`.
    Usage::

        @register_kl(Normal, Normal)
        def kl_normal_normal(p, q):
            # insert implementation here

    Lookup returns the most specific (type,type) match ordered by subclass. If
    the match is ambiguous, a `RuntimeWarning` is issued. For example to
    resolve the ambiguous situation::

        @register_kl(BaseP, DerivedQ)
        def kl_version1(p, q): ...
        @register_kl(DerivedP, BaseQ)
        def kl_version2(p, q): ...

    you should register a third most-specific implementation, e.g.::

        register_kl(DerivedP, DerivedQ)(kl_version1)  # Break the tie.

    Args:
        type_p (type): A subclass of :class:`~torch.distributions.Distribution`.
        type_q (type): A subclass of :class:`~torch.distributions.Distribution`.

    Returns:
        A decorator that stores the decorated function in the registry under
        ``(type_p, type_q)`` and returns the function unchanged.

    Raises:
        TypeError: If ``type_p`` or ``type_q`` is not a class, or is a class
            that does not derive from ``Distribution``.
    """
    # ``not`` must negate the whole conjunction. Without the parentheses a
    # class unrelated to Distribution (e.g. ``int``) passed validation, and a
    # non-class argument reached ``issubclass`` and failed with an unrelated
    # message.
    if not (isinstance(type_p, type) and issubclass(type_p, Distribution)):
        raise TypeError(
            f"Expected type_p to be a Distribution subclass but got {type_p}"
        )
    if not (isinstance(type_q, type) and issubclass(type_q, Distribution)):
        raise TypeError(
            f"Expected type_q to be a Distribution subclass but got {type_q}"
        )

    def decorator(fun):
        _KL_REGISTRY[type_p, type_q] = fun
        _KL_MEMOIZE.clear()  # reset since lookup order may have changed
        return fun

    return decorator
92*da0073e9SAndroid Build Coastguard Worker
93*da0073e9SAndroid Build Coastguard Worker
@total_ordering
class _Match:
    """
    Wrapper around a tuple of types, ordered by subclass specificity.

    ``a <= b`` holds when ``a`` is at least as specific as ``b``, comparing
    the wrapped types position by position (lexicographically).
    ``total_ordering`` derives the remaining comparisons, which lets ``min``
    pick the most specific match.
    """

    __slots__ = ("types",)

    def __init__(self, *types):
        self.types = types

    def __eq__(self, other):
        # Defer to the other operand instead of failing on a missing
        # ``types`` attribute.
        if not isinstance(other, _Match):
            return NotImplemented
        return self.types == other.types

    def __le__(self, other):
        if not isinstance(other, _Match):
            return NotImplemented
        for x, y in zip(self.types, other.types):
            if not issubclass(x, y):
                return False
            # The first strictly more specific position decides the order;
            # later positions are not inspected.
            if x is not y:
                break
        return True
111*da0073e9SAndroid Build Coastguard Worker
112*da0073e9SAndroid Build Coastguard Worker
def _dispatch_kl(type_p, type_q):
    """
    Find the most specific approximate match, assuming single inheritance.

    Args:
        type_p (type): Concrete type of the first distribution.
        type_q (type): Concrete type of the second distribution.

    Returns:
        The registered function whose ``(type, type)`` key is the most
        specific one compatible with ``(type_p, type_q)``, or
        ``NotImplemented`` when no registered pair matches.

    A ``RuntimeWarning`` is issued when ordering by ``p`` first and ordering
    by ``q`` first select different functions.
    """
    matches = [
        (super_p, super_q)
        for super_p, super_q in _KL_REGISTRY
        if issubclass(type_p, super_p) and issubclass(type_q, super_q)
    ]
    if not matches:
        return NotImplemented
    # Check that the left- and right- lexicographic orders agree.
    # mypy isn't smart enough to know that _Match implements __lt__
    # see: https://github.com/python/typing/issues/760#issuecomment-710670503
    left_p, left_q = min(_Match(*m) for m in matches).types  # type: ignore[type-var]
    right_q, right_p = min(_Match(*reversed(m)) for m in matches).types  # type: ignore[type-var]
    left_fun = _KL_REGISTRY[left_p, left_q]
    right_fun = _KL_REGISTRY[right_p, right_q]
    if left_fun is not right_fun:
        warnings.warn(
            f"Ambiguous kl_divergence({type_p.__name__}, {type_q.__name__}). "
            f"Please register_kl({left_p.__name__}, {right_q.__name__})",
            RuntimeWarning,
        )
    # The p-first ordering wins when the two orderings disagree.
    return left_fun
138*da0073e9SAndroid Build Coastguard Worker
139*da0073e9SAndroid Build Coastguard Worker
def _infinite_like(tensor):
    """
    Helper function for obtaining infinite KL Divergence throughout.

    Returns a tensor filled with ``inf``, built by ``torch.full_like`` from
    ``tensor``.
    """
    return torch.full_like(tensor, inf)
145*da0073e9SAndroid Build Coastguard Worker
146*da0073e9SAndroid Build Coastguard Worker
def _x_log_x(tensor):
    """
    Utility function for calculating x log x.

    Uses ``torch.xlogy`` so that an input of exactly ``0`` yields ``0`` (the
    limit of ``x log x`` as ``x -> 0+``) rather than the ``NaN`` produced by
    the naive ``0 * log(0)``.
    """
    return torch.xlogy(tensor, tensor)
152*da0073e9SAndroid Build Coastguard Worker
153*da0073e9SAndroid Build Coastguard Worker
def _batch_trace_XXT(bmat):
    """
    Utility function for calculating the trace of XX^{T} with X having arbitrary leading batch dimensions.

    ``trace(X @ X.T)`` equals the sum of the squared entries of ``X``, so the
    product is never formed. The result has shape ``bmat.shape[:-2]``.
    """
    # Reducing over the last two dims directly also handles matrices with a
    # zero-sized dimension, where flattening via ``reshape(-1, 0)`` is
    # ambiguous and raises.
    return bmat.pow(2).sum(dim=(-2, -1))
162*da0073e9SAndroid Build Coastguard Worker
163*da0073e9SAndroid Build Coastguard Worker
def kl_divergence(p: Distribution, q: Distribution) -> torch.Tensor:
    r"""
    Compute Kullback-Leibler divergence :math:`KL(p \| q)` between two distributions.

    .. math::

        KL(p \| q) = \int p(x) \log\frac {p(x)} {q(x)} \,dx

    Args:
        p (Distribution): A :class:`~torch.distributions.Distribution` object.
        q (Distribution): A :class:`~torch.distributions.Distribution` object.

    Returns:
        Tensor: A batch of KL divergences of shape `batch_shape`.

    Raises:
        NotImplementedError: If the distribution types have not been registered via
            :meth:`register_kl`.
    """
    try:
        fun = _KL_MEMOIZE[type(p), type(q)]
    except KeyError:
        # First lookup for this exact pair of types: resolve through the
        # registry and cache the outcome. A ``NotImplemented`` result is
        # cached as well, so repeated misses skip the dispatch.
        fun = _dispatch_kl(type(p), type(q))
        _KL_MEMOIZE[type(p), type(q)] = fun
    if fun is NotImplemented:
        raise NotImplementedError(
            f"No KL(p || q) is implemented for p type {p.__class__.__name__} and q type {q.__class__.__name__}"
        )
    return fun(p, q)
193*da0073e9SAndroid Build Coastguard Worker
194*da0073e9SAndroid Build Coastguard Worker
195*da0073e9SAndroid Build Coastguard Worker################################################################################
196*da0073e9SAndroid Build Coastguard Worker# KL Divergence Implementations
197*da0073e9SAndroid Build Coastguard Worker################################################################################
198*da0073e9SAndroid Build Coastguard Worker
199*da0073e9SAndroid Build Coastguard Worker# Same distributions
200*da0073e9SAndroid Build Coastguard Worker
201*da0073e9SAndroid Build Coastguard Worker
@register_kl(Bernoulli, Bernoulli)
def _kl_bernoulli_bernoulli(p, q):
    """KL(p || q) for two Bernoulli distributions, computed from their logits."""
    # Contribution of the outcome x = 1.
    t1 = p.probs * (
        torch.nn.functional.softplus(-q.logits)
        - torch.nn.functional.softplus(-p.logits)
    )
    # Order matters: the p == 0 mask is applied last, so it overrides the
    # infinity set for q == 0.
    t1[q.probs == 0] = inf
    t1[p.probs == 0] = 0
    # Contribution of the outcome x = 0.
    t2 = (1 - p.probs) * (
        torch.nn.functional.softplus(q.logits) - torch.nn.functional.softplus(p.logits)
    )
    t2[q.probs == 1] = inf
    t2[p.probs == 1] = 0
    return t1 + t2
216*da0073e9SAndroid Build Coastguard Worker
217*da0073e9SAndroid Build Coastguard Worker
@register_kl(Beta, Beta)
def _kl_beta_beta(p, q):
    """KL(p || q) for two Beta distributions, via ``lgamma`` and ``digamma`` of the concentrations."""
    sum_params_p = p.concentration1 + p.concentration0
    sum_params_q = q.concentration1 + q.concentration0
    # t1 - t2 is the difference of the two log-normalizers.
    t1 = q.concentration1.lgamma() + q.concentration0.lgamma() + (sum_params_p).lgamma()
    t2 = p.concentration1.lgamma() + p.concentration0.lgamma() + (sum_params_q).lgamma()
    t3 = (p.concentration1 - q.concentration1) * torch.digamma(p.concentration1)
    t4 = (p.concentration0 - q.concentration0) * torch.digamma(p.concentration0)
    t5 = (sum_params_q - sum_params_p) * torch.digamma(sum_params_p)
    return t1 - t2 + t3 + t4 + t5
228*da0073e9SAndroid Build Coastguard Worker
229*da0073e9SAndroid Build Coastguard Worker
@register_kl(Binomial, Binomial)
def _kl_binomial_binomial(p, q):
    """
    KL(p || q) for two Binomial distributions.

    Raises:
        NotImplementedError: If any element of ``p.total_count`` is smaller
            than the corresponding element of ``q.total_count``.
    """
    # from https://math.stackexchange.com/questions/2214993/
    # kullback-leibler-divergence-for-binomial-distributions-p-and-q
    if (p.total_count < q.total_count).any():
        raise NotImplementedError(
            "KL between Binomials where q.total_count > p.total_count is not implemented"
        )
    kl = p.total_count * (
        p.probs * (p.logits - q.logits) + (-p.probs).log1p() - (-q.probs).log1p()
    )
    # Entries where p has the larger total_count are set to infinity.
    inf_idxs = p.total_count > q.total_count
    kl[inf_idxs] = _infinite_like(kl[inf_idxs])
    return kl
244*da0073e9SAndroid Build Coastguard Worker
245*da0073e9SAndroid Build Coastguard Worker
@register_kl(Categorical, Categorical)
def _kl_categorical_categorical(p, q):
    """KL(p || q) for two Categorical distributions, summed over the last (category) dimension."""
    t = p.probs * (p.logits - q.logits)
    # Order matters: the p == 0 mask is applied last, so categories with
    # p == 0 contribute nothing even where q == 0.
    t[(q.probs == 0).expand_as(t)] = inf
    t[(p.probs == 0).expand_as(t)] = 0
    return t.sum(-1)
252*da0073e9SAndroid Build Coastguard Worker
253*da0073e9SAndroid Build Coastguard Worker
@register_kl(ContinuousBernoulli, ContinuousBernoulli)
def _kl_continuous_bernoulli_continuous_bernoulli(p, q):
    """KL(p || q) for two ContinuousBernoulli distributions."""
    t1 = p.mean * (p.logits - q.logits)
    # t2 and t3 have the same form, built from p and from q, with opposite signs.
    t2 = p._cont_bern_log_norm() + torch.log1p(-p.probs)
    t3 = -q._cont_bern_log_norm() - torch.log1p(-q.probs)
    return t1 + t2 + t3
260*da0073e9SAndroid Build Coastguard Worker
261*da0073e9SAndroid Build Coastguard Worker
@register_kl(Dirichlet, Dirichlet)
def _kl_dirichlet_dirichlet(p, q):
    """KL(p || q) for two Dirichlet distributions, reduced over the last dimension of ``concentration``."""
    # From http://bariskurt.com/kullback-leibler-divergence-between-two-dirichlet-and-beta-distributions/
    sum_p_concentration = p.concentration.sum(-1)
    sum_q_concentration = q.concentration.sum(-1)
    t1 = sum_p_concentration.lgamma() - sum_q_concentration.lgamma()
    t2 = (p.concentration.lgamma() - q.concentration.lgamma()).sum(-1)
    t3 = p.concentration - q.concentration
    # unsqueeze so the per-batch digamma broadcasts against every component.
    t4 = p.concentration.digamma() - sum_p_concentration.digamma().unsqueeze(-1)
    return t1 - t2 + (t3 * t4).sum(-1)
272*da0073e9SAndroid Build Coastguard Worker
273*da0073e9SAndroid Build Coastguard Worker
@register_kl(Exponential, Exponential)
def _kl_exponential_exponential(p, q):
    """KL(p || q) for two Exponential distributions: ``r - log(r) - 1`` with ``r = q.rate / p.rate``."""
    rate_ratio = q.rate / p.rate
    t1 = -rate_ratio.log()
    return t1 + rate_ratio - 1
279*da0073e9SAndroid Build Coastguard Worker
280*da0073e9SAndroid Build Coastguard Worker
@register_kl(ExponentialFamily, ExponentialFamily)
def _kl_expfamily_expfamily(p, q):
    """
    Generic KL(p || q) for two members of the same exponential family.

    The result is assembled from the log-normalizers of both distributions
    and the autograd gradient of ``p``'s log-normalizer with respect to its
    natural parameters.

    Raises:
        NotImplementedError: If ``p`` and ``q`` are not of exactly the same type.
    """
    if type(p) is not type(q):
        raise NotImplementedError(
            "The cross KL-divergence between different exponential families cannot "
            "be computed using Bregman divergences"
        )
    # Detached copies that require grad, so autograd can differentiate the
    # log-normalizer with respect to them.
    p_nparams = [param.detach().requires_grad_() for param in p._natural_params]
    q_nparams = q._natural_params
    lg_normal = p._log_normalizer(*p_nparams)
    gradients = torch.autograd.grad(lg_normal.sum(), p_nparams, create_graph=True)
    result = q._log_normalizer(*q_nparams) - lg_normal
    for pnp, qnp, g in zip(p_nparams, q_nparams, gradients):
        term = (qnp - pnp) * g
        # Collapse the event dimensions so the term matches the batch shape.
        result -= _sum_rightmost(term, len(q.event_shape))
    return result
297*da0073e9SAndroid Build Coastguard Worker
298*da0073e9SAndroid Build Coastguard Worker
@register_kl(Gamma, Gamma)
def _kl_gamma_gamma(p, q):
    """KL(p || q) for two Gamma distributions, from their ``concentration`` and ``rate``."""
    t1 = q.concentration * (p.rate / q.rate).log()
    t2 = torch.lgamma(q.concentration) - torch.lgamma(p.concentration)
    t3 = (p.concentration - q.concentration) * torch.digamma(p.concentration)
    t4 = (q.rate - p.rate) * (p.concentration / p.rate)
    return t1 + t2 + t3 + t4
306*da0073e9SAndroid Build Coastguard Worker
307*da0073e9SAndroid Build Coastguard Worker
@register_kl(Gumbel, Gumbel)
def _kl_gumbel_gumbel(p, q):
    """KL(p || q) for two Gumbel distributions, from their ``loc`` and ``scale``."""
    # Every ratio below is expressed in units of q's scale.
    ct1 = p.scale / q.scale
    ct2 = q.loc / q.scale
    ct3 = p.loc / q.scale
    t1 = -ct1.log() - ct2 + ct3
    t2 = ct1 * _euler_gamma
    t3 = torch.exp(ct2 + (1 + ct1).lgamma() - ct3)
    return t1 + t2 + t3 - (1 + _euler_gamma)
317*da0073e9SAndroid Build Coastguard Worker
318*da0073e9SAndroid Build Coastguard Worker
@register_kl(Geometric, Geometric)
def _kl_geometric_geometric(p, q):
    """KL(p || q) for two Geometric distributions, built from ``p.entropy()`` and q's ``probs`` and ``logits``."""
    return -p.entropy() - torch.log1p(-q.probs) / p.probs - q.logits
322*da0073e9SAndroid Build Coastguard Worker
323*da0073e9SAndroid Build Coastguard Worker
@register_kl(HalfNormal, HalfNormal)
def _kl_halfnormal_halfnormal(p, q):
    """KL(p || q) for two HalfNormal distributions, delegated to the Normal-Normal KL of their ``base_dist``."""
    return _kl_normal_normal(p.base_dist, q.base_dist)
327*da0073e9SAndroid Build Coastguard Worker
328*da0073e9SAndroid Build Coastguard Worker
@register_kl(Laplace, Laplace)
def _kl_laplace_laplace(p, q):
    """KL(p || q) for two Laplace distributions, from their ``loc`` and ``scale``."""
    # From http://www.mast.queensu.ca/~communications/Papers/gil-msc11.pdf
    scale_ratio = p.scale / q.scale
    loc_abs_diff = (p.loc - q.loc).abs()
    t1 = -scale_ratio.log()
    t2 = loc_abs_diff / q.scale
    t3 = scale_ratio * torch.exp(-loc_abs_diff / p.scale)
    return t1 + t2 + t3 - 1
338*da0073e9SAndroid Build Coastguard Worker
339*da0073e9SAndroid Build Coastguard Worker
@register_kl(LowRankMultivariateNormal, LowRankMultivariateNormal)
def _kl_lowrankmultivariatenormal_lowrankmultivariatenormal(p, q):
    """
    KL(p || q) for two LowRankMultivariateNormal distributions.

    Computed as ``0.5 * (term1 + term2 + term3 - n)`` where ``n`` is the event
    size, ``term1`` is the difference of the two low-rank log-determinants,
    ``term2`` is ``trace(inv(qcov) @ pcov)`` and ``term3`` is the low-rank
    Mahalanobis term for ``q.loc - p.loc`` under q's covariance factors.

    Raises:
        ValueError: If ``p`` and ``q`` have different event shapes.
    """
    if p.event_shape != q.event_shape:
        raise ValueError(
            "KL-divergence between two Low Rank Multivariate Normals with "
            "different event shapes cannot be computed"
        )

    term1 = _batch_lowrank_logdet(
        q._unbroadcasted_cov_factor, q._unbroadcasted_cov_diag, q._capacitance_tril
    ) - _batch_lowrank_logdet(
        p._unbroadcasted_cov_factor, p._unbroadcasted_cov_diag, p._capacitance_tril
    )
    term3 = _batch_lowrank_mahalanobis(
        q._unbroadcasted_cov_factor,
        q._unbroadcasted_cov_diag,
        q.loc - p.loc,
        q._capacitance_tril,
    )
    # Expands term2 according to
    # inv(qcov) @ pcov = [inv(qD) - inv(qD) @ qW @ inv(qC) @ qW.T @ inv(qD)] @ (pW @ pW.T + pD)
    #                  = [inv(qD) - A.T @ A] @ (pD + pW @ pW.T)
    qWt_qDinv = q._unbroadcasted_cov_factor.mT / q._unbroadcasted_cov_diag.unsqueeze(-2)
    A = torch.linalg.solve_triangular(q._capacitance_tril, qWt_qDinv, upper=False)
    # The four traces of the expanded product above.
    term21 = (p._unbroadcasted_cov_diag / q._unbroadcasted_cov_diag).sum(-1)
    term22 = _batch_trace_XXT(
        p._unbroadcasted_cov_factor * q._unbroadcasted_cov_diag.rsqrt().unsqueeze(-1)
    )
    term23 = _batch_trace_XXT(A * p._unbroadcasted_cov_diag.sqrt().unsqueeze(-2))
    term24 = _batch_trace_XXT(A.matmul(p._unbroadcasted_cov_factor))
    term2 = term21 + term22 - term23 - term24
    return 0.5 * (term1 + term2 + term3 - p.event_shape[0])
372*da0073e9SAndroid Build Coastguard Worker
373*da0073e9SAndroid Build Coastguard Worker
@register_kl(MultivariateNormal, LowRankMultivariateNormal)
def _kl_multivariatenormal_lowrankmultivariatenormal(p, q):
    """
    KL(p || q) for a MultivariateNormal ``p`` and a LowRankMultivariateNormal ``q``.

    Computed as ``0.5 * (term1 + term2 + term3 - n)`` where ``n`` is the event
    size, ``term1`` is the difference of the two log-determinants, ``term2``
    is ``trace(inv(qcov) @ pcov)`` and ``term3`` is the low-rank Mahalanobis
    term for ``q.loc - p.loc`` under q's covariance factors.

    Raises:
        ValueError: If ``p`` and ``q`` have different event shapes.
    """
    if p.event_shape != q.event_shape:
        raise ValueError(
            "KL-divergence between two (Low Rank) Multivariate Normals with "
            "different event shapes cannot be computed"
        )

    # For p, twice the sum of the log-diagonal of scale_tril is the
    # log-determinant of scale_tril @ scale_tril.T.
    term1 = _batch_lowrank_logdet(
        q._unbroadcasted_cov_factor, q._unbroadcasted_cov_diag, q._capacitance_tril
    ) - 2 * p._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1)
    term3 = _batch_lowrank_mahalanobis(
        q._unbroadcasted_cov_factor,
        q._unbroadcasted_cov_diag,
        q.loc - p.loc,
        q._capacitance_tril,
    )
    # Expands term2 according to
    # inv(qcov) @ pcov = [inv(qD) - inv(qD) @ qW @ inv(qC) @ qW.T @ inv(qD)] @ p_tril @ p_tril.T
    #                  = [inv(qD) - A.T @ A] @ p_tril @ p_tril.T
    qWt_qDinv = q._unbroadcasted_cov_factor.mT / q._unbroadcasted_cov_diag.unsqueeze(-2)
    A = torch.linalg.solve_triangular(q._capacitance_tril, qWt_qDinv, upper=False)
    term21 = _batch_trace_XXT(
        p._unbroadcasted_scale_tril * q._unbroadcasted_cov_diag.rsqrt().unsqueeze(-1)
    )
    term22 = _batch_trace_XXT(A.matmul(p._unbroadcasted_scale_tril))
    term2 = term21 - term22
    return 0.5 * (term1 + term2 + term3 - p.event_shape[0])
402*da0073e9SAndroid Build Coastguard Worker
403*da0073e9SAndroid Build Coastguard Worker
@register_kl(LowRankMultivariateNormal, MultivariateNormal)
def _kl_lowrankmultivariatenormal_multivariatenormal(p, q):
    """Return KL(p || q) for a low-rank ``p`` and a full multivariate normal ``q``.

    The result is ``0.5 * (term1 + term2 + term3 - n)`` with ``n = p.event_shape[0]``:

    * ``term1``: twice the summed log-diagonal of ``q``'s scale_tril minus
      ``_batch_lowrank_logdet`` evaluated on ``p``'s factors,
    * ``term2``: ``trace(inv(qcov) @ pcov)``, split into a low-rank part and a
      diagonal part,
    * ``term3``: ``_batch_mahalanobis`` of ``q.loc - p.loc`` under ``q``'s scale_tril.

    Raises:
        ValueError: if ``p.event_shape != q.event_shape``.
    """
    if p.event_shape != q.event_shape:
        # Adjacent literals keep the message on one line; a backslash
        # continuation inside the literal would embed the source indentation.
        raise ValueError(
            "KL-divergence between two (Low Rank) Multivariate Normals with "
            "different event shapes cannot be computed"
        )

    n = p.event_shape[0]
    q_tril = q._unbroadcasted_scale_tril
    p_factor = p._unbroadcasted_cov_factor
    p_diag = p._unbroadcasted_cov_diag

    term1 = 2 * q_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1) - _batch_lowrank_logdet(
        p_factor, p_diag, p._capacitance_tril
    )
    term3 = _batch_mahalanobis(q_tril, (q.loc - p.loc))
    # Expands term2 according to
    # inv(qcov) @ pcov = inv(q_tril @ q_tril.T) @ (pW @ pW.T + pD)
    # The common batch shape also accounts for the batch dims of p's diagonal,
    # so a p whose only batched parameter is cov_diag can still be expanded.
    batch_shape = torch.broadcast_shapes(
        q_tril.shape[:-2], p_factor.shape[:-2], p_diag.shape[:-1]
    )
    q_scale_tril = q_tril.expand(batch_shape + (n, n))
    p_cov_factor = p_factor.expand(batch_shape + (n, p.cov_factor.size(-1)))
    # Square root of the diagonal, embedded as a matrix, so that X @ X.T == pD.
    p_cov_diag = torch.diag_embed(p_diag.sqrt()).expand(batch_shape + (n, n))
    term21 = _batch_trace_XXT(
        torch.linalg.solve_triangular(q_scale_tril, p_cov_factor, upper=False)
    )
    term22 = _batch_trace_XXT(
        torch.linalg.solve_triangular(q_scale_tril, p_cov_diag, upper=False)
    )
    term2 = term21 + term22
    return 0.5 * (term1 + term2 + term3 - n)
439*da0073e9SAndroid Build Coastguard Worker
440*da0073e9SAndroid Build Coastguard Worker
@register_kl(MultivariateNormal, MultivariateNormal)
def _kl_multivariatenormal_multivariatenormal(p, q):
    """Return KL(p || q) between two multivariate normals.

    Computes ``half_term1 + 0.5 * (term2 + term3 - n)`` with ``n = p.event_shape[0]``:

    * ``half_term1``: summed log-diagonal of ``q``'s scale_tril minus that of ``p``,
    * ``term2``: ``_batch_trace_XXT`` of ``solve_triangular(q_tril, p_tril)``,
    * ``term3``: ``_batch_mahalanobis`` of ``q.loc - p.loc`` under ``q``'s scale_tril.

    Raises:
        ValueError: if ``p.event_shape != q.event_shape``.
    """
    # From https://en.wikipedia.org/wiki/Multivariate_normal_distribution#Kullback%E2%80%93Leibler_divergence
    if p.event_shape != q.event_shape:
        # Adjacent literals keep the message on one line; a backslash
        # continuation inside the literal would embed the source indentation.
        raise ValueError(
            "KL-divergence between two Multivariate Normals with "
            "different event shapes cannot be computed"
        )

    n = p.event_shape[0]
    p_tril = p._unbroadcasted_scale_tril
    q_tril = q._unbroadcasted_scale_tril

    half_term1 = q_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1) - p_tril.diagonal(
        dim1=-2, dim2=-1
    ).log().sum(-1)
    # Expand both factors to their common batch shape before the triangular solve.
    batch_shape = torch.broadcast_shapes(q_tril.shape[:-2], p_tril.shape[:-2])
    q_scale_tril = q_tril.expand(batch_shape + (n, n))
    p_scale_tril = p_tril.expand(batch_shape + (n, n))
    term2 = _batch_trace_XXT(
        torch.linalg.solve_triangular(q_scale_tril, p_scale_tril, upper=False)
    )
    term3 = _batch_mahalanobis(q_tril, (q.loc - p.loc))
    return half_term1 + 0.5 * (term2 + term3 - n)
464*da0073e9SAndroid Build Coastguard Worker
465*da0073e9SAndroid Build Coastguard Worker
@register_kl(Normal, Normal)
def _kl_normal_normal(p, q):
    """Return KL(p || q) between two Normal distributions.

    Evaluates ``0.5 * (r + d - 1 - log(r))`` where
    ``r = (p.scale / q.scale) ** 2`` and ``d = ((p.loc - q.loc) / q.scale) ** 2``.
    """
    scale_ratio = p.scale / q.scale
    standardized_shift = (p.loc - q.loc) / q.scale
    variance_ratio = scale_ratio.pow(2)
    shift_sq = standardized_shift.pow(2)
    return 0.5 * (variance_ratio + shift_sq - 1 - torch.log(variance_ratio))
471*da0073e9SAndroid Build Coastguard Worker
472*da0073e9SAndroid Build Coastguard Worker
@register_kl(OneHotCategorical, OneHotCategorical)
def _kl_onehotcategorical_onehotcategorical(p, q):
    """Return KL(p || q) for two OneHotCategorical distributions.

    Delegates to ``_kl_categorical_categorical`` on the ``_categorical``
    attribute of each argument.
    """
    p_categorical = p._categorical
    q_categorical = q._categorical
    return _kl_categorical_categorical(p_categorical, q_categorical)
476*da0073e9SAndroid Build Coastguard Worker
477*da0073e9SAndroid Build Coastguard Worker
@register_kl(Pareto, Pareto)
def _kl_pareto_pareto(p, q):
    """Return KL(p || q) between two Pareto distributions.

    The finite value is
    ``q.alpha * log(p.scale / q.scale) - log(q.alpha / p.alpha) + q.alpha / p.alpha - 1``;
    entries where ``p.support.lower_bound < q.support.lower_bound`` are ``inf``.
    """
    # From http://www.mast.queensu.ca/~communications/Papers/gil-msc11.pdf
    scale_ratio = p.scale / q.scale
    alpha_ratio = q.alpha / p.alpha
    scale_term = q.alpha * scale_ratio.log()
    finite_kl = scale_term - alpha_ratio.log() + alpha_ratio - 1
    p_starts_below_q = p.support.lower_bound < q.support.lower_bound
    return finite_kl.masked_fill(p_starts_below_q, inf)
488*da0073e9SAndroid Build Coastguard Worker
489*da0073e9SAndroid Build Coastguard Worker
@register_kl(Poisson, Poisson)
def _kl_poisson_poisson(p, q):
    """Return KL(p || q) between two Poisson distributions.

    Evaluates ``p.rate * (log(p.rate) - log(q.rate)) - (p.rate - q.rate)``.
    """
    log_rate_gap = p.rate.log() - q.rate.log()
    rate_gap = p.rate - q.rate
    return p.rate * log_rate_gap - rate_gap
493*da0073e9SAndroid Build Coastguard Worker
494*da0073e9SAndroid Build Coastguard Worker
@register_kl(TransformedDistribution, TransformedDistribution)
def _kl_transformed_transformed(p, q):
    """Return the KL divergence of the base distributions of ``p`` and ``q``.

    Raises:
        NotImplementedError: if ``p.transforms != q.transforms`` or the event
            shapes of ``p`` and ``q`` differ.
    """
    # The transforms are compared first; the event shapes are only inspected
    # when the transforms compare equal.
    if p.transforms != q.transforms or p.event_shape != q.event_shape:
        raise NotImplementedError
    return kl_divergence(p.base_dist, q.base_dist)
502*da0073e9SAndroid Build Coastguard Worker
503*da0073e9SAndroid Build Coastguard Worker
@register_kl(Uniform, Uniform)
def _kl_uniform_uniform(p, q):
    """Return KL(p || q) between two Uniform distributions.

    The finite value is ``log((q.high - q.low) / (p.high - p.low))``; entries
    where ``q.low > p.low`` or ``q.high < p.high`` are ``inf``.
    """
    p_width = p.high - p.low
    q_width = q.high - q.low
    q_misses_p = (q.low > p.low) | (q.high < p.high)
    return torch.log(q_width / p_width).masked_fill(q_misses_p, inf)
509*da0073e9SAndroid Build Coastguard Worker
510*da0073e9SAndroid Build Coastguard Worker
511*da0073e9SAndroid Build Coastguard Worker# Different distributions
@register_kl(Bernoulli, Poisson)
def _kl_bernoulli_poisson(p, q):
    """Return KL(p || q) for a Bernoulli ``p`` and a Poisson ``q``.

    Evaluates ``-p.entropy() - (p.probs * log(q.rate) - q.rate)``.
    """
    rate_term = p.probs * q.rate.log() - q.rate
    return -p.entropy() - rate_term
515*da0073e9SAndroid Build Coastguard Worker
516*da0073e9SAndroid Build Coastguard Worker
@register_kl(Beta, ContinuousBernoulli)
def _kl_beta_continuous_bernoulli(p, q):
    """Return KL(p || q) for a Beta ``p`` and a ContinuousBernoulli ``q``.

    Evaluates ``-p.entropy() - p.mean * q.logits - log1p(-q.probs)
    - q._cont_bern_log_norm()``.
    """
    neg_entropy = -p.entropy()
    logit_term = p.mean * q.logits
    log_one_minus_probs = torch.log1p(-q.probs)
    return neg_entropy - logit_term - log_one_minus_probs - q._cont_bern_log_norm()
525*da0073e9SAndroid Build Coastguard Worker
526*da0073e9SAndroid Build Coastguard Worker
@register_kl(Beta, Pareto)
def _kl_beta_infinity(p, q):
    """Return ``_infinite_like(p.concentration1)`` for a Beta ``p`` and Pareto ``q``.

    ``q`` is never read; the output is built from ``p.concentration1`` alone.
    """
    template = p.concentration1
    return _infinite_like(template)
530*da0073e9SAndroid Build Coastguard Worker
531*da0073e9SAndroid Build Coastguard Worker
@register_kl(Beta, Exponential)
def _kl_beta_exponential(p, q):
    """Return KL(p || q) for a Beta ``p`` and an Exponential ``q``.

    Evaluates ``-p.entropy() - log(q.rate) + q.rate * c1 / (c1 + c0)`` with
    ``c1 = p.concentration1`` and ``c0 = p.concentration0``.
    """
    concentration_sum = p.concentration1 + p.concentration0
    concentration_fraction = p.concentration1 / concentration_sum
    return -p.entropy() - q.rate.log() + q.rate * concentration_fraction
539*da0073e9SAndroid Build Coastguard Worker
540*da0073e9SAndroid Build Coastguard Worker
@register_kl(Beta, Gamma)
def _kl_beta_gamma(p, q):
    """Return KL(p || q) for a Beta ``p`` and a Gamma ``q``.

    With ``c1 = p.concentration1``, ``c0 = p.concentration0``,
    ``a = q.concentration`` and ``b = q.rate`` the value is
    ``-p.entropy() + lgamma(a) - a * log(b)
    - (a - 1) * (digamma(c1) - digamma(c1 + c0)) + b * c1 / (c1 + c0)``.
    """
    concentration_sum = p.concentration1 + p.concentration0
    neg_entropy = -p.entropy()
    gamma_log_normalizer = q.concentration.lgamma() - q.concentration * q.rate.log()
    digamma_gap = p.concentration1.digamma() - concentration_sum.digamma()
    power_term = (q.concentration - 1) * digamma_gap
    rate_term = q.rate * p.concentration1 / concentration_sum
    return neg_entropy + gamma_log_normalizer - power_term + rate_term
550*da0073e9SAndroid Build Coastguard Worker
551*da0073e9SAndroid Build Coastguard Worker
552*da0073e9SAndroid Build Coastguard Worker# TODO: Add Beta-Laplace KL Divergence
553*da0073e9SAndroid Build Coastguard Worker
554*da0073e9SAndroid Build Coastguard Worker
@register_kl(Beta, Normal)
def _kl_beta_normal(p, q):
    """Return KL(p || q) for a Beta ``p`` and a Normal ``q``.

    With ``m = c1 / (c1 + c0)`` and ``v = q.scale ** 2`` the value is
    ``-p.entropy() + 0.5 * log(2 * pi * v)
    + (0.5 * (m * (1 - m) / (c1 + c0 + 1) + m ** 2) - q.loc * m + 0.5 * q.loc ** 2) / v``.
    """
    concentration_sum = p.concentration1 + p.concentration0
    beta_mean = p.concentration1 / concentration_sum
    normal_var = q.scale.pow(2)
    neg_entropy = -p.entropy()
    log_normalizer = 0.5 * (normal_var * 2 * math.pi).log()
    half_second_moment = (
        beta_mean * (1 - beta_mean) / (concentration_sum + 1) + beta_mean.pow(2)
    ) * 0.5
    loc_times_mean = q.loc * beta_mean
    half_loc_sq = q.loc.pow(2) * 0.5
    quadratic = (half_second_moment - loc_times_mean + half_loc_sq) / normal_var
    return neg_entropy + log_normalizer + quadratic
568*da0073e9SAndroid Build Coastguard Worker
569*da0073e9SAndroid Build Coastguard Worker
@register_kl(Beta, Uniform)
def _kl_beta_uniform(p, q):
    """Return KL(p || q) for a Beta ``p`` and a Uniform ``q``.

    The finite value is ``-p.entropy() + log(q.high - q.low)``. Entries where
    ``q.low > p.support.lower_bound`` or ``q.high < p.support.upper_bound``
    are ``inf``.
    """
    result = -p.entropy() + (q.high - q.low).log()
    support = p.support
    q_misses_p = (q.low > support.lower_bound) | (q.high < support.upper_bound)
    # The mask is built from q's bounds, while ``result`` has the broadcast
    # shape of p and q. masked_fill broadcasts the mask; boolean-index
    # assignment (``result[mask] = inf``) requires the mask to line up with
    # result's leading dims and raises when p carries extra batch dims.
    return result.masked_fill(q_misses_p, inf)
575*da0073e9SAndroid Build Coastguard Worker
576*da0073e9SAndroid Build Coastguard Worker
577*da0073e9SAndroid Build Coastguard Worker# Note that the KL between a ContinuousBernoulli and Beta has no closed form
578*da0073e9SAndroid Build Coastguard Worker
579*da0073e9SAndroid Build Coastguard Worker
@register_kl(ContinuousBernoulli, Pareto)
def _kl_continuous_bernoulli_infinity(p, q):
    """Return ``_infinite_like(p.probs)`` for a ContinuousBernoulli ``p`` and Pareto ``q``.

    ``q`` is never read; the output is built from ``p.probs`` alone.
    """
    template = p.probs
    return _infinite_like(template)
583*da0073e9SAndroid Build Coastguard Worker
584*da0073e9SAndroid Build Coastguard Worker
@register_kl(ContinuousBernoulli, Exponential)
def _kl_continuous_bernoulli_exponential(p, q):
    """Return KL(p || q) for a ContinuousBernoulli ``p`` and an Exponential ``q``.

    Evaluates ``-p.entropy() - log(q.rate) + q.rate * p.mean``.
    """
    neg_entropy = -p.entropy()
    log_rate = torch.log(q.rate)
    return neg_entropy - log_rate + q.rate * p.mean
588*da0073e9SAndroid Build Coastguard Worker
589*da0073e9SAndroid Build Coastguard Worker
590*da0073e9SAndroid Build Coastguard Worker# Note that the KL between a ContinuousBernoulli and Gamma has no closed form
591*da0073e9SAndroid Build Coastguard Worker# TODO: Add ContinuousBernoulli-Laplace KL Divergence
592*da0073e9SAndroid Build Coastguard Worker
593*da0073e9SAndroid Build Coastguard Worker
@register_kl(ContinuousBernoulli, Normal)
def _kl_continuous_bernoulli_normal(p, q):
    """Return KL(p || q) for a ContinuousBernoulli ``p`` and a Normal ``q``.

    Sums three pieces:

    * ``-p.entropy()``,
    * ``0.5 * (log(2 * pi) + (q.loc / q.scale) ** 2) + log(q.scale)``,
    * ``(p.variance + p.mean ** 2 - 2 * q.loc * p.mean) / (2 * q.scale ** 2)``.
    """
    neg_entropy = -p.entropy()
    standardized_loc_sq = torch.square(q.loc / q.scale)
    normal_log_terms = 0.5 * (math.log(2.0 * math.pi) + standardized_loc_sq) + torch.log(
        q.scale
    )
    moment_numerator = p.variance + torch.square(p.mean) - 2.0 * q.loc * p.mean
    moment_denominator = 2.0 * torch.square(q.scale)
    return neg_entropy + normal_log_terms + moment_numerator / moment_denominator
604*da0073e9SAndroid Build Coastguard Worker
605*da0073e9SAndroid Build Coastguard Worker
@register_kl(ContinuousBernoulli, Uniform)
def _kl_continuous_bernoulli_uniform(p, q):
    """Return KL(p || q) for a ContinuousBernoulli ``p`` and a Uniform ``q``.

    The finite value is ``-p.entropy() + log(q.high - q.low)``. Entries where
    ``q.low > p.support.lower_bound`` or ``q.high < p.support.upper_bound``
    are ``inf``.
    """
    result = -p.entropy() + (q.high - q.low).log()
    support = p.support
    # Strict comparisons, matching _kl_beta_uniform: a Uniform whose bounds
    # coincide with p's support bounds still covers p except possibly at the
    # end points, so the divergence is finite there. The earlier ``>=`` / ``<=``
    # test reported ``inf`` whenever either bound merely touched p's support.
    q_misses_p = (q.low > support.lower_bound) | (q.high < support.upper_bound)
    return result.masked_fill(q_misses_p, inf)
617*da0073e9SAndroid Build Coastguard Worker
618*da0073e9SAndroid Build Coastguard Worker
@register_kl(Exponential, Beta)
@register_kl(Exponential, ContinuousBernoulli)
@register_kl(Exponential, Pareto)
@register_kl(Exponential, Uniform)
def _kl_exponential_infinity(p, q):
    """Return ``_infinite_like(p.rate)`` for an Exponential ``p``.

    Registered against Beta, ContinuousBernoulli, Pareto and Uniform ``q``;
    ``q`` is never read.
    """
    template = p.rate
    return _infinite_like(template)
625*da0073e9SAndroid Build Coastguard Worker
626*da0073e9SAndroid Build Coastguard Worker
@register_kl(Exponential, Gamma)
def _kl_exponential_gamma(p, q):
    """Return KL(p || q) for an Exponential ``p`` and a Gamma ``q``.

    With ``r = q.rate / p.rate`` and ``a = q.concentration`` the value is
    ``-a * log(r) + r + lgamma(a) + a * _euler_gamma - (1 + _euler_gamma)``.
    """
    rate_ratio = q.rate / p.rate
    log_ratio_term = -q.concentration * torch.log(rate_ratio)
    log_gamma_term = q.concentration.lgamma()
    euler_term = q.concentration * _euler_gamma
    return log_ratio_term + rate_ratio + log_gamma_term + euler_term - (1 + _euler_gamma)
638*da0073e9SAndroid Build Coastguard Worker
639*da0073e9SAndroid Build Coastguard Worker
@register_kl(Exponential, Gumbel)
def _kl_exponential_gumbel(p, q):
    """Return KL(p || q) for an Exponential ``p`` and a Gumbel ``q``.

    With ``s = p.rate * q.scale`` and ``z = q.loc / q.scale`` the value is
    ``log(s) - 1 - z + exp(z) * s / (s + 1) + 1 / s``.
    """
    rate_scale = p.rate * q.scale
    standardized_loc = q.loc / q.scale
    log_part = rate_scale.log() - 1
    exp_part = torch.exp(standardized_loc) * rate_scale / (rate_scale + 1)
    inverse_part = rate_scale.reciprocal()
    return log_part - standardized_loc + exp_part + inverse_part
648*da0073e9SAndroid Build Coastguard Worker
649*da0073e9SAndroid Build Coastguard Worker
650*da0073e9SAndroid Build Coastguard Worker# TODO: Add Exponential-Laplace KL Divergence
651*da0073e9SAndroid Build Coastguard Worker
652*da0073e9SAndroid Build Coastguard Worker
@register_kl(Exponential, Normal)
def _kl_exponential_normal(p, q):
    """Return KL(p || q) for an Exponential ``p`` and a Normal ``q``.

    With ``v = q.scale ** 2`` and ``r2 = p.rate ** 2`` the value is
    ``0.5 * log(2 * pi * r2 * v) - 1 + (1 / r2 - q.loc / p.rate + 0.5 * q.loc ** 2) / v``.
    """
    normal_var = q.scale.pow(2)
    rate_sq = p.rate.pow(2)
    log_part = 0.5 * torch.log(rate_sq * normal_var * 2 * math.pi)
    inverse_rate_sq = rate_sq.reciprocal()
    loc_over_rate = q.loc / p.rate
    half_loc_sq = q.loc.pow(2) * 0.5
    quadratic = (inverse_rate_sq - loc_over_rate + half_loc_sq) / normal_var
    return log_part - 1 + quadratic
662*da0073e9SAndroid Build Coastguard Worker
663*da0073e9SAndroid Build Coastguard Worker
@register_kl(Gamma, Beta)
@register_kl(Gamma, ContinuousBernoulli)
@register_kl(Gamma, Pareto)
@register_kl(Gamma, Uniform)
def _kl_gamma_infinity(p, q):
    """Return ``_infinite_like(p.concentration)`` for a Gamma ``p``.

    Registered against Beta, ContinuousBernoulli, Pareto and Uniform ``q``;
    ``q`` is never read.
    """
    template = p.concentration
    return _infinite_like(template)
670*da0073e9SAndroid Build Coastguard Worker
671*da0073e9SAndroid Build Coastguard Worker
@register_kl(Gamma, Exponential)
def _kl_gamma_exponential(p, q):
    """Return KL(p || q) for a Gamma ``p`` and an Exponential ``q``.

    Evaluates ``-p.entropy() - log(q.rate) + q.rate * p.concentration / p.rate``.
    """
    neg_entropy = -p.entropy()
    rate_term = q.rate * p.concentration / p.rate
    return neg_entropy - q.rate.log() + rate_term
675*da0073e9SAndroid Build Coastguard Worker
676*da0073e9SAndroid Build Coastguard Worker
@register_kl(Gamma, Gumbel)
def _kl_gamma_gumbel(p, q):
    """Return KL(p || q) for a Gamma ``p`` and a Gumbel ``q``.

    With ``a = p.concentration``, ``s = p.rate * q.scale`` and
    ``z = q.loc / q.scale`` the value is the sum of

    * ``(a - 1) * digamma(a) - lgamma(a) - a``,
    * ``log(s) + a / s``,
    * ``exp(z) * (1 + 1 / s) ** (-a) - z``.
    """
    rate_scale = p.rate * q.scale
    standardized_loc = q.loc / q.scale
    gamma_part = (
        (p.concentration - 1) * p.concentration.digamma()
        - p.concentration.lgamma()
        - p.concentration
    )
    scale_part = rate_scale.log() + p.concentration / rate_scale
    decay = (1 + rate_scale.reciprocal()).pow(-p.concentration)
    gumbel_part = torch.exp(standardized_loc) * decay - standardized_loc
    return gamma_part + scale_part + gumbel_part
693*da0073e9SAndroid Build Coastguard Worker
694*da0073e9SAndroid Build Coastguard Worker
695*da0073e9SAndroid Build Coastguard Worker# TODO: Add Gamma-Laplace KL Divergence
696*da0073e9SAndroid Build Coastguard Worker
697*da0073e9SAndroid Build Coastguard Worker
@register_kl(Gamma, Normal)
def _kl_gamma_normal(p, q):
    """Return KL(p || q) for a Gamma ``p`` and a Normal ``q``.

    With ``a = p.concentration``, ``b = p.rate`` and ``v = q.scale ** 2`` the
    value is the sum of

    * ``0.5 * log(2 * pi * b ** 2 * v) - a - lgamma(a)``,
    * ``(a - 1) * digamma(a)``,
    * ``(0.5 * (a ** 2 + a) / b ** 2 - q.loc * a / b + 0.5 * q.loc ** 2) / v``.
    """
    normal_var = q.scale.pow(2)
    rate_sq = p.rate.pow(2)
    log_part = (
        0.5 * torch.log(rate_sq * normal_var * 2 * math.pi)
        - p.concentration
        - p.concentration.lgamma()
    )
    half_second_moment = 0.5 * (p.concentration.pow(2) + p.concentration) / rate_sq
    loc_times_mean = q.loc * p.concentration / p.rate
    half_loc_sq = 0.5 * q.loc.pow(2)
    digamma_part = (p.concentration - 1) * p.concentration.digamma()
    quadratic = (half_second_moment - loc_times_mean + half_loc_sq) / normal_var
    return log_part + digamma_part + quadratic
715*da0073e9SAndroid Build Coastguard Worker
716*da0073e9SAndroid Build Coastguard Worker
@register_kl(Gumbel, Beta)
@register_kl(Gumbel, ContinuousBernoulli)
@register_kl(Gumbel, Exponential)
@register_kl(Gumbel, Gamma)
@register_kl(Gumbel, Pareto)
@register_kl(Gumbel, Uniform)
def _kl_gumbel_infinity(p, q):
    """Return ``_infinite_like(p.loc)`` for a Gumbel ``p``.

    Registered against Beta, ContinuousBernoulli, Exponential, Gamma, Pareto
    and Uniform ``q``; ``q`` is never read.
    """
    template = p.loc
    return _infinite_like(template)
725*da0073e9SAndroid Build Coastguard Worker
726*da0073e9SAndroid Build Coastguard Worker
727*da0073e9SAndroid Build Coastguard Worker# TODO: Add Gumbel-Laplace KL Divergence
728*da0073e9SAndroid Build Coastguard Worker
729*da0073e9SAndroid Build Coastguard Worker
@register_kl(Gumbel, Normal)
def _kl_gumbel_normal(p, q):
    """Return KL(p || q) for a Gumbel ``p`` and a Normal ``q``.

    With ``r = p.scale / q.scale`` the value is
    ``-log(r / sqrt(2 * pi)) + (pi * r / 2) ** 2 / 3
    + 0.5 * ((p.loc + p.scale * _euler_gamma - q.loc) / q.scale) ** 2
    - (_euler_gamma + 1)``.
    """
    scale_ratio = p.scale / q.scale
    log_part = (scale_ratio / math.sqrt(2 * math.pi)).log()
    spread_part = (math.pi * scale_ratio * 0.5).pow(2) / 3
    standardized_shift = (p.loc + p.scale * _euler_gamma - q.loc) / q.scale
    shift_part = standardized_shift.pow(2) * 0.5
    return -log_part + spread_part + shift_part - (_euler_gamma + 1)
737*da0073e9SAndroid Build Coastguard Worker
738*da0073e9SAndroid Build Coastguard Worker
@register_kl(Laplace, Beta)
@register_kl(Laplace, ContinuousBernoulli)
@register_kl(Laplace, Exponential)
@register_kl(Laplace, Gamma)
@register_kl(Laplace, Pareto)
@register_kl(Laplace, Uniform)
def _kl_laplace_infinity(p, q):
    """Return ``_infinite_like(p.loc)`` for a Laplace ``p``.

    Registered against Beta, ContinuousBernoulli, Exponential, Gamma, Pareto
    and Uniform ``q``; ``q`` is never read.
    """
    template = p.loc
    return _infinite_like(template)
747*da0073e9SAndroid Build Coastguard Worker
748*da0073e9SAndroid Build Coastguard Worker
@register_kl(Laplace, Normal)
def _kl_laplace_normal(p, q):
    """Return KL(p || q) for a Laplace ``p`` and a Normal ``q``.

    With ``v = q.scale ** 2`` and ``r = p.scale ** 2 / v`` the value is
    ``-0.5 * log(2 * r / pi) + r
    + (0.5 * p.loc ** 2 - p.loc * q.loc + 0.5 * q.loc ** 2) / v - 1``.
    """
    normal_var = q.scale.pow(2)
    scale_sq_ratio = p.scale.pow(2) / normal_var
    log_part = 0.5 * torch.log(2 * scale_sq_ratio / math.pi)
    half_p_loc_sq = 0.5 * p.loc.pow(2)
    loc_product = p.loc * q.loc
    half_q_loc_sq = 0.5 * q.loc.pow(2)
    quadratic = (half_p_loc_sq - loc_product + half_q_loc_sq) / normal_var
    return -log_part + scale_sq_ratio + quadratic - 1
758*da0073e9SAndroid Build Coastguard Worker
759*da0073e9SAndroid Build Coastguard Worker
@register_kl(Normal, Beta)
@register_kl(Normal, ContinuousBernoulli)
@register_kl(Normal, Exponential)
@register_kl(Normal, Gamma)
@register_kl(Normal, Pareto)
@register_kl(Normal, Uniform)
def _kl_normal_infinity(p, q):
    """Return ``_infinite_like(p.loc)`` for a Normal ``p``.

    Registered against Beta, ContinuousBernoulli, Exponential, Gamma, Pareto
    and Uniform ``q``; ``q`` is never read.
    """
    template = p.loc
    return _infinite_like(template)
768*da0073e9SAndroid Build Coastguard Worker
769*da0073e9SAndroid Build Coastguard Worker
@register_kl(Normal, Gumbel)
def _kl_normal_gumbel(p, q):
    """Return KL(p || q) for a Normal ``p`` and a Gumbel ``q``.

    With ``m = p.loc / q.scale``, ``r = (p.scale / q.scale) ** 2`` and
    ``z = q.loc / q.scale`` the value is
    ``-0.5 * log(r) + (m - z) + exp(-m + 0.5 * r + z) - 0.5 * (1 + log(2 * pi))``.
    """
    standardized_mean = p.loc / q.scale
    variance_ratio = (p.scale / q.scale).pow(2)
    standardized_loc = q.loc / q.scale
    half_log_ratio = variance_ratio.log() * 0.5
    mean_gap = standardized_mean - standardized_loc
    exp_part = torch.exp(-standardized_mean + 0.5 * variance_ratio + standardized_loc)
    return -half_log_ratio + mean_gap + exp_part - (0.5 * (1 + math.log(2 * math.pi)))
779*da0073e9SAndroid Build Coastguard Worker
780*da0073e9SAndroid Build Coastguard Worker
@register_kl(Normal, Laplace)
def _kl_normal_laplace(p, q):
    """Return KL(p || q) for a Normal ``p`` and a Laplace ``q``.

    With ``d = p.loc - q.loc`` and ``u = d / p.scale`` the value is
    ``-log(p.scale / q.scale)
    + (sqrt(2 / pi) * p.scale * exp(-0.5 * u ** 2) + d * erf(sqrt(0.5) * u)) / q.scale
    - 0.5 * (1 + log(0.5 * pi))``.
    """
    loc_gap = p.loc - q.loc
    scale_ratio = p.scale / q.scale
    standardized_gap = loc_gap / p.scale
    log_part = torch.log(scale_ratio)
    gaussian_part = (
        math.sqrt(2 / math.pi) * p.scale * torch.exp(-0.5 * standardized_gap.pow(2))
    )
    erf_part = loc_gap * torch.erf(math.sqrt(0.5) * standardized_gap)
    constant = 0.5 * (1 + math.log(0.5 * math.pi))
    return -log_part + (gaussian_part + erf_part) / q.scale - constant
792*da0073e9SAndroid Build Coastguard Worker
793*da0073e9SAndroid Build Coastguard Worker
@register_kl(Pareto, Beta)
@register_kl(Pareto, ContinuousBernoulli)
@register_kl(Pareto, Uniform)
def _kl_pareto_infinity(p, q):
    """Return ``_infinite_like(p.scale)`` for a Pareto ``p``.

    Registered against Beta, ContinuousBernoulli and Uniform ``q``; ``q`` is
    never read.
    """
    template = p.scale
    return _infinite_like(template)
799*da0073e9SAndroid Build Coastguard Worker
800*da0073e9SAndroid Build Coastguard Worker
@register_kl(Pareto, Exponential)
def _kl_pareto_exponential(p, q):
    """Return KL(p || q) for a Pareto ``p`` and an Exponential ``q``.

    With ``s = p.scale * q.rate`` and ``a = p.alpha`` the finite value is
    ``log(a / s) - 1 / a + a * s / (a - 1) - 1``; entries where ``a <= 1``
    are ``inf``.
    """
    scale_rate_prod = p.scale * q.rate
    t1 = (p.alpha / scale_rate_prod).log()
    t2 = p.alpha.reciprocal()
    t3 = p.alpha * scale_rate_prod / (p.alpha - 1)
    result = t1 - t2 + t3 - 1
    # The mask is built from p alone, while ``result`` has the broadcast shape
    # of p and q. masked_fill broadcasts the mask; boolean-index assignment
    # (``result[mask] = inf``) raises when q contributes extra leading dims.
    return result.masked_fill(p.alpha <= 1, inf)
810*da0073e9SAndroid Build Coastguard Worker
811*da0073e9SAndroid Build Coastguard Worker
@register_kl(Pareto, Gamma)
def _kl_pareto_gamma(p, q):
    """Return KL(p || q) for a Pareto ``p`` and a Gamma ``q``.

    With ``c = log(p.scale) + 1 / p.alpha`` the finite value is
    ``(log(p.alpha) - c) + (lgamma(q.concentration) - q.concentration * log(q.rate))
    + (1 - q.concentration) * c + q.rate * p.alpha * p.scale / (p.alpha - 1) - 1``;
    entries where ``p.alpha <= 1`` are ``inf``.
    """
    common_term = p.scale.log() + p.alpha.reciprocal()
    t1 = p.alpha.log() - common_term
    t2 = q.concentration.lgamma() - q.concentration * q.rate.log()
    t3 = (1 - q.concentration) * common_term
    t4 = q.rate * p.alpha * p.scale / (p.alpha - 1)
    result = t1 + t2 + t3 + t4 - 1
    # The mask is built from p alone, while ``result`` has the broadcast shape
    # of p and q. masked_fill broadcasts the mask; boolean-index assignment
    # (``result[mask] = inf``) raises when q contributes extra leading dims.
    return result.masked_fill(p.alpha <= 1, inf)
822*da0073e9SAndroid Build Coastguard Worker
823*da0073e9SAndroid Build Coastguard Worker
824*da0073e9SAndroid Build Coastguard Worker# TODO: Add Pareto-Laplace KL Divergence
825*da0073e9SAndroid Build Coastguard Worker
826*da0073e9SAndroid Build Coastguard Worker
@register_kl(Pareto, Normal)
def _kl_pareto_normal(p, q):
    """Return KL(p || q) for a Pareto ``p`` and a Normal ``q``.

    With ``a = p.alpha``, ``c = p.scale / (a - 1)`` and ``w = 2 * q.scale ** 2``
    the finite value is
    ``log(sqrt(2 * pi) * q.scale * a / p.scale) - 1 / a
    + (a * c ** 2 / (a - 2) + (a * c - q.loc) ** 2) / w - 1``;
    entries where ``a <= 2`` are ``inf``.
    """
    # Twice the squared scale of q: the divisor of the quadratic term.
    twice_normal_var = 2 * q.scale.pow(2)
    common_term = p.scale / (p.alpha - 1)
    t1 = (math.sqrt(2 * math.pi) * q.scale * p.alpha / p.scale).log()
    t2 = p.alpha.reciprocal()
    t3 = p.alpha * common_term.pow(2) / (p.alpha - 2)
    t4 = (p.alpha * common_term - q.loc).pow(2)
    result = t1 - t2 + (t3 + t4) / twice_normal_var - 1
    # The mask is built from p alone, while ``result`` has the broadcast shape
    # of p and q. masked_fill broadcasts the mask; boolean-index assignment
    # (``result[mask] = inf``) raises when q contributes extra leading dims.
    return result.masked_fill(p.alpha <= 2, inf)
838*da0073e9SAndroid Build Coastguard Worker
839*da0073e9SAndroid Build Coastguard Worker
@register_kl(Poisson, Bernoulli)
@register_kl(Poisson, Binomial)
def _kl_poisson_infinity(p, q):
    """Return the KL value registered for Poisson vs. Bernoulli/Binomial.

    ``q`` is not consulted; the result is whatever ``_infinite_like``
    produces for ``p.rate``.
    """
    poisson_rate = p.rate
    return _infinite_like(poisson_rate)
844*da0073e9SAndroid Build Coastguard Worker
845*da0073e9SAndroid Build Coastguard Worker
@register_kl(Uniform, Beta)
def _kl_uniform_beta(p, q):
    """Compute KL(p || q) for a Uniform ``p`` and a Beta ``q``.

    Args:
        p: Uniform distribution; its ``low`` and ``high`` tensors are read.
        q: Beta distribution; its ``concentration1``, ``concentration0`` and
            ``support`` bounds are read.

    Returns:
        Tensor with the broadcast shape of the parameters of ``p`` and ``q``.
        Entries where ``[p.low, p.high]`` extends beyond the bounds of
        ``q.support`` are ``inf``.
    """
    common_term = p.high - p.low
    t1 = torch.log(common_term)
    t2 = (
        (q.concentration1 - 1)
        * (_x_log_x(p.high) - _x_log_x(p.low) - common_term)
        / common_term
    )
    t3 = (
        (q.concentration0 - 1)
        * (_x_log_x(1 - p.high) - _x_log_x(1 - p.low) + common_term)
        / common_term
    )
    t4 = (
        q.concentration1.lgamma()
        + q.concentration0.lgamma()
        - (q.concentration1 + q.concentration0).lgamma()
    )
    result = t3 + t4 - t1 - t2
    outside_support = (p.high > q.support.upper_bound) | (
        p.low < q.support.lower_bound
    )
    # Out-of-place masked_fill instead of ``result[mask] = inf``: the mask is
    # built from p's parameters, so boolean-mask assignment raises whenever
    # q's batch shape makes ``result`` larger than the mask; masked_fill
    # broadcasts the mask.
    return result.masked_fill(outside_support, inf)
868*da0073e9SAndroid Build Coastguard Worker
869*da0073e9SAndroid Build Coastguard Worker
@register_kl(Uniform, ContinuousBernoulli)
def _kl_uniform_continuous_bernoulli(p, q):
    """Compute KL(p || q) for a Uniform ``p`` and a ContinuousBernoulli ``q``.

    The closed-form value is replaced by ``inf`` wherever ``p.high`` is at or
    above the upper bound of ``q.support`` or ``p.low`` is at or below its
    lower bound.
    """
    # NOTE(review): the bounds test is inclusive (ge / le), unlike the strict
    # comparison used in _kl_uniform_beta -- confirm this is intended.
    closed_form = -p.entropy()
    closed_form = closed_form - p.mean * q.logits
    closed_form = closed_form - torch.log1p(-q.probs)
    closed_form = closed_form - q._cont_bern_log_norm()

    bounds = q.support
    touches_upper = torch.ge(p.high, bounds.upper_bound)
    touches_lower = torch.le(p.low, bounds.lower_bound)
    return torch.where(
        touches_upper | touches_lower,
        torch.full_like(closed_form, inf),
        closed_form,
    )
886*da0073e9SAndroid Build Coastguard Worker
887*da0073e9SAndroid Build Coastguard Worker
@register_kl(Uniform, Exponential)
def _kl_uniform_exponetial(p, q):
    """Compute KL(p || q) for a Uniform ``p`` and an Exponential ``q``.

    Args:
        p: Uniform distribution; its ``low`` and ``high`` tensors are read.
        q: Exponential distribution; its ``rate`` and the lower bound of its
            ``support`` are read.

    Returns:
        Tensor with the broadcast shape of the parameters of ``p`` and ``q``.
        Entries where ``p.low`` lies below the lower bound of ``q.support``
        are ``inf``.
    """
    result = q.rate * (p.high + p.low) / 2 - ((p.high - p.low) * q.rate).log()
    # Out-of-place masked_fill instead of ``result[mask] = inf``: the mask is
    # built from p.low, so boolean-mask assignment raises whenever q's batch
    # shape makes ``result`` larger than the mask; masked_fill broadcasts it.
    return result.masked_fill(p.low < q.support.lower_bound, inf)
893*da0073e9SAndroid Build Coastguard Worker
894*da0073e9SAndroid Build Coastguard Worker
@register_kl(Uniform, Gamma)
def _kl_uniform_gamma(p, q):
    """Compute KL(p || q) for a Uniform ``p`` and a Gamma ``q``.

    Args:
        p: Uniform distribution; its ``low`` and ``high`` tensors are read.
        q: Gamma distribution; its ``concentration``, ``rate`` and the lower
            bound of its ``support`` are read.

    Returns:
        Tensor with the broadcast shape of the parameters of ``p`` and ``q``.
        Entries where ``p.low`` lies below the lower bound of ``q.support``
        are ``inf``.
    """
    common_term = p.high - p.low
    t1 = common_term.log()
    t2 = q.concentration.lgamma() - q.concentration * q.rate.log()
    t3 = (
        (1 - q.concentration)
        * (_x_log_x(p.high) - _x_log_x(p.low) - common_term)
        / common_term
    )
    t4 = q.rate * (p.high + p.low) / 2
    result = -t1 + t2 + t3 + t4
    # Out-of-place masked_fill instead of ``result[mask] = inf``: the mask is
    # built from p.low, so boolean-mask assignment raises whenever q's batch
    # shape makes ``result`` larger than the mask; masked_fill broadcasts it.
    return result.masked_fill(p.low < q.support.lower_bound, inf)
909*da0073e9SAndroid Build Coastguard Worker
910*da0073e9SAndroid Build Coastguard Worker
@register_kl(Uniform, Gumbel)
def _kl_uniform_gumbel(p, q):
    """Compute KL(p || q) for a Uniform ``p`` and a Gumbel ``q``.

    Reads ``p.low``/``p.high`` and ``q.loc``/``q.scale``; no support masking
    is applied to the result.
    """
    scale_over_width = q.scale / (p.high - p.low)
    # Interval endpoints standardised by q's location and scale.
    z_high = (p.high - q.loc) / q.scale
    z_low = (p.low - q.loc) / q.scale
    linear_part = scale_over_width.log() + 0.5 * (z_high + z_low)
    exp_part = scale_over_width * (torch.exp(-z_high) - torch.exp(-z_low))
    return linear_part - exp_part
919*da0073e9SAndroid Build Coastguard Worker
920*da0073e9SAndroid Build Coastguard Worker
# TODO: Add Uniform-Laplace KL Divergence
922*da0073e9SAndroid Build Coastguard Worker
923*da0073e9SAndroid Build Coastguard Worker
@register_kl(Uniform, Normal)
def _kl_uniform_normal(p, q):
    """Compute KL(p || q) for a Uniform ``p`` and a Normal ``q``.

    Reads ``p.low``/``p.high`` and ``q.loc``/``q.scale``; no support masking
    is applied to the result.
    """
    width = p.high - p.low
    log_ratio = (math.sqrt(math.pi * 2) * q.scale / width).log()
    # width**2 / 12 plus the squared offset of the interval midpoint from q.loc.
    spread_sq = (width).pow(2) / 12
    offset_sq = ((p.high + p.low - 2 * q.loc) / 2).pow(2)
    return log_ratio + 0.5 * (spread_sq + offset_sq) / q.scale.pow(2)
931*da0073e9SAndroid Build Coastguard Worker
932*da0073e9SAndroid Build Coastguard Worker
@register_kl(Uniform, Pareto)
def _kl_uniform_pareto(p, q):
    """Compute KL(p || q) for a Uniform ``p`` and a Pareto ``q``.

    Entries where ``p.low`` lies below the lower bound of ``q.support`` are
    set to ``inf``.
    """
    width = p.high - p.low
    log_norm = (q.alpha * q.scale.pow(q.alpha) * (width)).log()
    mean_log_term = (_x_log_x(p.high) - _x_log_x(p.low) - width) / width
    kl = mean_log_term * (q.alpha + 1) - log_norm
    below_support = p.low < q.support.lower_bound
    kl[below_support] = inf
    return kl
941*da0073e9SAndroid Build Coastguard Worker
942*da0073e9SAndroid Build Coastguard Worker
@register_kl(Independent, Independent)
def _kl_independent_independent(p, q):
    """Compute KL(p || q) for two Independent wrappers.

    The divergence of the base distributions is summed over the rightmost
    ``reinterpreted_batch_ndims`` dimensions.

    Raises:
        NotImplementedError: if ``p`` and ``q`` reinterpret a different
            number of batch dimensions.
    """
    event_ndims = p.reinterpreted_batch_ndims
    if event_ndims == q.reinterpreted_batch_ndims:
        base_kl = kl_divergence(p.base_dist, q.base_dist)
        return _sum_rightmost(base_kl, event_ndims)
    raise NotImplementedError
949*da0073e9SAndroid Build Coastguard Worker
950*da0073e9SAndroid Build Coastguard Worker
@register_kl(Cauchy, Cauchy)
def _kl_cauchy_cauchy(p, q):
    """Compute KL(p || q) for two Cauchy distributions.

    Closed form taken from https://arxiv.org/abs/1905.10965.
    """
    scale_sum_sq = (p.scale + q.scale).pow(2)
    loc_diff_sq = (p.loc - q.loc).pow(2)
    log_numerator = (scale_sum_sq + loc_diff_sq).log()
    log_denominator = (4 * p.scale * q.scale).log()
    return log_numerator - log_denominator
957*da0073e9SAndroid Build Coastguard Worker
958*da0073e9SAndroid Build Coastguard Worker
959*da0073e9SAndroid Build Coastguard Workerdef _add_kl_info():
960*da0073e9SAndroid Build Coastguard Worker    """Appends a list of implemented KL functions to the doc for kl_divergence."""
961*da0073e9SAndroid Build Coastguard Worker    rows = [
962*da0073e9SAndroid Build Coastguard Worker        "KL divergence is currently implemented for the following distribution pairs:"
963*da0073e9SAndroid Build Coastguard Worker    ]
964*da0073e9SAndroid Build Coastguard Worker    for p, q in sorted(
965*da0073e9SAndroid Build Coastguard Worker        _KL_REGISTRY, key=lambda p_q: (p_q[0].__name__, p_q[1].__name__)
966*da0073e9SAndroid Build Coastguard Worker    ):
967*da0073e9SAndroid Build Coastguard Worker        rows.append(
968*da0073e9SAndroid Build Coastguard Worker            f"* :class:`~torch.distributions.{p.__name__}` and :class:`~torch.distributions.{q.__name__}`"
969*da0073e9SAndroid Build Coastguard Worker        )
970*da0073e9SAndroid Build Coastguard Worker    kl_info = "\n\t".join(rows)
971*da0073e9SAndroid Build Coastguard Worker    if kl_divergence.__doc__:
972*da0073e9SAndroid Build Coastguard Worker        kl_divergence.__doc__ += kl_info  # type: ignore[operator]
973