# mypy: allow-untyped-defs
import torch
from torch import nan
from torch.distributions import constraints
from torch.distributions.transformed_distribution import TransformedDistribution
from torch.distributions.transforms import AffineTransform, PowerTransform
from torch.distributions.uniform import Uniform
from torch.distributions.utils import broadcast_all, euler_constant


__all__ = ["Kumaraswamy"]


def _moments(a, b, n):
    """
    Computes nth moment of Kumaraswamy using torch.lgamma

    The value returned is ``b * B(1 + n / a, b)`` where ``B`` is the Beta
    function; it is evaluated through ``lgamma`` and exponentiated at the end.

    Args:
        a (Tensor): first concentration parameter (``concentration1``)
        b (Tensor): second concentration parameter (``concentration0``)
        n (int or float): order of the moment
    """
    arg1 = 1 + n / a
    log_value = torch.lgamma(arg1) + torch.lgamma(b) - torch.lgamma(arg1 + b)
    return b * torch.exp(log_value)


class Kumaraswamy(TransformedDistribution):
    r"""
    Samples from a Kumaraswamy distribution.

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = Kumaraswamy(torch.tensor([1.0]), torch.tensor([1.0]))
        >>> m.sample()  # sample from a Kumaraswamy distribution with concentration alpha=1 and beta=1
        tensor([ 0.1729])

    Args:
        concentration1 (float or Tensor): 1st concentration parameter of the distribution
            (often referred to as alpha)
        concentration0 (float or Tensor): 2nd concentration parameter of the distribution
            (often referred to as beta)
    """
    arg_constraints = {
        "concentration1": constraints.positive,
        "concentration0": constraints.positive,
    }
    support = constraints.unit_interval
    has_rsample = True

    def __init__(self, concentration1, concentration0, validate_args=None):
        """
        Build the distribution as a transformed ``Uniform(0, 1)``.

        A base sample ``u`` is mapped to
        ``(1 - u ** (1 / concentration0)) ** (1 / concentration1)`` by the
        three transforms below.

        Args:
            concentration1 (float or Tensor): 1st concentration parameter (alpha)
            concentration0 (float or Tensor): 2nd concentration parameter (beta)
            validate_args (bool, optional): forwarded to the base distribution
                and to ``TransformedDistribution``
        """
        self.concentration1, self.concentration0 = broadcast_all(
            concentration1, concentration0
        )
        base_dist = Uniform(
            torch.full_like(self.concentration0, 0),
            torch.full_like(self.concentration0, 1),
            validate_args=validate_args,
        )
        transforms = [
            PowerTransform(exponent=self.concentration0.reciprocal()),
            AffineTransform(loc=1.0, scale=-1.0),
            PowerTransform(exponent=self.concentration1.reciprocal()),
        ]
        super().__init__(base_dist, transforms, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        """
        Return an instance whose concentration parameters are expanded to
        ``batch_shape``; the remaining work is delegated to the parent class.
        """
        new = self._get_checked_instance(Kumaraswamy, _instance)
        new.concentration1 = self.concentration1.expand(batch_shape)
        new.concentration0 = self.concentration0.expand(batch_shape)
        return super().expand(batch_shape, _instance=new)

    @property
    def mean(self):
        """First moment of the distribution."""
        return _moments(self.concentration1, self.concentration0, 1)

    @property
    def mode(self):
        """
        Mode of the distribution, ``((a - 1) / (a * b - 1)) ** (1 / a)`` with
        ``a = concentration1`` and ``b = concentration0``.

        Entries where either concentration is below 1 are ``nan``. The case
        ``a == b == 1`` also yields ``nan`` (``-inf - -inf`` in log-space).
        """
        a = self.concentration1
        b = self.concentration0
        # Evaluate in log-space for numerical stability.
        log_mode = a.reciprocal() * ((a - 1).log() - (a * b - 1).log())
        # Out-of-place masking: the closed form is only used when a >= 1 and b >= 1.
        log_mode = log_mode.masked_fill((b < 1) | (a < 1), nan)
        return log_mode.exp()

    @property
    def variance(self):
        """Second moment minus the squared mean."""
        return _moments(self.concentration1, self.concentration0, 2) - torch.pow(
            self.mean, 2
        )

    def entropy(self):
        """
        Differential entropy:
        ``(1 - 1/b) + (1 - 1/a) * H_b - log(a) - log(b)`` where
        ``a = concentration1``, ``b = concentration0`` and
        ``H_b = digamma(b + 1) + euler_constant``.
        """
        t1 = 1 - self.concentration1.reciprocal()
        t0 = 1 - self.concentration0.reciprocal()
        H0 = torch.digamma(self.concentration0 + 1) + euler_constant
        return (
            t0
            + t1 * H0
            - torch.log(self.concentration1)
            - torch.log(self.concentration0)
        )