xref: /aosp_15_r20/external/pytorch/torch/distributions/constraints.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2r"""
3The following constraints are implemented:
4
5- ``constraints.boolean``
6- ``constraints.cat``
7- ``constraints.corr_cholesky``
8- ``constraints.dependent``
9- ``constraints.greater_than(lower_bound)``
- ``constraints.greater_than_eq(lower_bound)``
- ``constraints.half_open_interval(lower_bound, upper_bound)``
11- ``constraints.independent(constraint, reinterpreted_batch_ndims)``
12- ``constraints.integer_interval(lower_bound, upper_bound)``
13- ``constraints.interval(lower_bound, upper_bound)``
14- ``constraints.less_than(upper_bound)``
15- ``constraints.lower_cholesky``
16- ``constraints.lower_triangular``
17- ``constraints.multinomial``
18- ``constraints.nonnegative``
19- ``constraints.nonnegative_integer``
20- ``constraints.one_hot``
21- ``constraints.positive_integer``
22- ``constraints.positive``
23- ``constraints.positive_semidefinite``
24- ``constraints.positive_definite``
25- ``constraints.real_vector``
26- ``constraints.real``
27- ``constraints.simplex``
- ``constraints.stack``
30- ``constraints.square``
31- ``constraints.symmetric``
32- ``constraints.unit_interval``
33"""
34
35import torch
36
37
38__all__ = [
39    "Constraint",
40    "boolean",
41    "cat",
42    "corr_cholesky",
43    "dependent",
44    "dependent_property",
45    "greater_than",
46    "greater_than_eq",
47    "independent",
48    "integer_interval",
49    "interval",
50    "half_open_interval",
51    "is_dependent",
52    "less_than",
53    "lower_cholesky",
54    "lower_triangular",
55    "multinomial",
56    "nonnegative",
57    "nonnegative_integer",
58    "one_hot",
59    "positive",
60    "positive_semidefinite",
61    "positive_definite",
62    "positive_integer",
63    "real",
64    "real_vector",
65    "simplex",
66    "square",
67    "stack",
68    "symmetric",
69    "unit_interval",
70]
71
72
class Constraint:
    """
    Abstract base class for constraints.

    A constraint object represents a region over which a variable is valid,
    e.g. within which a variable can be optimized.

    Attributes:
        is_discrete (bool): Whether constrained space is discrete.
            Defaults to False.
        event_dim (int): Number of rightmost dimensions that together define
            an event. The :meth:`check` method will remove this many dimensions
            when computing validity.
    """

    is_discrete = False  # Default to continuous.
    event_dim = 0  # Default to univariate.

    def check(self, value):
        """
        Returns a boolean tensor of ``sample_shape + batch_shape`` indicating
        whether each event in value satisfies this constraint.

        The base implementation always raises ``NotImplementedError``;
        concrete constraints override it.
        """
        raise NotImplementedError

    def __repr__(self):
        """
        Return ``"<ClassName>()"`` with a single leading underscore removed.

        Only an actual leading underscore is dropped, so classes whose name
        does not start with ``_`` (including this base class) keep their
        full name instead of losing their first letter.
        """
        name = self.__class__.__name__
        if name.startswith("_"):
            name = name[1:]
        return name + "()"
