xref: /aosp_15_r20/external/pytorch/torch/distributions/von_mises.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import math
3
4import torch
5import torch.jit
6from torch.distributions import constraints
7from torch.distributions.distribution import Distribution
8from torch.distributions.utils import broadcast_all, lazy_property
9
10
11__all__ = ["VonMises"]
12
13
14def _eval_poly(y, coef):
15    coef = list(coef)
16    result = coef.pop()
17    while coef:
18        result = coef.pop() + y * result
19    return result
20
21
# Polynomial coefficient tables used by ``_log_modified_bessel_fn``; each list
# is in ascending order of power, as expected by ``_eval_poly``.
# NOTE(review): the values look like the classic Abramowitz & Stegun (sec. 9.8)
# polynomial approximations of I0 / I1 -- confirm against the reference.

# I0, small-argument regime (x < 3.75): polynomial in (x / 3.75) ** 2.
_I0_COEF_SMALL = [
    1.0,
    3.5156229,
    3.0899424,
    1.2067492,
    0.2659732,
    0.360768e-1,
    0.45813e-2,
]
# I0, large-argument regime (x >= 3.75): polynomial in 3.75 / x.
_I0_COEF_LARGE = [
    0.39894228,
    0.1328592e-1,
    0.225319e-2,
    -0.157565e-2,
    0.916281e-2,
    -0.2057706e-1,
    0.2635537e-1,
    -0.1647633e-1,
    0.392377e-2,
]
# I1, small-argument regime: polynomial in (x / 3.75) ** 2; the caller
# multiplies the result by |x|.
_I1_COEF_SMALL = [
    0.5,
    0.87890594,
    0.51498869,
    0.15084934,
    0.2658733e-1,
    0.301532e-2,
    0.32411e-3,
]
# I1, large-argument regime: polynomial in 3.75 / x.
_I1_COEF_LARGE = [
    0.39894228,
    -0.3988024e-1,
    -0.362018e-2,
    0.163801e-2,
    -0.1031555e-1,
    0.2282967e-1,
    -0.2895312e-1,
    0.1787654e-1,
    -0.420059e-2,
]

# Lookup tables indexed by the Bessel order (0 or 1).
_COEF_SMALL = [_I0_COEF_SMALL, _I1_COEF_SMALL]
_COEF_LARGE = [_I0_COEF_LARGE, _I1_COEF_LARGE]
65
66
def _log_modified_bessel_fn(x, order=0):
    """
    Returns ``log(I_order(x))`` for ``x > 0``,
    where `order` is either 0 or 1.

    Two polynomial approximations are combined element-wise: one in
    ``(x / 3.75) ** 2`` for ``x < 3.75`` and one in ``3.75 / x`` otherwise.

    Each approximation is only evaluated on inputs from its own regime;
    elements belonging to the other regime are replaced by a harmless
    placeholder first. Evaluating a branch on out-of-regime inputs can
    overflow to ``inf``/``nan`` (e.g. ``3.75 / x`` for tiny ``x``), and even
    though ``torch.where`` discards those values in the forward pass, autograd
    would still turn them into ``nan`` gradients (``0 * inf``).

    :param torch.Tensor x: argument of the Bessel function.
    :param int order: order of the Bessel function, 0 or 1.
    :return: tensor of ``log(I_order(x))`` values, same shape as ``x``.
    """
    assert order == 0 or order == 1

    # NaN compares False here, so NaN inputs flow through the large branch
    # and still yield NaN, exactly as before.
    use_small = x < 3.75
    x_small = torch.where(use_small, x, torch.ones_like(x))
    x_large = torch.where(use_small, torch.full_like(x, 3.75), x)

    # compute small solution
    y = x_small / 3.75
    y = y * y
    small = _eval_poly(y, _COEF_SMALL[order])
    if order == 1:
        small = x_small.abs() * small
    small = small.log()

    # compute large solution
    y = 3.75 / x_large
    large = (
        x_large - 0.5 * x_large.log() + _eval_poly(y, _COEF_LARGE[order]).log()
    )

    return torch.where(use_small, small, large)
88
89
@torch.jit.script_if_tracing
def _rejection_sample(loc, concentration, proposal_r, x):
    """
    Draw von Mises samples by rejection sampling.

    Every loop iteration proposes a fresh candidate for each element of ``x``
    and overwrites the elements whose candidate is accepted. The loop stops
    once every element has been accepted at least once.

    NOTE(review): there is no iteration cap; an element whose acceptance test
    can never pass (e.g. NaN parameters make both comparisons False) would
    keep the loop running -- confirm callers always pass valid parameters.

    :param loc: location angle(s), added to the accepted draws before wrapping.
    :param concentration: concentration parameter(s) used in the acceptance test.
    :param proposal_r: parameter of the proposal, broadcastable against ``x``.
    :param x: buffer whose shape fixes the number of draws; each element is
        overwritten by an accepted candidate before the function returns.
    :return: the accepted draws shifted by ``loc`` and wrapped into
        ``[-pi, pi)``.
    """
    # Tracks which elements already hold an accepted candidate.
    done = torch.zeros(x.shape, dtype=torch.bool, device=loc.device)
    while not done.all():
        # Three independent uniform draws per element, generated in one call.
        u = torch.rand((3,) + x.shape, dtype=loc.dtype, device=loc.device)
        u1, u2, u3 = u.unbind()
        z = torch.cos(math.pi * u1)
        # Candidate cosine of the angle and the quantity used to test it.
        f = (1 + proposal_r * z) / (proposal_r + z)
        c = concentration * (proposal_r - f)
        # Accept if either the cheap quadratic bound or the log test passes.
        accept = ((c * (2 - c) - u2) > 0) | ((c / u2).log() + 1 - c >= 0)
        if accept.any():
            # u3 chooses the sign of the angle; rejected elements keep x.
            x = torch.where(accept, (u3 - 0.5).sign() * f.acos(), x)
            done = done | accept
    # Shift by loc and wrap the angle into [-pi, pi).
    return (x + math.pi + loc) % (2 * math.pi) - math.pi
