# mypy: allow-untyped-defs
from torch.distributions import constraints
from torch.distributions.exponential import Exponential
from torch.distributions.transformed_distribution import TransformedDistribution
from torch.distributions.transforms import AffineTransform, ExpTransform
from torch.distributions.utils import broadcast_all


__all__ = ["Pareto"]


class Pareto(TransformedDistribution):
    r"""
    Samples from a Pareto Type 1 distribution.

    The distribution is constructed as ``X = scale * exp(E)`` where
    ``E ~ Exponential(rate=alpha)``, which yields the density

    .. math::
        p(x) = \frac{\alpha \, \mathrm{scale}^{\alpha}}{x^{\alpha + 1}},
        \qquad x \ge \mathrm{scale}

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = Pareto(torch.tensor([1.0]), torch.tensor([1.0]))
        >>> m.sample() # sample from a Pareto distribution with scale=1 and alpha=1
        tensor([ 1.5623])

    Args:
        scale (float or Tensor): Scale parameter of the distribution
        alpha (float or Tensor): Shape parameter of the distribution
        validate_args (bool, optional): Forwarded both to the base
            ``Exponential`` and to ``TransformedDistribution``.
    """
    arg_constraints = {"alpha": constraints.positive, "scale": constraints.positive}

    def __init__(self, scale, alpha, validate_args=None):
        # Broadcast up front so ``scale`` and ``alpha`` share one batch shape.
        self.scale, self.alpha = broadcast_all(scale, alpha)
        base_dist = Exponential(self.alpha, validate_args=validate_args)
        # exp maps the Exponential's [0, inf) support onto [1, inf); the
        # affine step then multiplies by ``scale`` so the support starts at
        # ``scale``.
        transforms = [ExpTransform(), AffineTransform(loc=0, scale=self.scale)]
        super().__init__(base_dist, transforms, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        """Return a new ``Pareto`` whose parameters are expanded to ``batch_shape``."""
        new = self._get_checked_instance(Pareto, _instance)
        new.scale = self.scale.expand(batch_shape)
        new.alpha = self.alpha.expand(batch_shape)
        # The parent class expands the base distribution and batch metadata.
        return super().expand(batch_shape, _instance=new)

    @property
    def mean(self):
        """``alpha * scale / (alpha - 1)``, or ``inf`` when ``alpha <= 1``."""
        # mean is inf for alpha <= 1: clamping makes the denominator exactly
        # zero there, so the (positive) numerator divided by 0 gives inf.
        a = self.alpha.clamp(min=1)
        return a * self.scale / (a - 1)

    @property
    def mode(self):
        """The mode sits at the lower edge of the support, ``scale``."""
        return self.scale

    @property
    def variance(self):
        """``scale**2 * alpha / ((alpha - 1)**2 * (alpha - 2))``, or ``inf`` when ``alpha <= 2``."""
        # var is inf for alpha <= 2: clamping to 2 zeroes the (a - 2) factor.
        a = self.alpha.clamp(min=2)
        return self.scale.pow(2) * a / ((a - 1).pow(2) * (a - 2))

    @constraints.dependent_property(is_discrete=False, event_dim=0)
    def support(self):
        """Samples satisfy ``x >= scale``."""
        return constraints.greater_than_eq(self.scale)

    def entropy(self):
        """Differential entropy ``log(scale / alpha) + 1 + 1 / alpha``."""
        return (self.scale / self.alpha).log() + (1 + self.alpha.reciprocal())