# mypy: allow-untyped-defs
import math

import torch
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.utils import _standard_normal, lazy_property
from torch.types import _size


__all__ = ["MultivariateNormal"]


def _batch_mv(bmat, bvec):
    r"""
    Performs a batched matrix-vector product, with compatible but different batch shapes.

    This function takes as input `bmat`, containing :math:`n \times n` matrices, and
    `bvec`, containing length :math:`n` vectors.

    Both `bmat` and `bvec` may have any number of leading dimensions, which correspond
    to a batch shape. They are not required to have the same batch shape; their batch
    shapes only need to be broadcastable against each other.
    """
    return torch.matmul(bmat, bvec.unsqueeze(-1)).squeeze(-1)


def _batch_mahalanobis(bL, bx):
    r"""
    Computes the squared Mahalanobis distance :math:`\mathbf{x}^\top\mathbf{M}^{-1}\mathbf{x}`
    for a factored :math:`\mathbf{M} = \mathbf{L}\mathbf{L}^\top`.

    Accepts batches for both `bL` and `bx`. They are not required to have the same
    batch shape, but the batch shape of `bL` must be broadcastable to that of `bx`.
    """
    n = bx.size(-1)
    bx_batch_shape = bx.shape[:-1]

    # Assume that bL.shape = (i, 1, n, n) and bx.shape = (..., i, j, n);
    # we are going to make bx have shape (..., 1, j, i, 1, n) to apply batched tri.solve
    bx_batch_dims = len(bx_batch_shape)
    bL_batch_dims = bL.dim() - 2
    outer_batch_dims = bx_batch_dims - bL_batch_dims
    old_batch_dims = outer_batch_dims + bL_batch_dims
    new_batch_dims = outer_batch_dims + 2 * bL_batch_dims
    # Reshape bx to the shape (..., 1, i, j, 1, n)
    bx_new_shape = bx.shape[:outer_batch_dims]
    for sL, sx in zip(bL.shape[:-2], bx.shape[outer_batch_dims:-1]):
        bx_new_shape += (sx // sL, sL)
    bx_new_shape += (n,)
    bx = bx.reshape(bx_new_shape)
    # Permute bx to make it have shape (..., 1, j, i, 1, n)
    permute_dims = (
        list(range(outer_batch_dims))
        + list(range(outer_batch_dims, new_batch_dims, 2))
        + list(range(outer_batch_dims + 1, new_batch_dims, 2))
        + [new_batch_dims]
    )
    bx = bx.permute(permute_dims)

    flat_L = bL.reshape(-1, n, n)  # shape = b x n x n
    flat_x = bx.reshape(-1, flat_L.size(0), n)  # shape = c x b x n
    flat_x_swap = flat_x.permute(1, 2, 0)  # shape = b x n x c
    M_swap = (
        torch.linalg.solve_triangular(flat_L, flat_x_swap, upper=False).pow(2).sum(-2)
    )  # shape = b x c
    M = M_swap.t()  # shape = c x b

    # Now we revert the above reshape and permute operators.
    permuted_M = M.reshape(bx.shape[:-1])  # shape = (..., 1, j, i, 1)
    permute_inv_dims = list(range(outer_batch_dims))
    for i in range(bL_batch_dims):
        permute_inv_dims += [outer_batch_dims + i, old_batch_dims + i]
    reshaped_M = permuted_M.permute(permute_inv_dims)  # shape = (..., 1, i, j, 1)
    return reshaped_M.reshape(bx_batch_shape)
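

# A quick illustrative check of `_batch_mahalanobis` (a hedged sketch, not part
# of the module's behavior; shapes and values chosen arbitrarily): for a single
# factor L of M = L @ L.T and a batch of vectors x, the result agrees with the
# direct computation x^T M^{-1} x.
#
#     >>> M = torch.tensor([[2.0, 0.5], [0.5, 1.0]])
#     >>> L = torch.linalg.cholesky(M)
#     >>> x = torch.randn(3, 2)
#     >>> expected = (x @ torch.inverse(M) * x).sum(-1)
#     >>> torch.allclose(_batch_mahalanobis(L, x), expected)
#     True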


def _precision_to_scale_tril(P):
    # Given a positive-definite precision matrix P, return a lower-triangular L
    # with positive diagonal such that L @ L.mT equals the inverse of P, without
    # explicitly inverting P.
    # Ref: https://nbviewer.jupyter.org/gist/fehiepsi/5ef8e09e61604f10607380467eb82006#Precision-to-scale_tril
    Lf = torch.linalg.cholesky(torch.flip(P, (-2, -1)))
    L_inv = torch.transpose(torch.flip(Lf, (-2, -1)), -2, -1)
    Id = torch.eye(P.shape[-1], dtype=P.dtype, device=P.device)
    L = torch.linalg.solve_triangular(L_inv, Id, upper=False)
    return L


class MultivariateNormal(Distribution):
    r"""
    Creates a multivariate normal (also called Gaussian) distribution
    parameterized by a mean vector and a covariance matrix.

    The multivariate normal distribution can be parameterized either
    in terms of a positive definite covariance matrix :math:`\mathbf{\Sigma}`
    or a positive definite precision matrix :math:`\mathbf{\Sigma}^{-1}`
    or a lower-triangular matrix :math:`\mathbf{L}` with positive-valued
    diagonal entries, such that
    :math:`\mathbf{\Sigma} = \mathbf{L}\mathbf{L}^\top`. This triangular matrix
    can be obtained via, e.g., a Cholesky decomposition of the covariance.

    Example:

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = MultivariateNormal(torch.zeros(2), torch.eye(2))
        >>> m.sample()  # normally distributed with mean=`[0,0]` and covariance_matrix=`I`
        tensor([-0.2102, -0.5429])

    Args:
        loc (Tensor): mean of the distribution
        covariance_matrix (Tensor): positive-definite covariance matrix
        precision_matrix (Tensor): positive-definite precision matrix
        scale_tril (Tensor): lower-triangular factor of covariance, with positive-valued diagonal

    Note:
        Only one of :attr:`covariance_matrix`, :attr:`precision_matrix`, or
        :attr:`scale_tril` can be specified.

        Using :attr:`scale_tril` is more efficient: all computations internally
        are based on :attr:`scale_tril`. If :attr:`covariance_matrix` or
        :attr:`precision_matrix` is passed instead, it is only used to compute
        the corresponding lower-triangular matrix via a Cholesky decomposition.
    """
    arg_constraints = {
        "loc": constraints.real_vector,
        "covariance_matrix": constraints.positive_definite,
        "precision_matrix": constraints.positive_definite,
        "scale_tril": constraints.lower_cholesky,
    }
    support = constraints.real_vector
    has_rsample = True

    def __init__(
        self,
        loc,
        covariance_matrix=None,
        precision_matrix=None,
        scale_tril=None,
        validate_args=None,
    ):
        if loc.dim() < 1:
            raise ValueError("loc must be at least one-dimensional.")
        if (covariance_matrix is not None) + (scale_tril is not None) + (
            precision_matrix is not None
        ) != 1:
            raise ValueError(
                "Exactly one of covariance_matrix or precision_matrix or scale_tril may be specified."
            )

        if scale_tril is not None:
            if scale_tril.dim() < 2:
                raise ValueError(
                    "scale_tril matrix must be at least two-dimensional, "
                    "with optional leading batch dimensions"
                )
            batch_shape = torch.broadcast_shapes(scale_tril.shape[:-2], loc.shape[:-1])
            self.scale_tril = scale_tril.expand(batch_shape + (-1, -1))
        elif covariance_matrix is not None:
            if covariance_matrix.dim() < 2:
                raise ValueError(
                    "covariance_matrix must be at least two-dimensional, "
                    "with optional leading batch dimensions"
                )
            batch_shape = torch.broadcast_shapes(
                covariance_matrix.shape[:-2], loc.shape[:-1]
            )
            self.covariance_matrix = covariance_matrix.expand(batch_shape + (-1, -1))
        else:
            if precision_matrix.dim() < 2:
                raise ValueError(
                    "precision_matrix must be at least two-dimensional, "
                    "with optional leading batch dimensions"
                )
            batch_shape = torch.broadcast_shapes(
                precision_matrix.shape[:-2], loc.shape[:-1]
            )
            self.precision_matrix = precision_matrix.expand(batch_shape + (-1, -1))
        self.loc = loc.expand(batch_shape + (-1,))

        event_shape = self.loc.shape[-1:]
        super().__init__(batch_shape, event_shape, validate_args=validate_args)

        if scale_tril is not None:
            self._unbroadcasted_scale_tril = scale_tril
        elif covariance_matrix is not None:
            self._unbroadcasted_scale_tril = torch.linalg.cholesky(covariance_matrix)
        else:  # precision_matrix is not None
            self._unbroadcasted_scale_tril = _precision_to_scale_tril(precision_matrix)
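
    # Illustrative sketch (hedged, not part of the class API; values chosen
    # arbitrarily): the three parameterizations are interchangeable, since each
    # is reduced internally to `scale_tril`.
    #
    #     >>> cov = torch.tensor([[2.0, 0.5], [0.5, 1.0]])
    #     >>> d1 = MultivariateNormal(torch.zeros(2), covariance_matrix=cov)
    #     >>> d2 = MultivariateNormal(torch.zeros(2), scale_tril=torch.linalg.cholesky(cov))
    #     >>> x = torch.randn(2)
    #     >>> torch.allclose(d1.log_prob(x), d2.log_prob(x))
    #     True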

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(MultivariateNormal, _instance)
        batch_shape = torch.Size(batch_shape)
        loc_shape = batch_shape + self.event_shape
        cov_shape = batch_shape + self.event_shape + self.event_shape
        new.loc = self.loc.expand(loc_shape)
        new._unbroadcasted_scale_tril = self._unbroadcasted_scale_tril
        if "covariance_matrix" in self.__dict__:
            new.covariance_matrix = self.covariance_matrix.expand(cov_shape)
        if "scale_tril" in self.__dict__:
            new.scale_tril = self.scale_tril.expand(cov_shape)
        if "precision_matrix" in self.__dict__:
            new.precision_matrix = self.precision_matrix.expand(cov_shape)
        super(MultivariateNormal, new).__init__(
            batch_shape, self.event_shape, validate_args=False
        )
        new._validate_args = self._validate_args
        return new

    @lazy_property
    def scale_tril(self):
        return self._unbroadcasted_scale_tril.expand(
            self._batch_shape + self._event_shape + self._event_shape
        )

    @lazy_property
    def covariance_matrix(self):
        return torch.matmul(
            self._unbroadcasted_scale_tril, self._unbroadcasted_scale_tril.mT
        ).expand(self._batch_shape + self._event_shape + self._event_shape)

    @lazy_property
    def precision_matrix(self):
        return torch.cholesky_inverse(self._unbroadcasted_scale_tril).expand(
            self._batch_shape + self._event_shape + self._event_shape
        )

    @property
    def mean(self):
        return self.loc

    @property
    def mode(self):
        return self.loc

    @property
    def variance(self):
        return (
            self._unbroadcasted_scale_tril.pow(2)
            .sum(-1)
            .expand(self._batch_shape + self._event_shape)
        )

    def rsample(self, sample_shape: _size = torch.Size()) -> torch.Tensor:
        shape = self._extended_shape(sample_shape)
        eps = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
        # Reparameterization: if eps ~ N(0, I), then loc + L @ eps ~ N(loc, L @ L^T).
        return self.loc + _batch_mv(self._unbroadcasted_scale_tril, eps)
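
    # Illustrative shape check (a hedged sketch; shapes chosen arbitrarily):
    # `sample_shape` is prepended to the batch and event shapes, so drawing 5
    # samples from a batch of 3 bivariate normals yields shape (5, 3, 2).
    #
    #     >>> d = MultivariateNormal(torch.zeros(3, 2), torch.eye(2).expand(3, 2, 2))
    #     >>> d.rsample((5,)).shape
    #     torch.Size([5, 3, 2])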

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        diff = value - self.loc
        M = _batch_mahalanobis(self._unbroadcasted_scale_tril, diff)
        # sum(log(diag(L))) = 0.5 * log|Sigma| for Sigma = L @ L^T.
        half_log_det = (
            self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1)
        )
        return -0.5 * (self._event_shape[0] * math.log(2 * math.pi) + M) - half_log_det

    def entropy(self):
        half_log_det = (
            self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1)
        )
        H = 0.5 * self._event_shape[0] * (1.0 + math.log(2 * math.pi)) + half_log_det
        if len(self._batch_shape) == 0:
            return H
        else:
            return H.expand(self._batch_shape)
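

# Illustrative cross-check (a hedged sketch, not executed by this module): in
# one dimension the density reduces to the univariate normal, so `log_prob`
# should agree with `torch.distributions.Normal` when `scale_tril` holds the
# standard deviation.
#
#     >>> mvn = MultivariateNormal(torch.zeros(1), scale_tril=torch.tensor([[2.0]]))
#     >>> normal = torch.distributions.Normal(0.0, 2.0)
#     >>> x = torch.randn(4)
#     >>> torch.allclose(mvn.log_prob(x.unsqueeze(-1)), normal.log_prob(x))
#     True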