# mypy: allow-untyped-defs
import math
import warnings
from functools import total_ordering
from typing import Callable, Dict, Tuple, Type

import torch
from torch import inf

from .bernoulli import Bernoulli
from .beta import Beta
from .binomial import Binomial
from .categorical import Categorical
from .cauchy import Cauchy
from .continuous_bernoulli import ContinuousBernoulli
from .dirichlet import Dirichlet
from .distribution import Distribution
from .exp_family import ExponentialFamily
from .exponential import Exponential
from .gamma import Gamma
from .geometric import Geometric
from .gumbel import Gumbel
from .half_normal import HalfNormal
from .independent import Independent
from .laplace import Laplace
from .lowrank_multivariate_normal import (
    _batch_lowrank_logdet,
    _batch_lowrank_mahalanobis,
    LowRankMultivariateNormal,
)
from .multivariate_normal import _batch_mahalanobis, MultivariateNormal
from .normal import Normal
from .one_hot_categorical import OneHotCategorical
from .pareto import Pareto
from .poisson import Poisson
from .transformed_distribution import TransformedDistribution
from .uniform import Uniform
from .utils import _sum_rightmost, euler_constant as _euler_gamma


_KL_REGISTRY: Dict[
    Tuple[Type, Type], Callable
] = {}  # Source of truth mapping a few general (type, type) pairs to functions.
_KL_MEMOIZE: Dict[
    Tuple[Type, Type], Callable
] = {}  # Memoized version mapping many specific (type, type) pairs to functions.

__all__ = ["register_kl", "kl_divergence"]


def register_kl(type_p, type_q):
    """
    Decorator to register a pairwise function with :meth:`kl_divergence`.
    Usage::

        @register_kl(Normal, Normal)
        def kl_normal_normal(p, q):
            # insert implementation here

    Lookup returns the most specific (type,type) match ordered by subclass. If
    the match is ambiguous, a `RuntimeWarning` is emitted. For example to
    resolve the ambiguous situation::

        @register_kl(BaseP, DerivedQ)
        def kl_version1(p, q): ...
        @register_kl(DerivedP, BaseQ)
        def kl_version2(p, q): ...

    you should register a third most-specific implementation, e.g.::

        register_kl(DerivedP, DerivedQ)(kl_version1)  # Break the tie.

    Args:
        type_p (type): A subclass of :class:`~torch.distributions.Distribution`.
        type_q (type): A subclass of :class:`~torch.distributions.Distribution`.

    Raises:
        TypeError: If ``type_p`` or ``type_q`` is not a subclass of
            :class:`~torch.distributions.Distribution`.
    """
    # The whole conjunction must be negated: ``not a and b`` would let any
    # non-Distribution class through, and would only reject non-types via
    # issubclass's own TypeError.
    if not (isinstance(type_p, type) and issubclass(type_p, Distribution)):
        raise TypeError(
            f"Expected type_p to be a Distribution subclass but got {type_p}"
        )
    if not (isinstance(type_q, type) and issubclass(type_q, Distribution)):
        raise TypeError(
            f"Expected type_q to be a Distribution subclass but got {type_q}"
        )

    def decorator(fun):
        _KL_REGISTRY[type_p, type_q] = fun
        _KL_MEMOIZE.clear()  # reset since lookup order may have changed
        return fun

    return decorator


@total_ordering
class _Match:
    """
    Wraps a tuple of types so that ``min`` picks the most specific match.

    ``a <= b`` holds when, at the first position where the two tuples differ,
    ``a``'s type is a subclass of ``b``'s (or when no position differs).
    Earlier positions therefore take priority over later ones.
    """

    __slots__ = ["types"]

    def __init__(self, *types):
        self.types = types

    def __eq__(self, other):
        return self.types == other.types

    def __le__(self, other):
        for x, y in zip(self.types, other.types):
            if not issubclass(x, y):
                return False
            if x is not y:
                break
        return True


def _dispatch_kl(type_p, type_q):
    """
    Find the most specific approximate match, assuming single inheritance.

    Returns the registered function, or ``NotImplemented`` if no registered
    ``(super_p, super_q)`` pair covers ``(type_p, type_q)``.
    """
    matches = [
        (super_p, super_q)
        for super_p, super_q in _KL_REGISTRY
        if issubclass(type_p, super_p) and issubclass(type_q, super_q)
    ]
    if not matches:
        return NotImplemented
    # Check that the left- and right- lexicographic orders agree.
    # mypy isn't smart enough to know that _Match implements __lt__
    # see: https://github.com/python/typing/issues/760#issuecomment-710670503
    left_p, left_q = min(_Match(*m) for m in matches).types  # type: ignore[type-var]
    right_q, right_p = min(_Match(*reversed(m)) for m in matches).types  # type: ignore[type-var]
    left_fun = _KL_REGISTRY[left_p, left_q]
    right_fun = _KL_REGISTRY[right_p, right_q]
    if left_fun is not right_fun:
        warnings.warn(
            f"Ambiguous kl_divergence({type_p.__name__}, {type_q.__name__}). "
            f"Please register_kl({left_p.__name__}, {right_q.__name__})",
            RuntimeWarning,
        )
    return left_fun


def _infinite_like(tensor):
    """
    Helper function for obtaining infinite KL Divergence throughout
    """
    return torch.full_like(tensor, inf)


def _x_log_x(tensor):
    """
    Utility function for calculating x log x
    """
    return tensor * tensor.log()


def _batch_trace_XXT(bmat):
    """
    Utility function for calculating the trace of XX^{T} with X having arbitrary
    leading batch dimensions.

    The trace of X X^T equals the sum of squares of all entries of X, so the
    squares are reduced over the last two dimensions. The result has shape
    ``bmat.shape[:-2]``. Reducing directly (instead of reshaping to
    ``(-1, m * n)``) also works when a matrix dimension is zero.
    """
    return bmat.pow(2).sum((-2, -1))


def kl_divergence(p: Distribution, q: Distribution) -> torch.Tensor:
    r"""
    Compute Kullback-Leibler divergence :math:`KL(p \| q)` between two distributions.

    .. math::

        KL(p \| q) = \int p(x) \log\frac {p(x)} {q(x)} \,dx

    Args:
        p (Distribution): A :class:`~torch.distributions.Distribution` object.
        q (Distribution): A :class:`~torch.distributions.Distribution` object.

    Returns:
        Tensor: A batch of KL divergences of shape `batch_shape`.

    Raises:
        NotImplementedError: If the distribution types have not been registered via
            :meth:`register_kl`.
    """
    try:
        fun = _KL_MEMOIZE[type(p), type(q)]
    except KeyError:
        fun = _dispatch_kl(type(p), type(q))
        # A NotImplemented result is memoized too, so repeated misses are cheap.
        _KL_MEMOIZE[type(p), type(q)] = fun
    if fun is NotImplemented:
        raise NotImplementedError(
            f"No KL(p || q) is implemented for p type {p.__class__.__name__} and q type {q.__class__.__name__}"
        )
    return fun(p, q)


