# mypy: allow-untyped-defs
from numbers import Number

import torch
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.utils import broadcast_all
from torch.types import _size


__all__ = ["Laplace"]


class Laplace(Distribution):
    r"""
    Creates a Laplace distribution parameterized by :attr:`loc` and :attr:`scale`.

    The density is ``exp(-|x - loc| / scale) / (2 * scale)``.

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = Laplace(torch.tensor([0.0]), torch.tensor([1.0]))
        >>> m.sample()  # Laplace distributed with loc=0, scale=1
        tensor([ 0.1046])

    Args:
        loc (float or Tensor): mean of the distribution
        scale (float or Tensor): scale of the distribution
    """
    arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
    support = constraints.real
    has_rsample = True

    @property
    def mean(self):
        """Mean of the distribution, equal to ``loc``."""
        return self.loc

    @property
    def mode(self):
        """Mode of the distribution, equal to ``loc``."""
        return self.loc

    @property
    def variance(self):
        """Variance of the distribution, ``2 * scale**2``."""
        return 2 * self.scale.pow(2)

    @property
    def stddev(self):
        """Standard deviation of the distribution, ``sqrt(2) * scale``."""
        return (2**0.5) * self.scale

    def __init__(self, loc, scale, validate_args=None):
        """Broadcast ``loc`` and ``scale`` together and set the batch shape.

        When both arguments are plain Python numbers the batch shape is empty;
        otherwise it is the broadcast shape of the two parameters.
        """
        self.loc, self.scale = broadcast_all(loc, scale)
        if isinstance(loc, Number) and isinstance(scale, Number):
            batch_shape = torch.Size()
        else:
            batch_shape = self.loc.size()
        super().__init__(batch_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        """Return a copy of this distribution with parameters expanded to ``batch_shape``."""
        new = self._get_checked_instance(Laplace, _instance)
        batch_shape = torch.Size(batch_shape)
        new.loc = self.loc.expand(batch_shape)
        new.scale = self.scale.expand(batch_shape)
        # Parameters were validated when ``self`` was built; skip re-validation
        # here and just carry the flag over.
        super(Laplace, new).__init__(batch_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    def rsample(self, sample_shape: _size = torch.Size()) -> torch.Tensor:
        """Draw a reparameterized sample of shape ``sample_shape + batch_shape``.

        Uses inverse-CDF sampling: for ``u`` uniform on (-1, 1),
        ``loc - scale * sign(u) * log1p(-|u|)`` is Laplace(loc, scale)
        distributed and differentiable with respect to ``loc`` and ``scale``.
        """
        shape = self._extended_shape(sample_shape)
        finfo = torch.finfo(self.loc.dtype)
        if torch._C._get_tracing_state():
            # [JIT WORKAROUND] lack of support for .uniform_()
            u = torch.rand(shape, dtype=self.loc.dtype, device=self.loc.device) * 2 - 1
            # torch.rand may return exactly 0, i.e. u == -1, and log1p(-1) is
            # -inf. Cap |u| just below 1 so the sample (and its gradient) stay
            # finite; this matches the eps - 1 lower bound of the path below.
            return self.loc - self.scale * u.sign() * torch.log1p(
                -u.abs().clamp(max=1 - finfo.eps)
            )
        # The lower bound eps - 1 keeps u away from -1, where log1p(-|u|)
        # would be -inf.
        u = self.loc.new(shape).uniform_(finfo.eps - 1, 1)
        return self.loc - self.scale * u.sign() * torch.log1p(-u.abs())

    def log_prob(self, value):
        """Log density ``-log(2 * scale) - |value - loc| / scale``."""
        if self._validate_args:
            self._validate_sample(value)
        return -torch.log(2 * self.scale) - torch.abs(value - self.loc) / self.scale

    def cdf(self, value):
        """Cumulative distribution function evaluated at ``value``.

        Written with ``expm1`` so values near ``loc`` keep precision:
        ``0.5 * exp(-|d| / scale)`` for ``d < 0`` and
        ``1 - 0.5 * exp(-|d| / scale)`` for ``d > 0``, where ``d = value - loc``.
        """
        if self._validate_args:
            self._validate_sample(value)
        return 0.5 - 0.5 * (value - self.loc).sign() * torch.expm1(
            -(value - self.loc).abs() / self.scale
        )

    def icdf(self, value):
        """Inverse CDF (quantile function) for probabilities ``value`` in (0, 1).

        ``value`` is not validated here.
        """
        term = value - 0.5
        return self.loc - self.scale * term.sign() * torch.log1p(-2 * term.abs())

    def entropy(self):
        """Differential entropy ``1 + log(2 * scale)``."""
        return 1 + torch.log(2 * self.scale)