# mypy: allow-untyped-defs
import math

import torch
from torch import inf, nan
from torch.distributions import Chi2, constraints
from torch.distributions.distribution import Distribution
from torch.distributions.utils import _standard_normal, broadcast_all
from torch.types import _size


__all__ = ["StudentT"]


class StudentT(Distribution):
    r"""
    Student's t-distribution with degrees of freedom :attr:`df`, location
    :attr:`loc` and scale :attr:`scale`.

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = StudentT(torch.tensor([2.0]))
        >>> m.sample()  # Student's t-distributed with degrees of freedom=2
        tensor([ 0.1046])

    Args:
        df (float or Tensor): degrees of freedom
        loc (float or Tensor): mean of the distribution
        scale (float or Tensor): scale of the distribution
    """

    arg_constraints = {
        "df": constraints.positive,
        "loc": constraints.real,
        "scale": constraints.positive,
    }
    support = constraints.real
    has_rsample = True

    @property
    def mean(self):
        # The mean equals loc, but is undefined (NaN) when df <= 1.
        out = self.loc.clone(memory_format=torch.contiguous_format)
        out[self.df <= 1] = nan
        return out

    @property
    def mode(self):
        # The density is unimodal and symmetric about loc for every df.
        return self.loc

    @property
    def variance(self):
        # Three disjoint regimes:
        #   df > 2      -> finite variance scale^2 * df / (df - 2)
        #   1 < df <= 2 -> variance diverges (inf)
        #   df <= 1     -> variance undefined (nan)
        df = self.df
        finite = df > 2
        out = df.clone(memory_format=torch.contiguous_format)
        out[finite] = self.scale[finite].pow(2) * df[finite] / (df[finite] - 2)
        out[(df <= 2) & (df > 1)] = inf
        out[df <= 1] = nan
        return out

    def __init__(self, df, loc=0.0, scale=1.0, validate_args=None):
        self.df, self.loc, self.scale = broadcast_all(df, loc, scale)
        # Auxiliary chi-squared distribution used by rsample().
        self._chi2 = Chi2(self.df)
        super().__init__(self.df.size(), validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(StudentT, _instance)
        shape = torch.Size(batch_shape)
        new.df = self.df.expand(shape)
        new.loc = self.loc.expand(shape)
        new.scale = self.scale.expand(shape)
        new._chi2 = self._chi2.expand(shape)
        super(StudentT, new).__init__(shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    def rsample(self, sample_shape: _size = torch.Size()) -> torch.Tensor:
        # NOTE: Agreement with the scipy implementation is weaker than for
        # other distributions (see
        # https://github.com/fritzo/notebooks/blob/master/debug-student-t.ipynb).
        # Using DoubleTensor parameters seems to help.
        #
        # Reparameterized construction:
        #   X ~ Normal(0, 1),  Z ~ Chi2(df)
        #   Y = X / sqrt(Z / df) ~ StudentT(df)
        shape = self._extended_shape(sample_shape)
        gauss = _standard_normal(shape, dtype=self.df.dtype, device=self.df.device)
        chi2_draw = self._chi2.rsample(sample_shape)
        standardized = gauss * torch.rsqrt(chi2_draw / self.df)
        return self.loc + self.scale * standardized

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        # Standardize the value, then subtract the log normalization constant.
        t = (value - self.loc) / self.scale
        norm_const = (
            self.scale.log()
            + 0.5 * self.df.log()
            + 0.5 * math.log(math.pi)
            + torch.lgamma(0.5 * self.df)
            - torch.lgamma(0.5 * (self.df + 1.0))
        )
        return -0.5 * (self.df + 1.0) * torch.log1p(t**2.0 / self.df) - norm_const

    def entropy(self):
        # log B(df/2, 1/2), with lgamma(1/2) supplied by math since it is scalar.
        log_beta = (
            torch.lgamma(0.5 * self.df)
            + math.lgamma(0.5)
            - torch.lgamma(0.5 * (self.df + 1))
        )
        return (
            self.scale.log()
            + 0.5
            * (self.df + 1)
            * (torch.digamma(0.5 * (self.df + 1)) - torch.digamma(0.5 * self.df))
            + 0.5 * self.df.log()
            + log_beta
        )