# mypy: allow-untyped-defs
import math
import warnings
from numbers import Number
from typing import Optional, Union

import torch
from torch import nan
from torch.distributions import constraints
from torch.distributions.exp_family import ExponentialFamily
from torch.distributions.multivariate_normal import _precision_to_scale_tril
from torch.distributions.utils import lazy_property
from torch.types import _size


__all__ = ["Wishart"]

_log_2 = math.log(2)


def _mvdigamma(x: torch.Tensor, p: int) -> torch.Tensor:
    r"""Multivariate digamma :math:`\psi_p(x) = \sum_{j=0}^{p-1} \psi(x - j/2)`.

    Evaluated elementwise over ``x``; every entry must exceed ``(p - 1) / 2``.
    The result has the same shape as ``x``.
    """
    assert x.gt((p - 1) / 2).all(), "Wrong domain for multivariate digamma function."
    # A trailing axis of size p holds the offsets 0, 1/2, ..., (p - 1)/2, which
    # is then summed away.
    return torch.digamma(
        x.unsqueeze(-1)
        - torch.arange(p, dtype=x.dtype, device=x.device).div(2).expand(x.shape + (-1,))
    ).sum(-1)


def _clamp_above_eps(x: torch.Tensor) -> torch.Tensor:
    """Clamp ``x`` from below at the machine epsilon of its floating dtype."""
    # We assume positive input for this function
    return x.clamp(min=torch.finfo(x.dtype).eps)