100
101
class _Dependent(Constraint):
    """
    Stand-in constraint for variables whose support is determined by other
    variables, so that no simple coordinate-wise check applies.

    Args:
        is_discrete (bool): Optional static value for ``.is_discrete``. When
            omitted, reading ``.is_discrete`` raises a NotImplementedError.
        event_dim (int): Optional static value for ``.event_dim``. When
            omitted, reading ``.event_dim`` raises a NotImplementedError.
    """

    def __init__(self, *, is_discrete=NotImplemented, event_dim=NotImplemented):
        super().__init__()
        # ``NotImplemented`` is the sentinel for "not known statically".
        self._event_dim = event_dim
        self._is_discrete = is_discrete

    @property
    def is_discrete(self):
        known = self._is_discrete
        if known is not NotImplemented:
            return known
        raise NotImplementedError(".is_discrete cannot be determined statically")

    @property
    def event_dim(self):
        known = self._event_dim
        if known is not NotImplemented:
            return known
        raise NotImplementedError(".event_dim cannot be determined statically")

    def __call__(self, *, is_discrete=NotImplemented, event_dim=NotImplemented):
        """
        Build a new placeholder with customized static attributes::

            constraints.dependent(is_discrete=True, event_dim=1)

        Any attribute left unspecified is inherited from this instance.
        """
        return _Dependent(
            is_discrete=(
                self._is_discrete if is_discrete is NotImplemented else is_discrete
            ),
            event_dim=self._event_dim if event_dim is NotImplemented else event_dim,
        )

    def check(self, x):
        # A dependent placeholder carries no rule to validate against.
        raise ValueError("Cannot determine validity of dependent constraint")
147
148
def is_dependent(constraint):
    """
    Tells whether ``constraint`` is a dependent placeholder, i.e. an
    instance of ``_Dependent``.

    Args:
        constraint : A ``Constraint`` object.

    Returns:
        ``bool``: True if ``constraint`` is an instance of ``_Dependent``,
        False otherwise.

    Examples:
        >>> import torch
        >>> from torch.distributions import Bernoulli
        >>> from torch.distributions.constraints import is_dependent

        >>> dist = Bernoulli(probs = torch.tensor([0.6], requires_grad=True))
        >>> constraint1 = dist.arg_constraints["probs"]
        >>> constraint2 = dist.arg_constraints["logits"]

        >>> for constraint in [constraint1, constraint2]:
        ...     if is_dependent(constraint):
        ...         continue
    """
    placeholder_type = _Dependent
    return isinstance(constraint, placeholder_type)
173
174
class _DependentProperty(property, _Dependent):
    """
    Decorator built on @property: accessed on a class it behaves like a
    `Dependent` constraint, accessed on an instance it behaves like an
    ordinary property.

    Example::

        class Uniform(Distribution):
            def __init__(self, low, high):
                self.low = low
                self.high = high
            @constraints.dependent_property(is_discrete=False, event_dim=0)
            def support(self):
                return constraints.interval(self.low, self.high)

    Args:
        fn (Callable): The function to be decorated.
        is_discrete (bool): Optional static value for ``.is_discrete``. When
            omitted, reading ``.is_discrete`` raises a NotImplementedError.
        event_dim (int): Optional static value for ``.event_dim``. When
            omitted, reading ``.event_dim`` raises a NotImplementedError.
    """

    def __init__(
        self, fn=None, *, is_discrete=NotImplemented, event_dim=NotImplemented
    ):
        super().__init__(fn)
        self._event_dim = event_dim
        self._is_discrete = is_discrete

    def __call__(self, fn):
        """
        Wrap ``fn``, carrying over the static attributes of this instance.
        This is what enables the parametrized decorator syntax::

            @constraints.dependent_property(is_discrete=True, event_dim=1)
            def support(self):
                ...
        """
        static_attrs = {
            "is_discrete": self._is_discrete,
            "event_dim": self._event_dim,
        }
        return _DependentProperty(fn, **static_attrs)
218
219
class _IndependentConstraint(Constraint):
    """
    Constraint wrapper that folds the ``reinterpreted_batch_ndims`` rightmost
    dims of the base constraint's :meth:`check` result into the event, so an
    event is valid only when every one of its independent entries is valid.
    """

    def __init__(self, base_constraint, reinterpreted_batch_ndims):
        assert isinstance(base_constraint, Constraint)
        assert isinstance(reinterpreted_batch_ndims, int)
        assert reinterpreted_batch_ndims >= 0
        super().__init__()
        self.reinterpreted_batch_ndims = reinterpreted_batch_ndims
        self.base_constraint = base_constraint

    @property
    def is_discrete(self):
        # Discreteness is a property of the wrapped constraint alone.
        base = self.base_constraint
        return base.is_discrete

    @property
    def event_dim(self):
        # The base event dims plus the batch dims reinterpreted as event dims.
        base_dims = self.base_constraint.event_dim
        return base_dims + self.reinterpreted_batch_ndims

    def check(self, value):
        ndims = self.reinterpreted_batch_ndims
        base_result = self.base_constraint.check(value)
        if base_result.dim() < ndims:
            expected = self.base_constraint.event_dim + ndims
            raise ValueError(
                f"Expected value.dim() >= {expected} but got {value.dim()}"
            )
        # Collapse the trailing ``ndims`` dims into one and require all entries.
        kept_shape = base_result.shape[: base_result.dim() - ndims]
        flattened = base_result.reshape(kept_shape + (-1,))
        return flattened.all(-1)

    def __repr__(self):
        name = self.__class__.__name__[1:]
        return f"{name}({repr(self.base_constraint)}, {self.reinterpreted_batch_ndims})"