################################################################################
# KL Divergence Implementations
################################################################################

# Same distributions


@register_kl(Bernoulli, Bernoulli)
def _kl_bernoulli_bernoulli(p, q):
    # KL = p*log(p/q) + (1-p)*log((1-p)/(1-q)), written with softplus of logits:
    # softplus(-l) = -log(sigmoid(l)) and softplus(l) = -log(1 - sigmoid(l)).
    t1 = p.probs * (
        torch.nn.functional.softplus(-q.logits)
        - torch.nn.functional.softplus(-p.logits)
    )
    # Masks are expanded to the broadcast result shape so that p and q with
    # different (but broadcastable) batch shapes can be compared.
    t1[(q.probs == 0).expand_as(t1)] = inf
    t1[(p.probs == 0).expand_as(t1)] = 0
    t2 = (1 - p.probs) * (
        torch.nn.functional.softplus(q.logits) - torch.nn.functional.softplus(p.logits)
    )
    t2[(q.probs == 1).expand_as(t2)] = inf
    t2[(p.probs == 1).expand_as(t2)] = 0
    return t1 + t2


@register_kl(Beta, Beta)
def _kl_beta_beta(p, q):
    # Closed form in terms of log-gamma and digamma of the concentrations.
    sum_params_p = p.concentration1 + p.concentration0
    sum_params_q = q.concentration1 + q.concentration0
    t1 = q.concentration1.lgamma() + q.concentration0.lgamma() + (sum_params_p).lgamma()
    t2 = p.concentration1.lgamma() + p.concentration0.lgamma() + (sum_params_q).lgamma()
    t3 = (p.concentration1 - q.concentration1) * torch.digamma(p.concentration1)
    t4 = (p.concentration0 - q.concentration0) * torch.digamma(p.concentration0)
    t5 = (sum_params_q - sum_params_p) * torch.digamma(sum_params_p)
    return t1 - t2 + t3 + t4 + t5


@register_kl(Binomial, Binomial)
def _kl_binomial_binomial(p, q):
    # Closed form taken from https://math.stackexchange.com/questions/2214993/
    # kullback-leibler-divergence-for-binomial-distributions-p-and-q
    # Only p.total_count >= q.total_count is supported.
    n_p = p.total_count
    n_q = q.total_count
    if (n_p < n_q).any():
        raise NotImplementedError(
            "KL between Binomials where q.total_count > p.total_count is not implemented"
        )
    per_trial = (
        p.probs * (p.logits - q.logits)
        + torch.log1p(-p.probs)
        - torch.log1p(-q.probs)
    )
    result = n_p * per_trial
    # Where p has more trials than q, the divergence is set to +inf.
    more_trials = n_p > n_q
    result[more_trials] = _infinite_like(result[more_trials])
    return result


@register_kl(Categorical, Categorical)
def _kl_categorical_categorical(p, q):
    # Per-category p_i * (log p_i - log q_i), summed over the last axis.
    terms = p.probs * (p.logits - q.logits)
    q_is_zero = (q.probs == 0).expand_as(terms)
    p_is_zero = (p.probs == 0).expand_as(terms)
    terms[q_is_zero] = inf
    # Applied second so that 0 * log(0 / q) := 0 takes precedence.
    terms[p_is_zero] = 0
    return terms.sum(-1)


@register_kl(ContinuousBernoulli, ContinuousBernoulli)
def _kl_continuous_bernoulli_continuous_bernoulli(p, q):
    # Sum of a logit cross term, p's normalizer part and q's normalizer part.
    cross = p.mean * (p.logits - q.logits)
    p_part = p._cont_bern_log_norm() + torch.log1p(-p.probs)
    q_part = -q._cont_bern_log_norm() - torch.log1p(-q.probs)
    return cross + p_part + q_part


@register_kl(Dirichlet, Dirichlet)
def _kl_dirichlet_dirichlet(p, q):
    # See http://bariskurt.com/kullback-leibler-divergence-between-two-dirichlet-and-beta-distributions/
    alpha_p = p.concentration
    alpha_q = q.concentration
    total_p = alpha_p.sum(-1)
    total_q = alpha_q.sum(-1)
    log_norm_diff = total_p.lgamma() - total_q.lgamma()
    lgamma_diff = (alpha_p.lgamma() - alpha_q.lgamma()).sum(-1)
    digamma_gap = alpha_p.digamma() - total_p.digamma().unsqueeze(-1)
    weighted = (alpha_p - alpha_q) * digamma_gap
    return log_norm_diff - lgamma_diff + weighted.sum(-1)


@register_kl(Exponential, Exponential)
def _kl_exponential_exponential(p, q):
    # With r = rate_q / rate_p: KL = r - log(r) - 1.
    ratio = q.rate / p.rate
    return -ratio.log() + ratio - 1


@register_kl(ExponentialFamily, ExponentialFamily)
def _kl_expfamily_expfamily(p, q):
    """
    Generic KL between two distributions of the same exponential family.

    Uses the Bregman divergence of the log-normalizer ``A`` over the natural
    parameters ``eta``::

        KL(p || q) = A(eta_q) - A(eta_p) - <eta_q - eta_p, grad A(eta_p)>

    Raises:
        NotImplementedError: If ``p`` and ``q`` are of different types.
    """
    if type(p) is not type(q):
        raise NotImplementedError(
            "The cross KL-divergence between different exponential families cannot "
            "be computed using Bregman divergences"
        )
    # Detached copies of p's natural parameters become fresh leaves so that
    # autograd.grad can compute grad A(eta_p) with respect to them.
    p_nparams = [nparam.detach().requires_grad_() for nparam in p._natural_params]
    q_nparams = q._natural_params
    lg_normal = p._log_normalizer(*p_nparams)
    gradients = torch.autograd.grad(lg_normal.sum(), p_nparams, create_graph=True)
    result = q._log_normalizer(*q_nparams) - lg_normal
    for pnp, qnp, g in zip(p_nparams, q_nparams, gradients):
        term = (qnp - pnp) * g
        # The inner product is reduced over the event dimensions only.
        result -= _sum_rightmost(term, len(q.event_shape))
    return result


