# mypy: allow-untyped-defs
import math
from numbers import Number

import torch
from torch.distributions import constraints
from torch.distributions.transformed_distribution import TransformedDistribution
from torch.distributions.transforms import AffineTransform, ExpTransform
from torch.distributions.uniform import Uniform
from torch.distributions.utils import broadcast_all, euler_constant


__all__ = ["Gumbel"]


class Gumbel(TransformedDistribution):
    r"""
    Samples from a Gumbel Distribution.

    The distribution is built as a transformed ``Uniform`` via the inverse CDF:
    if ``U ~ Uniform(tiny, 1 - eps)`` then ``loc - scale * log(-log(U))`` is
    Gumbel-distributed with the given location and scale.

    Examples::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = Gumbel(torch.tensor([1.0]), torch.tensor([2.0]))
        >>> m.sample()  # sample from Gumbel distribution with loc=1, scale=2
        tensor([ 1.0124])

    Args:
        loc (float or Tensor): Location parameter of the distribution
        scale (float or Tensor): Scale parameter of the distribution
    """
    arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
    support = constraints.real

    def __init__(self, loc, scale, validate_args=None):
        """Build the Uniform base distribution and the inverse-CDF transform chain.

        Args:
            loc: location parameter (Python number or tensor).
            scale: scale parameter (Python number or tensor).
            validate_args: forwarded to the base ``Uniform`` and to
                ``TransformedDistribution``.
        """
        self.loc, self.scale = broadcast_all(loc, scale)
        finfo = torch.finfo(self.loc.dtype)
        # The base interval is [tiny, 1 - eps] rather than [0, 1]. This keeps
        # both logarithms in the transform chain finite: log(0) and
        # log(-log(1)) would otherwise appear at the endpoints.
        if isinstance(loc, Number) and isinstance(scale, Number):
            base_dist = Uniform(finfo.tiny, 1 - finfo.eps, validate_args=validate_args)
        else:
            # full_like gives the base distribution the same shape, dtype and
            # device as the broadcast parameters.
            base_dist = Uniform(
                torch.full_like(self.loc, finfo.tiny),
                torch.full_like(self.loc, 1 - finfo.eps),
                validate_args=validate_args,
            )
        # u -> log(u) -> -log(u) -> log(-log(u)) -> loc - scale * log(-log(u))
        transforms = [
            ExpTransform().inv,
            AffineTransform(loc=0, scale=-torch.ones_like(self.scale)),
            ExpTransform().inv,
            # `loc` is the argument as passed in, not the broadcast `self.loc`.
            AffineTransform(loc=loc, scale=-self.scale),
        ]
        super().__init__(base_dist, transforms, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        """Return a new ``Gumbel`` whose parameters are expanded to ``batch_shape``."""
        new = self._get_checked_instance(Gumbel, _instance)
        new.loc = self.loc.expand(batch_shape)
        new.scale = self.scale.expand(batch_shape)
        return super().expand(batch_shape, _instance=new)

    # Explicitly defining the log probability function for Gumbel due to precision
    # issues: the closed form below is used instead of inverting the four-step
    # transform chain and accumulating its log-abs-det-Jacobians.
    def log_prob(self, value):
        """Return ``-(z + exp(-z)) - log(scale)`` with ``z = (value - loc) / scale``.

        The code computes this as ``y - exp(y) - log(scale)``, where
        ``y = (loc - value) / scale = -z``.
        """
        if self._validate_args:
            self._validate_sample(value)
        y = (self.loc - value) / self.scale
        return (y - y.exp()) - self.scale.log()

    @property
    def mean(self):
        """Mean: ``loc + scale * gamma``, where ``gamma`` is the Euler–Mascheroni constant."""
        return self.loc + self.scale * euler_constant

    @property
    def mode(self):
        """Mode: the location parameter."""
        return self.loc

    @property
    def stddev(self):
        """Standard deviation: ``scale * pi / sqrt(6)``."""
        return (math.pi / math.sqrt(6)) * self.scale

    @property
    def variance(self):
        """Variance: the square of ``stddev``, i.e. ``pi**2 / 6 * scale**2``."""
        return self.stddev.pow(2)

    def entropy(self):
        """Differential entropy: ``log(scale) + 1 + gamma``."""
        return self.scale.log() + (1 + euler_constant)