# mypy: allow-untyped-defs
import math
from numbers import Number

import torch
from torch import inf, nan
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.utils import broadcast_all
from torch.types import _size


__all__ = ["Cauchy"]


class Cauchy(Distribution):
    r"""
    Samples from a Cauchy (Lorentz) distribution. The distribution of the ratio of
    independent normally distributed random variables with means `0` follows a
    Cauchy distribution.

    The probability density is

    .. math::
        f(x) = \frac{1}{\pi \cdot \text{scale}
               \left(1 + \left(\frac{x - \text{loc}}{\text{scale}}\right)^2\right)}

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = Cauchy(torch.tensor([0.0]), torch.tensor([1.0]))
        >>> m.sample()  # sample from a Cauchy distribution with loc=0 and scale=1
        tensor([ 2.3214])

    Args:
        loc (float or Tensor): mode or median of the distribution.
        scale (float or Tensor): half width at half maximum.
        validate_args (bool, optional): forwarded to :class:`Distribution`;
            when truthy, :meth:`log_prob` and :meth:`cdf` check their inputs
            with ``_validate_sample``.
    """

    arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
    support = constraints.real
    has_rsample = True

    def __init__(self, loc, scale, validate_args=None):
        # Bring loc and scale to a common shape so elementwise math lines up.
        self.loc, self.scale = broadcast_all(loc, scale)
        # Two plain Python numbers describe a single scalar distribution;
        # otherwise the (broadcast) shape of loc is the batch shape.
        if isinstance(loc, Number) and isinstance(scale, Number):
            batch_shape = torch.Size()
        else:
            batch_shape = self.loc.size()
        super().__init__(batch_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        """Return a new ``Cauchy`` whose parameters are expanded to ``batch_shape``.

        The parameters are expanded with ``Tensor.expand``. Argument validation
        is skipped during construction and the original ``_validate_args``
        setting is copied over afterwards.
        """
        new = self._get_checked_instance(Cauchy, _instance)
        batch_shape = torch.Size(batch_shape)
        new.loc = self.loc.expand(batch_shape)
        new.scale = self.scale.expand(batch_shape)
        super(Cauchy, new).__init__(batch_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    @property
    def mean(self):
        """The Cauchy mean is undefined, so this is a NaN-filled tensor."""
        return torch.full(
            self._extended_shape(), nan, dtype=self.loc.dtype, device=self.loc.device
        )

    @property
    def mode(self):
        """The mode equals ``loc``."""
        return self.loc

    @property
    def variance(self):
        """The Cauchy variance is reported as an ``inf``-filled tensor."""
        return torch.full(
            self._extended_shape(), inf, dtype=self.loc.dtype, device=self.loc.device
        )

    def rsample(self, sample_shape: _size = torch.Size()) -> torch.Tensor:
        """Draw a reparameterized sample of shape ``sample_shape + batch_shape``.

        Standard Cauchy noise is drawn with ``Tensor.cauchy_`` and then
        shifted by ``loc`` and scaled by ``scale``. As a result, gradients
        flow back to both parameters.
        """
        shape = self._extended_shape(sample_shape)
        # torch.empty replaces the legacy ``self.loc.new(shape)`` constructor;
        # it allocates an uninitialized tensor with the same dtype and device.
        eps = torch.empty(shape, dtype=self.loc.dtype, device=self.loc.device)
        eps.cauchy_()
        return self.loc + eps * self.scale

    def log_prob(self, value):
        """Log-density ``-log(pi) - log(scale) - log1p(((value - loc) / scale) ** 2)``."""
        if self._validate_args:
            self._validate_sample(value)
        return (
            -math.log(math.pi)
            - self.scale.log()
            - (((value - self.loc) / self.scale) ** 2).log1p()
        )

    def cdf(self, value):
        """Cumulative distribution ``atan((value - loc) / scale) / pi + 0.5``."""
        if self._validate_args:
            self._validate_sample(value)
        return torch.atan((value - self.loc) / self.scale) / math.pi + 0.5

    def icdf(self, value):
        """Inverse CDF ``tan(pi * (value - 0.5)) * scale + loc``.

        ``value`` is expected to be a probability; no validation is performed
        on it here.
        """
        return torch.tan(math.pi * (value - 0.5)) * self.scale + self.loc

    def entropy(self):
        """Differential entropy ``log(4 * pi * scale)``."""
        return math.log(4 * math.pi) + self.scale.log()