xref: /aosp_15_r20/external/pytorch/torch/distributions/cauchy.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import math
3from numbers import Number
4
5import torch
6from torch import inf, nan
7from torch.distributions import constraints
8from torch.distributions.distribution import Distribution
9from torch.distributions.utils import broadcast_all
10from torch.types import _size
11
12
# Public API of this module.
__all__ = ["Cauchy"]
14
15
class Cauchy(Distribution):
    r"""
    Samples from a Cauchy (Lorentz) distribution. The distribution of the ratio of
    independent normally distributed random variables with means `0` follows a
    Cauchy distribution.

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = Cauchy(torch.tensor([0.0]), torch.tensor([1.0]))
        >>> m.sample()  # sample from a Cauchy distribution with loc=0 and scale=1
        tensor([ 2.3214])

    Args:
        loc (float or Tensor): mode or median of the distribution.
        scale (float or Tensor): half width at half maximum.
        validate_args (bool, optional): forwarded to :class:`Distribution`;
            controls validation of parameters and of values passed to
            :meth:`log_prob` / :meth:`cdf`.
    """
    arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
    support = constraints.real
    has_rsample = True

    def __init__(self, loc, scale, validate_args=None):
        # Broadcast the parameters against each other so every batch element
        # has its own (loc, scale) pair.
        self.loc, self.scale = broadcast_all(loc, scale)
        # Two Python scalars describe a single distribution (empty batch shape);
        # otherwise the batch shape is the broadcast parameter shape.
        if isinstance(loc, Number) and isinstance(scale, Number):
            batch_shape = torch.Size()
        else:
            batch_shape = self.loc.size()
        super().__init__(batch_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        """Return a copy of this distribution broadcast to ``batch_shape``."""
        new = self._get_checked_instance(Cauchy, _instance)
        batch_shape = torch.Size(batch_shape)
        new.loc = self.loc.expand(batch_shape)
        new.scale = self.scale.expand(batch_shape)
        # The parameters were already validated on ``self``: skip re-validation
        # during construction, then inherit this instance's validation setting.
        super(Cauchy, new).__init__(batch_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    @property
    def mean(self):
        # The Cauchy mean is undefined; it is reported as NaN everywhere.
        return torch.full(
            self._extended_shape(), nan, dtype=self.loc.dtype, device=self.loc.device
        )

    @property
    def mode(self):
        # The density peaks at ``loc``.
        return self.loc

    @property
    def variance(self):
        # The Cauchy variance is not finite; it is reported as +inf everywhere.
        return torch.full(
            self._extended_shape(), inf, dtype=self.loc.dtype, device=self.loc.device
        )

    def rsample(self, sample_shape: _size = torch.Size()) -> torch.Tensor:
        """Draw a reparameterized sample: ``loc + scale * eps``, eps ~ Cauchy(0, 1).

        Gradients flow to ``loc`` and ``scale``; the noise ``eps`` carries none.
        """
        shape = self._extended_shape(sample_shape)
        # Fresh buffer with loc's dtype/device, filled in place with standard
        # Cauchy noise (``Tensor.cauchy_`` defaults: median=0, sigma=1).
        eps = self.loc.new_empty(shape).cauchy_()
        return self.loc + eps * self.scale

    def log_prob(self, value):
        """Log density: ``-log(pi) - log(scale) - log1p(((value - loc) / scale) ** 2)``."""
        if self._validate_args:
            self._validate_sample(value)
        return (
            -math.log(math.pi)
            - self.scale.log()
            - (((value - self.loc) / self.scale) ** 2).log1p()
        )

    def cdf(self, value):
        """Cumulative distribution: ``atan((value - loc) / scale) / pi + 1/2``."""
        if self._validate_args:
            self._validate_sample(value)
        return torch.atan((value - self.loc) / self.scale) / math.pi + 0.5

    def icdf(self, value):
        """Inverse CDF: ``loc + scale * tan(pi * (value - 1/2))``.

        ``value`` is not validated here; probabilities are expected in [0, 1].
        """
        return torch.tan(math.pi * (value - 0.5)) * self.scale + self.loc

    def entropy(self):
        """Differential entropy: ``log(4 * pi * scale)``."""
        return math.log(4 * math.pi) + self.scale.log()
94