@register_kl(Gamma, Gamma)
def _kl_gamma_gamma(p, q):
    # Closed form for the (concentration, rate) parameterization.
    t1 = q.concentration * (p.rate / q.rate).log()
    t2 = torch.lgamma(q.concentration) - torch.lgamma(p.concentration)
    t3 = (p.concentration - q.concentration) * torch.digamma(p.concentration)
    t4 = (q.rate - p.rate) * (p.concentration / p.rate)
    return t1 + t2 + t3 + t4


@register_kl(Gumbel, Gumbel)
def _kl_gumbel_gumbel(p, q):
    # Closed form; _euler_gamma is the Euler-Mascheroni constant.
    ct1 = p.scale / q.scale
    ct2 = q.loc / q.scale
    ct3 = p.loc / q.scale
    t1 = -ct1.log() - ct2 + ct3
    t2 = ct1 * _euler_gamma
    t3 = torch.exp(ct2 + (1 + ct1).lgamma() - ct3)
    return t1 + t2 + t3 - (1 + _euler_gamma)


@register_kl(Geometric, Geometric)
def _kl_geometric_geometric(p, q):
    # KL = -H(p) + H(p, q); the last two terms form the cross-entropy.
    return -p.entropy() - torch.log1p(-q.probs) / p.probs - q.logits


@register_kl(HalfNormal, HalfNormal)
def _kl_halfnormal_halfnormal(p, q):
    # Delegates to the KL between the base distributions.
    return _kl_normal_normal(p.base_dist, q.base_dist)


@register_kl(Laplace, Laplace)
def _kl_laplace_laplace(p, q):
    # From http://www.mast.queensu.ca/~communications/Papers/gil-msc11.pdf
    scale_ratio = p.scale / q.scale
    loc_abs_diff = (p.loc - q.loc).abs()
    t1 = -scale_ratio.log()
    t2 = loc_abs_diff / q.scale
    t3 = scale_ratio * torch.exp(-loc_abs_diff / p.scale)
    return t1 + t2 + t3 - 1


@register_kl(LowRankMultivariateNormal, LowRankMultivariateNormal)
def _kl_lowrankmultivariatenormal_lowrankmultivariatenormal(p, q):
    # Gaussian KL: 0.5 * (logdet ratio + trace term + Mahalanobis term - n).
    if p.event_shape != q.event_shape:
        raise ValueError(
            "KL-divergence between two Low Rank Multivariate Normals with "
            "different event shapes cannot be computed"
        )

    term1 = _batch_lowrank_logdet(
        q._unbroadcasted_cov_factor, q._unbroadcasted_cov_diag, q._capacitance_tril
    ) - _batch_lowrank_logdet(
        p._unbroadcasted_cov_factor, p._unbroadcasted_cov_diag, p._capacitance_tril
    )
    term3 = _batch_lowrank_mahalanobis(
        q._unbroadcasted_cov_factor,
        q._unbroadcasted_cov_diag,
        q.loc - p.loc,
        q._capacitance_tril,
    )
    # Expands term2 according to
    # inv(qcov) @ pcov = [inv(qD) - inv(qD) @ qW @ inv(qC) @ qW.T @ inv(qD)] @ (pW @ pW.T + pD)
    #                  = [inv(qD) - A.T @ A] @ (pD + pW @ pW.T)
    qWt_qDinv = q._unbroadcasted_cov_factor.mT / q._unbroadcasted_cov_diag.unsqueeze(-2)
    A = torch.linalg.solve_triangular(q._capacitance_tril, qWt_qDinv, upper=False)
    term21 = (p._unbroadcasted_cov_diag / q._unbroadcasted_cov_diag).sum(-1)
    term22 = _batch_trace_XXT(
        p._unbroadcasted_cov_factor * q._unbroadcasted_cov_diag.rsqrt().unsqueeze(-1)
    )
    term23 = _batch_trace_XXT(A * p._unbroadcasted_cov_diag.sqrt().unsqueeze(-2))
    term24 = _batch_trace_XXT(A.matmul(p._unbroadcasted_cov_factor))
    term2 = term21 + term22 - term23 - term24
    return 0.5 * (term1 + term2 + term3 - p.event_shape[0])


@register_kl(MultivariateNormal, LowRankMultivariateNormal)
def _kl_multivariatenormal_lowrankmultivariatenormal(p, q):
    # Gaussian KL with a full-rank p and a low-rank-plus-diagonal q.
    if p.event_shape != q.event_shape:
        raise ValueError(
            "KL-divergence between two (Low Rank) Multivariate Normals with "
            "different event shapes cannot be computed"
        )

    term1 = _batch_lowrank_logdet(
        q._unbroadcasted_cov_factor, q._unbroadcasted_cov_diag, q._capacitance_tril
    ) - 2 * p._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1)
    term3 = _batch_lowrank_mahalanobis(
        q._unbroadcasted_cov_factor,
        q._unbroadcasted_cov_diag,
        q.loc - p.loc,
        q._capacitance_tril,
    )
    # Expands term2 according to
    # inv(qcov) @ pcov = [inv(qD) - inv(qD) @ qW @ inv(qC) @ qW.T @ inv(qD)] @ p_tril @ p_tril.T
    #                  = [inv(qD) - A.T @ A] @ p_tril @ p_tril.T
    qWt_qDinv = q._unbroadcasted_cov_factor.mT / q._unbroadcasted_cov_diag.unsqueeze(-2)
    A = torch.linalg.solve_triangular(q._capacitance_tril, qWt_qDinv, upper=False)
    term21 = _batch_trace_XXT(
        p._unbroadcasted_scale_tril * q._unbroadcasted_cov_diag.rsqrt().unsqueeze(-1)
    )
    term22 = _batch_trace_XXT(A.matmul(p._unbroadcasted_scale_tril))
    term2 = term21 - term22
    return 0.5 * (term1 + term2 + term3 - p.event_shape[0])