258
259
class _Boolean(Constraint):
    """
    Constrain every entry to be either `0` or `1`.
    """

    is_discrete = True

    def check(self, value):
        is_zero = value == 0
        is_one = value == 1
        return is_zero | is_one
269
270
class _OneHot(Constraint):
    """
    Constrain the rightmost dimension to one-hot vectors: every entry is
    `0` or `1` and the entries sum to `1`.
    """

    is_discrete = True
    event_dim = 1

    def check(self, value):
        entries_binary = (value == 0) | (value == 1)
        sums_to_one = value.sum(-1).eq(1)
        all_binary = entries_binary.all(-1)
        return all_binary & sums_to_one
283
284
class _IntegerInterval(Constraint):
    """
    Constrain to integer values inside `[lower_bound, upper_bound]`
    (both ends included).
    """

    is_discrete = True

    def __init__(self, lower_bound, upper_bound):
        super().__init__()
        self.lower_bound, self.upper_bound = lower_bound, upper_bound

    def check(self, value):
        is_integer = value % 1 == 0
        above_lower = self.lower_bound <= value
        below_upper = value <= self.upper_bound
        return is_integer & above_lower & below_upper

    def __repr__(self):
        name = self.__class__.__name__[1:]
        return f"{name}(lower_bound={self.lower_bound}, upper_bound={self.upper_bound})"
308
309
class _IntegerLessThan(Constraint):
    """
    Constrain to integer values inside `(-inf, upper_bound]`.
    """

    is_discrete = True

    def __init__(self, upper_bound):
        super().__init__()
        self.upper_bound = upper_bound

    def check(self, value):
        is_integer = value % 1 == 0
        within_bound = value <= self.upper_bound
        return is_integer & within_bound

    def __repr__(self):
        name = self.__class__.__name__[1:]
        return f"{name}(upper_bound={self.upper_bound})"
328
329
class _IntegerGreaterThan(Constraint):
    """
    Constrain to integer values inside `[lower_bound, inf)`.
    """

    is_discrete = True

    def __init__(self, lower_bound):
        super().__init__()
        self.lower_bound = lower_bound

    def check(self, value):
        is_integer = value % 1 == 0
        within_bound = value >= self.lower_bound
        return is_integer & within_bound

    def __repr__(self):
        name = self.__class__.__name__[1:]
        return f"{name}(lower_bound={self.lower_bound})"
348
349
class _Real(Constraint):
    """
    Accept anything on the extended real line `[-inf, inf]`; only NaN
    entries are rejected.
    """

    def check(self, value):
        # NaN is the only value that does not compare equal to itself.
        equals_itself = value == value
        return equals_itself
357
358
class _GreaterThan(Constraint):
    """
    Constrain to the real half line `(lower_bound, inf]`, i.e. values
    strictly above ``lower_bound``.
    """

    def __init__(self, lower_bound):
        super().__init__()
        self.lower_bound = lower_bound

    def check(self, value):
        bound = self.lower_bound
        return bound < value

    def __repr__(self):
        name = self.__class__.__name__[1:]
        return f"{name}(lower_bound={self.lower_bound})"
