# mypy: allow-untyped-defs
import math
import warnings
from functools import total_ordering
from typing import Callable, Dict, Tuple, Type

import torch
from torch import inf

from .bernoulli import Bernoulli
from .beta import Beta
from .binomial import Binomial
from .categorical import Categorical
from .cauchy import Cauchy
from .continuous_bernoulli import ContinuousBernoulli
from .dirichlet import Dirichlet
from .distribution import Distribution
from .exp_family import ExponentialFamily
from .exponential import Exponential
from .gamma import Gamma
from .geometric import Geometric
from .gumbel import Gumbel
from .half_normal import HalfNormal
from .independent import Independent
from .laplace import Laplace
from .lowrank_multivariate_normal import (
    _batch_lowrank_logdet,
    _batch_lowrank_mahalanobis,
    LowRankMultivariateNormal,
)
from .multivariate_normal import _batch_mahalanobis, MultivariateNormal
from .normal import Normal
from .one_hot_categorical import OneHotCategorical
from .pareto import Pareto
from .poisson import Poisson
from .transformed_distribution import TransformedDistribution
from .uniform import Uniform
from .utils import _sum_rightmost, euler_constant as _euler_gamma


_KL_REGISTRY: Dict[
    Tuple[Type, Type], Callable
] = {}  # Source of truth mapping a few general (type, type) pairs to functions.
_KL_MEMOIZE: Dict[
    Tuple[Type, Type], Callable
] = {}  # Memoized version mapping many specific (type, type) pairs to functions.

__all__ = ["register_kl", "kl_divergence"]


def register_kl(type_p, type_q):
    """
    Decorator to register a pairwise function with :meth:`kl_divergence`.
    Usage::

        @register_kl(Normal, Normal)
        def kl_normal_normal(p, q):
            # insert implementation here

    Lookup returns the most specific (type,type) match ordered by subclass. If
    the match is ambiguous, a `RuntimeWarning` is raised. For example to
    resolve the ambiguous situation::

        @register_kl(BaseP, DerivedQ)
        def kl_version1(p, q): ...
        @register_kl(DerivedP, BaseQ)
        def kl_version2(p, q): ...

    you should register a third most-specific implementation, e.g.::

        register_kl(DerivedP, DerivedQ)(kl_version1)  # Break the tie.

    Args:
        type_p (type): A subclass of :class:`~torch.distributions.Distribution`.
        type_q (type): A subclass of :class:`~torch.distributions.Distribution`.
    """
    if not (isinstance(type_p, type) and issubclass(type_p, Distribution)):
        raise TypeError(
            f"Expected type_p to be a Distribution subclass but got {type_p}"
        )
    if not (isinstance(type_q, type) and issubclass(type_q, Distribution)):
        raise TypeError(
            f"Expected type_q to be a Distribution subclass but got {type_q}"
        )

    def decorator(fun):
        _KL_REGISTRY[type_p, type_q] = fun
        _KL_MEMOIZE.clear()  # reset since lookup order may have changed
        return fun

    return decorator


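# A hedged, illustrative sketch (not executed at import): one way to use
# ``register_kl`` for user-defined subclasses and to break a dispatch tie, as
# described in the docstring above. ``MyNormalP`` and ``MyNormalQ`` are
# hypothetical names, not part of torch.distributions.
def _register_kl_usage_sketch():
    class MyNormalP(Normal):
        pass

    class MyNormalQ(Normal):
        pass

    @register_kl(MyNormalP, Normal)
    def kl_version1(p, q):
        return _kl_normal_normal(p, q)

    @register_kl(Normal, MyNormalQ)
    def kl_version2(p, q):
        return _kl_normal_normal(p, q)

    # (MyNormalP, MyNormalQ) now matches both registrations, which would
    # otherwise trigger the RuntimeWarning below; registering the most
    # specific pair breaks the tie.
    register_kl(MyNormalP, MyNormalQ)(kl_version1)

    p = MyNormalP(torch.tensor(0.0), torch.tensor(1.0))
    q = MyNormalQ(torch.tensor(1.0), torch.tensor(2.0))
    return kl_divergence(p, q)

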
@total_ordering
class _Match:
    __slots__ = ["types"]

    def __init__(self, *types):
        self.types = types

    def __eq__(self, other):
        return self.types == other.types

    def __le__(self, other):
        for x, y in zip(self.types, other.types):
            if not issubclass(x, y):
                return False
            if x is not y:
                break
        return True


def _dispatch_kl(type_p, type_q):
    """
    Find the most specific approximate match, assuming single inheritance.
    """
    matches = [
        (super_p, super_q)
        for super_p, super_q in _KL_REGISTRY
        if issubclass(type_p, super_p) and issubclass(type_q, super_q)
    ]
    if not matches:
        return NotImplemented
    # Check that the left- and right- lexicographic orders agree.
    # mypy isn't smart enough to know that _Match implements __lt__
    # see: https://github.com/python/typing/issues/760#issuecomment-710670503
    left_p, left_q = min(_Match(*m) for m in matches).types  # type: ignore[type-var]
    right_q, right_p = min(_Match(*reversed(m)) for m in matches).types  # type: ignore[type-var]
    left_fun = _KL_REGISTRY[left_p, left_q]
    right_fun = _KL_REGISTRY[right_p, right_q]
    if left_fun is not right_fun:
        warnings.warn(
            f"Ambiguous kl_divergence({type_p.__name__}, {type_q.__name__}). "
            f"Please register_kl({left_p.__name__}, {right_q.__name__})",
            RuntimeWarning,
        )
    return left_fun


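# A hedged, illustrative sketch (not executed at import) of the dispatch rule
# above: when several registered (type, type) pairs match, the most derived
# pair wins.
def _dispatch_kl_specificity_sketch():
    # Normal inputs match both (ExponentialFamily, ExponentialFamily) and
    # (Normal, Normal); the more specific (Normal, Normal) entry is chosen.
    assert _dispatch_kl(Normal, Normal) is _kl_normal_normal
    # HalfNormal also matches (TransformedDistribution, TransformedDistribution),
    # but its own registration is more specific.
    assert _dispatch_kl(HalfNormal, HalfNormal) is _kl_halfnormal_halfnormal

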
def _infinite_like(tensor):
    """
    Helper function for obtaining infinite KL Divergence throughout
    """
    return torch.full_like(tensor, inf)


def _x_log_x(tensor):
    """
    Utility function for calculating x log x
    """
    return tensor * tensor.log()


def _batch_trace_XXT(bmat):
    """
    Utility function for calculating the trace of XX^{T} with X having arbitrary leading batch dimensions
    """
    n = bmat.size(-1)
    m = bmat.size(-2)
    flat_trace = bmat.reshape(-1, m * n).pow(2).sum(-1)
    return flat_trace.reshape(bmat.shape[:-2])


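# A hedged, illustrative check (not executed at import): tr(X @ X.T) is the
# sum of squared entries of X, so the helper above agrees with an einsum
# formulation over the same batch of matrices.
def _batch_trace_XXT_sketch():
    X = torch.randn(3, 2, 5, 4)
    expected = torch.einsum("...ij,...ij->...", X, X)
    assert torch.allclose(_batch_trace_XXT(X), expected)

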
def kl_divergence(p: Distribution, q: Distribution) -> torch.Tensor:
    r"""
    Compute Kullback-Leibler divergence :math:`KL(p \| q)` between two distributions.

    .. math::

        KL(p \| q) = \int p(x) \log\frac {p(x)} {q(x)} \,dx

    Args:
        p (Distribution): A :class:`~torch.distributions.Distribution` object.
        q (Distribution): A :class:`~torch.distributions.Distribution` object.

    Returns:
        Tensor: A batch of KL divergences of shape `batch_shape`.

    Raises:
        NotImplementedError: If the distribution types have not been registered via
            :meth:`register_kl`.
    """
    try:
        fun = _KL_MEMOIZE[type(p), type(q)]
    except KeyError:
        fun = _dispatch_kl(type(p), type(q))
        _KL_MEMOIZE[type(p), type(q)] = fun
    if fun is NotImplemented:
        raise NotImplementedError(
            f"No KL(p || q) is implemented for p type {p.__class__.__name__} and q type {q.__class__.__name__}"
        )
    return fun(p, q)


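# A hedged, illustrative usage sketch (not executed at import): kl_divergence
# returns one value per batch element and raises NotImplementedError for
# unregistered pairs.
def _kl_divergence_usage_sketch():
    p = Normal(torch.zeros(3), torch.ones(3))
    q = Normal(torch.ones(3), 2 * torch.ones(3))
    kl = kl_divergence(p, q)  # shape (3,): one KL value per batch element
    assert kl.shape == (3,)
    assert (kl >= 0).all()
    return kl

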
################################################################################
# KL Divergence Implementations
################################################################################

# Same distributions


@register_kl(Bernoulli, Bernoulli)
def _kl_bernoulli_bernoulli(p, q):
    t1 = p.probs * (
        torch.nn.functional.softplus(-q.logits)
        - torch.nn.functional.softplus(-p.logits)
    )
    t1[q.probs == 0] = inf
    t1[p.probs == 0] = 0
    t2 = (1 - p.probs) * (
        torch.nn.functional.softplus(q.logits) - torch.nn.functional.softplus(p.logits)
    )
    t2[q.probs == 1] = inf
    t2[p.probs == 1] = 0
    return t1 + t2


@register_kl(Beta, Beta)
def _kl_beta_beta(p, q):
    sum_params_p = p.concentration1 + p.concentration0
    sum_params_q = q.concentration1 + q.concentration0
    t1 = q.concentration1.lgamma() + q.concentration0.lgamma() + (sum_params_p).lgamma()
    t2 = p.concentration1.lgamma() + p.concentration0.lgamma() + (sum_params_q).lgamma()
    t3 = (p.concentration1 - q.concentration1) * torch.digamma(p.concentration1)
    t4 = (p.concentration0 - q.concentration0) * torch.digamma(p.concentration0)
    t5 = (sum_params_q - sum_params_p) * torch.digamma(sum_params_p)
    return t1 - t2 + t3 + t4 + t5


@register_kl(Binomial, Binomial)
def _kl_binomial_binomial(p, q):
    # from https://math.stackexchange.com/questions/2214993/
    # kullback-leibler-divergence-for-binomial-distributions-p-and-q
    if (p.total_count < q.total_count).any():
        raise NotImplementedError(
            "KL between Binomials where q.total_count > p.total_count is not implemented"
        )
    kl = p.total_count * (
        p.probs * (p.logits - q.logits) + (-p.probs).log1p() - (-q.probs).log1p()
    )
    inf_idxs = p.total_count > q.total_count
    kl[inf_idxs] = _infinite_like(kl[inf_idxs])
    return kl


@register_kl(Categorical, Categorical)
def _kl_categorical_categorical(p, q):
    t = p.probs * (p.logits - q.logits)
    t[(q.probs == 0).expand_as(t)] = inf
    t[(p.probs == 0).expand_as(t)] = 0
    return t.sum(-1)


@register_kl(ContinuousBernoulli, ContinuousBernoulli)
def _kl_continuous_bernoulli_continuous_bernoulli(p, q):
    t1 = p.mean * (p.logits - q.logits)
    t2 = p._cont_bern_log_norm() + torch.log1p(-p.probs)
    t3 = -q._cont_bern_log_norm() - torch.log1p(-q.probs)
    return t1 + t2 + t3


@register_kl(Dirichlet, Dirichlet)
def _kl_dirichlet_dirichlet(p, q):
    # From http://bariskurt.com/kullback-leibler-divergence-between-two-dirichlet-and-beta-distributions/
    sum_p_concentration = p.concentration.sum(-1)
    sum_q_concentration = q.concentration.sum(-1)
    t1 = sum_p_concentration.lgamma() - sum_q_concentration.lgamma()
    t2 = (p.concentration.lgamma() - q.concentration.lgamma()).sum(-1)
    t3 = p.concentration - q.concentration
    t4 = p.concentration.digamma() - sum_p_concentration.digamma().unsqueeze(-1)
    return t1 - t2 + (t3 * t4).sum(-1)


@register_kl(Exponential, Exponential)
def _kl_exponential_exponential(p, q):
    rate_ratio = q.rate / p.rate
    t1 = -rate_ratio.log()
    return t1 + rate_ratio - 1


@register_kl(ExponentialFamily, ExponentialFamily)
def _kl_expfamily_expfamily(p, q):
    if not type(p) == type(q):
        raise NotImplementedError(
            "The cross KL-divergence between different exponential families cannot \
                            be computed using Bregman divergences"
        )
    p_nparams = [np.detach().requires_grad_() for np in p._natural_params]
    q_nparams = q._natural_params
    lg_normal = p._log_normalizer(*p_nparams)
    gradients = torch.autograd.grad(lg_normal.sum(), p_nparams, create_graph=True)
    result = q._log_normalizer(*q_nparams) - lg_normal
    for pnp, qnp, g in zip(p_nparams, q_nparams, gradients):
        term = (qnp - pnp) * g
        result -= _sum_rightmost(term, len(q.event_shape))
    return result


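# A hedged numerical sanity check (not executed at import): for two Gamma
# distributions, which are exponential families exposing _natural_params and
# _log_normalizer, the generic Bregman-divergence path above should agree
# with the closed-form Gamma/Gamma formula registered below, up to the
# tolerance chosen here.
def _kl_expfamily_gamma_sketch():
    p = Gamma(torch.tensor([2.0, 0.5]), torch.tensor([1.0, 3.0]))
    q = Gamma(torch.tensor([1.5, 2.0]), torch.tensor([2.0, 1.0]))
    assert torch.allclose(
        _kl_expfamily_expfamily(p, q), _kl_gamma_gamma(p, q), atol=1e-4
    )

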
@register_kl(Gamma, Gamma)
def _kl_gamma_gamma(p, q):
    t1 = q.concentration * (p.rate / q.rate).log()
    t2 = torch.lgamma(q.concentration) - torch.lgamma(p.concentration)
    t3 = (p.concentration - q.concentration) * torch.digamma(p.concentration)
    t4 = (q.rate - p.rate) * (p.concentration / p.rate)
    return t1 + t2 + t3 + t4


@register_kl(Gumbel, Gumbel)
def _kl_gumbel_gumbel(p, q):
    ct1 = p.scale / q.scale
    ct2 = q.loc / q.scale
    ct3 = p.loc / q.scale
    t1 = -ct1.log() - ct2 + ct3
    t2 = ct1 * _euler_gamma
    t3 = torch.exp(ct2 + (1 + ct1).lgamma() - ct3)
    return t1 + t2 + t3 - (1 + _euler_gamma)


@register_kl(Geometric, Geometric)
def _kl_geometric_geometric(p, q):
    return -p.entropy() - torch.log1p(-q.probs) / p.probs - q.logits


@register_kl(HalfNormal, HalfNormal)
def _kl_halfnormal_halfnormal(p, q):
    return _kl_normal_normal(p.base_dist, q.base_dist)


@register_kl(Laplace, Laplace)
def _kl_laplace_laplace(p, q):
    # From http://www.mast.queensu.ca/~communications/Papers/gil-msc11.pdf
    scale_ratio = p.scale / q.scale
    loc_abs_diff = (p.loc - q.loc).abs()
    t1 = -scale_ratio.log()
    t2 = loc_abs_diff / q.scale
    t3 = scale_ratio * torch.exp(-loc_abs_diff / p.scale)
    return t1 + t2 + t3 - 1


@register_kl(LowRankMultivariateNormal, LowRankMultivariateNormal)
def _kl_lowrankmultivariatenormal_lowrankmultivariatenormal(p, q):
    if p.event_shape != q.event_shape:
        raise ValueError(
            "KL-divergence between two Low Rank Multivariate Normals with\
                          different event shapes cannot be computed"
        )

    term1 = _batch_lowrank_logdet(
        q._unbroadcasted_cov_factor, q._unbroadcasted_cov_diag, q._capacitance_tril
    ) - _batch_lowrank_logdet(
        p._unbroadcasted_cov_factor, p._unbroadcasted_cov_diag, p._capacitance_tril
    )
    term3 = _batch_lowrank_mahalanobis(
        q._unbroadcasted_cov_factor,
        q._unbroadcasted_cov_diag,
        q.loc - p.loc,
        q._capacitance_tril,
    )
    # Expands term2 according to
    # inv(qcov) @ pcov = [inv(qD) - inv(qD) @ qW @ inv(qC) @ qW.T @ inv(qD)] @ (pW @ pW.T + pD)
    #                  = [inv(qD) - A.T @ A] @ (pD + pW @ pW.T)
    qWt_qDinv = q._unbroadcasted_cov_factor.mT / q._unbroadcasted_cov_diag.unsqueeze(-2)
    A = torch.linalg.solve_triangular(q._capacitance_tril, qWt_qDinv, upper=False)
    term21 = (p._unbroadcasted_cov_diag / q._unbroadcasted_cov_diag).sum(-1)
    term22 = _batch_trace_XXT(
        p._unbroadcasted_cov_factor * q._unbroadcasted_cov_diag.rsqrt().unsqueeze(-1)
    )
    term23 = _batch_trace_XXT(A * p._unbroadcasted_cov_diag.sqrt().unsqueeze(-2))
    term24 = _batch_trace_XXT(A.matmul(p._unbroadcasted_cov_factor))
    term2 = term21 + term22 - term23 - term24
    return 0.5 * (term1 + term2 + term3 - p.event_shape[0])


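# A hedged numerical sanity check (not executed at import): a
# LowRankMultivariateNormal with covariance W @ W.T + diag(D) describes the
# same distribution as a dense MultivariateNormal with that covariance, so
# the low-rank KL above should agree with the dense KL registered below, up
# to the tolerance chosen here.
def _kl_lowrank_vs_dense_mvn_sketch():
    torch.manual_seed(0)
    loc_p, loc_q = torch.randn(4), torch.randn(4)
    Wp, Wq = torch.randn(4, 2), torch.randn(4, 2)
    Dp, Dq = torch.rand(4) + 0.5, torch.rand(4) + 0.5
    kl_lowrank = kl_divergence(
        LowRankMultivariateNormal(loc_p, Wp, Dp),
        LowRankMultivariateNormal(loc_q, Wq, Dq),
    )
    kl_dense = kl_divergence(
        MultivariateNormal(loc_p, covariance_matrix=Wp @ Wp.T + torch.diag(Dp)),
        MultivariateNormal(loc_q, covariance_matrix=Wq @ Wq.T + torch.diag(Dq)),
    )
    assert torch.allclose(kl_lowrank, kl_dense, atol=1e-3)

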
@register_kl(MultivariateNormal, LowRankMultivariateNormal)
def _kl_multivariatenormal_lowrankmultivariatenormal(p, q):
    if p.event_shape != q.event_shape:
        raise ValueError(
            "KL-divergence between two (Low Rank) Multivariate Normals with\
                          different event shapes cannot be computed"
        )

    term1 = _batch_lowrank_logdet(
        q._unbroadcasted_cov_factor, q._unbroadcasted_cov_diag, q._capacitance_tril
    ) - 2 * p._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1)
    term3 = _batch_lowrank_mahalanobis(
        q._unbroadcasted_cov_factor,
        q._unbroadcasted_cov_diag,
        q.loc - p.loc,
        q._capacitance_tril,
    )
    # Expands term2 according to
    # inv(qcov) @ pcov = [inv(qD) - inv(qD) @ qW @ inv(qC) @ qW.T @ inv(qD)] @ p_tril @ p_tril.T
    #                  = [inv(qD) - A.T @ A] @ p_tril @ p_tril.T
    qWt_qDinv = q._unbroadcasted_cov_factor.mT / q._unbroadcasted_cov_diag.unsqueeze(-2)
    A = torch.linalg.solve_triangular(q._capacitance_tril, qWt_qDinv, upper=False)
    term21 = _batch_trace_XXT(
        p._unbroadcasted_scale_tril * q._unbroadcasted_cov_diag.rsqrt().unsqueeze(-1)
    )
    term22 = _batch_trace_XXT(A.matmul(p._unbroadcasted_scale_tril))
    term2 = term21 - term22
    return 0.5 * (term1 + term2 + term3 - p.event_shape[0])


@register_kl(LowRankMultivariateNormal, MultivariateNormal)
def _kl_lowrankmultivariatenormal_multivariatenormal(p, q):
    if p.event_shape != q.event_shape:
        raise ValueError(
            "KL-divergence between two (Low Rank) Multivariate Normals with\
                          different event shapes cannot be computed"
        )

    term1 = 2 * q._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(
        -1
    ) - _batch_lowrank_logdet(
        p._unbroadcasted_cov_factor, p._unbroadcasted_cov_diag, p._capacitance_tril
    )
    term3 = _batch_mahalanobis(q._unbroadcasted_scale_tril, (q.loc - p.loc))
    # Expands term2 according to
    # inv(qcov) @ pcov = inv(q_tril @ q_tril.T) @ (pW @ pW.T + pD)
    combined_batch_shape = torch._C._infer_size(
        q._unbroadcasted_scale_tril.shape[:-2], p._unbroadcasted_cov_factor.shape[:-2]
    )
    n = p.event_shape[0]
    q_scale_tril = q._unbroadcasted_scale_tril.expand(combined_batch_shape + (n, n))
    p_cov_factor = p._unbroadcasted_cov_factor.expand(
        combined_batch_shape + (n, p.cov_factor.size(-1))
    )
    p_cov_diag = torch.diag_embed(p._unbroadcasted_cov_diag.sqrt()).expand(
        combined_batch_shape + (n, n)
    )
    term21 = _batch_trace_XXT(
        torch.linalg.solve_triangular(q_scale_tril, p_cov_factor, upper=False)
    )
    term22 = _batch_trace_XXT(
        torch.linalg.solve_triangular(q_scale_tril, p_cov_diag, upper=False)
    )
    term2 = term21 + term22
    return 0.5 * (term1 + term2 + term3 - p.event_shape[0])


@register_kl(MultivariateNormal, MultivariateNormal)
def _kl_multivariatenormal_multivariatenormal(p, q):
    # From https://en.wikipedia.org/wiki/Multivariate_normal_distribution#Kullback%E2%80%93Leibler_divergence
    if p.event_shape != q.event_shape:
        raise ValueError(
            "KL-divergence between two Multivariate Normals with\
                          different event shapes cannot be computed"
        )

    half_term1 = q._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(
        -1
    ) - p._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1)
    combined_batch_shape = torch._C._infer_size(
        q._unbroadcasted_scale_tril.shape[:-2], p._unbroadcasted_scale_tril.shape[:-2]
    )
    n = p.event_shape[0]
    q_scale_tril = q._unbroadcasted_scale_tril.expand(combined_batch_shape + (n, n))
    p_scale_tril = p._unbroadcasted_scale_tril.expand(combined_batch_shape + (n, n))
    term2 = _batch_trace_XXT(
        torch.linalg.solve_triangular(q_scale_tril, p_scale_tril, upper=False)
    )
    term3 = _batch_mahalanobis(q._unbroadcasted_scale_tril, (q.loc - p.loc))
    return half_term1 + 0.5 * (term2 + term3 - n)


@register_kl(Normal, Normal)
def _kl_normal_normal(p, q):
    var_ratio = (p.scale / q.scale).pow(2)
    t1 = ((p.loc - q.loc) / q.scale).pow(2)
    return 0.5 * (var_ratio + t1 - 1 - var_ratio.log())


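# A hedged Monte Carlo sanity check (not executed at import): an empirical
# estimate of E_p[log p(x) - log q(x)] should land close to, though not
# exactly on, the closed form above.
def _kl_normal_normal_mc_sketch():
    torch.manual_seed(0)
    p = Normal(torch.tensor(0.0), torch.tensor(1.0))
    q = Normal(torch.tensor(1.0), torch.tensor(2.0))
    x = p.sample((100_000,))
    mc_estimate = (p.log_prob(x) - q.log_prob(x)).mean()
    assert torch.allclose(mc_estimate, kl_divergence(p, q), atol=1e-2)

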
@register_kl(OneHotCategorical, OneHotCategorical)
def _kl_onehotcategorical_onehotcategorical(p, q):
    return _kl_categorical_categorical(p._categorical, q._categorical)


@register_kl(Pareto, Pareto)
def _kl_pareto_pareto(p, q):
    # From http://www.mast.queensu.ca/~communications/Papers/gil-msc11.pdf
    scale_ratio = p.scale / q.scale
    alpha_ratio = q.alpha / p.alpha
    t1 = q.alpha * scale_ratio.log()
    t2 = -alpha_ratio.log()
    result = t1 + t2 + alpha_ratio - 1
    result[p.support.lower_bound < q.support.lower_bound] = inf
    return result


@register_kl(Poisson, Poisson)
def _kl_poisson_poisson(p, q):
    return p.rate * (p.rate.log() - q.rate.log()) - (p.rate - q.rate)


@register_kl(TransformedDistribution, TransformedDistribution)
def _kl_transformed_transformed(p, q):
    if p.transforms != q.transforms:
        raise NotImplementedError
    if p.event_shape != q.event_shape:
        raise NotImplementedError
    return kl_divergence(p.base_dist, q.base_dist)


@register_kl(Uniform, Uniform)
def _kl_uniform_uniform(p, q):
    result = ((q.high - q.low) / (p.high - p.low)).log()
    result[(q.low > p.low) | (q.high < p.high)] = inf
    return result


# Different distributions
@register_kl(Bernoulli, Poisson)
def _kl_bernoulli_poisson(p, q):
    return -p.entropy() - (p.probs * q.rate.log() - q.rate)


@register_kl(Beta, ContinuousBernoulli)
def _kl_beta_continuous_bernoulli(p, q):
    return (
        -p.entropy()
        - p.mean * q.logits
        - torch.log1p(-q.probs)
        - q._cont_bern_log_norm()
    )


@register_kl(Beta, Pareto)
def _kl_beta_infinity(p, q):
    return _infinite_like(p.concentration1)


@register_kl(Beta, Exponential)
def _kl_beta_exponential(p, q):
    return (
        -p.entropy()
        - q.rate.log()
        + q.rate * (p.concentration1 / (p.concentration1 + p.concentration0))
    )


@register_kl(Beta, Gamma)
def _kl_beta_gamma(p, q):
    t1 = -p.entropy()
    t2 = q.concentration.lgamma() - q.concentration * q.rate.log()
    t3 = (q.concentration - 1) * (
        p.concentration1.digamma() - (p.concentration1 + p.concentration0).digamma()
    )
    t4 = q.rate * p.concentration1 / (p.concentration1 + p.concentration0)
    return t1 + t2 - t3 + t4


# TODO: Add Beta-Laplace KL Divergence


@register_kl(Beta, Normal)
def _kl_beta_normal(p, q):
    E_beta = p.concentration1 / (p.concentration1 + p.concentration0)
    var_normal = q.scale.pow(2)
    t1 = -p.entropy()
    t2 = 0.5 * (var_normal * 2 * math.pi).log()
    t3 = (
        E_beta * (1 - E_beta) / (p.concentration1 + p.concentration0 + 1)
        + E_beta.pow(2)
    ) * 0.5
    t4 = q.loc * E_beta
    t5 = q.loc.pow(2) * 0.5
    return t1 + t2 + (t3 - t4 + t5) / var_normal


@register_kl(Beta, Uniform)
def _kl_beta_uniform(p, q):
    result = -p.entropy() + (q.high - q.low).log()
    result[(q.low > p.support.lower_bound) | (q.high < p.support.upper_bound)] = inf
    return result


# Note that the KL between a ContinuousBernoulli and Beta has no closed form


@register_kl(ContinuousBernoulli, Pareto)
def _kl_continuous_bernoulli_infinity(p, q):
    return _infinite_like(p.probs)


@register_kl(ContinuousBernoulli, Exponential)
def _kl_continuous_bernoulli_exponential(p, q):
    return -p.entropy() - torch.log(q.rate) + q.rate * p.mean


# Note that the KL between a ContinuousBernoulli and Gamma has no closed form
# TODO: Add ContinuousBernoulli-Laplace KL Divergence


@register_kl(ContinuousBernoulli, Normal)
def _kl_continuous_bernoulli_normal(p, q):
    t1 = -p.entropy()
    t2 = 0.5 * (math.log(2.0 * math.pi) + torch.square(q.loc / q.scale)) + torch.log(
        q.scale
    )
    t3 = (p.variance + torch.square(p.mean) - 2.0 * q.loc * p.mean) / (
        2.0 * torch.square(q.scale)
    )
    return t1 + t2 + t3


@register_kl(ContinuousBernoulli, Uniform)
def _kl_continuous_bernoulli_uniform(p, q):
    result = -p.entropy() + (q.high - q.low).log()
    return torch.where(
        torch.max(
            torch.ge(q.low, p.support.lower_bound),
            torch.le(q.high, p.support.upper_bound),
        ),
        torch.ones_like(result) * inf,
        result,
    )


@register_kl(Exponential, Beta)
@register_kl(Exponential, ContinuousBernoulli)
@register_kl(Exponential, Pareto)
@register_kl(Exponential, Uniform)
def _kl_exponential_infinity(p, q):
    return _infinite_like(p.rate)


@register_kl(Exponential, Gamma)
def _kl_exponential_gamma(p, q):
    ratio = q.rate / p.rate
    t1 = -q.concentration * torch.log(ratio)
    return (
        t1
        + ratio
        + q.concentration.lgamma()
        + q.concentration * _euler_gamma
        - (1 + _euler_gamma)
    )


@register_kl(Exponential, Gumbel)
def _kl_exponential_gumbel(p, q):
    scale_rate_prod = p.rate * q.scale
    loc_scale_ratio = q.loc / q.scale
    t1 = scale_rate_prod.log() - 1
    t2 = torch.exp(loc_scale_ratio) * scale_rate_prod / (scale_rate_prod + 1)
    t3 = scale_rate_prod.reciprocal()
    return t1 - loc_scale_ratio + t2 + t3


# TODO: Add Exponential-Laplace KL Divergence


@register_kl(Exponential, Normal)
def _kl_exponential_normal(p, q):
    var_normal = q.scale.pow(2)
    rate_sqr = p.rate.pow(2)
    t1 = 0.5 * torch.log(rate_sqr * var_normal * 2 * math.pi)
    t2 = rate_sqr.reciprocal()
    t3 = q.loc / p.rate
    t4 = q.loc.pow(2) * 0.5
    return t1 - 1 + (t2 - t3 + t4) / var_normal


@register_kl(Gamma, Beta)
@register_kl(Gamma, ContinuousBernoulli)
@register_kl(Gamma, Pareto)
@register_kl(Gamma, Uniform)
def _kl_gamma_infinity(p, q):
    return _infinite_like(p.concentration)


@register_kl(Gamma, Exponential)
def _kl_gamma_exponential(p, q):
    return -p.entropy() - q.rate.log() + q.rate * p.concentration / p.rate


@register_kl(Gamma, Gumbel)
def _kl_gamma_gumbel(p, q):
    beta_scale_prod = p.rate * q.scale
    loc_scale_ratio = q.loc / q.scale
    t1 = (
        (p.concentration - 1) * p.concentration.digamma()
        - p.concentration.lgamma()
        - p.concentration
    )
    t2 = beta_scale_prod.log() + p.concentration / beta_scale_prod
    t3 = (
        torch.exp(loc_scale_ratio)
        * (1 + beta_scale_prod.reciprocal()).pow(-p.concentration)
        - loc_scale_ratio
    )
    return t1 + t2 + t3


# TODO: Add Gamma-Laplace KL Divergence


@register_kl(Gamma, Normal)
def _kl_gamma_normal(p, q):
    var_normal = q.scale.pow(2)
    beta_sqr = p.rate.pow(2)
    t1 = (
        0.5 * torch.log(beta_sqr * var_normal * 2 * math.pi)
        - p.concentration
        - p.concentration.lgamma()
    )
    t2 = 0.5 * (p.concentration.pow(2) + p.concentration) / beta_sqr
    t3 = q.loc * p.concentration / p.rate
    t4 = 0.5 * q.loc.pow(2)
    return (
        t1
        + (p.concentration - 1) * p.concentration.digamma()
        + (t2 - t3 + t4) / var_normal
    )


@register_kl(Gumbel, Beta)
@register_kl(Gumbel, ContinuousBernoulli)
@register_kl(Gumbel, Exponential)
@register_kl(Gumbel, Gamma)
@register_kl(Gumbel, Pareto)
@register_kl(Gumbel, Uniform)
def _kl_gumbel_infinity(p, q):
    return _infinite_like(p.loc)


# TODO: Add Gumbel-Laplace KL Divergence


@register_kl(Gumbel, Normal)
def _kl_gumbel_normal(p, q):
    param_ratio = p.scale / q.scale
    t1 = (param_ratio / math.sqrt(2 * math.pi)).log()
    t2 = (math.pi * param_ratio * 0.5).pow(2) / 3
    t3 = ((p.loc + p.scale * _euler_gamma - q.loc) / q.scale).pow(2) * 0.5
    return -t1 + t2 + t3 - (_euler_gamma + 1)


@register_kl(Laplace, Beta)
@register_kl(Laplace, ContinuousBernoulli)
@register_kl(Laplace, Exponential)
@register_kl(Laplace, Gamma)
@register_kl(Laplace, Pareto)
@register_kl(Laplace, Uniform)
def _kl_laplace_infinity(p, q):
    return _infinite_like(p.loc)


@register_kl(Laplace, Normal)
def _kl_laplace_normal(p, q):
    var_normal = q.scale.pow(2)
    scale_sqr_var_ratio = p.scale.pow(2) / var_normal
    t1 = 0.5 * torch.log(2 * scale_sqr_var_ratio / math.pi)
    t2 = 0.5 * p.loc.pow(2)
    t3 = p.loc * q.loc
    t4 = 0.5 * q.loc.pow(2)
    return -t1 + scale_sqr_var_ratio + (t2 - t3 + t4) / var_normal - 1


@register_kl(Normal, Beta)
@register_kl(Normal, ContinuousBernoulli)
@register_kl(Normal, Exponential)
@register_kl(Normal, Gamma)
@register_kl(Normal, Pareto)
@register_kl(Normal, Uniform)
def _kl_normal_infinity(p, q):
    return _infinite_like(p.loc)


@register_kl(Normal, Gumbel)
def _kl_normal_gumbel(p, q):
    mean_scale_ratio = p.loc / q.scale
    var_scale_sqr_ratio = (p.scale / q.scale).pow(2)
    loc_scale_ratio = q.loc / q.scale
    t1 = var_scale_sqr_ratio.log() * 0.5
    t2 = mean_scale_ratio - loc_scale_ratio
    t3 = torch.exp(-mean_scale_ratio + 0.5 * var_scale_sqr_ratio + loc_scale_ratio)
    return -t1 + t2 + t3 - (0.5 * (1 + math.log(2 * math.pi)))


@register_kl(Normal, Laplace)
def _kl_normal_laplace(p, q):
    loc_diff = p.loc - q.loc
    scale_ratio = p.scale / q.scale
    loc_diff_scale_ratio = loc_diff / p.scale
    t1 = torch.log(scale_ratio)
    t2 = (
        math.sqrt(2 / math.pi) * p.scale * torch.exp(-0.5 * loc_diff_scale_ratio.pow(2))
    )
    t3 = loc_diff * torch.erf(math.sqrt(0.5) * loc_diff_scale_ratio)
    return -t1 + (t2 + t3) / q.scale - (0.5 * (1 + math.log(0.5 * math.pi)))


@register_kl(Pareto, Beta)
@register_kl(Pareto, ContinuousBernoulli)
@register_kl(Pareto, Uniform)
def _kl_pareto_infinity(p, q):
    return _infinite_like(p.scale)


@register_kl(Pareto, Exponential)
def _kl_pareto_exponential(p, q):
    scale_rate_prod = p.scale * q.rate
    t1 = (p.alpha / scale_rate_prod).log()
    t2 = p.alpha.reciprocal()
    t3 = p.alpha * scale_rate_prod / (p.alpha - 1)
    result = t1 - t2 + t3 - 1
    result[p.alpha <= 1] = inf
    return result


@register_kl(Pareto, Gamma)
def _kl_pareto_gamma(p, q):
    common_term = p.scale.log() + p.alpha.reciprocal()
    t1 = p.alpha.log() - common_term
    t2 = q.concentration.lgamma() - q.concentration * q.rate.log()
    t3 = (1 - q.concentration) * common_term
    t4 = q.rate * p.alpha * p.scale / (p.alpha - 1)
    result = t1 + t2 + t3 + t4 - 1
    result[p.alpha <= 1] = inf
    return result


# TODO: Add Pareto-Laplace KL Divergence


@register_kl(Pareto, Normal)
def _kl_pareto_normal(p, q):
    var_normal = 2 * q.scale.pow(2)
    common_term = p.scale / (p.alpha - 1)
    t1 = (math.sqrt(2 * math.pi) * q.scale * p.alpha / p.scale).log()
    t2 = p.alpha.reciprocal()
    t3 = p.alpha * common_term.pow(2) / (p.alpha - 2)
    t4 = (p.alpha * common_term - q.loc).pow(2)
    result = t1 - t2 + (t3 + t4) / var_normal - 1
    result[p.alpha <= 2] = inf
    return result


@register_kl(Poisson, Bernoulli)
@register_kl(Poisson, Binomial)
def _kl_poisson_infinity(p, q):
    return _infinite_like(p.rate)


@register_kl(Uniform, Beta)
def _kl_uniform_beta(p, q):
    common_term = p.high - p.low
    t1 = torch.log(common_term)
    t2 = (
        (q.concentration1 - 1)
        * (_x_log_x(p.high) - _x_log_x(p.low) - common_term)
        / common_term
    )
    t3 = (
        (q.concentration0 - 1)
        * (_x_log_x(1 - p.high) - _x_log_x(1 - p.low) + common_term)
        / common_term
    )
    t4 = (
        q.concentration1.lgamma()
        + q.concentration0.lgamma()
        - (q.concentration1 + q.concentration0).lgamma()
    )
    result = t3 + t4 - t1 - t2
    result[(p.high > q.support.upper_bound) | (p.low < q.support.lower_bound)] = inf
    return result


@register_kl(Uniform, ContinuousBernoulli)
def _kl_uniform_continuous_bernoulli(p, q):
    result = (
        -p.entropy()
        - p.mean * q.logits
        - torch.log1p(-q.probs)
        - q._cont_bern_log_norm()
    )
    return torch.where(
        torch.max(
            torch.ge(p.high, q.support.upper_bound),
            torch.le(p.low, q.support.lower_bound),
        ),
        torch.ones_like(result) * inf,
        result,
    )


@register_kl(Uniform, Exponential)
def _kl_uniform_exponential(p, q):
    result = q.rate * (p.high + p.low) / 2 - ((p.high - p.low) * q.rate).log()
    result[p.low < q.support.lower_bound] = inf
    return result


@register_kl(Uniform, Gamma)
def _kl_uniform_gamma(p, q):
    common_term = p.high - p.low
    t1 = common_term.log()
    t2 = q.concentration.lgamma() - q.concentration * q.rate.log()
    t3 = (
        (1 - q.concentration)
        * (_x_log_x(p.high) - _x_log_x(p.low) - common_term)
        / common_term
    )
    t4 = q.rate * (p.high + p.low) / 2
    result = -t1 + t2 + t3 + t4
    result[p.low < q.support.lower_bound] = inf
    return result


@register_kl(Uniform, Gumbel)
def _kl_uniform_gumbel(p, q):
    common_term = q.scale / (p.high - p.low)
    high_loc_diff = (p.high - q.loc) / q.scale
    low_loc_diff = (p.low - q.loc) / q.scale
    t1 = common_term.log() + 0.5 * (high_loc_diff + low_loc_diff)
    t2 = common_term * (torch.exp(-high_loc_diff) - torch.exp(-low_loc_diff))
    return t1 - t2


# TODO: Uniform-Laplace KL Divergence


@register_kl(Uniform, Normal)
def _kl_uniform_normal(p, q):
    common_term = p.high - p.low
    t1 = (math.sqrt(math.pi * 2) * q.scale / common_term).log()
    t2 = (common_term).pow(2) / 12
    t3 = ((p.high + p.low - 2 * q.loc) / 2).pow(2)
    return t1 + 0.5 * (t2 + t3) / q.scale.pow(2)


@register_kl(Uniform, Pareto)
def _kl_uniform_pareto(p, q):
    support_uniform = p.high - p.low
    t1 = (q.alpha * q.scale.pow(q.alpha) * (support_uniform)).log()
    t2 = (_x_log_x(p.high) - _x_log_x(p.low) - support_uniform) / support_uniform
    result = t2 * (q.alpha + 1) - t1
    result[p.low < q.support.lower_bound] = inf
    return result


@register_kl(Independent, Independent)
def _kl_independent_independent(p, q):
    if p.reinterpreted_batch_ndims != q.reinterpreted_batch_ndims:
        raise NotImplementedError
    result = kl_divergence(p.base_dist, q.base_dist)
    return _sum_rightmost(result, p.reinterpreted_batch_ndims)


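# A hedged, illustrative sketch (not executed at import): wrapping both
# distributions in Independent sums the base KL over the reinterpreted
# event dimensions.
def _kl_independent_sketch():
    loc_p, loc_q = torch.randn(5, 3), torch.randn(5, 3)
    scale = torch.ones(5, 3)
    kl_independent = kl_divergence(
        Independent(Normal(loc_p, scale), 1), Independent(Normal(loc_q, scale), 1)
    )
    kl_base = kl_divergence(Normal(loc_p, scale), Normal(loc_q, scale)).sum(-1)
    assert torch.allclose(kl_independent, kl_base)

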
@register_kl(Cauchy, Cauchy)
def _kl_cauchy_cauchy(p, q):
    # From https://arxiv.org/abs/1905.10965
    t1 = ((p.scale + q.scale).pow(2) + (p.loc - q.loc).pow(2)).log()
    t2 = (4 * p.scale * q.scale).log()
    return t1 - t2


def _add_kl_info():
    """Appends a list of implemented KL functions to the doc for kl_divergence."""
    rows = [
        "KL divergence is currently implemented for the following distribution pairs:"
    ]
    for p, q in sorted(
        _KL_REGISTRY, key=lambda p_q: (p_q[0].__name__, p_q[1].__name__)
    ):
        rows.append(
            f"* :class:`~torch.distributions.{p.__name__}` and :class:`~torch.distributions.{q.__name__}`"
        )
    kl_info = "\n\t".join(rows)
    if kl_divergence.__doc__:
        kl_divergence.__doc__ += kl_info  # type: ignore[operator]
973