@register_kl(LowRankMultivariateNormal, MultivariateNormal)
def _kl_lowrankmultivariatenormal_multivariatenormal(p, q):
    # Gaussian KL with a low-rank-plus-diagonal p and a full-rank q.
    if p.event_shape != q.event_shape:
        raise ValueError(
            "KL-divergence between two (Low Rank) Multivariate Normals with "
            "different event shapes cannot be computed"
        )

    term1 = 2 * q._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(
        -1
    ) - _batch_lowrank_logdet(
        p._unbroadcasted_cov_factor, p._unbroadcasted_cov_diag, p._capacitance_tril
    )
    term3 = _batch_mahalanobis(q._unbroadcasted_scale_tril, (q.loc - p.loc))
    # Expands term2 according to
    # inv(qcov) @ pcov = inv(q_tril @ q_tril.T) @ (pW @ pW.T + pD)
    # Both operands are expanded to a common batch shape for solve_triangular.
    combined_batch_shape = torch._C._infer_size(
        q._unbroadcasted_scale_tril.shape[:-2], p._unbroadcasted_cov_factor.shape[:-2]
    )
    n = p.event_shape[0]
    q_scale_tril = q._unbroadcasted_scale_tril.expand(combined_batch_shape + (n, n))
    p_cov_factor = p._unbroadcasted_cov_factor.expand(
        combined_batch_shape + (n, p.cov_factor.size(-1))
    )
    p_cov_diag = torch.diag_embed(p._unbroadcasted_cov_diag.sqrt()).expand(
        combined_batch_shape + (n, n)
    )
    term21 = _batch_trace_XXT(
        torch.linalg.solve_triangular(q_scale_tril, p_cov_factor, upper=False)
    )
    term22 = _batch_trace_XXT(
        torch.linalg.solve_triangular(q_scale_tril, p_cov_diag, upper=False)
    )
    term2 = term21 + term22
    return 0.5 * (term1 + term2 + term3 - p.event_shape[0])


@register_kl(MultivariateNormal, MultivariateNormal)
def _kl_multivariatenormal_multivariatenormal(p, q):
    # From https://en.wikipedia.org/wiki/Multivariate_normal_distribution#Kullback%E2%80%93Leibler_divergence
    if p.event_shape != q.event_shape:
        raise ValueError(
            "KL-divergence between two Multivariate Normals with "
            "different event shapes cannot be computed"
        )

    # Half the log-determinant ratio, read off the Cholesky diagonals.
    half_term1 = q._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(
        -1
    ) - p._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1)
    combined_batch_shape = torch._C._infer_size(
        q._unbroadcasted_scale_tril.shape[:-2], p._unbroadcasted_scale_tril.shape[:-2]
    )
    n = p.event_shape[0]
    q_scale_tril = q._unbroadcasted_scale_tril.expand(combined_batch_shape + (n, n))
    p_scale_tril = p._unbroadcasted_scale_tril.expand(combined_batch_shape + (n, n))
    term2 = _batch_trace_XXT(
        torch.linalg.solve_triangular(q_scale_tril, p_scale_tril, upper=False)
    )
    term3 = _batch_mahalanobis(q._unbroadcasted_scale_tril, (q.loc - p.loc))
    return half_term1 + 0.5 * (term2 + term3 - n)


@register_kl(Normal, Normal)
def _kl_normal_normal(p, q):
    # Univariate Gaussian closed form.
    var_ratio = (p.scale / q.scale).pow(2)
    t1 = ((p.loc - q.loc) / q.scale).pow(2)
    return 0.5 * (var_ratio + t1 - 1 - var_ratio.log())


@register_kl(OneHotCategorical, OneHotCategorical)
def _kl_onehotcategorical_onehotcategorical(p, q):
    # Delegates to the KL between the wrapped Categorical distributions.
    return _kl_categorical_categorical(p._categorical, q._categorical)
@register_kl(Pareto, Pareto)
def _kl_pareto_pareto(p, q):
    """KL(p || q) between two Pareto distributions.

    Entries where ``p.support.lower_bound < q.support.lower_bound`` are ``inf``:
    p then places mass below the start of q's support.
    """
    # From http://www.mast.queensu.ca/~communications/Papers/gil-msc11.pdf
    scale_ratio = p.scale / q.scale
    alpha_ratio = q.alpha / p.alpha
    t1 = q.alpha * scale_ratio.log()
    t2 = -alpha_ratio.log()
    result = t1 + t2 + alpha_ratio - 1
    # masked_fill broadcasts the mask against ``result`` (boolean indexing needs
    # an exact leading-dim match) and does not write in place.
    return result.masked_fill(p.support.lower_bound < q.support.lower_bound, inf)


@register_kl(Poisson, Poisson)
def _kl_poisson_poisson(p, q):
    """KL(p || q) between two Poisson distributions, elementwise in the rates."""
    return p.rate * (p.rate.log() - q.rate.log()) - (p.rate - q.rate)


@register_kl(TransformedDistribution, TransformedDistribution)
def _kl_transformed_transformed(p, q):
    """KL(p || q) for TransformedDistributions that share their transforms.

    Delegates to ``kl_divergence(p.base_dist, q.base_dist)``.

    Raises:
        NotImplementedError: if the transforms or the event shapes differ.
    """
    if p.transforms != q.transforms:
        raise NotImplementedError
    if p.event_shape != q.event_shape:
        raise NotImplementedError
    return kl_divergence(p.base_dist, q.base_dist)


@register_kl(Uniform, Uniform)
def _kl_uniform_uniform(p, q):
    """KL(p || q) between two Uniform distributions.

    ``log((q.high - q.low) / (p.high - p.low))`` when ``[p.low, p.high]`` lies
    inside ``[q.low, q.high]``, and ``inf`` otherwise.
    """
    result = ((q.high - q.low) / (p.high - p.low)).log()
    return result.masked_fill((q.low > p.low) | (q.high < p.high), inf)


# Different distributions
@register_kl(Bernoulli, Poisson)
def _kl_bernoulli_poisson(p, q):
    """KL(p || q) for a Bernoulli ``p`` and a Poisson ``q``.

    ``-H(p) - E_p[log q]``; on p's outcomes {0, 1} the ``log x!`` term of the
    Poisson log-pmf vanishes, leaving ``x * log(rate) - rate``.
    """
    return -p.entropy() - (p.probs * q.rate.log() - q.rate)


@register_kl(Beta, ContinuousBernoulli)
def _kl_beta_continuous_bernoulli(p, q):
    """KL(p || q) for a Beta ``p`` and a ContinuousBernoulli ``q``.

    q's log-density enters as ``x * logits + log1p(-probs) + _cont_bern_log_norm()``,
    which is linear in x, so only ``p.mean`` is needed.
    """
    return (
        -p.entropy()
        - p.mean * q.logits
        - torch.log1p(-q.probs)
        - q._cont_bern_log_norm()
    )