104
105
class VonMises(Distribution):
    """
    Von Mises distribution over angles on the circle.

    Angles are expressed in polar coordinates. Both ``loc`` and the ``value``
    passed to :meth:`log_prob` may be arbitrary real numbers (which keeps
    unconstrained optimization simple); they are interpreted modulo 2 pi.

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = VonMises(torch.tensor([1.0]), torch.tensor([1.0]))
        >>> m.sample()  # von Mises distributed with loc=1 and concentration=1
        tensor([1.9777])

    :param torch.Tensor loc: an angle in radians.
    :param torch.Tensor concentration: concentration parameter
    """

    arg_constraints = {"loc": constraints.real, "concentration": constraints.positive}
    support = constraints.real
    has_rsample = False

    def __init__(self, loc, concentration, validate_args=None):
        # Broadcast both parameters to a common shape; that shape is the
        # batch shape and the event shape is scalar.
        self.loc, self.concentration = broadcast_all(loc, concentration)
        super().__init__(self.loc.shape, torch.Size(), validate_args)

    def log_prob(self, value):
        """
        Return the log density evaluated at ``value`` (an angle in radians).
        """
        if self._validate_args:
            self._validate_sample(value)
        kappa = self.concentration
        unnormalized = kappa * torch.cos(value - self.loc)
        log_two_pi = math.log(2 * math.pi)
        # Normalizer is 2 * pi * I_0(kappa), subtracted in log space.
        return unnormalized - log_two_pi - _log_modified_bessel_fn(kappa, order=0)

    @lazy_property
    def _loc(self):
        # Double-precision copy of ``loc`` used by the sampler.
        return self.loc.to(torch.double)

    @lazy_property
    def _concentration(self):
        # Double-precision copy of ``concentration`` used by the sampler.
        return self.concentration.to(torch.double)

    @lazy_property
    def _proposal_r(self):
        # Proposal parameter handed to ``_rejection_sample``, computed in
        # double precision from the concentration.
        kappa = self._concentration
        tau = 1 + (1 + 4 * kappa**2).sqrt()
        rho = (tau - (2 * tau).sqrt()) / (2 * kappa)
        closed_form = (1 + rho**2) / (2 * rho)
        # second order Taylor expansion around 0 for small kappa
        taylor = 1 / kappa + kappa
        tiny_kappa = kappa < 1e-5
        return torch.where(tiny_kappa, taylor, closed_form)

    @torch.no_grad()
    def sample(self, sample_shape=torch.Size()):
        """
        Draw samples of shape ``sample_shape + batch_shape``.

        The algorithm follows D.J. Best and N.I. Fisher, "Efficient
        simulation of the von Mises distribution." Applied Statistics (1979):
        152-157.

        The rejection sampler always runs in double precision and the result
        is cast back to the dtype of ``loc``. With single precision,
        _rejection_sample() starts to hang for small concentrations, around
        1e-4 (see issue #88443).
        """
        loc64 = self._loc
        buffer = torch.empty(
            self._extended_shape(sample_shape),
            dtype=loc64.dtype,
            device=self.loc.device,
        )
        draws = _rejection_sample(
            loc64, self._concentration, self._proposal_r, buffer
        )
        return draws.to(self.loc.dtype)

    def expand(self, batch_shape):
        """
        Return a distribution whose parameters are expanded to ``batch_shape``.
        """
        try:
            return super().expand(batch_shape)
        except NotImplementedError:
            # Fall back to building a new instance from expanded parameters,
            # carrying over the validation setting of this instance.
            return type(self)(
                self.loc.expand(batch_shape),
                self.concentration.expand(batch_shape),
                validate_args=self.__dict__.get("_validate_args"),
            )

    @property
    def mean(self):
        """
        The provided mean is the circular one.
        """
        return self.loc

    @property
    def mode(self):
        return self.loc

    @lazy_property
    def variance(self):
        """
        The provided variance is the circular one.
        """
        kappa = self.concentration
        log_i1 = _log_modified_bessel_fn(kappa, order=1)
        log_i0 = _log_modified_bessel_fn(kappa, order=0)
        # 1 - I_1(kappa) / I_0(kappa), with the ratio formed in log space.
        return 1 - (log_i1 - log_i0).exp()
212