xref: /aosp_15_r20/external/pytorch/torch/distributions/lowrank_multivariate_normal.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import math
3
4import torch
5from torch.distributions import constraints
6from torch.distributions.distribution import Distribution
7from torch.distributions.multivariate_normal import _batch_mahalanobis, _batch_mv
8from torch.distributions.utils import _standard_normal, lazy_property
9from torch.types import _size
10
11
# Only the distribution class is exported; the underscore-prefixed _batch_*
# helpers defined in this module are private.
__all__ = ["LowRankMultivariateNormal"]
13
14
15def _batch_capacitance_tril(W, D):
16    r"""
17    Computes Cholesky of :math:`I + W.T @ inv(D) @ W` for a batch of matrices :math:`W`
18    and a batch of vectors :math:`D`.
19    """
20    m = W.size(-1)
21    Wt_Dinv = W.mT / D.unsqueeze(-2)
22    K = torch.matmul(Wt_Dinv, W).contiguous()
23    K.view(-1, m * m)[:, :: m + 1] += 1  # add identity matrix to K
24    return torch.linalg.cholesky(K)
25
26
27def _batch_lowrank_logdet(W, D, capacitance_tril):
28    r"""
29    Uses "matrix determinant lemma"::
30        log|W @ W.T + D| = log|C| + log|D|,
31    where :math:`C` is the capacitance matrix :math:`I + W.T @ inv(D) @ W`, to compute
32    the log determinant.
33    """
34    return 2 * capacitance_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1) + D.log().sum(
35        -1
36    )
37
38
39def _batch_lowrank_mahalanobis(W, D, x, capacitance_tril):
40    r"""
41    Uses "Woodbury matrix identity"::
42        inv(W @ W.T + D) = inv(D) - inv(D) @ W @ inv(C) @ W.T @ inv(D),
43    where :math:`C` is the capacitance matrix :math:`I + W.T @ inv(D) @ W`, to compute the squared
44    Mahalanobis distance :math:`x.T @ inv(W @ W.T + D) @ x`.
45    """
46    Wt_Dinv = W.mT / D.unsqueeze(-2)
47    Wt_Dinv_x = _batch_mv(Wt_Dinv, x)
48    mahalanobis_term1 = (x.pow(2) / D).sum(-1)
49    mahalanobis_term2 = _batch_mahalanobis(capacitance_tril, Wt_Dinv_x)
50    return mahalanobis_term1 - mahalanobis_term2
51
52
class LowRankMultivariateNormal(Distribution):
    r"""
    Creates a multivariate normal distribution with covariance matrix having a low-rank form
    parameterized by :attr:`cov_factor` and :attr:`cov_diag`::

        covariance_matrix = cov_factor @ cov_factor.T + cov_diag

    Example:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = LowRankMultivariateNormal(torch.zeros(2), torch.tensor([[1.], [0.]]), torch.ones(2))
        >>> m.sample()  # normally distributed with mean=`[0,0]`, cov_factor=`[[1],[0]]`, cov_diag=`[1,1]`
        tensor([-0.2102, -0.5429])

    Args:
        loc (Tensor): mean of the distribution with shape `batch_shape + event_shape`
        cov_factor (Tensor): factor part of low-rank form of covariance matrix with shape
            `batch_shape + event_shape + (rank,)`
        cov_diag (Tensor): diagonal part of low-rank form of covariance matrix with shape
            `batch_shape + event_shape`
        validate_args (bool, optional): forwarded to :class:`Distribution`. Argument
            validation runs before the capacitance matrix is factorized, so invalid
            parameters are reported by the standard validation rather than by a
            Cholesky failure.

    Raises:
        ValueError: if ``loc`` is not at least 1-D, ``cov_factor`` is not at least 2-D,
            the event sizes of the three arguments disagree, or their batch shapes
            cannot be broadcast together.

    Note:
        The computation for determinant and inverse of covariance matrix is avoided when
        `cov_factor.shape[-1] << cov_factor.shape[-2]` thanks to `Woodbury matrix identity
        <https://en.wikipedia.org/wiki/Woodbury_matrix_identity>`_ and
        `matrix determinant lemma <https://en.wikipedia.org/wiki/Matrix_determinant_lemma>`_.
        Thanks to these formulas, we just need to compute the determinant and inverse of
        the small size "capacitance" matrix::

            capacitance = I + cov_factor.T @ inv(cov_diag) @ cov_factor
    """
    arg_constraints = {
        "loc": constraints.real_vector,
        "cov_factor": constraints.independent(constraints.real, 2),
        "cov_diag": constraints.independent(constraints.positive, 1),
    }
    support = constraints.real_vector
    has_rsample = True

    def __init__(self, loc, cov_factor, cov_diag, validate_args=None):
        if loc.dim() < 1:
            raise ValueError("loc must be at least one-dimensional.")
        event_shape = loc.shape[-1:]
        if cov_factor.dim() < 2:
            raise ValueError(
                "cov_factor must be at least two-dimensional, "
                "with optional leading batch dimensions"
            )
        if cov_factor.shape[-2:-1] != event_shape:
            raise ValueError(
                f"cov_factor must be a batch of matrices with shape {event_shape[0]} x m"
            )
        if cov_diag.shape[-1:] != event_shape:
            raise ValueError(
                f"cov_diag must be a batch of vectors with shape {event_shape}"
            )

        # Give loc and cov_diag a trailing singleton dim so that all three parameters
        # (cov_factor has an extra rank dim) can be broadcast over their batch dims.
        loc_ = loc.unsqueeze(-1)
        cov_diag_ = cov_diag.unsqueeze(-1)
        try:
            loc_, self.cov_factor, cov_diag_ = torch.broadcast_tensors(
                loc_, cov_factor, cov_diag_
            )
        except RuntimeError as e:
            raise ValueError(
                f"Incompatible batch shapes: loc {loc.shape}, cov_factor {cov_factor.shape}, cov_diag {cov_diag.shape}"
            ) from e
        self.loc = loc_[..., 0]
        self.cov_diag = cov_diag_[..., 0]
        batch_shape = self.loc.shape[:-1]

        # Keep the original (non-expanded) tensors so the heavy linear algebra below
        # runs on the smallest shapes; results are expanded lazily where needed.
        self._unbroadcasted_cov_factor = cov_factor
        self._unbroadcasted_cov_diag = cov_diag
        # Validate arguments BEFORE factorizing (as MultivariateNormal does): an invalid
        # cov_diag could otherwise make the capacitance matrix non-positive-definite and
        # surface as a Cholesky error instead of the standard validation error.
        super().__init__(batch_shape, event_shape, validate_args=validate_args)
        self._capacitance_tril = _batch_capacitance_tril(cov_factor, cov_diag)

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(LowRankMultivariateNormal, _instance)
        batch_shape = torch.Size(batch_shape)
        loc_shape = batch_shape + self.event_shape
        new.loc = self.loc.expand(loc_shape)
        new.cov_diag = self.cov_diag.expand(loc_shape)
        new.cov_factor = self.cov_factor.expand(loc_shape + self.cov_factor.shape[-1:])
        # The unbroadcasted parameters and the capacitance factor are shared, not copied.
        new._unbroadcasted_cov_factor = self._unbroadcasted_cov_factor
        new._unbroadcasted_cov_diag = self._unbroadcasted_cov_diag
        new._capacitance_tril = self._capacitance_tril
        super(LowRankMultivariateNormal, new).__init__(
            batch_shape, self.event_shape, validate_args=False
        )
        new._validate_args = self._validate_args
        return new

    @property
    def mean(self):
        return self.loc

    @property
    def mode(self):
        return self.loc

    @lazy_property
    def variance(self):
        # diag(W @ W.T + D) = row-wise sum of squares of W, plus D.
        return (
            self._unbroadcasted_cov_factor.pow(2).sum(-1) + self._unbroadcasted_cov_diag
        ).expand(self._batch_shape + self._event_shape)

    @lazy_property
    def scale_tril(self):
        # The following identity is used to increase the numerical stability of the
        # Cholesky decomposition (see http://www.gaussianprocess.org/gpml/, Section 3.4.3):
        #     W @ W.T + D = D^(1/2) @ (I + D^(-1/2) @ W @ W.T @ D^(-1/2)) @ D^(1/2)
        # The matrix "I + D^(-1/2) @ W @ W.T @ D^(-1/2)" has eigenvalues bounded from
        # below by 1, hence it is well-conditioned and safe to take its Cholesky
        # decomposition.
        cov_diag_sqrt_unsqueeze = self._unbroadcasted_cov_diag.sqrt().unsqueeze(-1)
        Dinvsqrt_W = self._unbroadcasted_cov_factor / cov_diag_sqrt_unsqueeze
        K = torch.matmul(Dinvsqrt_W, Dinvsqrt_W.mT).contiguous()
        # Add the identity through a diagonal view; unlike flattening the batch with
        # ``view(-1, n * n)``, this also works when the batch is empty.
        K.diagonal(dim1=-2, dim2=-1).add_(1)
        # D^(1/2) @ L scales the rows of L, so the result stays lower-triangular.
        scale_tril = cov_diag_sqrt_unsqueeze * torch.linalg.cholesky(K)
        return scale_tril.expand(
            self._batch_shape + self._event_shape + self._event_shape
        )

    @lazy_property
    def covariance_matrix(self):
        covariance_matrix = torch.matmul(
            self._unbroadcasted_cov_factor, self._unbroadcasted_cov_factor.mT
        ) + torch.diag_embed(self._unbroadcasted_cov_diag)
        return covariance_matrix.expand(
            self._batch_shape + self._event_shape + self._event_shape
        )

    @lazy_property
    def precision_matrix(self):
        # We use "Woodbury matrix identity" to take advantage of low rank form::
        #     inv(W @ W.T + D) = inv(D) - inv(D) @ W @ inv(C) @ W.T @ inv(D)
        # where :math:`C` is the capacitance matrix.
        Wt_Dinv = (
            self._unbroadcasted_cov_factor.mT
            / self._unbroadcasted_cov_diag.unsqueeze(-2)
        )
        # With C = L @ L.T and A = inv(L) @ W.T @ inv(D), the correction term is A.T @ A.
        A = torch.linalg.solve_triangular(self._capacitance_tril, Wt_Dinv, upper=False)
        precision_matrix = (
            torch.diag_embed(self._unbroadcasted_cov_diag.reciprocal()) - A.mT @ A
        )
        return precision_matrix.expand(
            self._batch_shape + self._event_shape + self._event_shape
        )

    def rsample(self, sample_shape: _size = torch.Size()) -> torch.Tensor:
        # Reparameterized sample: loc + W @ eps_W + sqrt(D) * eps_D, whose covariance is
        # W @ W.T + D because eps_W and eps_D are independent standard normals.
        shape = self._extended_shape(sample_shape)
        W_shape = shape[:-1] + self.cov_factor.shape[-1:]
        eps_W = _standard_normal(W_shape, dtype=self.loc.dtype, device=self.loc.device)
        eps_D = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
        return (
            self.loc
            + _batch_mv(self._unbroadcasted_cov_factor, eps_W)
            + self._unbroadcasted_cov_diag.sqrt() * eps_D
        )

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        diff = value - self.loc
        M = _batch_lowrank_mahalanobis(
            self._unbroadcasted_cov_factor,
            self._unbroadcasted_cov_diag,
            diff,
            self._capacitance_tril,
        )
        log_det = _batch_lowrank_logdet(
            self._unbroadcasted_cov_factor,
            self._unbroadcasted_cov_diag,
            self._capacitance_tril,
        )
        return -0.5 * (self._event_shape[0] * math.log(2 * math.pi) + log_det + M)

    def entropy(self):
        log_det = _batch_lowrank_logdet(
            self._unbroadcasted_cov_factor,
            self._unbroadcasted_cov_diag,
            self._capacitance_tril,
        )
        H = 0.5 * (self._event_shape[0] * (1.0 + math.log(2 * math.pi)) + log_det)
        if len(self._batch_shape) == 0:
            return H
        else:
            return H.expand(self._batch_shape)
241