@register_kl(Beta, Pareto)
def _kl_beta_infinity(p, q):
    """KL(Beta || Pareto) is reported as infinite (``_infinite_like(p.concentration1)``)."""
    return _infinite_like(p.concentration1)


@register_kl(Beta, Exponential)
def _kl_beta_exponential(p, q):
    """KL(p || q) for a Beta ``p`` and an Exponential ``q``: ``-H(p) - log(rate) + rate * E_p[x]``."""
    return (
        -p.entropy()
        - q.rate.log()
        + q.rate * (p.concentration1 / (p.concentration1 + p.concentration0))
    )


@register_kl(Beta, Gamma)
def _kl_beta_gamma(p, q):
    """KL(p || q) for a Beta ``p`` and a Gamma ``q``.

    ``t3`` uses ``E_p[log x] = digamma(c1) - digamma(c1 + c0)`` and ``t4`` uses
    ``E_p[x] = c1 / (c1 + c0)``.
    """
    t1 = -p.entropy()
    t2 = q.concentration.lgamma() - q.concentration * q.rate.log()
    t3 = (q.concentration - 1) * (
        p.concentration1.digamma() - (p.concentration1 + p.concentration0).digamma()
    )
    t4 = q.rate * p.concentration1 / (p.concentration1 + p.concentration0)
    return t1 + t2 - t3 + t4


# TODO: Add Beta-Laplace KL Divergence


@register_kl(Beta, Normal)
def _kl_beta_normal(p, q):
    """KL(p || q) for a Beta ``p`` and a Normal ``q``.

    ``t3`` is half of ``E_p[x^2]`` (Beta variance plus squared mean).
    """
    E_beta = p.concentration1 / (p.concentration1 + p.concentration0)
    var_normal = q.scale.pow(2)
    t1 = -p.entropy()
    t2 = 0.5 * (var_normal * 2 * math.pi).log()
    t3 = (
        E_beta * (1 - E_beta) / (p.concentration1 + p.concentration0 + 1)
        + E_beta.pow(2)
    ) * 0.5
    t4 = q.loc * E_beta
    t5 = q.loc.pow(2) * 0.5
    return t1 + t2 + (t3 - t4 + t5) / var_normal


@register_kl(Beta, Uniform)
def _kl_beta_uniform(p, q):
    """KL(p || q) for a Beta ``p`` and a Uniform ``q``.

    ``inf`` wherever q does not cover p's support.
    """
    result = -p.entropy() + (q.high - q.low).log()
    # The mask depends on q only while result also broadcasts over p, so use
    # broadcasting masked_fill rather than boolean indexing.
    return result.masked_fill(
        (q.low > p.support.lower_bound) | (q.high < p.support.upper_bound), inf
    )


# Note that the KL between a ContinuousBernoulli and Beta has no closed form


@register_kl(ContinuousBernoulli, Pareto)
def _kl_continuous_bernoulli_infinity(p, q):
    """KL(ContinuousBernoulli || Pareto) is reported as infinite (``_infinite_like(p.probs)``)."""
    return _infinite_like(p.probs)


@register_kl(ContinuousBernoulli, Exponential)
def _kl_continuous_bernoulli_exponential(p, q):
    """KL(p || q) for a ContinuousBernoulli ``p`` and an Exponential ``q``."""
    return -p.entropy() - torch.log(q.rate) + q.rate * p.mean


# Note that the KL between a ContinuousBernoulli and Gamma has no closed form
# TODO: Add ContinuousBernoulli-Laplace KL Divergence


@register_kl(ContinuousBernoulli, Normal)
def _kl_continuous_bernoulli_normal(p, q):
    """KL(p || q) for a ContinuousBernoulli ``p`` and a Normal ``q``.

    Uses only ``p.mean`` and ``p.variance`` of p beyond its entropy.
    """
    t1 = -p.entropy()
    t2 = 0.5 * (math.log(2.0 * math.pi) + torch.square(q.loc / q.scale)) + torch.log(
        q.scale
    )
    t3 = (p.variance + torch.square(p.mean) - 2.0 * q.loc * p.mean) / (
        2.0 * torch.square(q.scale)
    )
    return t1 + t2 + t3


@register_kl(ContinuousBernoulli, Uniform)
def _kl_continuous_bernoulli_uniform(p, q):
    """KL(p || q) for a ContinuousBernoulli ``p`` and a Uniform ``q``.

    ``inf`` only where q fails to cover p's support. A Uniform bound that
    coincides with a support bound still covers it, so the comparisons are
    strict, as in ``_kl_beta_uniform``.
    """
    result = -p.entropy() + (q.high - q.low).log()
    return result.masked_fill(
        (q.low > p.support.lower_bound) | (q.high < p.support.upper_bound), inf
    )


@register_kl(Exponential, Beta)
@register_kl(Exponential, ContinuousBernoulli)
@register_kl(Exponential, Pareto)
@register_kl(Exponential, Uniform)
def _kl_exponential_infinity(p, q):
    """KL from an Exponential to these targets is reported as infinite (``_infinite_like(p.rate)``)."""
    return _infinite_like(p.rate)


@register_kl(Exponential, Gamma)
def _kl_exponential_gamma(p, q):
    """KL(p || q) for an Exponential ``p`` and a Gamma ``q``.

    The ``_euler_gamma`` terms come from ``E_p[log x] = -euler_gamma - log(p.rate)``.
    """
    ratio = q.rate / p.rate
    t1 = -q.concentration * torch.log(ratio)
    return (
        t1
        + ratio
        + q.concentration.lgamma()
        + q.concentration * _euler_gamma
        - (1 + _euler_gamma)
    )


@register_kl(Exponential, Gumbel)
def _kl_exponential_gumbel(p, q):
    """KL(p || q) for an Exponential ``p`` and a Gumbel ``q``."""
    scale_rate_prod = p.rate * q.scale
    loc_scale_ratio = q.loc / q.scale
    t1 = scale_rate_prod.log() - 1
    t2 = torch.exp(loc_scale_ratio) * scale_rate_prod / (scale_rate_prod + 1)
    t3 = scale_rate_prod.reciprocal()
    return t1 - loc_scale_ratio + t2 + t3