375
376
class _GreaterThanEq(Constraint):
    """
    Constrain to the real half line `[lower_bound, inf)`, i.e. values at or
    above ``lower_bound``.
    """

    def __init__(self, lower_bound):
        super().__init__()
        self.lower_bound = lower_bound

    def check(self, value):
        bound = self.lower_bound
        return bound <= value

    def __repr__(self):
        name = self.__class__.__name__[1:]
        return f"{name}(lower_bound={self.lower_bound})"
393
394
class _LessThan(Constraint):
    """
    Constrain to the real half line `[-inf, upper_bound)`, i.e. values
    strictly below ``upper_bound``.
    """

    def __init__(self, upper_bound):
        super().__init__()
        self.upper_bound = upper_bound

    def check(self, value):
        bound = self.upper_bound
        return value < bound

    def __repr__(self):
        name = self.__class__.__name__[1:]
        return f"{name}(upper_bound={self.upper_bound})"
411
412
class _Interval(Constraint):
    """
    Constrain to the closed real interval `[lower_bound, upper_bound]`.
    """

    def __init__(self, lower_bound, upper_bound):
        super().__init__()
        self.lower_bound, self.upper_bound = lower_bound, upper_bound

    def check(self, value):
        above_lower = self.lower_bound <= value
        below_upper = value <= self.upper_bound
        return above_lower & below_upper

    def __repr__(self):
        name = self.__class__.__name__[1:]
        return f"{name}(lower_bound={self.lower_bound}, upper_bound={self.upper_bound})"
432
433
class _HalfOpenInterval(Constraint):
    """
    Constrain to the half-open real interval `[lower_bound, upper_bound)`:
    the lower end is included, the upper end is not.
    """

    def __init__(self, lower_bound, upper_bound):
        super().__init__()
        self.lower_bound, self.upper_bound = lower_bound, upper_bound

    def check(self, value):
        above_lower = self.lower_bound <= value
        below_upper = value < self.upper_bound
        return above_lower & below_upper

    def __repr__(self):
        name = self.__class__.__name__[1:]
        return f"{name}(lower_bound={self.lower_bound}, upper_bound={self.upper_bound})"
453
454
class _Simplex(Constraint):
    """
    Constrain the innermost (rightmost) dimension to the unit simplex.
    Specifically: `x >= 0` and `x.sum(-1) == 1`, where the sum is compared
    with an absolute tolerance of `1e-6`.
    """

    event_dim = 1

    def check(self, value):
        nonnegative_entries = torch.all(value >= 0, dim=-1)
        sum_error = (value.sum(-1) - 1).abs()
        return nonnegative_entries & (sum_error < 1e-6)
465
466
class _Multinomial(Constraint):
    """
    Constrain to nonnegative integer values summing to at most an upper bound.

    Note that only nonnegativity and the sum are actually verified. Due to
    limitations of the Multinomial distribution, the sum condition checked
    is the weaker ``value.sum(-1) <= upper_bound``. In the future this may
    be strengthened to ``value.sum(-1) == upper_bound``.
    """

    is_discrete = True
    event_dim = 1

    def __init__(self, upper_bound):
        self.upper_bound = upper_bound

    def check(self, x):
        nonnegative_entries = (x >= 0).all(dim=-1)
        total = x.sum(dim=-1)
        return nonnegative_entries & (total <= self.upper_bound)
484
485
class _LowerTriangular(Constraint):
    """
    Constrain to lower-triangular square matrices.

    :meth:`check` only verifies that every entry above the main diagonal of
    the two rightmost dimensions is zero; it returns one boolean per matrix.
    """

    event_dim = 2

    def check(self, value):
        # A matrix is lower triangular iff zeroing its strict upper triangle
        # leaves it unchanged.
        matches_tril = value.tril() == value
        # Reduce over both matrix dims with ``all`` rather than flattening
        # with ``view(..., -1)`` and taking ``min``: ``all`` also handles
        # empty batches and 0x0 matrices, and does not depend on the memory
        # layout of the comparison result.
        return matches_tril.all(-2).all(-1)
