# mypy: allow-untyped-defs
import math

import torch
import torch.jit
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.utils import broadcast_all, lazy_property


__all__ = ["VonMises"]


def _eval_poly(y, coef):
    coef = list(coef)
    result = coef.pop()
    while coef:
        result = coef.pop() + y * result
    return result


_I0_COEF_SMALL = [
    1.0,
    3.5156229,
    3.0899424,
    1.2067492,
    0.2659732,
    0.360768e-1,
    0.45813e-2,
]
_I0_COEF_LARGE = [
    0.39894228,
    0.1328592e-1,
    0.225319e-2,
    -0.157565e-2,
    0.916281e-2,
    -0.2057706e-1,
    0.2635537e-1,
    -0.1647633e-1,
    0.392377e-2,
]
_I1_COEF_SMALL = [
    0.5,
    0.87890594,
    0.51498869,
    0.15084934,
    0.2658733e-1,
    0.301532e-2,
    0.32411e-3,
]
_I1_COEF_LARGE = [
    0.39894228,
    -0.3988024e-1,
    -0.362018e-2,
    0.163801e-2,
    -0.1031555e-1,
    0.2282967e-1,
    -0.2895312e-1,
    0.1787654e-1,
    -0.420059e-2,
]

_COEF_SMALL = [_I0_COEF_SMALL, _I1_COEF_SMALL]
_COEF_LARGE = [_I0_COEF_LARGE, _I1_COEF_LARGE]


def _log_modified_bessel_fn(x, order=0):
    """
    Returns ``log(I_order(x))`` for ``x > 0``,
    where `order` is either 0 or 1.
    """
    assert order == 0 or order == 1

    # compute small solution
    y = x / 3.75
    y = y * y
    small = _eval_poly(y, _COEF_SMALL[order])
    if order == 1:
        small = x.abs() * small
    small = small.log()

    # compute large solution
    y = 3.75 / x
    large = x - 0.5 * x.log() + _eval_poly(y, _COEF_LARGE[order]).log()

    result = torch.where(x < 3.75, small, large)
    return result
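

# The two branches above use the classical Abramowitz & Stegun-style polynomial
# approximations to I0 and I1, switching at x = 3.75. As an illustrative
# cross-check (a sketch, not executed as part of this module), the result can
# be compared against ``torch.special.i0`` / ``torch.special.i1`` for moderate
# arguments:
#
#     x = torch.linspace(0.5, 20.0, 8, dtype=torch.double)
#     torch.allclose(
#         _log_modified_bessel_fn(x, order=0), torch.special.i0(x).log(), atol=1e-5
#     )  # True
#     torch.allclose(
#         _log_modified_bessel_fn(x, order=1), torch.special.i1(x).log(), atol=1e-5
#     )  # True

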
@torch.jit.script_if_tracing
def _rejection_sample(loc, concentration, proposal_r, x):
    done = torch.zeros(x.shape, dtype=torch.bool, device=loc.device)
    while not done.all():
        u = torch.rand((3,) + x.shape, dtype=loc.dtype, device=loc.device)
        u1, u2, u3 = u.unbind()
        z = torch.cos(math.pi * u1)
        f = (1 + proposal_r * z) / (proposal_r + z)
        c = concentration * (proposal_r - f)
        accept = ((c * (2 - c) - u2) > 0) | ((c / u2).log() + 1 - c >= 0)
        if accept.any():
            x = torch.where(accept, (u3 - 0.5).sign() * f.acos(), x)
            done = done | accept
    return (x + math.pi + loc) % (2 * math.pi) - math.pi


class VonMises(Distribution):
    """
    A circular von Mises distribution.

    This implementation uses polar coordinates. The ``loc`` and ``value`` args
    can be any real number (to facilitate unconstrained optimization), but are
    interpreted as angles modulo 2 pi.

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = VonMises(torch.tensor([1.0]), torch.tensor([1.0]))
        >>> m.sample()  # von Mises distributed with loc=1 and concentration=1
        tensor([1.9777])

    :param torch.Tensor loc: an angle in radians.
    :param torch.Tensor concentration: concentration parameter
    """

    arg_constraints = {"loc": constraints.real, "concentration": constraints.positive}
    support = constraints.real
    has_rsample = False

    def __init__(self, loc, concentration, validate_args=None):
        self.loc, self.concentration = broadcast_all(loc, concentration)
        batch_shape = self.loc.shape
        event_shape = torch.Size()
        super().__init__(batch_shape, event_shape, validate_args)

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        log_prob = self.concentration * torch.cos(value - self.loc)
        log_prob = (
            log_prob
            - math.log(2 * math.pi)
            - _log_modified_bessel_fn(self.concentration, order=0)
        )
        return log_prob

    @lazy_property
    def _loc(self):
        return self.loc.to(torch.double)

    @lazy_property
    def _concentration(self):
        return self.concentration.to(torch.double)

    @lazy_property
    def _proposal_r(self):
        kappa = self._concentration
        tau = 1 + (1 + 4 * kappa**2).sqrt()
        rho = (tau - (2 * tau).sqrt()) / (2 * kappa)
        _proposal_r = (1 + rho**2) / (2 * rho)
        # second order Taylor expansion around 0 for small kappa
        _proposal_r_taylor = 1 / kappa + kappa
        return torch.where(kappa < 1e-5, _proposal_r_taylor, _proposal_r)

    @torch.no_grad()
    def sample(self, sample_shape=torch.Size()):
        """
        The sampling algorithm for the von Mises distribution is based on the
        following paper: D.J. Best and N.I. Fisher, "Efficient simulation of the
        von Mises distribution." Applied Statistics (1979): 152-157.

        Sampling is always done in double precision internally to avoid a hang
        in _rejection_sample() for small values of the concentration, which
        starts to happen for single precision around 1e-4 (see issue #88443).
        """
        shape = self._extended_shape(sample_shape)
        x = torch.empty(shape, dtype=self._loc.dtype, device=self.loc.device)
        return _rejection_sample(
            self._loc, self._concentration, self._proposal_r, x
        ).to(self.loc.dtype)

    def expand(self, batch_shape):
        try:
            return super().expand(batch_shape)
        except NotImplementedError:
            validate_args = self.__dict__.get("_validate_args")
            loc = self.loc.expand(batch_shape)
            concentration = self.concentration.expand(batch_shape)
            return type(self)(loc, concentration, validate_args=validate_args)

    @property
    def mean(self):
        """
        The provided mean is the circular one.
        """
        return self.loc

    @property
    def mode(self):
        return self.loc

    @lazy_property
    def variance(self):
        """
        The provided variance is the circular one.
        """
        return (
            1
            - (
                _log_modified_bessel_fn(self.concentration, order=1)
                - _log_modified_bessel_fn(self.concentration, order=0)
            ).exp()
        )
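

# Minimal usage sketch, guarded so it only runs when this file is executed
# directly. It exercises only the public API defined above: sampling wraps
# angles into [-pi, pi), ``log_prob`` is periodic with period 2*pi, and the
# circular variance equals 1 - I1(kappa) / I0(kappa).
if __name__ == "__main__":
    loc = torch.tensor([0.0, 1.0])
    concentration = torch.tensor([1.0, 4.0])
    d = VonMises(loc, concentration)

    # Samples have shape sample_shape + batch_shape and lie in [-pi, pi).
    theta = d.sample((1000,))
    assert theta.shape == (1000, 2)
    assert theta.abs().max() <= math.pi + 1e-6

    # Shifting the value by 2*pi leaves the log density (numerically) unchanged.
    lp = d.log_prob(theta)
    assert torch.allclose(lp, d.log_prob(theta + 2 * math.pi), atol=1e-5)

    # Circular mean equals ``loc``; circular variance shrinks as kappa grows.
    print("mean:", d.mean, "variance:", d.variance)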