648*da0073e9SAndroid Build Coastguard Worker 649*da0073e9SAndroid Build Coastguard Worker 650*da0073e9SAndroid Build Coastguard Worker# TODO: Add Exponential-Laplace KL Divergence 651*da0073e9SAndroid Build Coastguard Worker 652*da0073e9SAndroid Build Coastguard Worker 653*da0073e9SAndroid Build Coastguard Worker@register_kl(Exponential, Normal) 654*da0073e9SAndroid Build Coastguard Workerdef _kl_exponential_normal(p, q): 655*da0073e9SAndroid Build Coastguard Worker var_normal = q.scale.pow(2) 656*da0073e9SAndroid Build Coastguard Worker rate_sqr = p.rate.pow(2) 657*da0073e9SAndroid Build Coastguard Worker t1 = 0.5 * torch.log(rate_sqr * var_normal * 2 * math.pi) 658*da0073e9SAndroid Build Coastguard Worker t2 = rate_sqr.reciprocal() 659*da0073e9SAndroid Build Coastguard Worker t3 = q.loc / p.rate 660*da0073e9SAndroid Build Coastguard Worker t4 = q.loc.pow(2) * 0.5 661*da0073e9SAndroid Build Coastguard Worker return t1 - 1 + (t2 - t3 + t4) / var_normal 662*da0073e9SAndroid Build Coastguard Worker 663*da0073e9SAndroid Build Coastguard Worker 664*da0073e9SAndroid Build Coastguard Worker@register_kl(Gamma, Beta) 665*da0073e9SAndroid Build Coastguard Worker@register_kl(Gamma, ContinuousBernoulli) 666*da0073e9SAndroid Build Coastguard Worker@register_kl(Gamma, Pareto) 667*da0073e9SAndroid Build Coastguard Worker@register_kl(Gamma, Uniform) 668*da0073e9SAndroid Build Coastguard Workerdef _kl_gamma_infinity(p, q): 669*da0073e9SAndroid Build Coastguard Worker return _infinite_like(p.concentration) 670*da0073e9SAndroid Build Coastguard Worker 671*da0073e9SAndroid Build Coastguard Worker 672*da0073e9SAndroid Build Coastguard Worker@register_kl(Gamma, Exponential) 673*da0073e9SAndroid Build Coastguard Workerdef _kl_gamma_exponential(p, q): 674*da0073e9SAndroid Build Coastguard Worker return -p.entropy() - q.rate.log() + q.rate * p.concentration / p.rate 675*da0073e9SAndroid Build Coastguard Worker 676*da0073e9SAndroid Build Coastguard Worker 677*da0073e9SAndroid Build 
@register_kl(Gamma, Gumbel)
def _kl_gamma_gumbel(p, q):
    """KL(p || q) for a Gamma ``p`` and a Gumbel ``q``."""
    beta_scale_prod = p.rate * q.scale
    loc_scale_ratio = q.loc / q.scale
    t1 = (
        (p.concentration - 1) * p.concentration.digamma()
        - p.concentration.lgamma()
        - p.concentration
    )
    t2 = beta_scale_prod.log() + p.concentration / beta_scale_prod
    t3 = (
        torch.exp(loc_scale_ratio)
        * (1 + beta_scale_prod.reciprocal()).pow(-p.concentration)
        - loc_scale_ratio
    )
    return t1 + t2 + t3


# TODO: Add Gamma-Laplace KL Divergence


@register_kl(Gamma, Normal)
def _kl_gamma_normal(p, q):
    """KL(p || q) for a Gamma ``p`` and a Normal ``q``.

    ``t2`` is half of ``E_p[x^2] = (c^2 + c) / rate^2`` for concentration ``c``.
    """
    var_normal = q.scale.pow(2)
    beta_sqr = p.rate.pow(2)
    t1 = (
        0.5 * torch.log(beta_sqr * var_normal * 2 * math.pi)
        - p.concentration
        - p.concentration.lgamma()
    )
    t2 = 0.5 * (p.concentration.pow(2) + p.concentration) / beta_sqr
    t3 = q.loc * p.concentration / p.rate
    t4 = 0.5 * q.loc.pow(2)
    return (
        t1
        + (p.concentration - 1) * p.concentration.digamma()
        + (t2 - t3 + t4) / var_normal
    )


@register_kl(Gumbel, Beta)
@register_kl(Gumbel, ContinuousBernoulli)
@register_kl(Gumbel, Exponential)
@register_kl(Gumbel, Gamma)
@register_kl(Gumbel, Pareto)
@register_kl(Gumbel, Uniform)
def _kl_gumbel_infinity(p, q):
    """KL from a Gumbel to these targets is reported as infinite (``_infinite_like(p.loc)``)."""
    return _infinite_like(p.loc)


# TODO: Add Gumbel-Laplace KL Divergence


@register_kl(Gumbel, Normal)
def _kl_gumbel_normal(p, q):
    """KL(p || q) for a Gumbel ``p`` and a Normal ``q``."""
    param_ratio = p.scale / q.scale
    t1 = (param_ratio / math.sqrt(2 * math.pi)).log()
    t2 = (math.pi * param_ratio * 0.5).pow(2) / 3
    t3 = ((p.loc + p.scale * _euler_gamma - q.loc) / q.scale).pow(2) * 0.5
    return -t1 + t2 + t3 - (_euler_gamma + 1)


@register_kl(Laplace, Beta)
@register_kl(Laplace, ContinuousBernoulli)
@register_kl(Laplace, Exponential)
@register_kl(Laplace, Gamma)
@register_kl(Laplace, Pareto)
@register_kl(Laplace, Uniform)
def _kl_laplace_infinity(p, q):
    """KL from a Laplace to these targets is reported as infinite (``_infinite_like(p.loc)``)."""
    return _infinite_like(p.loc)


@register_kl(Laplace, Normal)
def _kl_laplace_normal(p, q):
    """KL(p || q) for a Laplace ``p`` and a Normal ``q``."""
    var_normal = q.scale.pow(2)
    scale_sqr_var_ratio = p.scale.pow(2) / var_normal
    t1 = 0.5 * torch.log(2 * scale_sqr_var_ratio / math.pi)
    t2 = 0.5 * p.loc.pow(2)
    t3 = p.loc * q.loc
    t4 = 0.5 * q.loc.pow(2)
    return -t1 + scale_sqr_var_ratio + (t2 - t3 + t4) / var_normal - 1


@register_kl(Normal, Beta)
@register_kl(Normal, ContinuousBernoulli)
@register_kl(Normal, Exponential)
@register_kl(Normal, Gamma)
@register_kl(Normal, Pareto)
@register_kl(Normal, Uniform)
def _kl_normal_infinity(p, q):
    """KL from a Normal to these targets is reported as infinite (``_infinite_like(p.loc)``)."""
    return _infinite_like(p.loc)


