xref: /aosp_15_r20/external/pytorch/torch/distributions/wishart.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-defs
2*da0073e9SAndroid Build Coastguard Workerimport math
3*da0073e9SAndroid Build Coastguard Workerimport warnings
4*da0073e9SAndroid Build Coastguard Workerfrom numbers import Number
5*da0073e9SAndroid Build Coastguard Workerfrom typing import Optional, Union
6*da0073e9SAndroid Build Coastguard Worker
7*da0073e9SAndroid Build Coastguard Workerimport torch
8*da0073e9SAndroid Build Coastguard Workerfrom torch import nan
9*da0073e9SAndroid Build Coastguard Workerfrom torch.distributions import constraints
10*da0073e9SAndroid Build Coastguard Workerfrom torch.distributions.exp_family import ExponentialFamily
11*da0073e9SAndroid Build Coastguard Workerfrom torch.distributions.multivariate_normal import _precision_to_scale_tril
12*da0073e9SAndroid Build Coastguard Workerfrom torch.distributions.utils import lazy_property
13*da0073e9SAndroid Build Coastguard Workerfrom torch.types import _size
14*da0073e9SAndroid Build Coastguard Worker
15*da0073e9SAndroid Build Coastguard Worker
# Public names exported by this module.
__all__ = ["Wishart"]

