# mypy: allow-untyped-defs
import math
from numbers import Number

import torch
from torch.distributions import constraints
from torch.distributions.transformed_distribution import TransformedDistribution
from torch.distributions.transforms import AffineTransform, ExpTransform
from torch.distributions.uniform import Uniform
from torch.distributions.utils import broadcast_all, euler_constant


__all__ = ["Gumbel"]


class Gumbel(TransformedDistribution):
    r"""
    Samples from a Gumbel distribution parameterized by :attr:`loc` and :attr:`scale`.

    With :math:`\mu` = ``loc``, :math:`\beta` = ``scale`` and
    :math:`z = (x - \mu) / \beta`, the density is

    .. math::
        p(x) = \frac{1}{\beta} \exp\left(-z - e^{-z}\right).

    Sampling pushes a uniform variate :math:`u` through
    :math:`x = \mu - \beta \log(-\log u)`.

    Examples::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = Gumbel(torch.tensor([1.0]), torch.tensor([2.0]))
        >>> m.sample()  # sample from Gumbel distribution with loc=1, scale=2
        tensor([ 1.0124])

    Args:
        loc (float or Tensor): Location parameter of the distribution
        scale (float or Tensor): Scale parameter of the distribution (must be positive)
    """

    arg_constraints = dict(loc=constraints.real, scale=constraints.positive)
    support = constraints.real

    def __init__(self, loc, scale, validate_args=None):
        self.loc, self.scale = broadcast_all(loc, scale)
        info = torch.finfo(self.loc.dtype)

        # Keep u away from 0 and 1: u == 0 makes log(u) infinite, and
        # u == 1 makes log(-log(u)) = log(0) infinite.
        low = info.tiny
        high = 1 - info.eps
        if not (isinstance(loc, Number) and isinstance(scale, Number)):
            # Tensor parameters: give the base distribution the broadcast shape.
            low = torch.full_like(self.loc, low)
            high = torch.full_like(self.loc, high)
        base_dist = Uniform(low, high, validate_args=validate_args)

        # u -> log(u) -> -log(u) -> log(-log(u)) -> loc - scale * log(-log(u))
        log_u = ExpTransform().inv
        negate = AffineTransform(loc=0, scale=-torch.ones_like(self.scale))
        log_neg_log_u = ExpTransform().inv
        location_scale = AffineTransform(loc=loc, scale=-self.scale)

        super().__init__(
            base_dist,
            [log_u, negate, log_neg_log_u, location_scale],
            validate_args=validate_args,
        )

    def expand(self, batch_shape, _instance=None):
        expanded = self._get_checked_instance(Gumbel, _instance)
        expanded.loc = self.loc.expand(batch_shape)
        expanded.scale = self.scale.expand(batch_shape)
        # Defer the remaining setup of the new instance to TransformedDistribution.expand.
        return super().expand(batch_shape, _instance=expanded)

    def log_prob(self, value):
        # Closed-form log-density, defined explicitly rather than derived from
        # the transform chain because of precision issues.
        if self._validate_args:
            self._validate_sample(value)
        neg_z = (self.loc - value) / self.scale
        return neg_z - torch.exp(neg_z) - torch.log(self.scale)

    @property
    def mean(self):
        # E[X] = mu + beta * gamma, gamma being the Euler-Mascheroni constant.
        return self.loc + euler_constant * self.scale

    @property
    def mode(self):
        return self.loc

    @property
    def stddev(self):
        # SD[X] = beta * pi / sqrt(6)
        return self.scale * (math.pi / math.sqrt(6))

    @property
    def variance(self):
        return self.stddev ** 2

    def entropy(self):
        # H[X] = log(beta) + gamma + 1
        return torch.log(self.scale) + (1 + euler_constant)