@register_kl(Normal, Gumbel)
def _kl_normal_gumbel(p, q):
    """KL(p || q) for a Normal ``p`` and a Gumbel ``q``."""
    mean_scale_ratio = p.loc / q.scale
    var_scale_sqr_ratio = (p.scale / q.scale).pow(2)
    loc_scale_ratio = q.loc / q.scale
    t1 = var_scale_sqr_ratio.log() * 0.5
    t2 = mean_scale_ratio - loc_scale_ratio
    t3 = torch.exp(-mean_scale_ratio + 0.5 * var_scale_sqr_ratio + loc_scale_ratio)
    return -t1 + t2 + t3 - (0.5 * (1 + math.log(2 * math.pi)))


@register_kl(Normal, Laplace)
def _kl_normal_laplace(p, q):
    """KL(p || q) for a Normal ``p`` and a Laplace ``q``.

    ``t2 + t3`` is ``E_p|x - q.loc|`` (folded-normal mean).
    """
    loc_diff = p.loc - q.loc
    scale_ratio = p.scale / q.scale
    loc_diff_scale_ratio = loc_diff / p.scale
    t1 = torch.log(scale_ratio)
    t2 = (
        math.sqrt(2 / math.pi) * p.scale * torch.exp(-0.5 * loc_diff_scale_ratio.pow(2))
    )
    t3 = loc_diff * torch.erf(math.sqrt(0.5) * loc_diff_scale_ratio)
    return -t1 + (t2 + t3) / q.scale - (0.5 * (1 + math.log(0.5 * math.pi)))


@register_kl(Pareto, Beta)
@register_kl(Pareto, ContinuousBernoulli)
@register_kl(Pareto, Uniform)
def _kl_pareto_infinity(p, q):
    """KL from a Pareto to these targets is reported as infinite (``_infinite_like(p.scale)``)."""
    return _infinite_like(p.scale)


@register_kl(Pareto, Exponential)
def _kl_pareto_exponential(p, q):
    """KL(p || q) for a Pareto ``p`` and an Exponential ``q``.

    ``inf`` where ``p.alpha <= 1``, since the closed form needs ``alpha > 1``
    (``t3`` divides by ``alpha - 1``).
    """
    scale_rate_prod = p.scale * q.rate
    t1 = (p.alpha / scale_rate_prod).log()
    t2 = p.alpha.reciprocal()
    t3 = p.alpha * scale_rate_prod / (p.alpha - 1)
    result = t1 - t2 + t3 - 1
    # The mask depends on p only while result also broadcasts over q, so use
    # broadcasting masked_fill rather than boolean indexing.
    return result.masked_fill(p.alpha <= 1, inf)


@register_kl(Pareto, Gamma)
def _kl_pareto_gamma(p, q):
    """KL(p || q) for a Pareto ``p`` and a Gamma ``q``; ``inf`` where ``p.alpha <= 1``."""
    common_term = p.scale.log() + p.alpha.reciprocal()
    t1 = p.alpha.log() - common_term
    t2 = q.concentration.lgamma() - q.concentration * q.rate.log()
    t3 = (1 - q.concentration) * common_term
    t4 = q.rate * p.alpha * p.scale / (p.alpha - 1)
    result = t1 + t2 + t3 + t4 - 1
    # Broadcasting masked_fill: the mask covers p's batch shape only.
    return result.masked_fill(p.alpha <= 1, inf)


# TODO: Add Pareto-Laplace KL Divergence


@register_kl(Pareto, Normal)
def _kl_pareto_normal(p, q):
    """KL(p || q) for a Pareto ``p`` and a Normal ``q``.

    ``inf`` where ``p.alpha <= 2``, since the closed form needs ``alpha > 2``
    (``t3`` divides by ``alpha - 2``).
    """
    twice_var_normal = 2 * q.scale.pow(2)
    common_term = p.scale / (p.alpha - 1)
    t1 = (math.sqrt(2 * math.pi) * q.scale * p.alpha / p.scale).log()
    t2 = p.alpha.reciprocal()
    t3 = p.alpha * common_term.pow(2) / (p.alpha - 2)
    t4 = (p.alpha * common_term - q.loc).pow(2)
    result = t1 - t2 + (t3 + t4) / twice_var_normal - 1
    # Broadcasting masked_fill: the mask covers p's batch shape only.
    return result.masked_fill(p.alpha <= 2, inf)


@register_kl(Poisson, Bernoulli)
@register_kl(Poisson, Binomial)
def _kl_poisson_infinity(p, q):
    """KL from a Poisson to these targets is reported as infinite (``_infinite_like(p.rate)``)."""
    return _infinite_like(p.rate)
@register_kl(Uniform, Beta)
def _kl_uniform_beta(p, q):
    """KL(p || q) for a Uniform ``p`` and a Beta ``q``.

    ``t2`` and ``t3`` are the ``E_p[log x]`` and ``E_p[log(1 - x)]`` terms
    written with ``_x_log_x``; ``t4`` is the log Beta function. The result is
    ``inf`` where p's interval extends outside q's support.
    """
    common_term = p.high - p.low
    t1 = torch.log(common_term)
    t2 = (
        (q.concentration1 - 1)
        * (_x_log_x(p.high) - _x_log_x(p.low) - common_term)
        / common_term
    )
    t3 = (
        (q.concentration0 - 1)
        * (_x_log_x(1 - p.high) - _x_log_x(1 - p.low) + common_term)
        / common_term
    )
    t4 = (
        q.concentration1.lgamma()
        + q.concentration0.lgamma()
        - (q.concentration1 + q.concentration0).lgamma()
    )
    result = t3 + t4 - t1 - t2
    # The mask depends on p only while result also broadcasts over q, so use
    # broadcasting masked_fill rather than boolean indexing.
    return result.masked_fill(
        (p.high > q.support.upper_bound) | (p.low < q.support.lower_bound), inf
    )


@register_kl(Uniform, ContinuousBernoulli)
def _kl_uniform_continuous_bernoulli(p, q):
    """KL(p || q) for a Uniform ``p`` and a ContinuousBernoulli ``q``.

    ``inf`` only where p's interval extends outside q's support. An interval
    endpoint equal to a support bound is still inside it, so the comparisons
    are strict, as in ``_kl_uniform_beta``.
    """
    result = (
        -p.entropy()
        - p.mean * q.logits
        - torch.log1p(-q.probs)
        - q._cont_bern_log_norm()
    )
    return result.masked_fill(
        (p.high > q.support.upper_bound) | (p.low < q.support.lower_bound), inf
    )