class Wishart(ExponentialFamily):
    r"""
    Creates a Wishart distribution parameterized by a symmetric positive definite matrix :math:`\Sigma`,
    or its Cholesky decomposition :math:`\mathbf{\Sigma} = \mathbf{L}\mathbf{L}^\top`

    Example:
        >>> # xdoctest: +SKIP("FIXME: scale_tril must be at least two-dimensional")
        >>> m = Wishart(torch.Tensor([2]), covariance_matrix=torch.eye(2))
        >>> m.sample()  # Wishart distributed with mean=`df * I` and
        >>> # variance(x_ij)=`df` for i != j and variance(x_ij)=`2 * df` for i == j

    Args:
        df (float or Tensor): degrees of freedom; a real-valued parameter greater than
            (dimension of the square matrix) - 1
        covariance_matrix (Tensor): positive-definite covariance matrix
        precision_matrix (Tensor): positive-definite precision matrix
        scale_tril (Tensor): lower-triangular factor of covariance, with positive-valued diagonal
    Note:
        Only one of :attr:`covariance_matrix` or :attr:`precision_matrix` or
        :attr:`scale_tril` can be specified.
        Using :attr:`scale_tril` will be more efficient: all computations internally
        are based on :attr:`scale_tril`. If :attr:`covariance_matrix` or
        :attr:`precision_matrix` is passed instead, it is only used to compute
        the corresponding lower triangular matrices using a Cholesky decomposition.
        :class:`torch.distributions.LKJCholesky` is a restricted Wishart distribution [1].

    **References**

    [1] Wang, Z., Wu, Y. and Chu, H., 2018. `On equivalence of the LKJ distribution and the restricted Wishart distribution`.
    [2] Sawyer, S., 2007. `Wishart Distributions and Inverse-Wishart Sampling`.
    [3] Anderson, T. W., 2003. `An Introduction to Multivariate Statistical Analysis (3rd ed.)`.
    [4] Odell, P. L. & Feiveson, A. H., 1966. `A Numerical Procedure to Generate a Sample Covariance Matrix`. JASA, 61(313):199-203.
    [5] Ku, Y.-C. & Bloomfield, P., 2010. `Generating Random Wishart Matrices with Fractional Degrees of Freedom in OX`.
    """
    arg_constraints = {
        "covariance_matrix": constraints.positive_definite,
        "precision_matrix": constraints.positive_definite,
        "scale_tril": constraints.lower_cholesky,
        "df": constraints.greater_than(0),
    }
    support = constraints.positive_definite
    has_rsample = True
    _mean_carrier_measure = 0

    def __init__(
        self,
        df: Union[torch.Tensor, Number],
        covariance_matrix: Optional[torch.Tensor] = None,
        precision_matrix: Optional[torch.Tensor] = None,
        scale_tril: Optional[torch.Tensor] = None,
        validate_args=None,
    ):
        assert (covariance_matrix is not None) + (scale_tril is not None) + (
            precision_matrix is not None
        ) == 1, "Exactly one of covariance_matrix or precision_matrix or scale_tril may be specified."

        param = next(
            p
            for p in (covariance_matrix, precision_matrix, scale_tril)
            if p is not None
        )

        if param.dim() < 2:
            raise ValueError(
                "scale_tril must be at least two-dimensional, with optional leading batch dimensions"
            )

        if isinstance(df, Number):
            batch_shape = torch.Size(param.shape[:-2])
            # Expand so that ``self.df`` always has shape ``batch_shape`` (as in
            # the tensor branch); ``mean``/``mode``/``variance`` rely on that.
            self.df = torch.tensor(df, dtype=param.dtype, device=param.device).expand(
                batch_shape
            )
        else:
            batch_shape = torch.broadcast_shapes(param.shape[:-2], df.shape)
            self.df = df.expand(batch_shape)
        event_shape = param.shape[-2:]

        if self.df.le(event_shape[-1] - 1).any():
            raise ValueError(
                f"Value of df={df} expected to be greater than ndim - 1 = {event_shape[-1]-1}."
            )

        if scale_tril is not None:
            self.scale_tril = param.expand(batch_shape + (-1, -1))
        elif covariance_matrix is not None:
            self.covariance_matrix = param.expand(batch_shape + (-1, -1))
        elif precision_matrix is not None:
            self.precision_matrix = param.expand(batch_shape + (-1, -1))

        # The admissible df range depends on this instance's matrix size, so
        # store it in a per-instance copy instead of mutating the class-level
        # dict that every Wishart instance shares.
        self.arg_constraints = {
            **self.arg_constraints,
            "df": constraints.greater_than(event_shape[-1] - 1),
        }
        if self.df.lt(event_shape[-1]).any():
            warnings.warn(
                "Low df values detected. Singular samples are highly likely to occur for ndim - 1 < df < ndim."
            )

        super().__init__(batch_shape, event_shape, validate_args=validate_args)
        self._batch_dims = [-(x + 1) for x in range(len(self._batch_shape))]

        if scale_tril is not None:
            self._unbroadcasted_scale_tril = scale_tril
        elif covariance_matrix is not None:
            self._unbroadcasted_scale_tril = torch.linalg.cholesky(covariance_matrix)
        else:  # precision_matrix is not None
            self._unbroadcasted_scale_tril = _precision_to_scale_tril(precision_matrix)

        # Chi2 distribution is needed for Bartlett decomposition sampling
        self._dist_chi2 = self._bartlett_chi2()

    def _bartlett_chi2(self):
        """Build the Chi2 distribution of the squared Bartlett diagonal.

        Its ``df`` has shape ``batch_shape + (p,)`` with entries ``df - i`` for
        ``i = 0, ..., p - 1``. It requires ``self.df``, ``self._batch_shape``,
        ``self._event_shape`` and ``self._unbroadcasted_scale_tril`` to be set.
        """
        offsets = torch.arange(
            self._event_shape[-1],
            dtype=self._unbroadcasted_scale_tril.dtype,
            device=self._unbroadcasted_scale_tril.device,
        )
        return torch.distributions.chi2.Chi2(
            df=self.df.unsqueeze(-1) - offsets.expand(self._batch_shape + (-1,))
        )

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(Wishart, _instance)
        batch_shape = torch.Size(batch_shape)
        cov_shape = batch_shape + self.event_shape
        new._unbroadcasted_scale_tril = self._unbroadcasted_scale_tril.expand(cov_shape)
        new.df = self.df.expand(batch_shape)
        # Carry over the per-instance df constraint set in ``__init__``.
        new.arg_constraints = self.arg_constraints

        new._batch_dims = [-(x + 1) for x in range(len(batch_shape))]

        # Only propagate the matrix parameterizations already materialized on
        # ``self``; the others remain lazy properties on ``new``.
        for name in ("covariance_matrix", "scale_tril", "precision_matrix"):
            if name in self.__dict__:
                setattr(new, name, getattr(self, name).expand(cov_shape))

        super(Wishart, new).__init__(batch_shape, self.event_shape, validate_args=False)
        new._validate_args = self._validate_args

        # Chi2 distribution is needed for Bartlett decomposition sampling
        new._dist_chi2 = new._bartlett_chi2()
        return new

    @lazy_property
    def scale_tril(self):
        return self._unbroadcasted_scale_tril.expand(
            self._batch_shape + self._event_shape
        )

    @lazy_property
    def covariance_matrix(self):
        return (
            self._unbroadcasted_scale_tril
            @ self._unbroadcasted_scale_tril.transpose(-2, -1)
        ).expand(self._batch_shape + self._event_shape)

    @lazy_property
    def precision_matrix(self):
        identity = torch.eye(
            self._event_shape[-1],
            device=self._unbroadcasted_scale_tril.device,
            dtype=self._unbroadcasted_scale_tril.dtype,
        )
        return torch.cholesky_solve(identity, self._unbroadcasted_scale_tril).expand(
            self._batch_shape + self._event_shape
        )

    @property
    def mean(self):
        # df has shape batch_shape; two trailing singleton dims scale each matrix.
        return self.df[..., None, None] * self.covariance_matrix

    @property
    def mode(self):
        factor = self.df - self.covariance_matrix.shape[-1] - 1
        # The mode (df - p - 1) * V is reported as NaN when df <= p + 1.
        factor = factor.masked_fill(factor <= 0, nan)
        return factor[..., None, None] * self.covariance_matrix

    @property
    def variance(self):
        V = self.covariance_matrix  # has shape (batch_shape x event_shape)
        diag_V = V.diagonal(dim1=-2, dim2=-1)
        # Var(X_ij) = df * (V_ij^2 + V_ii * V_jj)
        return self.df[..., None, None] * (
            V.pow(2) + torch.einsum("...i,...j->...ij", diag_V, diag_V)
        )

    def _bartlett_sampling(self, sample_shape=torch.Size()):
        """Draw samples via the Bartlett decomposition.

        A lower-triangular matrix ``A`` is built with ``sqrt(Chi2(df - i))`` on
        the diagonal (clamped at eps) and standard normals strictly below it.
        The sample is ``(L A)(L A)^T``, where ``L`` is the scale Cholesky factor.
        """
        p = self._event_shape[-1]  # has singleton shape

        # Implemented Sampling using Bartlett decomposition
        noise = _clamp_above_eps(
            self._dist_chi2.rsample(sample_shape).sqrt()
        ).diag_embed(dim1=-2, dim2=-1)

        i, j = torch.tril_indices(p, p, offset=-1)
        noise[..., i, j] = torch.randn(
            torch.Size(sample_shape) + self._batch_shape + (p * (p - 1) // 2,),
            dtype=noise.dtype,
            device=noise.device,
        )
        chol = self._unbroadcasted_scale_tril @ noise
        return chol @ chol.transpose(-2, -1)

    def _is_singular(self, sample):
        """Return a bool mask over the leading sample dims of ``sample``.

        An entry is True when the sample fails the positive-definite support
        check. Batch dims are reduced with ``amax`` (logical any), so a sample
        index is flagged if any batch member at that index is singular.
        """
        is_singular = ~self.support.check(sample)
        if self._batch_shape:
            is_singular = is_singular.amax(self._batch_dims)
        return is_singular

    def rsample(
        self, sample_shape: _size = torch.Size(), max_try_correction=None
    ) -> torch.Tensor:
        r"""
        Args:
            sample_shape: leading shape of the returned samples.
            max_try_correction: number of redraw attempts for samples that fail
                the positive-definite check; defaults to 3 under JIT tracing and
                10 otherwise.

        .. warning::
            In some cases, sampling algorithm based on Bartlett decomposition may return singular matrix samples.
            Several tries to correct singular samples are performed by default, but it may end up returning
            singular matrix samples. Singular samples may return `-inf` values in `.log_prob()`.
            In those cases, the user should validate the samples and either fix the value of `df`
            or adjust `max_try_correction` value for argument in `.rsample` accordingly.
        """

        if max_try_correction is None:
            max_try_correction = 3 if torch._C._get_tracing_state() else 10

        sample_shape = torch.Size(sample_shape)
        sample = self._bartlett_sampling(sample_shape)

        # Below part is to improve numerical stability temporarily and should be removed in the future
        is_singular = self._is_singular(sample)

        if torch._C._get_tracing_state():
            # Less optimized version for JIT
            for _ in range(max_try_correction):
                sample_new = self._bartlett_sampling(sample_shape)
                # is_singular has shape sample_shape; append singleton dims so it
                # broadcasts over the batch and event dims of ``sample``.
                mask = is_singular.reshape(
                    is_singular.shape + (1,) * (sample.dim() - is_singular.dim())
                )
                sample = torch.where(mask, sample_new, sample)

                is_singular = self._is_singular(sample)

        else:
            # More optimized version with data-dependent control flow.
            if is_singular.any():
                warnings.warn("Singular sample detected.")

                for _ in range(max_try_correction):
                    # Redraw only the flagged samples and write them back in place.
                    sample_new = self._bartlett_sampling(is_singular[is_singular].shape)
                    sample[is_singular] = sample_new

                    is_singular_new = self._is_singular(sample_new)
                    is_singular[is_singular.clone()] = is_singular_new

                    if not is_singular.any():
                        break

        return sample

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        nu = self.df  # has shape (batch_shape)
        p = self._event_shape[-1]  # has singleton shape
        return (
            -nu
            * (
                p * _log_2 / 2
                + self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1)
                .log()
                .sum(-1)
            )
            - torch.mvlgamma(nu / 2, p=p)
            + (nu - p - 1) / 2 * torch.linalg.slogdet(value).logabsdet
            - torch.cholesky_solve(value, self._unbroadcasted_scale_tril)
            .diagonal(dim1=-2, dim2=-1)
            .sum(dim=-1)
            / 2
        )

    def entropy(self):
        nu = self.df  # has shape (batch_shape)
        p = self._event_shape[-1]  # has singleton shape
        return (
            (p + 1)
            * (
                p * _log_2 / 2
                + self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1)
                .log()
                .sum(-1)
            )
            + torch.mvlgamma(nu / 2, p=p)
            - (nu - p - 1) / 2 * _mvdigamma(nu / 2, p=p)
            + nu * p / 2
        )

    @property
    def _natural_params(self):
        nu = self.df  # has shape (batch_shape)
        p = self._event_shape[-1]  # has singleton shape
        return -self.precision_matrix / 2, (nu - p - 1) / 2

    def _log_normalizer(self, x, y):
        p = self._event_shape[-1]
        return (y + (p + 1) / 2) * (
            -torch.linalg.slogdet(-2 * x).logabsdet + _log_2 * p
        ) + torch.mvlgamma(y + (p + 1) / 2, p=p)