# mypy: allow-untyped-defs
import math

import torch
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.multivariate_normal import _batch_mahalanobis, _batch_mv
from torch.distributions.utils import _standard_normal, lazy_property
from torch.types import _size


__all__ = ["LowRankMultivariateNormal"]


def _batch_capacitance_tril(W, D):
    r"""
    Computes Cholesky of :math:`I + W.T @ inv(D) @ W` for a batch of matrices :math:`W`
    and a batch of vectors :math:`D`.

    Args:
        W (Tensor): batch of factor matrices; the last dimension is the rank ``m``.
        D (Tensor): batch of diagonal vectors, broadcast against the rows of ``W``.

    Returns:
        Tensor: lower-triangular Cholesky factor of the ``m x m`` capacitance matrix.
    """
    m = W.size(-1)
    # inv(D) is diagonal, so W.T @ inv(D) is a column-wise division of W.T.
    Wt_Dinv = W.mT / D.unsqueeze(-2)
    K = torch.matmul(Wt_Dinv, W).contiguous()
    # Flattening each m x m matrix puts its diagonal at stride m + 1, which lets
    # the identity be added in place without materializing torch.eye.
    K.view(-1, m * m)[:, :: m + 1] += 1  # add identity matrix to K
    return torch.linalg.cholesky(K)


def _batch_lowrank_logdet(W, D, capacitance_tril):
    r"""
    Uses "matrix determinant lemma"::
        log|W @ W.T + D| = log|C| + log|D|,
    where :math:`C` is the capacitance matrix :math:`I + W.T @ inv(D) @ W`, to compute
    the log determinant.

    Args:
        W (Tensor): batch of factor matrices (unused here; kept so the signature
            parallels :func:`_batch_lowrank_mahalanobis`).
        D (Tensor): batch of diagonal vectors.
        capacitance_tril (Tensor): Cholesky factor of the capacitance matrix.

    Returns:
        Tensor: log determinant of ``W @ W.T + D``.
    """
    # log|C| is twice the sum of the logs of the Cholesky factor's diagonal.
    log_det_capacitance = 2 * capacitance_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1)
    return log_det_capacitance + D.log().sum(-1)


def _batch_lowrank_mahalanobis(W, D, x, capacitance_tril):
    r"""
    Uses "Woodbury matrix identity"::
        inv(W @ W.T + D) = inv(D) - inv(D) @ W @ inv(C) @ W.T @ inv(D),
    where :math:`C` is the capacitance matrix :math:`I + W.T @ inv(D) @ W`, to compute the squared
    Mahalanobis distance :math:`x.T @ inv(W @ W.T + D) @ x`.

    Args:
        W (Tensor): batch of factor matrices.
        D (Tensor): batch of diagonal vectors.
        x (Tensor): batch of vectors whose squared distance is computed.
        capacitance_tril (Tensor): Cholesky factor of the capacitance matrix.

    Returns:
        Tensor: squared Mahalanobis distance of ``x``.
    """
    Wt_Dinv = W.mT / D.unsqueeze(-2)
    Wt_Dinv_x = _batch_mv(Wt_Dinv, x)
    # x.T @ inv(D) @ x
    mahalanobis_term1 = (x.pow(2) / D).sum(-1)
    # x.T @ inv(D) @ W @ inv(C) @ W.T @ inv(D) @ x
    mahalanobis_term2 = _batch_mahalanobis(capacitance_tril, Wt_Dinv_x)
    return mahalanobis_term1 - mahalanobis_term2


