xref: /aosp_15_r20/external/pytorch/torch/distributions/multivariate_normal.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import math
3
4import torch
5from torch.distributions import constraints
6from torch.distributions.distribution import Distribution
7from torch.distributions.utils import _standard_normal, lazy_property
8from torch.types import _size
9
10
# Only the distribution class is exported; the underscore-prefixed helpers
# defined in this module are private implementation details.
__all__ = ["MultivariateNormal"]
12
13
14def _batch_mv(bmat, bvec):
15    r"""
16    Performs a batched matrix-vector product, with compatible but different batch shapes.
17
18    This function takes as input `bmat`, containing :math:`n \times n` matrices, and
19    `bvec`, containing length :math:`n` vectors.
20
21    Both `bmat` and `bvec` may have any number of leading dimensions, which correspond
22    to a batch shape. They are not necessarily assumed to have the same batch shape,
23    just ones which can be broadcasted.
24    """
25    return torch.matmul(bmat, bvec.unsqueeze(-1)).squeeze(-1)
26
27
def _batch_mahalanobis(bL, bx):
    r"""
    Computes the squared Mahalanobis distance :math:`\mathbf{x}^\top\mathbf{M}^{-1}\mathbf{x}`
    for a factored :math:`\mathbf{M} = \mathbf{L}\mathbf{L}^\top`.

    Both `bL` and `bx` may be batched. Their batch shapes need not be equal, but the
    batch shape of `bL` must be broadcastable to the batch shape of `bx`.

    Args:
        bL: factor(s) :math:`\mathbf{L}` of shape ``(*L_batch, n, n)``. They are solved
            against with ``upper=False``, so they are treated as lower-triangular.
        bx: vector(s) :math:`\mathbf{x}` of shape ``(*x_batch, n)``. The trailing
            ``len(L_batch)`` batch dimensions of `bx` are paired with `L_batch`;
            presumably `bx` has at least as many batch dimensions as `bL` (TODO confirm
            callers guarantee this).

    Returns:
        Tensor of shape ``x_batch`` (``bx.shape[:-1]``) with one squared distance per vector.
    """
    n = bx.size(-1)
    bx_batch_shape = bx.shape[:-1]

    # Example used by the shape comments below: if bL.shape = (i, 1, n, n) and
    # bx.shape = (..., i, j, n), we make bx have shape (..., 1, j, i, 1, n), so that
    # all of bx's extra batch entries sit in front of the dimensions matching bL.
    # A single batched triangular solve can then handle every vector at once.
    bx_batch_dims = len(bx_batch_shape)
    bL_batch_dims = bL.dim() - 2
    # Leading batch dims of bx that have no counterpart in bL.
    outer_batch_dims = bx_batch_dims - bL_batch_dims
    old_batch_dims = outer_batch_dims + bL_batch_dims
    # After splitting, every bL batch dim becomes a (sx // sL, sL) pair.
    new_batch_dims = outer_batch_dims + 2 * bL_batch_dims
    # Reshape bx with the shape (..., 1, i, j, 1, n): each paired batch dim of size sx
    # is split into (sx // sL, sL), so the second factor lines up with bL's dim.
    bx_new_shape = bx.shape[:outer_batch_dims]
    for sL, sx in zip(bL.shape[:-2], bx.shape[outer_batch_dims:-1]):
        bx_new_shape += (sx // sL, sL)
    bx_new_shape += (n,)
    bx = bx.reshape(bx_new_shape)
    # Permute bx to make it have shape (..., 1, j, i, 1, n): first the outer dims,
    # then every "sx // sL" factor (even offsets), then every "sL" factor (odd
    # offsets), then the event dim n.
    permute_dims = (
        list(range(outer_batch_dims))
        + list(range(outer_batch_dims, new_batch_dims, 2))
        + list(range(outer_batch_dims + 1, new_batch_dims, 2))
        + [new_batch_dims]
    )
    bx = bx.permute(permute_dims)

    flat_L = bL.reshape(-1, n, n)  # shape = b x n x n
    flat_x = bx.reshape(-1, flat_L.size(0), n)  # shape = c x b x n
    # Put the c vectors that share a factor into the columns of one right-hand side.
    flat_x_swap = flat_x.permute(1, 2, 0)  # shape = b x n x c
    # Solving L y = x gives y = L^{-1} x, and |y|^2 = x^T (L L^T)^{-1} x.
    M_swap = (
        torch.linalg.solve_triangular(flat_L, flat_x_swap, upper=False).pow(2).sum(-2)
    )  # shape = b x c
    M = M_swap.t()  # shape = c x b

    # Now we revert the above reshape and permute operators.
    permuted_M = M.reshape(bx.shape[:-1])  # shape = (..., 1, j, i, 1)
    # Interleave the two halves back into (sx // sL, sL) pairs.
    permute_inv_dims = list(range(outer_batch_dims))
    for i in range(bL_batch_dims):
        permute_inv_dims += [outer_batch_dims + i, old_batch_dims + i]
    reshaped_M = permuted_M.permute(permute_inv_dims)  # shape = (..., 1, i, j, 1)
    # Merge each pair back into a single dim, restoring bx's batch shape.
    return reshaped_M.reshape(bx_batch_shape)
76
77
78def _precision_to_scale_tril(P):
79    # Ref: https://nbviewer.jupyter.org/gist/fehiepsi/5ef8e09e61604f10607380467eb82006#Precision-to-scale_tril
80    Lf = torch.linalg.cholesky(torch.flip(P, (-2, -1)))
81    L_inv = torch.transpose(torch.flip(Lf, (-2, -1)), -2, -1)
82    Id = torch.eye(P.shape[-1], dtype=P.dtype, device=P.device)
83    L = torch.linalg.solve_triangular(L_inv, Id, upper=False)
84    return L
85
86
class MultivariateNormal(Distribution):
    r"""
    Creates a multivariate normal (also called Gaussian) distribution
    parameterized by a mean vector and a covariance matrix.

    The multivariate normal distribution can be parameterized either
    in terms of a positive definite covariance matrix :math:`\mathbf{\Sigma}`
    or a positive definite precision matrix :math:`\mathbf{\Sigma}^{-1}`
    or a lower-triangular matrix :math:`\mathbf{L}` with positive-valued
    diagonal entries, such that
    :math:`\mathbf{\Sigma} = \mathbf{L}\mathbf{L}^\top`. This triangular matrix
    can be obtained via e.g. Cholesky decomposition of the covariance.

    Example:

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = MultivariateNormal(torch.zeros(2), torch.eye(2))
        >>> m.sample()  # normally distributed with mean=`[0,0]` and covariance_matrix=`I`
        tensor([-0.2102, -0.5429])

    Args:
        loc (Tensor): mean of the distribution, of shape ``(..., k)``
        covariance_matrix (Tensor): positive-definite covariance matrix, of shape ``(..., k, k)``
        precision_matrix (Tensor): positive-definite precision matrix, of shape ``(..., k, k)``
        scale_tril (Tensor): lower-triangular factor of covariance, with positive-valued diagonal,
            of shape ``(..., k, k)``

    The leading (batch) dimensions of ``loc`` and of the given matrix are broadcast
    against each other to form the distribution's ``batch_shape``.

    Raises:
        ValueError: if ``loc`` is 0-dimensional, if not exactly one of
            :attr:`covariance_matrix`, :attr:`precision_matrix` or :attr:`scale_tril`
            is given, or if that matrix is not of shape ``(..., k, k)`` with
            ``k = loc.shape[-1]``.

    Note:
        Only one of :attr:`covariance_matrix` or :attr:`precision_matrix` or
        :attr:`scale_tril` can be specified.

        Using :attr:`scale_tril` will be more efficient: all computations internally
        are based on :attr:`scale_tril`. If :attr:`covariance_matrix` or
        :attr:`precision_matrix` is passed instead, it is only used to compute
        the corresponding lower triangular matrices using a Cholesky decomposition.
    """
    arg_constraints = {
        "loc": constraints.real_vector,
        "covariance_matrix": constraints.positive_definite,
        "precision_matrix": constraints.positive_definite,
        "scale_tril": constraints.lower_cholesky,
    }
    support = constraints.real_vector
    has_rsample = True

    def __init__(
        self,
        loc,
        covariance_matrix=None,
        precision_matrix=None,
        scale_tril=None,
        validate_args=None,
    ):
        if loc.dim() < 1:
            raise ValueError("loc must be at least one-dimensional.")
        if (covariance_matrix is not None) + (scale_tril is not None) + (
            precision_matrix is not None
        ) != 1:
            raise ValueError(
                "Exactly one of covariance_matrix or precision_matrix or scale_tril may be specified."
            )

        # Exactly one matrix is given. Broadcast its batch shape against loc's and
        # store an expanded view under the parameter's own name. This is done before
        # super().__init__ so that argument validation there can see the
        # user-supplied matrix.
        if scale_tril is not None:
            if scale_tril.dim() < 2:
                raise ValueError(
                    "scale_tril matrix must be at least two-dimensional, "
                    "with optional leading batch dimensions"
                )
            self._check_matrix_shape("scale_tril", scale_tril, loc)
            batch_shape = torch.broadcast_shapes(scale_tril.shape[:-2], loc.shape[:-1])
            self.scale_tril = scale_tril.expand(batch_shape + (-1, -1))
        elif covariance_matrix is not None:
            if covariance_matrix.dim() < 2:
                raise ValueError(
                    "covariance_matrix must be at least two-dimensional, "
                    "with optional leading batch dimensions"
                )
            self._check_matrix_shape("covariance_matrix", covariance_matrix, loc)
            batch_shape = torch.broadcast_shapes(
                covariance_matrix.shape[:-2], loc.shape[:-1]
            )
            self.covariance_matrix = covariance_matrix.expand(batch_shape + (-1, -1))
        else:
            if precision_matrix.dim() < 2:
                raise ValueError(
                    "precision_matrix must be at least two-dimensional, "
                    "with optional leading batch dimensions"
                )
            self._check_matrix_shape("precision_matrix", precision_matrix, loc)
            batch_shape = torch.broadcast_shapes(
                precision_matrix.shape[:-2], loc.shape[:-1]
            )
            self.precision_matrix = precision_matrix.expand(batch_shape + (-1, -1))
        self.loc = loc.expand(batch_shape + (-1,))

        event_shape = self.loc.shape[-1:]
        super().__init__(batch_shape, event_shape, validate_args=validate_args)

        # All computations use the *unbroadcasted* factor, so batched linear
        # algebra only runs over the matrix's own batch dims. It is derived
        # after super().__init__, i.e. after any argument validation.
        if scale_tril is not None:
            self._unbroadcasted_scale_tril = scale_tril
        elif covariance_matrix is not None:
            self._unbroadcasted_scale_tril = torch.linalg.cholesky(covariance_matrix)
        else:  # precision_matrix is not None
            self._unbroadcasted_scale_tril = _precision_to_scale_tril(precision_matrix)

    @staticmethod
    def _check_matrix_shape(name, matrix, loc):
        """Raise ``ValueError`` unless ``matrix`` has shape ``(..., k, k)``, ``k = loc.shape[-1]``.

        Without this check, a mismatched ``loc`` is accepted silently, e.g. shape ``(1,)``
        against a ``3 x 3`` covariance. With argument validation disabled, ``log_prob``
        then uses the wrong normalization constant, and ``rsample`` / ``scale_tril``
        fail later with unrelated shape errors.
        """
        k = loc.shape[-1]
        if matrix.shape[-2:] != (k, k):
            raise ValueError(
                f"{name} must have shape (..., {k}, {k}) to match the event size "
                f"of loc (loc.shape[-1] == {k}), but got shape {tuple(matrix.shape)}"
            )

    def expand(self, batch_shape, _instance=None):
        """Return an instance whose batch shape is expanded to ``batch_shape``.

        Parameters are expanded views of this instance's tensors. The unbroadcasted
        scale factor is shared, not recomputed.
        """
        new = self._get_checked_instance(MultivariateNormal, _instance)
        batch_shape = torch.Size(batch_shape)
        loc_shape = batch_shape + self.event_shape
        cov_shape = batch_shape + self.event_shape + self.event_shape
        new.loc = self.loc.expand(loc_shape)
        new._unbroadcasted_scale_tril = self._unbroadcasted_scale_tril
        # Only carry over matrices that are already materialized (given at
        # construction or already computed by their lazy property); the rest
        # stay lazy on the new instance.
        if "covariance_matrix" in self.__dict__:
            new.covariance_matrix = self.covariance_matrix.expand(cov_shape)
        if "scale_tril" in self.__dict__:
            new.scale_tril = self.scale_tril.expand(cov_shape)
        if "precision_matrix" in self.__dict__:
            new.precision_matrix = self.precision_matrix.expand(cov_shape)
        super(MultivariateNormal, new).__init__(
            batch_shape, self.event_shape, validate_args=False
        )
        new._validate_args = self._validate_args
        return new

    @lazy_property
    def scale_tril(self):
        """Lower-triangular factor ``L`` with ``Sigma = L L^T``, expanded to the full batch shape."""
        return self._unbroadcasted_scale_tril.expand(
            self._batch_shape + self._event_shape + self._event_shape
        )

    @lazy_property
    def covariance_matrix(self):
        """Covariance ``Sigma = L L^T``, expanded to the full batch shape."""
        return torch.matmul(
            self._unbroadcasted_scale_tril, self._unbroadcasted_scale_tril.mT
        ).expand(self._batch_shape + self._event_shape + self._event_shape)

    @lazy_property
    def precision_matrix(self):
        """Precision ``Sigma^{-1}``, computed from the Cholesky factor and expanded to the full batch shape."""
        return torch.cholesky_inverse(self._unbroadcasted_scale_tril).expand(
            self._batch_shape + self._event_shape + self._event_shape
        )

    @property
    def mean(self):
        """The mean vector ``loc``."""
        return self.loc

    @property
    def mode(self):
        """The mode, which for a Gaussian equals ``loc``."""
        return self.loc

    @property
    def variance(self):
        """Diagonal of ``Sigma = L L^T``, i.e. the row-wise sum of squares of ``L``."""
        return (
            self._unbroadcasted_scale_tril.pow(2)
            .sum(-1)
            .expand(self._batch_shape + self._event_shape)
        )

    def rsample(self, sample_shape: _size = torch.Size()) -> torch.Tensor:
        """Draw reparameterized samples ``loc + L @ eps`` with ``eps`` standard normal."""
        shape = self._extended_shape(sample_shape)
        eps = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
        return self.loc + _batch_mv(self._unbroadcasted_scale_tril, eps)

    def log_prob(self, value):
        """Log density ``-0.5 * (k * log(2*pi) + M) - 0.5 * log|Sigma|``.

        ``M`` is the squared Mahalanobis distance of ``value`` from ``loc``.
        """
        if self._validate_args:
            self._validate_sample(value)
        diff = value - self.loc
        M = _batch_mahalanobis(self._unbroadcasted_scale_tril, diff)
        # log|Sigma| = 2 * sum(log(diag(L))), so this is half the log-determinant.
        half_log_det = (
            self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1)
        )
        return -0.5 * (self._event_shape[0] * math.log(2 * math.pi) + M) - half_log_det

    def entropy(self):
        """Differential entropy ``0.5 * k * (1 + log(2*pi)) + 0.5 * log|Sigma|``, per batch element."""
        half_log_det = (
            self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1)
        )
        H = 0.5 * self._event_shape[0] * (1.0 + math.log(2 * math.pi)) + half_log_det
        if len(self._batch_shape) == 0:
            return H
        else:
            # The factor may be unbroadcasted; expand to the full batch shape.
            return H.expand(self._batch_shape)
266