# mypy: allow-untyped-defs
import torch
from torch.distributions import constraints
from torch.distributions.gamma import Gamma
from torch.distributions.transformed_distribution import TransformedDistribution
from torch.distributions.transforms import PowerTransform


__all__ = ["InverseGamma"]


class InverseGamma(TransformedDistribution):
    r"""
    Creates an inverse gamma distribution parameterized by :attr:`concentration` and :attr:`rate`
    where::

        X ~ Gamma(concentration, rate)
        Y = 1 / X ~ InverseGamma(concentration, rate)

    Internally this is a :class:`~torch.distributions.gamma.Gamma` base
    distribution pushed through the reciprocal map ``y = x ** -1``.

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = InverseGamma(torch.tensor([2.0]), torch.tensor([3.0]))
        >>> m.sample()
        tensor([ 1.2953])

    Args:
        concentration (float or Tensor): shape parameter of the distribution
            (often referred to as alpha)
        rate (float or Tensor): rate = 1 / scale of the distribution
            (often referred to as beta)
    """

    arg_constraints = {
        "concentration": constraints.positive,
        "rate": constraints.positive,
    }
    support = constraints.positive
    has_rsample = True

    def __init__(self, concentration, rate, validate_args=None):
        gamma = Gamma(concentration, rate, validate_args=validate_args)
        # 0-dim tensor holding -1 with the base rate's dtype/device, so the
        # transform is y = x ** -1 (the reciprocal).
        reciprocal = PowerTransform(gamma.rate.new_ones(()).neg())
        super().__init__(gamma, reciprocal, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        """Return an ``InverseGamma`` whose batch shape is expanded to ``batch_shape``."""
        instance = self._get_checked_instance(InverseGamma, _instance)
        return super().expand(batch_shape, _instance=instance)

    @property
    def concentration(self):
        """Shape parameter (alpha), read from the Gamma base distribution."""
        return self.base_dist.concentration

    @property
    def rate(self):
        """Rate parameter (beta), read from the Gamma base distribution."""
        return self.base_dist.rate

    @property
    def mean(self):
        """``rate / (concentration - 1)``, replaced by ``inf`` wherever ``concentration <= 1``."""
        alpha = self.concentration
        finite_mean = self.rate / (alpha - 1)
        return torch.where(alpha > 1, finite_mean, torch.inf)

    @property
    def mode(self):
        """``rate / (concentration + 1)``."""
        return self.rate / (1 + self.concentration)

    @property
    def variance(self):
        """``rate**2 / ((concentration - 1)**2 * (concentration - 2))``, ``inf`` where ``concentration <= 2``."""
        alpha = self.concentration
        denominator = (alpha - 1).square() * (alpha - 2)
        finite_variance = self.rate.square() / denominator
        return torch.where(alpha > 2, finite_variance, torch.inf)

    def entropy(self):
        """Differential entropy ``alpha + log(beta) + lgamma(alpha) - (1 + alpha) * digamma(alpha)``."""
        alpha = self.concentration
        partial = alpha + self.rate.log() + alpha.lgamma()
        return partial - (1 + alpha) * alpha.digamma()