# mypy: allow-untyped-defs
import math

import torch
from torch import inf, nan
from torch.distributions import Chi2, constraints
from torch.distributions.distribution import Distribution
from torch.distributions.utils import _standard_normal, broadcast_all
from torch.types import _size


__all__ = ["StudentT"]


class StudentT(Distribution):
    r"""
    Creates a Student's t-distribution parameterized by degrees of
    freedom :attr:`df`, location :attr:`loc` and scale :attr:`scale`.

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = StudentT(torch.tensor([2.0]))
        >>> m.sample()  # Student's t-distributed with degrees of freedom=2
        tensor([ 0.1046])

    Args:
        df (float or Tensor): degrees of freedom
        loc (float or Tensor): location of the distribution (its mean
            whenever ``df > 1``; the mean is reported as NaN otherwise)
        scale (float or Tensor): scale of the distribution
    """
    arg_constraints = {
        "df": constraints.positive,
        "loc": constraints.real,
        "scale": constraints.positive,
    }
    support = constraints.real
    has_rsample = True

    @property
    def mean(self):
        """Return ``loc`` where ``df > 1`` and NaN where ``df <= 1``."""
        # Clone so that the masked write below never touches ``self.loc``.
        m = self.loc.clone(memory_format=torch.contiguous_format)
        m[self.df <= 1] = nan
        return m

    @property
    def mode(self):
        """Return the mode, which is ``loc`` itself (not a copy)."""
        return self.loc

    @property
    def variance(self):
        """Return the variance, elementwise over the batch.

        * ``df > 2``: ``scale**2 * df / (df - 2)``
        * ``1 < df <= 2``: ``inf``
        * ``df <= 1``: ``nan``
        """
        df = self.df
        # Computed once; the masked reads and the masked write share it.
        finite = df > 2
        df_finite = df[finite]

        # Start from a copy of ``df`` so dtype/device/shape match; every entry
        # satisfying one of the three conditions is overwritten below.
        m = df.clone(memory_format=torch.contiguous_format)
        m[finite] = self.scale[finite].pow(2) * df_finite / (df_finite - 2)
        # Spelled out explicitly (rather than negating the other masks) so
        # that a NaN ``df`` matches no branch and stays NaN.
        m[(df <= 2) & (df > 1)] = inf
        m[df <= 1] = nan
        return m

    def __init__(self, df, loc=0.0, scale=1.0, validate_args=None):
        """Broadcast the parameters together and set up the batch shape.

        Args:
            df (float or Tensor): degrees of freedom
            loc (float or Tensor): location of the distribution
            scale (float or Tensor): scale of the distribution
            validate_args (bool, optional): forwarded to
                :class:`Distribution`.
        """
        self.df, self.loc, self.scale = broadcast_all(df, loc, scale)
        # Chi2(df) supplies the denominator draw used by ``rsample``.
        self._chi2 = Chi2(self.df)
        batch_shape = self.df.size()
        super().__init__(batch_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        """Return a ``StudentT`` whose parameters are expanded to ``batch_shape``.

        Args:
            batch_shape: the desired batch shape (converted with
                ``torch.Size``).
            _instance: optional pre-allocated instance to fill in, passed to
                ``_get_checked_instance``.
        """
        new = self._get_checked_instance(StudentT, _instance)
        batch_shape = torch.Size(batch_shape)
        new.df = self.df.expand(batch_shape)
        new.loc = self.loc.expand(batch_shape)
        new.scale = self.scale.expand(batch_shape)
        new._chi2 = self._chi2.expand(batch_shape)
        # Initialise the base class without validation, then carry over this
        # instance's validation flag.
        super(StudentT, new).__init__(batch_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    def rsample(self, sample_shape: _size = torch.Size()) -> torch.Tensor:
        """Draw a reparameterized sample of shape ``sample_shape + batch_shape``.

        Args:
            sample_shape: leading shape of the draw.

        Returns:
            ``loc + scale * Y`` where ``Y = X / sqrt(Z / df)``.
        """
        # NOTE: This does not agree with scipy implementation as much as other distributions.
        # (see https://github.com/fritzo/notebooks/blob/master/debug-student-t.ipynb). Using DoubleTensor
        # parameters seems to help.

        #   X ~ Normal(0, 1)
        #   Z ~ Chi2(df)
        #   Y = X / sqrt(Z / df) ~ StudentT(df)
        shape = self._extended_shape(sample_shape)
        X = _standard_normal(shape, dtype=self.df.dtype, device=self.df.device)
        Z = self._chi2.rsample(sample_shape)
        Y = X * torch.rsqrt(Z / self.df)
        return self.loc + self.scale * Y

    def log_prob(self, value):
        """Return the log density evaluated at ``value``.

        Args:
            value (Tensor): points at which to evaluate the density; checked
                with ``_validate_sample`` when argument validation is on.
        """
        if self._validate_args:
            self._validate_sample(value)
        # Standardized residual.
        y = (value - self.loc) / self.scale
        # Log normalizer:
        #   log(scale) + 0.5*log(df) + 0.5*log(pi)
        #   + lgamma(df/2) - lgamma((df+1)/2)
        Z = (
            self.scale.log()
            + 0.5 * self.df.log()
            + 0.5 * math.log(math.pi)
            + torch.lgamma(0.5 * self.df)
            - torch.lgamma(0.5 * (self.df + 1.0))
        )
        return -0.5 * (self.df + 1.0) * torch.log1p(y**2.0 / self.df) - Z

    def entropy(self):
        """Return the differential entropy, elementwise over the batch."""
        # log B(df/2, 1/2) written out with log-gamma functions.
        lbeta = (
            torch.lgamma(0.5 * self.df)
            + math.lgamma(0.5)
            - torch.lgamma(0.5 * (self.df + 1))
        )
        return (
            self.scale.log()
            + 0.5
            * (self.df + 1)
            * (torch.digamma(0.5 * (self.df + 1)) - torch.digamma(0.5 * self.df))
            + 0.5 * self.df.log()
            + lbeta
        )