# mypy: allow-untyped-defs
import torch
from torch.distributions import constraints
from torch.distributions.exponential import Exponential
from torch.distributions.gumbel import euler_constant
from torch.distributions.transformed_distribution import TransformedDistribution
from torch.distributions.transforms import AffineTransform, PowerTransform
from torch.distributions.utils import broadcast_all


__all__ = ["Weibull"]


class Weibull(TransformedDistribution):
    r"""
    Samples from a two-parameter Weibull distribution.

    The distribution is built as a transformed unit-rate ``Exponential``:
    a base sample ``e`` is mapped to ``scale * e ** (1 / concentration)``
    (a :class:`PowerTransform` followed by an :class:`AffineTransform`).

    Example:

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = Weibull(torch.tensor([1.0]), torch.tensor([1.0]))
        >>> m.sample()  # sample from a Weibull distribution with scale=1, concentration=1
        tensor([ 0.4784])

    Args:
        scale (float or Tensor): Scale parameter of distribution (lambda).
        concentration (float or Tensor): Concentration parameter of distribution (k/shape).
        validate_args (bool, optional): Forwarded to the base distribution and
            to :class:`TransformedDistribution` to control argument validation.
    """

    arg_constraints = {
        "scale": constraints.positive,
        "concentration": constraints.positive,
    }
    support = constraints.positive

    def __init__(self, scale, concentration, validate_args=None):
        # Broadcast both parameters to a common batch shape.
        self.scale, self.concentration = broadcast_all(scale, concentration)
        # Cached 1/k; reused by the power transform and the moment formulas.
        self.concentration_reciprocal = self.concentration.reciprocal()
        base_dist = Exponential(
            torch.ones_like(self.scale), validate_args=validate_args
        )
        transforms = [
            PowerTransform(exponent=self.concentration_reciprocal),
            AffineTransform(loc=0, scale=self.scale),
        ]
        super().__init__(base_dist, transforms, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        """Return a new Weibull whose parameters are expanded to ``batch_shape``."""
        new = self._get_checked_instance(Weibull, _instance)
        new.scale = self.scale.expand(batch_shape)
        new.concentration = self.concentration.expand(batch_shape)
        new.concentration_reciprocal = new.concentration.reciprocal()
        base_dist = self.base_dist.expand(batch_shape)
        # Rebuild the transforms so they reference the expanded parameters.
        transforms = [
            PowerTransform(exponent=new.concentration_reciprocal),
            AffineTransform(loc=0, scale=new.scale),
        ]
        # Parameters were already validated on ``self``; skip re-validation,
        # then restore the original validation setting.
        super(Weibull, new).__init__(base_dist, transforms, validate_args=False)
        new._validate_args = self._validate_args
        return new

    @property
    def mean(self):
        """Mean: ``scale * Gamma(1 + 1/k)``."""
        return self.scale * torch.exp(torch.lgamma(1 + self.concentration_reciprocal))

    @property
    def mode(self):
        """
        Mode: ``scale * ((k - 1) / k) ** (1/k)`` for ``k > 1`` and ``0`` for ``k <= 1``.
        """
        # For k < 1 the base (k - 1) / k is negative; raising it to the
        # fractional power 1/k would yield NaN. Clamping at 0 gives the
        # correct mode of 0 there (and is a no-op for k >= 1).
        base = ((self.concentration - 1) / self.concentration).clamp(min=0)
        return self.scale * base**self.concentration_reciprocal

    @property
    def variance(self):
        """Variance: ``scale**2 * (Gamma(1 + 2/k) - Gamma(1 + 1/k)**2)``."""
        return self.scale.pow(2) * (
            torch.exp(torch.lgamma(1 + 2 * self.concentration_reciprocal))
            - torch.exp(2 * torch.lgamma(1 + self.concentration_reciprocal))
        )

    def entropy(self):
        """Differential entropy: ``gamma * (1 - 1/k) + log(scale / k) + 1``."""
        return (
            euler_constant * (1 - self.concentration_reciprocal)
            + torch.log(self.scale * self.concentration_reciprocal)
            + 1
        )