496
497
class _LowerCholesky(Constraint):
    """
    Constrain to lower-triangular square matrices with positive diagonals.

    :meth:`check` returns one boolean per matrix formed by the two rightmost
    dimensions.
    """

    event_dim = 2

    def check(self, value):
        # Lower triangular iff zeroing the strict upper triangle is a no-op.
        # ``all`` is used for the reductions (instead of ``view(..., -1)``
        # followed by ``min``) so that empty batches and 0x0 matrices are
        # handled and the result does not depend on memory layout.
        lower_triangular = (value.tril() == value).all(-2).all(-1)

        diagonal = value.diagonal(dim1=-2, dim2=-1)
        positive_diagonal = (diagonal > 0).all(-1)
        return lower_triangular & positive_diagonal
513
514
class _CorrCholesky(Constraint):
    """
    Constrain to lower-triangular square matrices with positive diagonals
    whose row vectors each have unit length (up to a dtype-dependent
    tolerance).
    """

    event_dim = 2

    def check(self, value):
        # Tolerance scales with machine epsilon and the row length;
        # 10 is an adjustable fudge factor.
        eps = torch.finfo(value.dtype).eps
        tol = eps * value.size(-1) * 10
        norm_error = (torch.linalg.norm(value.detach(), dim=-1) - 1.0).abs()
        unit_row_norm = norm_error.le(tol).all(dim=-1)
        is_lower_cholesky = _LowerCholesky().check(value)
        return is_lower_cholesky & unit_row_norm
530
531
class _Square(Constraint):
    """
    Constrain to square matrices, i.e. the two rightmost dimensions must
    have equal size.
    """

    event_dim = 2

    def check(self, value):
        # Squareness depends only on the shape, so every batch entry gets
        # the same answer.
        is_square = value.shape[-2] == value.shape[-1]
        batch_shape = value.shape[:-2]
        return torch.full(
            size=batch_shape,
            fill_value=is_square,
            dtype=torch.bool,
            device=value.device,
        )
546
547
class _Symmetric(_Square):
    """
    Constrain to symmetric square matrices (compared against the transpose
    with an absolute tolerance of `1e-6`).
    """

    def check(self, value):
        is_square = super().check(value)
        if is_square.all():
            # Compare with the transpose of the two rightmost dimensions.
            close_to_transpose = torch.isclose(value, value.mT, atol=1e-6)
            return close_to_transpose.all(-2).all(-1)
        # Not square: report the failed squareness check as is.
        return is_square
558
559
class _PositiveSemidefinite(_Symmetric):
    """
    Constrain to positive-semidefinite matrices: symmetric matrices whose
    eigenvalues are all nonnegative.
    """

    def check(self, value):
        is_symmetric = super().check(value)
        if is_symmetric.all():
            eigenvalues = torch.linalg.eigvalsh(value)
            return eigenvalues.ge(0).all(-1)
        # Symmetry failed somewhere: report the symmetry check as is.
        return is_symmetric
570
571
class _PositiveDefinite(_Symmetric):
    """
    Constrain to positive-definite matrices: symmetric matrices for which a
    Cholesky factorization succeeds.
    """

    def check(self, value):
        is_symmetric = super().check(value)
        if is_symmetric.all():
            # ``info`` is 0 exactly where the factorization succeeded.
            factorization = torch.linalg.cholesky_ex(value)
            return factorization.info.eq(0)
        # Symmetry failed somewhere: report the symmetry check as is.
        return is_symmetric