# Natural logarithm of 2, computed once at import time; it appears in the
# Wishart log-density, entropy and log-normalizer expressions below.
_log_2 = math.log(2)
19*da0073e9SAndroid Build Coastguard Worker
20*da0073e9SAndroid Build Coastguard Worker
21*da0073e9SAndroid Build Coastguard Workerdef _mvdigamma(x: torch.Tensor, p: int) -> torch.Tensor:
22*da0073e9SAndroid Build Coastguard Worker    assert x.gt((p - 1) / 2).all(), "Wrong domain for multivariate digamma function."
23*da0073e9SAndroid Build Coastguard Worker    return torch.digamma(
24*da0073e9SAndroid Build Coastguard Worker        x.unsqueeze(-1)
25*da0073e9SAndroid Build Coastguard Worker        - torch.arange(p, dtype=x.dtype, device=x.device).div(2).expand(x.shape + (-1,))
26*da0073e9SAndroid Build Coastguard Worker    ).sum(-1)
27*da0073e9SAndroid Build Coastguard Worker
28*da0073e9SAndroid Build Coastguard Worker
29*da0073e9SAndroid Build Coastguard Workerdef _clamp_above_eps(x: torch.Tensor) -> torch.Tensor:
30*da0073e9SAndroid Build Coastguard Worker    # We assume positive input for this function
31*da0073e9SAndroid Build Coastguard Worker    return x.clamp(min=torch.finfo(x.dtype).eps)
32*da0073e9SAndroid Build Coastguard Worker
33*da0073e9SAndroid Build Coastguard Worker
class Wishart(ExponentialFamily):
    r"""
    Creates a Wishart distribution parameterized by a symmetric positive definite matrix :math:`\Sigma`,
    or its Cholesky decomposition :math:`\mathbf{\Sigma} = \mathbf{L}\mathbf{L}^\top`

    Example:
        >>> # xdoctest: +SKIP("FIXME: scale_tril must be at least two-dimensional")
        >>> m = Wishart(torch.Tensor([2]), covariance_matrix=torch.eye(2))
        >>> m.sample()  # Wishart distributed with mean=`df * I` and
        >>>             # variance(x_ij)=`df` for i != j and variance(x_ij)=`2 * df` for i == j

    Args:
        df (float or Tensor): real-valued parameter larger than the (dimension of Square matrix) - 1
        covariance_matrix (Tensor): positive-definite covariance matrix
        precision_matrix (Tensor): positive-definite precision matrix
        scale_tril (Tensor): lower-triangular factor of covariance, with positive-valued diagonal
    Note:
        Only one of :attr:`covariance_matrix` or :attr:`precision_matrix` or
        :attr:`scale_tril` can be specified.
        Using :attr:`scale_tril` will be more efficient: all computations internally
        are based on :attr:`scale_tril`. If :attr:`covariance_matrix` or
        :attr:`precision_matrix` is passed instead, it is only used to compute
        the corresponding lower triangular matrices using a Cholesky decomposition.
        'torch.distributions.LKJCholesky' is a restricted Wishart distribution.[1]

        :attr:`df` is always stored as a tensor of shape ``batch_shape``, also when
        it is given as a Python number.

    **References**

    [1] Wang, Z., Wu, Y. and Chu, H., 2018. `On equivalence of the LKJ distribution and the restricted Wishart distribution`.
    [2] Sawyer, S., 2007. `Wishart Distributions and Inverse-Wishart Sampling`.
    [3] Anderson, T. W., 2003. `An Introduction to Multivariate Statistical Analysis (3rd ed.)`.
    [4] Odell, P. L. & Feiveson, A. H., 1966. `A Numerical Procedure to Generate a Sample Covariance Matrix`. JASA, 61(313):199-203.
    [5] Ku, Y.-C. & Bloomfield, P., 2010. `Generating Random Wishart Matrices with Fractional Degrees of Freedom in OX`.
    """
    # Class-level defaults. The bound on ``df`` depends on the matrix size, so
    # each instance installs its own refined copy in ``__init__``.
    arg_constraints = {
        "covariance_matrix": constraints.positive_definite,
        "precision_matrix": constraints.positive_definite,
        "scale_tril": constraints.lower_cholesky,
        "df": constraints.greater_than(0),
    }
    support = constraints.positive_definite
    has_rsample = True
    _mean_carrier_measure = 0

    def __init__(
        self,
        df: Union[torch.Tensor, Number],
        covariance_matrix: Optional[torch.Tensor] = None,
        precision_matrix: Optional[torch.Tensor] = None,
        scale_tril: Optional[torch.Tensor] = None,
        validate_args=None,
    ):
        """Build the distribution from ``df`` and exactly one scale parameterization.

        Raises:
            AssertionError: if zero or more than one of ``covariance_matrix``,
                ``precision_matrix`` and ``scale_tril`` is given.
            ValueError: if the matrix parameter has fewer than two dimensions, or
                if any ``df`` is not greater than ``ndim - 1``.
        """
        assert (covariance_matrix is not None) + (scale_tril is not None) + (
            precision_matrix is not None
        ) == 1, "Exactly one of covariance_matrix or precision_matrix or scale_tril may be specified."

        # The single matrix argument that was actually supplied.
        param = next(
            p
            for p in (covariance_matrix, precision_matrix, scale_tril)
            if p is not None
        )

        if param.dim() < 2:
            raise ValueError(
                "scale_tril must be at least two-dimensional, with optional leading batch dimensions"
            )

        if isinstance(df, Number):
            batch_shape = torch.Size(param.shape[:-2])
            # Expand the scalar so that ``df`` has shape ``batch_shape`` like in the
            # tensor branch; ``mean``/``mode``/``variance`` reshape it to
            # ``batch_shape + (1, 1)`` and would fail on a 0-dim tensor when the
            # matrix parameter is batched.
            self.df = torch.tensor(df, dtype=param.dtype, device=param.device).expand(
                batch_shape
            )
        else:
            batch_shape = torch.broadcast_shapes(param.shape[:-2], df.shape)
            self.df = df.expand(batch_shape)
        event_shape = param.shape[-2:]

        if self.df.le(event_shape[-1] - 1).any():
            raise ValueError(
                f"Value of df={df} expected to be greater than ndim - 1 = {event_shape[-1]-1}."
            )

        # Store the user-supplied parameterization broadcast to the batch shape; the
        # other two stay lazy properties derived from the Cholesky factor.
        if scale_tril is not None:
            self.scale_tril = param.expand(batch_shape + (-1, -1))
        elif covariance_matrix is not None:
            self.covariance_matrix = param.expand(batch_shape + (-1, -1))
        elif precision_matrix is not None:
            self.precision_matrix = param.expand(batch_shape + (-1, -1))

        # Refine the ``df`` bound on a per-instance copy. Writing into the
        # class-level dict would change the constraint seen by every other
        # Wishart instance, whatever its matrix size.
        self.arg_constraints = dict(
            self.arg_constraints,
            df=constraints.greater_than(event_shape[-1] - 1),
        )
        if self.df.lt(event_shape[-1]).any():
            warnings.warn(
                "Low df values detected. Singular samples are highly likely to occur for ndim - 1 < df < ndim."
            )

        super().__init__(batch_shape, event_shape, validate_args=validate_args)
        # Negative indices of the batch axes of a tensor shaped
        # ``sample_shape + batch_shape``; used to reduce over the batch.
        self._batch_dims = [-(x + 1) for x in range(len(self._batch_shape))]

        if scale_tril is not None:
            self._unbroadcasted_scale_tril = scale_tril
        elif covariance_matrix is not None:
            self._unbroadcasted_scale_tril = torch.linalg.cholesky(covariance_matrix)
        else:  # precision_matrix is not None
            self._unbroadcasted_scale_tril = _precision_to_scale_tril(precision_matrix)

        # Chi2 distribution is needed for Bartlett decomposition sampling
        self._dist_chi2 = torch.distributions.chi2.Chi2(
            df=(
                self.df.unsqueeze(-1)
                - torch.arange(
                    self._event_shape[-1],
                    dtype=self._unbroadcasted_scale_tril.dtype,
                    device=self._unbroadcasted_scale_tril.device,
                ).expand(batch_shape + (-1,))
            )
        )

    def expand(self, batch_shape, _instance=None):
        """Return a new distribution whose parameters are expanded to ``batch_shape``."""
        new = self._get_checked_instance(Wishart, _instance)
        batch_shape = torch.Size(batch_shape)
        cov_shape = batch_shape + self.event_shape
        new._unbroadcasted_scale_tril = self._unbroadcasted_scale_tril.expand(cov_shape)
        new.df = self.df.expand(batch_shape)
        # Carry over the instance-specific ``df`` bound (the event shape is unchanged).
        new.arg_constraints = self.arg_constraints

        new._batch_dims = [-(x + 1) for x in range(len(batch_shape))]

        # Only copy the parameterizations that have already been materialized.
        if "covariance_matrix" in self.__dict__:
            new.covariance_matrix = self.covariance_matrix.expand(cov_shape)
        if "scale_tril" in self.__dict__:
            new.scale_tril = self.scale_tril.expand(cov_shape)
        if "precision_matrix" in self.__dict__:
            new.precision_matrix = self.precision_matrix.expand(cov_shape)

        # Chi2 distribution is needed for Bartlett decomposition sampling
        new._dist_chi2 = torch.distributions.chi2.Chi2(
            df=(
                new.df.unsqueeze(-1)
                - torch.arange(
                    self.event_shape[-1],
                    dtype=new._unbroadcasted_scale_tril.dtype,
                    device=new._unbroadcasted_scale_tril.device,
                ).expand(batch_shape + (-1,))
            )
        )

        super(Wishart, new).__init__(batch_shape, self.event_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    @lazy_property
    def scale_tril(self):
        """Lower Cholesky factor of the scale matrix, expanded to the batch shape."""
        return self._unbroadcasted_scale_tril.expand(
            self._batch_shape + self._event_shape
        )

    @lazy_property
    def covariance_matrix(self):
        """Scale matrix ``L @ L^T``, expanded to the batch shape."""
        return (
            self._unbroadcasted_scale_tril
            @ self._unbroadcasted_scale_tril.transpose(-2, -1)
        ).expand(self._batch_shape + self._event_shape)

    @lazy_property
    def precision_matrix(self):
        """Inverse of the scale matrix, obtained by a Cholesky solve against the identity."""
        identity = torch.eye(
            self._event_shape[-1],
            device=self._unbroadcasted_scale_tril.device,
            dtype=self._unbroadcasted_scale_tril.dtype,
        )
        return torch.cholesky_solve(identity, self._unbroadcasted_scale_tril).expand(
            self._batch_shape + self._event_shape
        )

    @property
    def mean(self):
        """Mean of the distribution: ``df * covariance_matrix``."""
        return self.df.view(self._batch_shape + (1, 1)) * self.covariance_matrix

    @property
    def mode(self):
        """``(df - p - 1) * covariance_matrix``; NaN wherever ``df - p - 1 <= 0``."""
        # The subtraction allocates a new tensor, so the in-place write below
        # does not touch ``self.df``.
        factor = self.df - self.covariance_matrix.shape[-1] - 1
        factor[factor <= 0] = nan
        return factor.view(self._batch_shape + (1, 1)) * self.covariance_matrix

    @property
    def variance(self):
        """Elementwise variance: ``df * (V_ij ** 2 + V_ii * V_jj)``."""
        V = self.covariance_matrix  # has shape (batch_shape x event_shape)
        diag_V = V.diagonal(dim1=-2, dim2=-1)
        return self.df.view(self._batch_shape + (1, 1)) * (
            V.pow(2) + torch.einsum("...i,...j->...ij", diag_V, diag_V)
        )

    def _bartlett_sampling(self, sample_shape=torch.Size()):
        """Draw samples of shape ``sample_shape + batch_shape + event_shape``.

        No singularity correction is applied here; see :meth:`rsample`.
        """
        p = self._event_shape[-1]  # has singleton shape

        # Implemented Sampling using Bartlett decomposition
        # Diagonal: square roots of chi-square draws, kept away from zero.
        noise = _clamp_above_eps(
            self._dist_chi2.rsample(sample_shape).sqrt()
        ).diag_embed(dim1=-2, dim2=-1)

        # Strictly lower triangle: independent standard normal draws.
        i, j = torch.tril_indices(p, p, offset=-1)
        noise[..., i, j] = torch.randn(
            torch.Size(sample_shape) + self._batch_shape + (p * (p - 1) // 2,),
            dtype=noise.dtype,
            device=noise.device,
        )
        chol = self._unbroadcasted_scale_tril @ noise
        return chol @ chol.transpose(-2, -1)

    def _singular_mask(self, sample):
        """Flag draws that contain at least one matrix outside the support.

        ``sample`` has shape ``lead_shape + batch_shape + event_shape``; the
        returned boolean tensor has shape ``lead_shape``.
        """
        is_singular = ~self.support.check(sample)
        if self._batch_shape:
            is_singular = is_singular.amax(self._batch_dims)
        return is_singular

    def rsample(
        self, sample_shape: _size = torch.Size(), max_try_correction=None
    ) -> torch.Tensor:
        r"""
        Draw reparameterized samples via the Bartlett decomposition, redrawing
        samples that fail the positive-definiteness check up to
        ``max_try_correction`` times (default: 3 while tracing, 10 otherwise).

        .. warning::
            In some cases, sampling algorithm based on Bartlett decomposition may return singular matrix samples.
            Several tries to correct singular samples are performed by default, but it may end up returning
            singular matrix samples. Singular samples may return `-inf` values in `.log_prob()`.
            In those cases, the user should validate the samples and either fix the value of `df`
            or adjust `max_try_correction` value for argument in `.rsample` accordingly.
        """

        if max_try_correction is None:
            max_try_correction = 3 if torch._C._get_tracing_state() else 10

        sample_shape = torch.Size(sample_shape)
        sample = self._bartlett_sampling(sample_shape)

        # Below part is to improve numerical stability temporally and should be removed in the future
        # A draw is singular when it FAILS the support check, so the check
        # result must be negated (done inside ``_singular_mask``).
        is_singular = self._singular_mask(sample)

        if torch._C._get_tracing_state():
            # Less optimized version for JIT
            for _ in range(max_try_correction):
                sample_new = self._bartlett_sampling(sample_shape)
                # The mask has shape ``sample_shape``; append singleton axes for the
                # batch and event dimensions so it selects whole draws instead of
                # being broadcast against the trailing matrix axes.
                mask = is_singular[
                    (...,) + (None,) * (sample.dim() - is_singular.dim())
                ]
                sample = torch.where(mask, sample_new, sample)
                is_singular = self._singular_mask(sample)

        else:
            # More optimized version with data-dependent control flow.
            if is_singular.any():
                warnings.warn("Singular sample detected.")

                for _ in range(max_try_correction):
                    # Redraw only the flagged entries: one new draw per True in the mask.
                    sample_new = self._bartlett_sampling(is_singular[is_singular].shape)
                    sample[is_singular] = sample_new

                    # Update the mask in place at the positions that were just redrawn;
                    # the clone keeps the index fixed while the target is written.
                    is_singular_new = self._singular_mask(sample_new)
                    is_singular[is_singular.clone()] = is_singular_new

                    if not is_singular.any():
                        break

        return sample

    def log_prob(self, value):
        """Log-density of ``value`` (validated against the support when enabled)."""
        if self._validate_args:
            self._validate_sample(value)
        nu = self.df  # has shape (batch_shape)
        p = self._event_shape[-1]  # has singleton shape
        # sum(log(diag(L))) equals half the log-determinant of the scale matrix.
        return (
            -nu
            * (
                p * _log_2 / 2
                + self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1)
                .log()
                .sum(-1)
            )
            - torch.mvlgamma(nu / 2, p=p)
            + (nu - p - 1) / 2 * torch.linalg.slogdet(value).logabsdet
            - torch.cholesky_solve(value, self._unbroadcasted_scale_tril)
            .diagonal(dim1=-2, dim2=-1)
            .sum(dim=-1)
            / 2
        )

    def entropy(self):
        """Differential entropy, computed from ``df`` and the Cholesky factor only."""
        nu = self.df  # has shape (batch_shape)
        p = self._event_shape[-1]  # has singleton shape
        return (
            (p + 1)
            * (
                p * _log_2 / 2
                + self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1)
                .log()
                .sum(-1)
            )
            + torch.mvlgamma(nu / 2, p=p)
            - (nu - p - 1) / 2 * _mvdigamma(nu / 2, p=p)
            + nu * p / 2
        )

    @property
    def _natural_params(self):
        """Natural parameters ``(-precision_matrix / 2, (df - p - 1) / 2)``."""
        nu = self.df  # has shape (batch_shape)
        p = self._event_shape[-1]  # has singleton shape
        return -self.precision_matrix / 2, (nu - p - 1) / 2

    def _log_normalizer(self, x, y):
        """Log-normalizer as a function of the natural parameters ``(x, y)``."""
        p = self._event_shape[-1]
        return (y + (p + 1) / 2) * (
            -torch.linalg.slogdet(-2 * x).logabsdet + _log_2 * p
        ) + torch.mvlgamma(y + (p + 1) / 2, p=p)
340