class LowRankMultivariateNormal(Distribution):
    r"""
    Creates a multivariate normal distribution with covariance matrix having a low-rank form
    parameterized by :attr:`cov_factor` and :attr:`cov_diag`::

        covariance_matrix = cov_factor @ cov_factor.T + cov_diag

    Example:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = LowRankMultivariateNormal(torch.zeros(2), torch.tensor([[1.], [0.]]), torch.ones(2))
        >>> m.sample()  # normally distributed with mean=`[0,0]`, cov_factor=`[[1],[0]]`, cov_diag=`[1,1]`
        tensor([-0.2102, -0.5429])

    Args:
        loc (Tensor): mean of the distribution with shape `batch_shape + event_shape`
        cov_factor (Tensor): factor part of low-rank form of covariance matrix with shape
            `batch_shape + event_shape + (rank,)`
        cov_diag (Tensor): diagonal part of low-rank form of covariance matrix with shape
            `batch_shape + event_shape`

    Note:
        The computation for determinant and inverse of covariance matrix is avoided when
        `cov_factor.shape[1] << cov_factor.shape[0]` thanks to `Woodbury matrix identity
        <https://en.wikipedia.org/wiki/Woodbury_matrix_identity>`_ and
        `matrix determinant lemma <https://en.wikipedia.org/wiki/Matrix_determinant_lemma>`_.
        Thanks to these formulas, we just need to compute the determinant and inverse of
        the small size "capacitance" matrix::

            capacitance = I + cov_factor.T @ inv(cov_diag) @ cov_factor
    """

    arg_constraints = {
        "loc": constraints.real_vector,
        "cov_factor": constraints.independent(constraints.real, 2),
        "cov_diag": constraints.independent(constraints.positive, 1),
    }
    support = constraints.real_vector
    has_rsample = True

    def __init__(self, loc, cov_factor, cov_diag, validate_args=None):
        """
        Check shapes, broadcast the parameters and cache the capacitance factor.

        Raises:
            ValueError: if ``loc`` has fewer than one dimension, ``cov_factor``
                has fewer than two, the event dimensions disagree, or the batch
                shapes cannot be broadcast together. Argument validation done by
                the base-class constructor may raise as well.
        """
        if loc.dim() < 1:
            raise ValueError("loc must be at least one-dimensional.")
        event_shape = loc.shape[-1:]
        if cov_factor.dim() < 2:
            raise ValueError(
                "cov_factor must be at least two-dimensional, "
                "with optional leading batch dimensions"
            )
        if cov_factor.shape[-2:-1] != event_shape:
            raise ValueError(
                f"cov_factor must be a batch of matrices with shape {event_shape[0]} x m"
            )
        if cov_diag.shape[-1:] != event_shape:
            raise ValueError(
                f"cov_diag must be a batch of vectors with shape {event_shape}"
            )

        # Give the vectors a trailing singleton dimension so that all three
        # parameters can be broadcast together as batches of matrices.
        loc_ = loc.unsqueeze(-1)
        cov_diag_ = cov_diag.unsqueeze(-1)
        try:
            loc_, self.cov_factor, cov_diag_ = torch.broadcast_tensors(
                loc_, cov_factor, cov_diag_
            )
        except RuntimeError as e:
            raise ValueError(
                f"Incompatible batch shapes: loc {loc.shape}, cov_factor {cov_factor.shape}, cov_diag {cov_diag.shape}"
            ) from e
        self.loc = loc_[..., 0]
        self.cov_diag = cov_diag_[..., 0]
        batch_shape = self.loc.shape[:-1]

        # The unbroadcasted parameters are kept so later computations do not
        # repeat work across broadcast batch dimensions.
        self._unbroadcasted_cov_factor = cov_factor
        self._unbroadcasted_cov_diag = cov_diag

        # Call the base constructor *before* the Cholesky factorization so that
        # any argument validation it performs reports an invalid parameter
        # (e.g. a non-positive ``cov_diag``) as such, instead of the
        # factorization failing first with an unrelated linear-algebra error.
        super().__init__(batch_shape, event_shape, validate_args=validate_args)
        self._capacitance_tril = _batch_capacitance_tril(cov_factor, cov_diag)

    def expand(self, batch_shape, _instance=None):
        """
        Return a new distribution with batch dimensions expanded to ``batch_shape``.

        The unbroadcasted parameters and the cached capacitance factor are
        shared with ``self`` rather than recomputed.
        """
        new = self._get_checked_instance(LowRankMultivariateNormal, _instance)
        batch_shape = torch.Size(batch_shape)
        loc_shape = batch_shape + self.event_shape
        new.loc = self.loc.expand(loc_shape)
        new.cov_diag = self.cov_diag.expand(loc_shape)
        new.cov_factor = self.cov_factor.expand(loc_shape + self.cov_factor.shape[-1:])
        new._unbroadcasted_cov_factor = self._unbroadcasted_cov_factor
        new._unbroadcasted_cov_diag = self._unbroadcasted_cov_diag
        new._capacitance_tril = self._capacitance_tril
        # Validation is skipped here; the flag is copied from ``self`` below.
        super(LowRankMultivariateNormal, new).__init__(
            batch_shape, self.event_shape, validate_args=False
        )
        new._validate_args = self._validate_args
        return new

    @property
    def mean(self):
        """Mean of the distribution (``loc``)."""
        return self.loc

    @property
    def mode(self):
        """Mode of the distribution (``loc``)."""
        return self.loc

    @lazy_property
    def variance(self):
        """Diagonal of the covariance matrix: ``sum(cov_factor ** 2, -1) + cov_diag``."""
        return (
            self._unbroadcasted_cov_factor.pow(2).sum(-1) + self._unbroadcasted_cov_diag
        ).expand(self._batch_shape + self._event_shape)

    @lazy_property
    def scale_tril(self):
        """Lower-triangular Cholesky factor of the covariance matrix."""
        # The following identity is used to improve the numerical stability of
        # the Cholesky decomposition (see http://www.gaussianprocess.org/gpml/, Section 3.4.3):
        #     W @ W.T + D = D^{1/2} @ (I + D^{-1/2} @ W @ W.T @ D^{-1/2}) @ D^{1/2}
        # The matrix "I + D^{-1/2} @ W @ W.T @ D^{-1/2}" has eigenvalues bounded from below by 1,
        # hence it is well-conditioned and safe to take Cholesky decomposition.
        n = self._event_shape[0]
        cov_diag_sqrt_unsqueeze = self._unbroadcasted_cov_diag.sqrt().unsqueeze(-1)
        Dinvsqrt_W = self._unbroadcasted_cov_factor / cov_diag_sqrt_unsqueeze
        K = torch.matmul(Dinvsqrt_W, Dinvsqrt_W.mT).contiguous()
        K.view(-1, n * n)[:, :: n + 1] += 1  # add identity matrix to K
        scale_tril = cov_diag_sqrt_unsqueeze * torch.linalg.cholesky(K)
        return scale_tril.expand(
            self._batch_shape + self._event_shape + self._event_shape
        )

    @lazy_property
    def covariance_matrix(self):
        """Dense covariance matrix ``cov_factor @ cov_factor.T + diag(cov_diag)``."""
        covariance_matrix = torch.matmul(
            self._unbroadcasted_cov_factor, self._unbroadcasted_cov_factor.mT
        ) + torch.diag_embed(self._unbroadcasted_cov_diag)
        return covariance_matrix.expand(
            self._batch_shape + self._event_shape + self._event_shape
        )

    @lazy_property
    def precision_matrix(self):
        """Dense inverse of the covariance matrix."""
        # We use "Woodbury matrix identity" to take advantage of low rank form::
        #     inv(W @ W.T + D) = inv(D) - inv(D) @ W @ inv(C) @ W.T @ inv(D)
        # where :math:`C` is the capacitance matrix.
        Wt_Dinv = (
            self._unbroadcasted_cov_factor.mT
            / self._unbroadcasted_cov_diag.unsqueeze(-2)
        )
        # With C = L @ L.T and A = inv(L) @ W.T @ inv(D), the correction term
        # inv(D) @ W @ inv(C) @ W.T @ inv(D) equals A.T @ A.
        A = torch.linalg.solve_triangular(self._capacitance_tril, Wt_Dinv, upper=False)
        precision_matrix = (
            torch.diag_embed(self._unbroadcasted_cov_diag.reciprocal()) - A.mT @ A
        )
        return precision_matrix.expand(
            self._batch_shape + self._event_shape + self._event_shape
        )

    def rsample(self, sample_shape: _size = torch.Size()) -> torch.Tensor:
        """
        Draw reparameterized samples as ``loc + W @ eps_W + sqrt(D) * eps_D``.

        ``eps_W`` and ``eps_D`` are independent standard-normal draws, one of
        rank size and one of event size.
        """
        shape = self._extended_shape(sample_shape)
        W_shape = shape[:-1] + self.cov_factor.shape[-1:]
        eps_W = _standard_normal(W_shape, dtype=self.loc.dtype, device=self.loc.device)
        eps_D = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
        return (
            self.loc
            + _batch_mv(self._unbroadcasted_cov_factor, eps_W)
            + self._unbroadcasted_cov_diag.sqrt() * eps_D
        )

    def log_prob(self, value):
        """Log density at ``value``, using the cached capacitance factor."""
        if self._validate_args:
            self._validate_sample(value)
        diff = value - self.loc
        M = _batch_lowrank_mahalanobis(
            self._unbroadcasted_cov_factor,
            self._unbroadcasted_cov_diag,
            diff,
            self._capacitance_tril,
        )
        log_det = _batch_lowrank_logdet(
            self._unbroadcasted_cov_factor,
            self._unbroadcasted_cov_diag,
            self._capacitance_tril,
        )
        return -0.5 * (self._event_shape[0] * math.log(2 * math.pi) + log_det + M)

    def entropy(self):
        """Differential entropy, expanded to ``batch_shape`` when that is non-empty."""
        log_det = _batch_lowrank_logdet(
            self._unbroadcasted_cov_factor,
            self._unbroadcasted_cov_diag,
            self._capacitance_tril,
        )
        H = 0.5 * (self._event_shape[0] * (1.0 + math.log(2 * math.pi)) + log_det)
        if not self._batch_shape:
            return H
        return H.expand(self._batch_shape)