xref: /aosp_15_r20/external/pytorch/torch/distributions/studentT.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-defs
2*da0073e9SAndroid Build Coastguard Workerimport math
3*da0073e9SAndroid Build Coastguard Worker
4*da0073e9SAndroid Build Coastguard Workerimport torch
5*da0073e9SAndroid Build Coastguard Workerfrom torch import inf, nan
6*da0073e9SAndroid Build Coastguard Workerfrom torch.distributions import Chi2, constraints
7*da0073e9SAndroid Build Coastguard Workerfrom torch.distributions.distribution import Distribution
8*da0073e9SAndroid Build Coastguard Workerfrom torch.distributions.utils import _standard_normal, broadcast_all
9*da0073e9SAndroid Build Coastguard Workerfrom torch.types import _size
10*da0073e9SAndroid Build Coastguard Worker
11*da0073e9SAndroid Build Coastguard Worker
12*da0073e9SAndroid Build Coastguard Worker__all__ = ["StudentT"]
13*da0073e9SAndroid Build Coastguard Worker
14*da0073e9SAndroid Build Coastguard Worker
class StudentT(Distribution):
    r"""
    Creates a Student's t-distribution parameterized by degree of
    freedom :attr:`df`, location :attr:`loc` and scale :attr:`scale`.

    :attr:`loc` is reported as the mean only where ``df > 1``; see
    :attr:`mean` and :attr:`variance` for the values returned at small
    degrees of freedom.

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = StudentT(torch.tensor([2.0]))
        >>> m.sample()  # Student's t-distributed with degrees of freedom=2
        tensor([ 0.1046])

    Args:
        df (float or Tensor): degrees of freedom
        loc (float or Tensor): location of the distribution (its mean when
            ``df > 1``)
        scale (float or Tensor): scale of the distribution
    """
    arg_constraints = {
        "df": constraints.positive,
        "loc": constraints.real,
        "scale": constraints.positive,
    }
    support = constraints.real
    has_rsample = True

    @property
    def mean(self):
        """Return ``loc``, with NaN written wherever ``df <= 1``."""
        # Clone so the in-place masked write below never touches ``self.loc``.
        result = self.loc.clone(memory_format=torch.contiguous_format)
        result[self.df <= 1] = nan
        return result

    @property
    def mode(self):
        """Return ``loc`` (the same tensor object, not a copy)."""
        return self.loc

    @property
    def variance(self):
        """Return the variance, element-wise by degrees of freedom.

        * ``df > 2``: ``scale**2 * df / (df - 2)``
        * ``1 < df <= 2``: ``inf``
        * ``df <= 1``: ``nan``
        """
        df = self.df
        # Evaluate the "finite variance" mask once and reuse it for every
        # gather/scatter below.
        finite = df > 2
        # Start from a copy of ``df`` purely to get a writable tensor of the
        # right shape/dtype/device; the masked writes below fill it in.
        result = df.clone(memory_format=torch.contiguous_format)
        # Masked assignment (rather than computing df / (df - 2) everywhere)
        # means the division is only ever evaluated where df > 2.
        result[finite] = self.scale[finite].pow(2) * df[finite] / (df[finite] - 2)
        result[(df <= 2) & (df > 1)] = inf
        result[df <= 1] = nan
        return result

    def __init__(self, df, loc=0.0, scale=1.0, validate_args=None):
        # Broadcast all three parameters to a common shape; that shape is the
        # batch shape of the distribution.
        self.df, self.loc, self.scale = broadcast_all(df, loc, scale)
        # Chi-squared helper used by ``rsample``.
        self._chi2 = Chi2(self.df)
        batch_shape = self.df.size()
        super().__init__(batch_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        """Return a ``StudentT`` whose parameters are expanded to ``batch_shape``.

        Args:
            batch_shape: the desired expanded batch shape.
            _instance: optional pre-allocated instance to populate, passed
                through to ``_get_checked_instance``.
        """
        new = self._get_checked_instance(StudentT, _instance)
        batch_shape = torch.Size(batch_shape)
        new.df = self.df.expand(batch_shape)
        new.loc = self.loc.expand(batch_shape)
        new.scale = self.scale.expand(batch_shape)
        new._chi2 = self._chi2.expand(batch_shape)
        # Call the base initializer with validation off, then copy this
        # instance's validation flag onto the new one.
        super(StudentT, new).__init__(batch_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    def rsample(self, sample_shape: _size = torch.Size()) -> torch.Tensor:
        """Draw a reparameterized sample of shape ``sample_shape + batch_shape``."""
        # NOTE: This does not agree with scipy implementation as much as other distributions.
        # (see https://github.com/fritzo/notebooks/blob/master/debug-student-t.ipynb). Using DoubleTensor
        # parameters seems to help.

        #   X ~ Normal(0, 1)
        #   Z ~ Chi2(df)
        #   Y = X / sqrt(Z / df) ~ StudentT(df)
        shape = self._extended_shape(sample_shape)
        X = _standard_normal(shape, dtype=self.df.dtype, device=self.df.device)
        Z = self._chi2.rsample(sample_shape)
        Y = X * torch.rsqrt(Z / self.df)
        return self.loc + self.scale * Y

    def log_prob(self, value):
        """Return the log probability density evaluated at ``value``."""
        if self._validate_args:
            self._validate_sample(value)
        # Standardize, then evaluate the density of the standard t.
        y = (value - self.loc) / self.scale
        # Log normalizer: log(scale) + log(sqrt(df * pi)) + lgamma(df / 2)
        # - lgamma((df + 1) / 2).
        Z = (
            self.scale.log()
            + 0.5 * self.df.log()
            + 0.5 * math.log(math.pi)
            + torch.lgamma(0.5 * self.df)
            - torch.lgamma(0.5 * (self.df + 1.0))
        )
        return -0.5 * (self.df + 1.0) * torch.log1p(y**2.0 / self.df) - Z

    def entropy(self):
        """Return the differential entropy, batched over ``batch_shape``."""
        # log B(df / 2, 1 / 2), written out with log-gamma functions.
        lbeta = (
            torch.lgamma(0.5 * self.df)
            + math.lgamma(0.5)
            - torch.lgamma(0.5 * (self.df + 1))
        )
        return (
            self.scale.log()
            + 0.5
            * (self.df + 1)
            * (torch.digamma(0.5 * (self.df + 1)) - torch.digamma(0.5 * self.df))
            + 0.5 * self.df.log()
            + lbeta
        )
120