xref: /aosp_15_r20/external/pytorch/torch/distributions/utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from functools import update_wrapper
3from numbers import Number
4from typing import Any, Dict
5
6import torch
7import torch.nn.functional as F
8from torch.overrides import is_tensor_like
9
10
# Euler-Mascheroni constant (γ), used by distributions such as Gumbel/Weibull
# for entropy/mean computations elsewhere in the package.
euler_constant = 0.57721566490153286060  # Euler Mascheroni Constant

# Public API of this module; underscore-prefixed helpers are internal.
__all__ = [
    "broadcast_all",
    "logits_to_probs",
    "clamp_probs",
    "probs_to_logits",
    "lazy_property",
    "tril_matrix_to_vec",
    "vec_to_tril_matrix",
]
22
23
def broadcast_all(*values):
    r"""
    Given a list of values (possibly containing numbers), returns a list where each
    value is broadcasted based on the following rules:
      - `torch.*Tensor` instances are broadcasted as per :ref:`_broadcasting-semantics`.
      - numbers.Number instances (scalars) are upcast to tensors having
        the same size and type as the first tensor passed to `values`.  If all the
        values are scalars, then they are upcasted to scalar Tensors.

    Args:
        values (list of `numbers.Number`, `torch.*Tensor` or objects implementing __torch_function__)

    Raises:
        ValueError: if any of the values is not a `numbers.Number` instance,
            a `torch.*Tensor` instance, or an instance implementing __torch_function__
    """
    # Validate up front so the error names the whole contract, not one value.
    for v in values:
        if not (is_tensor_like(v) or isinstance(v, Number)):
            raise ValueError(
                "Input arguments must all be instances of numbers.Number, "
                "torch.Tensor or objects implementing __torch_function__."
            )
    if all(is_tensor_like(v) for v in values):
        return torch.broadcast_tensors(*values)
    # Scalars are materialized with the dtype/device of the first real tensor,
    # falling back to the default dtype when no tensor is present.
    tensor_opts: Dict[str, Any] = {"dtype": torch.get_default_dtype()}
    for v in values:
        if isinstance(v, torch.Tensor):
            tensor_opts = {"dtype": v.dtype, "device": v.device}
            break
    converted = [
        v if is_tensor_like(v) else torch.tensor(v, **tensor_opts) for v in values
    ]
    return torch.broadcast_tensors(*converted)
56
57
58def _standard_normal(shape, dtype, device):
59    if torch._C._get_tracing_state():
60        # [JIT WORKAROUND] lack of support for .normal_()
61        return torch.normal(
62            torch.zeros(shape, dtype=dtype, device=device),
63            torch.ones(shape, dtype=dtype, device=device),
64        )
65    return torch.empty(shape, dtype=dtype, device=device).normal_()
66
67
68def _sum_rightmost(value, dim):
69    r"""
70    Sum out ``dim`` many rightmost dimensions of a given tensor.
71
72    Args:
73        value (Tensor): A tensor of ``.dim()`` at least ``dim``.
74        dim (int): The number of rightmost dims to sum out.
75    """
76    if dim == 0:
77        return value
78    required_shape = value.shape[:-dim] + (-1,)
79    return value.reshape(required_shape).sum(-1)
80
81
def logits_to_probs(logits, is_binary=False):
    r"""
    Converts a tensor of logits into probabilities. Note that for the
    binary case, each value denotes log odds, whereas for the
    multi-dimensional case, the values along the last dimension denote
    the log probabilities (possibly unnormalized) of the events.
    """
    # Binary logits are independent log-odds; otherwise normalize over the
    # last dimension.
    return torch.sigmoid(logits) if is_binary else F.softmax(logits, dim=-1)
92
93
def clamp_probs(probs):
    """Clamps the probabilities to be in the open interval `(0, 1)`.

    The probabilities would be clamped between `eps` and `1 - eps`,
    and `eps` would be the smallest representable positive number for the input data type.

    Args:
        probs (Tensor): A tensor of probabilities.

    Returns:
        Tensor: The clamped probabilities.

    Examples:
        >>> probs = torch.tensor([0.0, 0.5, 1.0])
        >>> clamp_probs(probs)
        tensor([1.1921e-07, 5.0000e-01, 1.0000e+00])

        >>> probs = torch.tensor([0.0, 0.5, 1.0], dtype=torch.float64)
        >>> clamp_probs(probs)
        tensor([2.2204e-16, 5.0000e-01, 1.0000e+00], dtype=torch.float64)

    """
    # Machine epsilon for the incoming dtype keeps the clamp dtype-aware.
    tiny = torch.finfo(probs.dtype).eps
    return probs.clamp(min=tiny, max=1 - tiny)
