xref: /aosp_15_r20/external/pytorch/torch/distributions/inverse_gamma.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import torch
3from torch.distributions import constraints
4from torch.distributions.gamma import Gamma
5from torch.distributions.transformed_distribution import TransformedDistribution
6from torch.distributions.transforms import PowerTransform
7
8
9__all__ = ["InverseGamma"]
10
11
class InverseGamma(TransformedDistribution):
    r"""
    Creates an inverse gamma distribution parameterized by :attr:`concentration`
    and :attr:`rate`. A sample is the reciprocal of a gamma-distributed sample::

        X ~ Gamma(concentration, rate)
        Y = 1 / X ~ InverseGamma(concentration, rate)

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = InverseGamma(torch.tensor([2.0]), torch.tensor([3.0]))
        >>> m.sample()
        tensor([ 1.2953])

    Args:
        concentration (float or Tensor): shape parameter of the distribution
            (often referred to as alpha)
        rate (float or Tensor): rate = 1 / scale of the distribution
            (often referred to as beta)
    """
    arg_constraints = {
        "concentration": constraints.positive,
        "rate": constraints.positive,
    }
    support = constraints.positive
    has_rsample = True

    def __init__(self, concentration, rate, validate_args=None):
        gamma = Gamma(concentration, rate, validate_args=validate_args)
        # Y = X ** -1. The exponent comes from ``new_ones`` on the gamma's rate
        # tensor, so it has the same dtype and device as that tensor.
        reciprocal = PowerTransform(-gamma.rate.new_ones(()))
        super().__init__(gamma, reciprocal, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        """Return a copy of this distribution with an expanded batch shape."""
        instance = self._get_checked_instance(InverseGamma, _instance)
        return super().expand(batch_shape, _instance=instance)

    @property
    def concentration(self):
        """Shape parameter (alpha), taken from the underlying gamma distribution."""
        return self.base_dist.concentration

    @property
    def rate(self):
        """Rate parameter (beta), taken from the underlying gamma distribution."""
        return self.base_dist.rate

    @property
    def mean(self):
        """``rate / (concentration - 1)``; ``inf`` wherever concentration <= 1."""
        alpha, beta = self.concentration, self.rate
        finite = beta / (alpha - 1)
        # The mean is only finite for alpha > 1, so every other entry is masked to inf.
        return torch.where(alpha > 1, finite, torch.inf)

    @property
    def mode(self):
        """``rate / (concentration + 1)``."""
        alpha, beta = self.concentration, self.rate
        return beta / (alpha + 1)

    @property
    def variance(self):
        """``rate**2 / ((concentration - 1)**2 * (concentration - 2))``.

        Entries where concentration <= 2 are ``inf``.
        """
        alpha, beta = self.concentration, self.rate
        shifted = alpha - 1
        finite = beta.square() / (shifted.square() * (alpha - 2))
        # The variance is only finite for alpha > 2.
        return torch.where(alpha > 2, finite, torch.inf)

    def entropy(self):
        """Differential entropy.

        ``alpha + log(beta) + lgamma(alpha) - (1 + alpha) * digamma(alpha)``.
        """
        alpha, beta = self.concentration, self.rate
        return alpha + beta.log() + alpha.lgamma() - (1 + alpha) * alpha.digamma()
82