@register_kl(Uniform, Exponential)
def _kl_uniform_exponetial(p, q):
    """KL(p || q) for a Uniform ``p`` and an Exponential ``q``.

    ``inf`` where ``p.low`` lies below q's support.
    """
    result = q.rate * (p.high + p.low) / 2 - ((p.high - p.low) * q.rate).log()
    # Broadcasting masked_fill: the mask covers p's batch shape only.
    return result.masked_fill(p.low < q.support.lower_bound, inf)


@register_kl(Uniform, Gamma)
def _kl_uniform_gamma(p, q):
    """KL(p || q) for a Uniform ``p`` and a Gamma ``q``.

    ``inf`` where ``p.low`` lies below q's support.
    """
    common_term = p.high - p.low
    t1 = common_term.log()
    t2 = q.concentration.lgamma() - q.concentration * q.rate.log()
    t3 = (
        (1 - q.concentration)
        * (_x_log_x(p.high) - _x_log_x(p.low) - common_term)
        / common_term
    )
    t4 = q.rate * (p.high + p.low) / 2
    result = -t1 + t2 + t3 + t4
    # Broadcasting masked_fill: the mask covers p's batch shape only.
    return result.masked_fill(p.low < q.support.lower_bound, inf)


@register_kl(Uniform, Gumbel)
def _kl_uniform_gumbel(p, q):
    """KL(p || q) for a Uniform ``p`` and a Gumbel ``q``."""
    common_term = q.scale / (p.high - p.low)
    high_loc_diff = (p.high - q.loc) / q.scale
    low_loc_diff = (p.low - q.loc) / q.scale
    t1 = common_term.log() + 0.5 * (high_loc_diff + low_loc_diff)
    t2 = common_term * (torch.exp(-high_loc_diff) - torch.exp(-low_loc_diff))
    return t1 - t2


# TODO: Uniform-Laplace KL Divergence


@register_kl(Uniform, Normal)
def _kl_uniform_normal(p, q):
    common_term = p.high - p.low
    t1 = (math.sqrt(math.pi * 2) * q.scale / common_term).log()
    t2 = (common_term).pow(2) / 12
    t3 = ((p.high + p.low - 2 * q.loc) / 2).pow(2)
    return
t1 + 0.5 * (t2 + t3) / q.scale.pow(2) 931*da0073e9SAndroid Build Coastguard Worker 932*da0073e9SAndroid Build Coastguard Worker 933*da0073e9SAndroid Build Coastguard Worker@register_kl(Uniform, Pareto) 934*da0073e9SAndroid Build Coastguard Workerdef _kl_uniform_pareto(p, q): 935*da0073e9SAndroid Build Coastguard Worker support_uniform = p.high - p.low 936*da0073e9SAndroid Build Coastguard Worker t1 = (q.alpha * q.scale.pow(q.alpha) * (support_uniform)).log() 937*da0073e9SAndroid Build Coastguard Worker t2 = (_x_log_x(p.high) - _x_log_x(p.low) - support_uniform) / support_uniform 938*da0073e9SAndroid Build Coastguard Worker result = t2 * (q.alpha + 1) - t1 939*da0073e9SAndroid Build Coastguard Worker result[p.low < q.support.lower_bound] = inf 940*da0073e9SAndroid Build Coastguard Worker return result 941*da0073e9SAndroid Build Coastguard Worker 942*da0073e9SAndroid Build Coastguard Worker 943*da0073e9SAndroid Build Coastguard Worker@register_kl(Independent, Independent) 944*da0073e9SAndroid Build Coastguard Workerdef _kl_independent_independent(p, q): 945*da0073e9SAndroid Build Coastguard Worker if p.reinterpreted_batch_ndims != q.reinterpreted_batch_ndims: 946*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 947*da0073e9SAndroid Build Coastguard Worker result = kl_divergence(p.base_dist, q.base_dist) 948*da0073e9SAndroid Build Coastguard Worker return _sum_rightmost(result, p.reinterpreted_batch_ndims) 949*da0073e9SAndroid Build Coastguard Worker 950*da0073e9SAndroid Build Coastguard Worker 951*da0073e9SAndroid Build Coastguard Worker@register_kl(Cauchy, Cauchy) 952*da0073e9SAndroid Build Coastguard Workerdef _kl_cauchy_cauchy(p, q): 953*da0073e9SAndroid Build Coastguard Worker # From https://arxiv.org/abs/1905.10965 954*da0073e9SAndroid Build Coastguard Worker t1 = ((p.scale + q.scale).pow(2) + (p.loc - q.loc).pow(2)).log() 955*da0073e9SAndroid Build Coastguard Worker t2 = (4 * p.scale * q.scale).log() 956*da0073e9SAndroid Build Coastguard 
Worker return t1 - t2 957*da0073e9SAndroid Build Coastguard Worker 958*da0073e9SAndroid Build Coastguard Worker 959*da0073e9SAndroid Build Coastguard Workerdef _add_kl_info(): 960*da0073e9SAndroid Build Coastguard Worker """Appends a list of implemented KL functions to the doc for kl_divergence.""" 961*da0073e9SAndroid Build Coastguard Worker rows = [ 962*da0073e9SAndroid Build Coastguard Worker "KL divergence is currently implemented for the following distribution pairs:" 963*da0073e9SAndroid Build Coastguard Worker ] 964*da0073e9SAndroid Build Coastguard Worker for p, q in sorted( 965*da0073e9SAndroid Build Coastguard Worker _KL_REGISTRY, key=lambda p_q: (p_q[0].__name__, p_q[1].__name__) 966*da0073e9SAndroid Build Coastguard Worker ): 967*da0073e9SAndroid Build Coastguard Worker rows.append( 968*da0073e9SAndroid Build Coastguard Worker f"* :class:`~torch.distributions.{p.__name__}` and :class:`~torch.distributions.{q.__name__}`" 969*da0073e9SAndroid Build Coastguard Worker ) 970*da0073e9SAndroid Build Coastguard Worker kl_info = "\n\t".join(rows) 971*da0073e9SAndroid Build Coastguard Worker if kl_divergence.__doc__: 972*da0073e9SAndroid Build Coastguard Worker kl_divergence.__doc__ += kl_info # type: ignore[operator] 973