118
119
def probs_to_logits(probs, is_binary=False):
    r"""
    Converts a tensor of probabilities into logits. For the binary case,
    this denotes the probability of occurrence of the event indexed by `1`.
    For the multi-dimensional case, the values along the last dimension
    denote the probabilities of occurrence of each of the events.
    """
    # Clamp into the open interval (0, 1) (dtype-aware eps) so the logs below
    # never see exact 0 or 1.
    eps = torch.finfo(probs.dtype).eps
    safe = probs.clamp(min=eps, max=1 - eps)
    if is_binary:
        # log-odds: log(p) - log(1 - p), with log1p for the second term.
        return torch.log(safe) - torch.log1p(-safe)
    return torch.log(safe)
131
132
class lazy_property:
    r"""
    Decorator implementing a lazily-computed, cached class attribute.

    This is a non-data descriptor: the first access on an instance runs the
    wrapped method (inside ``torch.enable_grad()``) and stores the result
    directly on the instance under the same name, so all later accesses hit
    the instance ``__dict__`` and bypass this descriptor entirely.
    """

    def __init__(self, wrapped):
        self.wrapped = wrapped
        update_wrapper(self, wrapped)  # type:ignore[arg-type]

    def __get__(self, instance, obj_type=None):
        # Class-level access (e.g. documentation tooling) gets a
        # property-flavored view instead of computing anything.
        if instance is None:
            return _lazy_property_and_property(self.wrapped)
        with torch.enable_grad():
            computed = self.wrapped(instance)
        # Shadow the descriptor with the computed value to cache it.
        setattr(instance, self.wrapped.__name__, computed)
        return computed
152
153
class _lazy_property_and_property(lazy_property, property):
    """Hybrid descriptor: a ``lazy_property`` that is also a ``property``.

    We want lazy properties to look like multiple things:

    * property when Sphinx autodoc looks
    * lazy_property when Distribution validate_args looks
    """

    def __init__(self, wrapped):
        # Deliberately initialize only the property side; the lazy_property
        # side is used purely for isinstance checks.
        property.__init__(self, wrapped)
163
164
def tril_matrix_to_vec(mat: torch.Tensor, diag: int = 0) -> torch.Tensor:
    r"""
    Convert a `D x D` matrix or a batch of matrices into a (batched) vector
    which comprises of lower triangular elements from the matrix in row order.
    """
    dim = mat.shape[-1]
    # Validation is skipped while JIT tracing (same pattern as the other
    # helpers in this module).
    if not torch._C._get_tracing_state() and not (-dim <= diag < dim):
        raise ValueError(f"diag ({diag}) provided is outside [{-dim}, {dim-1}].")
    # Boolean mask selecting entries on/below the `diag`-offset diagonal.
    idx = torch.arange(dim, device=mat.device)
    mask = idx < idx.view(-1, 1) + (diag + 1)
    return mat[..., mask]
177
178
def vec_to_tril_matrix(vec: torch.Tensor, diag: int = 0) -> torch.Tensor:
    r"""
    Convert a vector or a batch of vectors into a batched `D x D`
    lower triangular matrix containing elements from the vector in row order.

    Args:
        vec (Tensor): vector(s) whose last dimension holds the
            lower-triangular entries (up to diagonal offset ``diag``)
            in row order.
        diag (int): diagonal offset, as in :func:`torch.tril`.

    Returns:
        Tensor: batched ``D x D`` matrix with ``vec`` scattered into the
        lower triangle and zeros elsewhere.

    Raises:
        ValueError: if the last dimension of ``vec`` does not match the
            lower-triangular size of any square ``D x D`` matrix.
    """
    # +ve root of D**2 + (1+2*diag)*D - |diag| * (diag+1) - 2*vec.shape[-1] = 0
    n = (
        -(1 + 2 * diag)
        + ((1 + 2 * diag) ** 2 + 8 * vec.shape[-1] + 4 * abs(diag) * (diag + 1)) ** 0.5
    ) / 2
    eps = torch.finfo(vec.dtype).eps
    # Validation is skipped under JIT tracing. The distance to the nearest
    # integer must use abs(): the previous one-sided check `round(n) - n > eps`
    # only fired when round() rounded *up*, so e.g. vec.shape[-1] == 4
    # (n ~= 2.372) slipped through and failed later with an opaque
    # shape-mismatch error instead of this ValueError.
    if not torch._C._get_tracing_state() and abs(round(n) - n) > eps:
        raise ValueError(
            f"The size of last dimension is {vec.shape[-1]} which cannot be expressed as "
            + "the lower triangular part of a square D x D matrix."
        )
    n = round(n.item()) if isinstance(n, torch.Tensor) else round(n)
    mat = vec.new_zeros(vec.shape[:-1] + torch.Size((n, n)))
    arange = torch.arange(n, device=vec.device)
    tril_mask = arange < arange.view(-1, 1) + (diag + 1)
    mat[..., tril_mask] = vec
    return mat
201