# mypy: allow-untyped-defs
import math
import warnings
from functools import total_ordering
from typing import Callable, Dict, Tuple, Type

import torch
from torch import inf

from .bernoulli import Bernoulli
from .beta import Beta
from .binomial import Binomial
from .categorical import Categorical
from .cauchy import Cauchy
from .continuous_bernoulli import ContinuousBernoulli
from .dirichlet import Dirichlet
from .distribution import Distribution
from .exp_family import ExponentialFamily
from .exponential import Exponential
from .gamma import Gamma
from .geometric import Geometric
from .gumbel import Gumbel
from .half_normal import HalfNormal
from .independent import Independent
from .laplace import Laplace
from .lowrank_multivariate_normal import (
    _batch_lowrank_logdet,
    _batch_lowrank_mahalanobis,
    LowRankMultivariateNormal,
)
from .multivariate_normal import _batch_mahalanobis, MultivariateNormal
from .normal import Normal
from .one_hot_categorical import OneHotCategorical
from .pareto import Pareto
from .poisson import Poisson
from .transformed_distribution import TransformedDistribution
from .uniform import Uniform
from .utils import _sum_rightmost, euler_constant as _euler_gamma


_KL_REGISTRY: Dict[
    Tuple[Type, Type], Callable
] = {}  # Source of truth mapping a few general (type, type) pairs to functions.
_KL_MEMOIZE: Dict[
    Tuple[Type, Type], Callable
] = {}  # Memoized version mapping many specific (type, type) pairs to functions.

__all__ = ["register_kl", "kl_divergence"]


def register_kl(type_p, type_q):
    """
    Decorator to register a pairwise function with :meth:`kl_divergence`.
    Usage::

        @register_kl(Normal, Normal)
        def kl_normal_normal(p, q):
            # insert implementation here

    Lookup returns the most specific (type, type) match ordered by subclass. If
    the match is ambiguous, a `RuntimeWarning` is raised. For example, to
    resolve the ambiguous situation::

        @register_kl(BaseP, DerivedQ)
        def kl_version1(p, q): ...
        @register_kl(DerivedP, BaseQ)
        def kl_version2(p, q): ...

    you should register a third most-specific implementation, e.g.::

        register_kl(DerivedP, DerivedQ)(kl_version1)  # Break the tie.

    Args:
        type_p (type): A subclass of :class:`~torch.distributions.Distribution`.
        type_q (type): A subclass of :class:`~torch.distributions.Distribution`.
    """
    if not isinstance(type_p, type) or not issubclass(type_p, Distribution):
        raise TypeError(
            f"Expected type_p to be a Distribution subclass but got {type_p}"
        )
    if not isinstance(type_q, type) or not issubclass(type_q, Distribution):
        raise TypeError(
            f"Expected type_q to be a Distribution subclass but got {type_q}"
        )

    def decorator(fun):
        _KL_REGISTRY[type_p, type_q] = fun
        _KL_MEMOIZE.clear()  # reset since lookup order may have changed
        return fun

    return decorator
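
# An illustrative sketch of registering a new pair (``MyNormal`` is hypothetical
# and not part of this module; any Distribution subclass works the same way):
#
#     class MyNormal(Normal):
#         pass
#
#     @register_kl(MyNormal, Normal)
#     def _kl_mynormal_normal(p, q):
#         return _kl_normal_normal(p, q)
#
# ``kl_divergence`` then routes any ``(MyNormal, Normal)`` pair to this
# function, since lookup picks the most specific registered (type, type) match.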


@total_ordering
class _Match:
    __slots__ = ["types"]

    def __init__(self, *types):
        self.types = types

    def __eq__(self, other):
        return self.types == other.types

    def __le__(self, other):
        for x, y in zip(self.types, other.types):
            if not issubclass(x, y):
                return False
            if x is not y:
                break
        return True


def _dispatch_kl(type_p, type_q):
    """
    Find the most specific approximate match, assuming single inheritance.
    """
    matches = [
        (super_p, super_q)
        for super_p, super_q in _KL_REGISTRY
        if issubclass(type_p, super_p) and issubclass(type_q, super_q)
    ]
    if not matches:
        return NotImplemented
    # Check that the left- and right- lexicographic orders agree.
    # mypy isn't smart enough to know that _Match implements __lt__
    # see: https://github.com/python/typing/issues/760#issuecomment-710670503
    left_p, left_q = min(_Match(*m) for m in matches).types  # type: ignore[type-var]
    right_q, right_p = min(_Match(*reversed(m)) for m in matches).types  # type: ignore[type-var]
    left_fun = _KL_REGISTRY[left_p, left_q]
    right_fun = _KL_REGISTRY[right_p, right_q]
    if left_fun is not right_fun:
        warnings.warn(
            f"Ambiguous kl_divergence({type_p.__name__}, {type_q.__name__}). "
            f"Please register_kl({left_p.__name__}, {right_q.__name__})",
            RuntimeWarning,
        )
    return left_fun


def _infinite_like(tensor):
    """
    Helper function for obtaining infinite KL Divergence throughout
    """
    return torch.full_like(tensor, inf)


def _x_log_x(tensor):
    """
    Utility function for calculating x log x
    """
    return tensor * tensor.log()


def _batch_trace_XXT(bmat):
    """
    Utility function for calculating the trace of XX^{T} with X having arbitrary leading batch dimensions
    """
    n = bmat.size(-1)
    m = bmat.size(-2)
    flat_trace = bmat.reshape(-1, m * n).pow(2).sum(-1)
    return flat_trace.reshape(bmat.shape[:-2])


def kl_divergence(p: Distribution, q: Distribution) -> torch.Tensor:
    r"""
    Compute Kullback-Leibler divergence :math:`KL(p \| q)` between two distributions.

    .. math::

        KL(p \| q) = \int p(x) \log\frac {p(x)} {q(x)} \,dx

    Args:
        p (Distribution): A :class:`~torch.distributions.Distribution` object.
        q (Distribution): A :class:`~torch.distributions.Distribution` object.

    Returns:
        Tensor: A batch of KL divergences of shape `batch_shape`.

    Raises:
        NotImplementedError: If the distribution types have not been registered via
            :meth:`register_kl`.
    """
    try:
        fun = _KL_MEMOIZE[type(p), type(q)]
    except KeyError:
        fun = _dispatch_kl(type(p), type(q))
        _KL_MEMOIZE[type(p), type(q)] = fun
    if fun is NotImplemented:
        raise NotImplementedError(
            f"No KL(p || q) is implemented for p type {p.__class__.__name__} and q type {q.__class__.__name__}"
        )
    return fun(p, q)
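
# A minimal usage sketch for the entry point above, using distributions that
# are already imported in this module:
#
#     p = Normal(torch.zeros(3), torch.ones(3))
#     q = Normal(torch.ones(3), 2 * torch.ones(3))
#     kl_divergence(p, q)  # tensor of shape (3,), one KL value per batch element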


################################################################################
# KL Divergence Implementations
################################################################################

# Same distributions


@register_kl(Bernoulli, Bernoulli)
def _kl_bernoulli_bernoulli(p, q):
    t1 = p.probs * (
        torch.nn.functional.softplus(-q.logits)
        - torch.nn.functional.softplus(-p.logits)
    )
    t1[q.probs == 0] = inf
    t1[p.probs == 0] = 0
    t2 = (1 - p.probs) * (
        torch.nn.functional.softplus(q.logits) - torch.nn.functional.softplus(p.logits)
    )
    t2[q.probs == 1] = inf
    t2[p.probs == 1] = 0
    return t1 + t2


@register_kl(Beta, Beta)
def _kl_beta_beta(p, q):
    sum_params_p = p.concentration1 + p.concentration0
    sum_params_q = q.concentration1 + q.concentration0
    t1 = q.concentration1.lgamma() + q.concentration0.lgamma() + (sum_params_p).lgamma()
    t2 = p.concentration1.lgamma() + p.concentration0.lgamma() + (sum_params_q).lgamma()
    t3 = (p.concentration1 - q.concentration1) * torch.digamma(p.concentration1)
    t4 = (p.concentration0 - q.concentration0) * torch.digamma(p.concentration0)
    t5 = (sum_params_q - sum_params_p) * torch.digamma(sum_params_p)
    return t1 - t2 + t3 + t4 + t5


@register_kl(Binomial, Binomial)
def _kl_binomial_binomial(p, q):
    # from https://math.stackexchange.com/questions/2214993/
    # kullback-leibler-divergence-for-binomial-distributions-p-and-q
    if (p.total_count < q.total_count).any():
        raise NotImplementedError(
            "KL between Binomials where q.total_count > p.total_count is not implemented"
        )
    kl = p.total_count * (
        p.probs * (p.logits - q.logits) + (-p.probs).log1p() - (-q.probs).log1p()
    )
    inf_idxs = p.total_count > q.total_count
    kl[inf_idxs] = _infinite_like(kl[inf_idxs])
    return kl


@register_kl(Categorical, Categorical)
def _kl_categorical_categorical(p, q):
    t = p.probs * (p.logits - q.logits)
    t[(q.probs == 0).expand_as(t)] = inf
    t[(p.probs == 0).expand_as(t)] = 0
    return t.sum(-1)


@register_kl(ContinuousBernoulli, ContinuousBernoulli)
def _kl_continuous_bernoulli_continuous_bernoulli(p, q):
    t1 = p.mean * (p.logits - q.logits)
    t2 = p._cont_bern_log_norm() + torch.log1p(-p.probs)
    t3 = -q._cont_bern_log_norm() - torch.log1p(-q.probs)
    return t1 + t2 + t3


@register_kl(Dirichlet, Dirichlet)
def _kl_dirichlet_dirichlet(p, q):
    # From http://bariskurt.com/kullback-leibler-divergence-between-two-dirichlet-and-beta-distributions/
    sum_p_concentration = p.concentration.sum(-1)
    sum_q_concentration = q.concentration.sum(-1)
    t1 = sum_p_concentration.lgamma() - sum_q_concentration.lgamma()
    t2 = (p.concentration.lgamma() - q.concentration.lgamma()).sum(-1)
    t3 = p.concentration - q.concentration
    t4 = p.concentration.digamma() - sum_p_concentration.digamma().unsqueeze(-1)
    return t1 - t2 + (t3 * t4).sum(-1)


@register_kl(Exponential, Exponential)
def _kl_exponential_exponential(p, q):
    rate_ratio = q.rate / p.rate
    t1 = -rate_ratio.log()
    return t1 + rate_ratio - 1
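
# The closed forms in this file can be sanity-checked against a Monte Carlo
# estimate of E_p[log p(x) - log q(x)]; an illustrative sketch:
#
#     p, q = Exponential(torch.tensor(1.0)), Exponential(torch.tensor(2.0))
#     x = p.sample((100000,))
#     (p.log_prob(x) - q.log_prob(x)).mean()  # ~ kl_divergence(p, q) = 1 - log(2)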


@register_kl(ExponentialFamily, ExponentialFamily)
def _kl_expfamily_expfamily(p, q):
    if not type(p) == type(q):
        raise NotImplementedError(
            "The cross KL-divergence between different exponential families cannot \
            be computed using Bregman divergences"
        )
    p_nparams = [np.detach().requires_grad_() for np in p._natural_params]
    q_nparams = q._natural_params
    lg_normal = p._log_normalizer(*p_nparams)
    gradients = torch.autograd.grad(lg_normal.sum(), p_nparams, create_graph=True)
    result = q._log_normalizer(*q_nparams) - lg_normal
    for pnp, qnp, g in zip(p_nparams, q_nparams, gradients):
        term = (qnp - pnp) * g
        result -= _sum_rightmost(term, len(q.event_shape))
    return result
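
# The implementation above uses the Bregman-divergence identity: for two
# members of the same exponential family with natural parameters eta_p, eta_q
# and log-normalizer A,
#
#     KL(p || q) = A(eta_q) - A(eta_p) - <eta_q - eta_p, grad A(eta_p)>,
#
# where grad A(eta_p) (the expected sufficient statistics under p) is obtained
# here via autograd on the log-normalizer.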


@register_kl(Gamma, Gamma)
def _kl_gamma_gamma(p, q):
    t1 = q.concentration * (p.rate / q.rate).log()
    t2 = torch.lgamma(q.concentration) - torch.lgamma(p.concentration)
    t3 = (p.concentration - q.concentration) * torch.digamma(p.concentration)
    t4 = (q.rate - p.rate) * (p.concentration / p.rate)
    return t1 + t2 + t3 + t4


@register_kl(Gumbel, Gumbel)
def _kl_gumbel_gumbel(p, q):
    ct1 = p.scale / q.scale
    ct2 = q.loc / q.scale
    ct3 = p.loc / q.scale
    t1 = -ct1.log() - ct2 + ct3
    t2 = ct1 * _euler_gamma
    t3 = torch.exp(ct2 + (1 + ct1).lgamma() - ct3)
    return t1 + t2 + t3 - (1 + _euler_gamma)


@register_kl(Geometric, Geometric)
def _kl_geometric_geometric(p, q):
    return -p.entropy() - torch.log1p(-q.probs) / p.probs - q.logits


@register_kl(HalfNormal, HalfNormal)
def _kl_halfnormal_halfnormal(p, q):
    return _kl_normal_normal(p.base_dist, q.base_dist)


@register_kl(Laplace, Laplace)
def _kl_laplace_laplace(p, q):
    # From http://www.mast.queensu.ca/~communications/Papers/gil-msc11.pdf
    scale_ratio = p.scale / q.scale
    loc_abs_diff = (p.loc - q.loc).abs()
    t1 = -scale_ratio.log()
    t2 = loc_abs_diff / q.scale
    t3 = scale_ratio * torch.exp(-loc_abs_diff / p.scale)
    return t1 + t2 + t3 - 1


@register_kl(LowRankMultivariateNormal, LowRankMultivariateNormal)
def _kl_lowrankmultivariatenormal_lowrankmultivariatenormal(p, q):
    if p.event_shape != q.event_shape:
        raise ValueError(
            "KL-divergence between two Low Rank Multivariate Normals with\
            different event shapes cannot be computed"
        )

    term1 = _batch_lowrank_logdet(
        q._unbroadcasted_cov_factor, q._unbroadcasted_cov_diag, q._capacitance_tril
    ) - _batch_lowrank_logdet(
        p._unbroadcasted_cov_factor, p._unbroadcasted_cov_diag, p._capacitance_tril
    )
    term3 = _batch_lowrank_mahalanobis(
        q._unbroadcasted_cov_factor,
        q._unbroadcasted_cov_diag,
        q.loc - p.loc,
        q._capacitance_tril,
    )
    # Expands term2 according to
    # inv(qcov) @ pcov = [inv(qD) - inv(qD) @ qW @ inv(qC) @ qW.T @ inv(qD)] @ (pW @ pW.T + pD)
    #                  = [inv(qD) - A.T @ A] @ (pD + pW @ pW.T)
    qWt_qDinv = q._unbroadcasted_cov_factor.mT / q._unbroadcasted_cov_diag.unsqueeze(-2)
    A = torch.linalg.solve_triangular(q._capacitance_tril, qWt_qDinv, upper=False)
    term21 = (p._unbroadcasted_cov_diag / q._unbroadcasted_cov_diag).sum(-1)
    term22 = _batch_trace_XXT(
        p._unbroadcasted_cov_factor * q._unbroadcasted_cov_diag.rsqrt().unsqueeze(-1)
    )
    term23 = _batch_trace_XXT(A * p._unbroadcasted_cov_diag.sqrt().unsqueeze(-2))
    term24 = _batch_trace_XXT(A.matmul(p._unbroadcasted_cov_factor))
    term2 = term21 + term22 - term23 - term24
    return 0.5 * (term1 + term2 + term3 - p.event_shape[0])


@register_kl(MultivariateNormal, LowRankMultivariateNormal)
def _kl_multivariatenormal_lowrankmultivariatenormal(p, q):
    if p.event_shape != q.event_shape:
        raise ValueError(
            "KL-divergence between two (Low Rank) Multivariate Normals with\
            different event shapes cannot be computed"
        )

    term1 = _batch_lowrank_logdet(
        q._unbroadcasted_cov_factor, q._unbroadcasted_cov_diag, q._capacitance_tril
    ) - 2 * p._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1)
    term3 = _batch_lowrank_mahalanobis(
        q._unbroadcasted_cov_factor,
        q._unbroadcasted_cov_diag,
        q.loc - p.loc,
        q._capacitance_tril,
    )
    # Expands term2 according to
    # inv(qcov) @ pcov = [inv(qD) - inv(qD) @ qW @ inv(qC) @ qW.T @ inv(qD)] @ p_tril @ p_tril.T
    #                  = [inv(qD) - A.T @ A] @ p_tril @ p_tril.T
    qWt_qDinv = q._unbroadcasted_cov_factor.mT / q._unbroadcasted_cov_diag.unsqueeze(-2)
    A = torch.linalg.solve_triangular(q._capacitance_tril, qWt_qDinv, upper=False)
    term21 = _batch_trace_XXT(
        p._unbroadcasted_scale_tril * q._unbroadcasted_cov_diag.rsqrt().unsqueeze(-1)
    )
    term22 = _batch_trace_XXT(A.matmul(p._unbroadcasted_scale_tril))
    term2 = term21 - term22
    return 0.5 * (term1 + term2 + term3 - p.event_shape[0])


@register_kl(LowRankMultivariateNormal, MultivariateNormal)
def _kl_lowrankmultivariatenormal_multivariatenormal(p, q):
    if p.event_shape != q.event_shape:
        raise ValueError(
            "KL-divergence between two (Low Rank) Multivariate Normals with\
            different event shapes cannot be computed"
        )

    term1 = 2 * q._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(
        -1
    ) - _batch_lowrank_logdet(
        p._unbroadcasted_cov_factor, p._unbroadcasted_cov_diag, p._capacitance_tril
    )
    term3 = _batch_mahalanobis(q._unbroadcasted_scale_tril, (q.loc - p.loc))
    # Expands term2 according to
    # inv(qcov) @ pcov = inv(q_tril @ q_tril.T) @ (pW @ pW.T + pD)
    combined_batch_shape = torch._C._infer_size(
        q._unbroadcasted_scale_tril.shape[:-2], p._unbroadcasted_cov_factor.shape[:-2]
    )
    n = p.event_shape[0]
    q_scale_tril = q._unbroadcasted_scale_tril.expand(combined_batch_shape + (n, n))
    p_cov_factor = p._unbroadcasted_cov_factor.expand(
        combined_batch_shape + (n, p.cov_factor.size(-1))
    )
    p_cov_diag = torch.diag_embed(p._unbroadcasted_cov_diag.sqrt()).expand(
        combined_batch_shape + (n, n)
    )
    term21 = _batch_trace_XXT(
        torch.linalg.solve_triangular(q_scale_tril, p_cov_factor, upper=False)
    )
    term22 = _batch_trace_XXT(
        torch.linalg.solve_triangular(q_scale_tril, p_cov_diag, upper=False)
    )
    term2 = term21 + term22
    return 0.5 * (term1 + term2 + term3 - p.event_shape[0])


@register_kl(MultivariateNormal, MultivariateNormal)
def _kl_multivariatenormal_multivariatenormal(p, q):
    # From https://en.wikipedia.org/wiki/Multivariate_normal_distribution#Kullback%E2%80%93Leibler_divergence
    if p.event_shape != q.event_shape:
        raise ValueError(
            "KL-divergence between two Multivariate Normals with\
            different event shapes cannot be computed"
        )

    half_term1 = q._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(
        -1
    ) - p._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1)
    combined_batch_shape = torch._C._infer_size(
        q._unbroadcasted_scale_tril.shape[:-2], p._unbroadcasted_scale_tril.shape[:-2]
    )
    n = p.event_shape[0]
    q_scale_tril = q._unbroadcasted_scale_tril.expand(combined_batch_shape + (n, n))
    p_scale_tril = p._unbroadcasted_scale_tril.expand(combined_batch_shape + (n, n))
    term2 = _batch_trace_XXT(
        torch.linalg.solve_triangular(q_scale_tril, p_scale_tril, upper=False)
    )
    term3 = _batch_mahalanobis(q._unbroadcasted_scale_tril, (q.loc - p.loc))
    return half_term1 + 0.5 * (term2 + term3 - n)
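
# An illustrative use of the multivariate-normal closed form above (a sketch,
# using the standard constructors):
#
#     p = MultivariateNormal(torch.zeros(2), covariance_matrix=torch.eye(2))
#     q = MultivariateNormal(torch.ones(2), covariance_matrix=2 * torch.eye(2))
#     kl_divergence(p, q)  # scalar tensor; any batch dimensions are preserved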


@register_kl(Normal, Normal)
def _kl_normal_normal(p, q):
    var_ratio = (p.scale / q.scale).pow(2)
    t1 = ((p.loc - q.loc) / q.scale).pow(2)
    return 0.5 * (var_ratio + t1 - 1 - var_ratio.log())


@register_kl(OneHotCategorical, OneHotCategorical)
def _kl_onehotcategorical_onehotcategorical(p, q):
    return _kl_categorical_categorical(p._categorical, q._categorical)


@register_kl(Pareto, Pareto)
def _kl_pareto_pareto(p, q):
    # From http://www.mast.queensu.ca/~communications/Papers/gil-msc11.pdf
    scale_ratio = p.scale / q.scale
    alpha_ratio = q.alpha / p.alpha
    t1 = q.alpha * scale_ratio.log()
    t2 = -alpha_ratio.log()
    result = t1 + t2 + alpha_ratio - 1
    result[p.support.lower_bound < q.support.lower_bound] = inf
    return result


@register_kl(Poisson, Poisson)
def _kl_poisson_poisson(p, q):
    return p.rate * (p.rate.log() - q.rate.log()) - (p.rate - q.rate)


@register_kl(TransformedDistribution, TransformedDistribution)
def _kl_transformed_transformed(p, q):
    if p.transforms != q.transforms:
        raise NotImplementedError
    if p.event_shape != q.event_shape:
        raise NotImplementedError
    return kl_divergence(p.base_dist, q.base_dist)


@register_kl(Uniform, Uniform)
def _kl_uniform_uniform(p, q):
    result = ((q.high - q.low) / (p.high - p.low)).log()
    result[(q.low > p.low) | (q.high < p.high)] = inf
    return result
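
# Support checks like the one above make the divergence infinite whenever q
# puts zero mass on part of p's support; an illustrative sketch:
#
#     kl_divergence(Uniform(0.0, 1.0), Uniform(0.5, 1.0))    # inf (q misses [0, 0.5))
#     kl_divergence(Uniform(0.25, 0.75), Uniform(0.0, 1.0))  # log(2)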


# Different distributions
@register_kl(Bernoulli, Poisson)
def _kl_bernoulli_poisson(p, q):
    return -p.entropy() - (p.probs * q.rate.log() - q.rate)


@register_kl(Beta, ContinuousBernoulli)
def _kl_beta_continuous_bernoulli(p, q):
    return (
        -p.entropy()
        - p.mean * q.logits
        - torch.log1p(-q.probs)
        - q._cont_bern_log_norm()
    )


@register_kl(Beta, Pareto)
def _kl_beta_infinity(p, q):
    return _infinite_like(p.concentration1)


@register_kl(Beta, Exponential)
def _kl_beta_exponential(p, q):
    return (
        -p.entropy()
        - q.rate.log()
        + q.rate * (p.concentration1 / (p.concentration1 + p.concentration0))
    )


@register_kl(Beta, Gamma)
def _kl_beta_gamma(p, q):
    t1 = -p.entropy()
    t2 = q.concentration.lgamma() - q.concentration * q.rate.log()
    t3 = (q.concentration - 1) * (
        p.concentration1.digamma() - (p.concentration1 + p.concentration0).digamma()
    )
    t4 = q.rate * p.concentration1 / (p.concentration1 + p.concentration0)
    return t1 + t2 - t3 + t4


# TODO: Add Beta-Laplace KL Divergence


@register_kl(Beta, Normal)
def _kl_beta_normal(p, q):
    E_beta = p.concentration1 / (p.concentration1 + p.concentration0)
    var_normal = q.scale.pow(2)
    t1 = -p.entropy()
    t2 = 0.5 * (var_normal * 2 * math.pi).log()
    t3 = (
        E_beta * (1 - E_beta) / (p.concentration1 + p.concentration0 + 1)
        + E_beta.pow(2)
    ) * 0.5
    t4 = q.loc * E_beta
    t5 = q.loc.pow(2) * 0.5
    return t1 + t2 + (t3 - t4 + t5) / var_normal


@register_kl(Beta, Uniform)
def _kl_beta_uniform(p, q):
    result = -p.entropy() + (q.high - q.low).log()
    result[(q.low > p.support.lower_bound) | (q.high < p.support.upper_bound)] = inf
    return result


# Note that the KL between a ContinuousBernoulli and Beta has no closed form


@register_kl(ContinuousBernoulli, Pareto)
def _kl_continuous_bernoulli_infinity(p, q):
    return _infinite_like(p.probs)


@register_kl(ContinuousBernoulli, Exponential)
def _kl_continuous_bernoulli_exponential(p, q):
    return -p.entropy() - torch.log(q.rate) + q.rate * p.mean


# Note that the KL between a ContinuousBernoulli and Gamma has no closed form
# TODO: Add ContinuousBernoulli-Laplace KL Divergence


@register_kl(ContinuousBernoulli, Normal)
def _kl_continuous_bernoulli_normal(p, q):
    t1 = -p.entropy()
    t2 = 0.5 * (math.log(2.0 * math.pi) + torch.square(q.loc / q.scale)) + torch.log(
        q.scale
    )
    t3 = (p.variance + torch.square(p.mean) - 2.0 * q.loc * p.mean) / (
        2.0 * torch.square(q.scale)
    )
    return t1 + t2 + t3


@register_kl(ContinuousBernoulli, Uniform)
def _kl_continuous_bernoulli_uniform(p, q):
    result = -p.entropy() + (q.high - q.low).log()
    return torch.where(
        torch.max(
            torch.ge(q.low, p.support.lower_bound),
            torch.le(q.high, p.support.upper_bound),
        ),
        torch.ones_like(result) * inf,
        result,
    )


@register_kl(Exponential, Beta)
@register_kl(Exponential, ContinuousBernoulli)
@register_kl(Exponential, Pareto)
@register_kl(Exponential, Uniform)
def _kl_exponential_infinity(p, q):
    return _infinite_like(p.rate)


@register_kl(Exponential, Gamma)
def _kl_exponential_gamma(p, q):
    ratio = q.rate / p.rate
    t1 = -q.concentration * torch.log(ratio)
    return (
        t1
        + ratio
        + q.concentration.lgamma()
        + q.concentration * _euler_gamma
        - (1 + _euler_gamma)
    )


@register_kl(Exponential, Gumbel)
def _kl_exponential_gumbel(p, q):
    scale_rate_prod = p.rate * q.scale
    loc_scale_ratio = q.loc / q.scale
    t1 = scale_rate_prod.log() - 1
    t2 = torch.exp(loc_scale_ratio) * scale_rate_prod / (scale_rate_prod + 1)
    t3 = scale_rate_prod.reciprocal()
    return t1 - loc_scale_ratio + t2 + t3


# TODO: Add Exponential-Laplace KL Divergence


@register_kl(Exponential, Normal)
def _kl_exponential_normal(p, q):
    var_normal = q.scale.pow(2)
    rate_sqr = p.rate.pow(2)
    t1 = 0.5 * torch.log(rate_sqr * var_normal * 2 * math.pi)
    t2 = rate_sqr.reciprocal()
    t3 = q.loc / p.rate
    t4 = q.loc.pow(2) * 0.5
    return t1 - 1 + (t2 - t3 + t4) / var_normal


@register_kl(Gamma, Beta)
@register_kl(Gamma, ContinuousBernoulli)
@register_kl(Gamma, Pareto)
@register_kl(Gamma, Uniform)
def _kl_gamma_infinity(p, q):
    return _infinite_like(p.concentration)


@register_kl(Gamma, Exponential)
def _kl_gamma_exponential(p, q):
    return -p.entropy() - q.rate.log() + q.rate * p.concentration / p.rate


@register_kl(Gamma, Gumbel)
def _kl_gamma_gumbel(p, q):
    beta_scale_prod = p.rate * q.scale
    loc_scale_ratio = q.loc / q.scale
    t1 = (
        (p.concentration - 1) * p.concentration.digamma()
        - p.concentration.lgamma()
        - p.concentration
    )
    t2 = beta_scale_prod.log() + p.concentration / beta_scale_prod
    t3 = (
        torch.exp(loc_scale_ratio)
        * (1 + beta_scale_prod.reciprocal()).pow(-p.concentration)
        - loc_scale_ratio
    )
    return t1 + t2 + t3


# TODO: Add Gamma-Laplace KL Divergence


@register_kl(Gamma, Normal)
def _kl_gamma_normal(p, q):
    var_normal = q.scale.pow(2)
    beta_sqr = p.rate.pow(2)
    t1 = (
        0.5 * torch.log(beta_sqr * var_normal * 2 * math.pi)
        - p.concentration
        - p.concentration.lgamma()
    )
    t2 = 0.5 * (p.concentration.pow(2) + p.concentration) / beta_sqr
    t3 = q.loc * p.concentration / p.rate
    t4 = 0.5 * q.loc.pow(2)
    return (
        t1
        + (p.concentration - 1) * p.concentration.digamma()
        + (t2 - t3 + t4) / var_normal
    )


@register_kl(Gumbel, Beta)
@register_kl(Gumbel, ContinuousBernoulli)
@register_kl(Gumbel, Exponential)
@register_kl(Gumbel, Gamma)
@register_kl(Gumbel, Pareto)
@register_kl(Gumbel, Uniform)
def _kl_gumbel_infinity(p, q):
    return _infinite_like(p.loc)


# TODO: Add Gumbel-Laplace KL Divergence


@register_kl(Gumbel, Normal)
def _kl_gumbel_normal(p, q):
    param_ratio = p.scale / q.scale
    t1 = (param_ratio / math.sqrt(2 * math.pi)).log()
    t2 = (math.pi * param_ratio * 0.5).pow(2) / 3
    t3 = ((p.loc + p.scale * _euler_gamma - q.loc) / q.scale).pow(2) * 0.5
    return -t1 + t2 + t3 - (_euler_gamma + 1)


@register_kl(Laplace, Beta)
@register_kl(Laplace, ContinuousBernoulli)
@register_kl(Laplace, Exponential)
@register_kl(Laplace, Gamma)
@register_kl(Laplace, Pareto)
@register_kl(Laplace, Uniform)
def _kl_laplace_infinity(p, q):
    return _infinite_like(p.loc)


@register_kl(Laplace, Normal)
def _kl_laplace_normal(p, q):
    var_normal = q.scale.pow(2)
    scale_sqr_var_ratio = p.scale.pow(2) / var_normal
    t1 = 0.5 * torch.log(2 * scale_sqr_var_ratio / math.pi)
    t2 = 0.5 * p.loc.pow(2)
    t3 = p.loc * q.loc
    t4 = 0.5 * q.loc.pow(2)
    return -t1 + scale_sqr_var_ratio + (t2 - t3 + t4) / var_normal - 1


@register_kl(Normal, Beta)
@register_kl(Normal, ContinuousBernoulli)
@register_kl(Normal, Exponential)
@register_kl(Normal, Gamma)
@register_kl(Normal, Pareto)
@register_kl(Normal, Uniform)
def _kl_normal_infinity(p, q):
    return _infinite_like(p.loc)
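
# The ``_kl_*_infinity`` registrations above and below encode the same rule in
# closed form: KL(p || q) is +inf whenever q assigns zero density to part of
# p's support (e.g. a Normal, supported on all of R, against a bounded
# Uniform); an illustrative sketch:
#
#     kl_divergence(Normal(0.0, 1.0), Uniform(-1.0, 1.0))  # tensor(inf)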


@register_kl(Normal, Gumbel)
def _kl_normal_gumbel(p, q):
    mean_scale_ratio = p.loc / q.scale
    var_scale_sqr_ratio = (p.scale / q.scale).pow(2)
    loc_scale_ratio = q.loc / q.scale
    t1 = var_scale_sqr_ratio.log() * 0.5
    t2 = mean_scale_ratio - loc_scale_ratio
    t3 = torch.exp(-mean_scale_ratio + 0.5 * var_scale_sqr_ratio + loc_scale_ratio)
    return -t1 + t2 + t3 - (0.5 * (1 + math.log(2 * math.pi)))


@register_kl(Normal, Laplace)
def _kl_normal_laplace(p, q):
    loc_diff = p.loc - q.loc
    scale_ratio = p.scale / q.scale
    loc_diff_scale_ratio = loc_diff / p.scale
    t1 = torch.log(scale_ratio)
    t2 = (
        math.sqrt(2 / math.pi) * p.scale * torch.exp(-0.5 * loc_diff_scale_ratio.pow(2))
    )
    t3 = loc_diff * torch.erf(math.sqrt(0.5) * loc_diff_scale_ratio)
    return -t1 + (t2 + t3) / q.scale - (0.5 * (1 + math.log(0.5 * math.pi)))


@register_kl(Pareto, Beta)
@register_kl(Pareto, ContinuousBernoulli)
@register_kl(Pareto, Uniform)
def _kl_pareto_infinity(p, q):
    return _infinite_like(p.scale)


@register_kl(Pareto, Exponential)
def _kl_pareto_exponential(p, q):
    scale_rate_prod = p.scale * q.rate
    t1 = (p.alpha / scale_rate_prod).log()
    t2 = p.alpha.reciprocal()
    t3 = p.alpha * scale_rate_prod / (p.alpha - 1)
    result = t1 - t2 + t3 - 1
    result[p.alpha <= 1] = inf
    return result


@register_kl(Pareto, Gamma)
def _kl_pareto_gamma(p, q):
    common_term = p.scale.log() + p.alpha.reciprocal()
    t1 = p.alpha.log() - common_term
    t2 = q.concentration.lgamma() - q.concentration * q.rate.log()
    t3 = (1 - q.concentration) * common_term
    t4 = q.rate * p.alpha * p.scale / (p.alpha - 1)
    result = t1 + t2 + t3 + t4 - 1
    result[p.alpha <= 1] = inf
    return result


# TODO: Add Pareto-Laplace KL Divergence


@register_kl(Pareto, Normal)
def _kl_pareto_normal(p, q):
    var_normal = 2 * q.scale.pow(2)
    common_term = p.scale / (p.alpha - 1)
    t1 = (math.sqrt(2 * math.pi) * q.scale * p.alpha / p.scale).log()
    t2 = p.alpha.reciprocal()
    t3 = p.alpha * common_term.pow(2) / (p.alpha - 2)
    t4 = (p.alpha * common_term - q.loc).pow(2)
    result = t1 - t2 + (t3 + t4) / var_normal - 1
    result[p.alpha <= 2] = inf
    return result


@register_kl(Poisson, Bernoulli)
@register_kl(Poisson, Binomial)
def _kl_poisson_infinity(p, q):
    return _infinite_like(p.rate)


@register_kl(Uniform, Beta)
def _kl_uniform_beta(p, q):
    common_term = p.high - p.low
    t1 = torch.log(common_term)
    t2 = (
        (q.concentration1 - 1)
        * (_x_log_x(p.high) - _x_log_x(p.low) - common_term)
        / common_term
    )
    t3 = (
        (q.concentration0 - 1)
        * (_x_log_x(1 - p.high) - _x_log_x(1 - p.low) + common_term)
        / common_term
    )
    t4 = (
        q.concentration1.lgamma()
        + q.concentration0.lgamma()
        - (q.concentration1 + q.concentration0).lgamma()
    )
    result = t3 + t4 - t1 - t2
    result[(p.high > q.support.upper_bound) | (p.low < q.support.lower_bound)] = inf
    return result


@register_kl(Uniform, ContinuousBernoulli)
def _kl_uniform_continuous_bernoulli(p, q):
    result = (
        -p.entropy()
        - p.mean * q.logits
        - torch.log1p(-q.probs)
        - q._cont_bern_log_norm()
    )
    return torch.where(
        torch.max(
            torch.ge(p.high, q.support.upper_bound),
            torch.le(p.low, q.support.lower_bound),
        ),
        torch.ones_like(result) * inf,
        result,
    )


@register_kl(Uniform, Exponential)
def _kl_uniform_exponential(p, q):
    result = q.rate * (p.high + p.low) / 2 - ((p.high - p.low) * q.rate).log()
    result[p.low < q.support.lower_bound] = inf
    return result


@register_kl(Uniform, Gamma)
def _kl_uniform_gamma(p, q):
    common_term = p.high - p.low
    t1 = common_term.log()
    t2 = q.concentration.lgamma() - q.concentration * q.rate.log()
    t3 = (
        (1 - q.concentration)
        * (_x_log_x(p.high) - _x_log_x(p.low) - common_term)
        / common_term
    )
    t4 = q.rate * (p.high + p.low) / 2
    result = -t1 + t2 + t3 + t4
    result[p.low < q.support.lower_bound] = inf
    return result


@register_kl(Uniform, Gumbel)
def _kl_uniform_gumbel(p, q):
    common_term = q.scale / (p.high - p.low)
    high_loc_diff = (p.high - q.loc) / q.scale
    low_loc_diff = (p.low - q.loc) / q.scale
    t1 = common_term.log() + 0.5 * (high_loc_diff + low_loc_diff)
    t2 = common_term * (torch.exp(-high_loc_diff) - torch.exp(-low_loc_diff))
    return t1 - t2


# TODO: Add Uniform-Laplace KL Divergence


@register_kl(Uniform, Normal)
def _kl_uniform_normal(p, q):
    common_term = p.high - p.low
    t1 = (math.sqrt(math.pi * 2) * q.scale / common_term).log()
    t2 = (common_term).pow(2) / 12
    t3 = ((p.high + p.low - 2 * q.loc) / 2).pow(2)
    return t1 + 0.5 * (t2 + t3) / q.scale.pow(2)


@register_kl(Uniform, Pareto)
def _kl_uniform_pareto(p, q):
    support_uniform = p.high - p.low
    t1 = (q.alpha * q.scale.pow(q.alpha) * (support_uniform)).log()
    t2 = (_x_log_x(p.high) - _x_log_x(p.low) - support_uniform) / support_uniform
    result = t2 * (q.alpha + 1) - t1
    result[p.low < q.support.lower_bound] = inf
    return result


@register_kl(Independent, Independent)
def _kl_independent_independent(p, q):
    if p.reinterpreted_batch_ndims != q.reinterpreted_batch_ndims:
        raise NotImplementedError
    result = kl_divergence(p.base_dist, q.base_dist)
    return _sum_rightmost(result, p.reinterpreted_batch_ndims)


@register_kl(Cauchy, Cauchy)
def _kl_cauchy_cauchy(p, q):
    # From https://arxiv.org/abs/1905.10965
    t1 = ((p.scale + q.scale).pow(2) + (p.loc - q.loc).pow(2)).log()
    t2 = (4 * p.scale * q.scale).log()
    return t1 - t2


def _add_kl_info():
    """Appends a list of implemented KL functions to the doc for kl_divergence."""
    rows = [
        "KL divergence is currently implemented for the following distribution pairs:"
    ]
    for p, q in sorted(
        _KL_REGISTRY, key=lambda p_q: (p_q[0].__name__, p_q[1].__name__)
    ):
        rows.append(
            f"* :class:`~torch.distributions.{p.__name__}` and :class:`~torch.distributions.{q.__name__}`"
        )
    kl_info = "\n\t".join(rows)
    if kl_divergence.__doc__:
        kl_divergence.__doc__ += kl_info  # type: ignore[operator]