582
583
class _Cat(Constraint):
    """
    Constraint functor that applies a sequence of constraints
    `cseq` at the submatrices at dimension `dim`,
    each of size `lengths[dim]`, in a way compatible with :func:`torch.cat`.

    Args:
        cseq: Iterable of ``Constraint`` objects, one per slice.
        dim (int): Dimension along which the value is split. Defaults to 0.
        lengths: Size of each slice along ``dim``, one entry per constraint.
            Defaults to a length of 1 for every constraint.
    """

    def __init__(self, cseq, dim=0, lengths=None):
        # Materialize before validating: ``cseq`` may be a one-shot iterator
        # (e.g. a generator), which the validation pass would otherwise
        # exhaust, leaving an empty constraint list behind.
        self.cseq = list(cseq)
        assert all(isinstance(c, Constraint) for c in self.cseq)
        if lengths is None:
            lengths = [1] * len(self.cseq)
        self.lengths = list(lengths)
        assert len(self.lengths) == len(self.cseq)
        self.dim = dim
        super().__init__()

    @property
    def is_discrete(self):
        # Discrete as soon as any component constraint is discrete.
        return any(c.is_discrete for c in self.cseq)

    @property
    def event_dim(self):
        # The largest event dimension among the component constraints.
        return max(c.event_dim for c in self.cseq)

    def check(self, value):
        """
        Check each slice of ``value`` along ``self.dim`` against its
        constraint and concatenate the per-slice results along ``self.dim``.
        """
        assert -value.dim() <= self.dim < value.dim()
        checks = []
        start = 0
        for constr, length in zip(self.cseq, self.lengths):
            v = value.narrow(self.dim, start, length)
            checks.append(constr.check(v))
            start = start + length  # avoid += for jit compat
        return torch.cat(checks, self.dim)
618
619
class _Stack(Constraint):
    """
    Constraint functor that applies a sequence of constraints
    `cseq` at the submatrices at dimension `dim`,
    in a way compatible with :func:`torch.stack`.

    Args:
        cseq: Iterable of ``Constraint`` objects, one per index along ``dim``.
        dim (int): Dimension along which the value is unstacked. Defaults to 0.
    """

    def __init__(self, cseq, dim=0):
        # Materialize before validating: ``cseq`` may be a one-shot iterator
        # (e.g. a generator), which the validation pass would otherwise
        # exhaust, leaving an empty constraint list behind.
        self.cseq = list(cseq)
        assert all(isinstance(c, Constraint) for c in self.cseq)
        self.dim = dim
        super().__init__()

    @property
    def is_discrete(self):
        # Discrete as soon as any component constraint is discrete.
        return any(c.is_discrete for c in self.cseq)

    @property
    def event_dim(self):
        # Start from the largest component event dimension; when the stacking
        # dim (counted from the right) lands inside those event dims, count
        # one extra dimension for it.
        dim = max(c.event_dim for c in self.cseq)
        if self.dim + dim < 0:
            dim += 1
        return dim

    def check(self, value):
        """
        Check each index of ``value`` along ``self.dim`` against its
        constraint and stack the per-index results along ``self.dim``.
        """
        assert -value.dim() <= self.dim < value.dim()
        vs = [value.select(self.dim, i) for i in range(value.size(self.dim))]
        return torch.stack(
            [constr.check(v) for v, constr in zip(vs, self.cseq)], self.dim
        )
650
651
# Public interface.
# Names bound to an instance are ready-to-use constraints; names bound to a
# class must be called with their parameters to build a constraint.

# Placeholders and wrappers.
dependent = _Dependent()
dependent_property = _DependentProperty
independent = _IndependentConstraint

# Discrete constraints.
boolean = _Boolean()
one_hot = _OneHot()
nonnegative_integer = _IntegerGreaterThan(0)
positive_integer = _IntegerGreaterThan(1)
integer_interval = _IntegerInterval
multinomial = _Multinomial

# Real-valued constraints.
real = _Real()
real_vector = independent(real, 1)
positive = _GreaterThan(0.0)
nonnegative = _GreaterThanEq(0.0)
greater_than = _GreaterThan
greater_than_eq = _GreaterThanEq
less_than = _LessThan
unit_interval = _Interval(0.0, 1.0)
interval = _Interval
half_open_interval = _HalfOpenInterval

# Vector- and matrix-valued constraints.
simplex = _Simplex()
lower_triangular = _LowerTriangular()
lower_cholesky = _LowerCholesky()
corr_cholesky = _CorrCholesky()
square = _Square()
symmetric = _Symmetric()
positive_semidefinite = _PositiveSemidefinite()
positive_definite = _PositiveDefinite()

# Combinators over sequences of constraints.
cat = _Cat
stack = _Stack
682