# xref: /aosp_15_r20/external/pytorch/test/distributions/test_constraints.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# Owner(s): ["module: distributions"]
2
import pytest

import torch
from torch.distributions import biject_to, constraints, transform_to
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_utils import run_tests
9
10
# Fixtures for ``test_constraint``. Each entry is ``(constraint, expected,
# value)``, where ``expected`` is what ``constraint.check(value).all()`` must
# produce. The 2x2 matrices (and batches of them) probe the symmetric /
# positive-semidefinite / positive-definite constraints; the eigenvalues that
# decide each expected flag are noted per group.
EXAMPLES = [
    # Not symmetric, so none of the three constraints hold.
    (constraints.symmetric, False, [[2.0, 0], [2.0, 2]]),
    (constraints.positive_semidefinite, False, [[2.0, 0], [2.0, 2]]),
    (constraints.positive_definite, False, [[2.0, 0], [2.0, 2]]),
    # Symmetric with eigenvalues {-2, 8}: indefinite.
    (constraints.symmetric, True, [[3.0, -5], [-5.0, 3]]),
    (constraints.positive_semidefinite, False, [[3.0, -5], [-5.0, 3]]),
    (constraints.positive_definite, False, [[3.0, -5], [-5.0, 3]]),
    # Symmetric with eigenvalues {0, 5}: singular, so PSD but not PD.
    (constraints.symmetric, True, [[1.0, 2], [2.0, 4]]),
    (constraints.positive_semidefinite, True, [[1.0, 2], [2.0, 4]]),
    (constraints.positive_definite, False, [[1.0, 2], [2.0, 4]]),
    # Batch of symmetric matrices with eigenvalues {-1, 3} and {-1, 5}.
    (constraints.symmetric, True, [[[1.0, -2], [-2.0, 1]], [[2.0, 3], [3.0, 2]]]),
    (
        constraints.positive_semidefinite,
        False,
        [[[1.0, -2], [-2.0, 1]], [[2.0, 3], [3.0, 2]]],
    ),
    (
        constraints.positive_definite,
        False,
        [[[1.0, -2], [-2.0, 1]], [[2.0, 3], [3.0, 2]]],
    ),
    # Batch with eigenvalues {0, 5} and {0, 2}: PSD but not PD.
    (constraints.symmetric, True, [[[1.0, -2], [-2.0, 4]], [[1.0, -1], [-1.0, 1]]]),
    (
        constraints.positive_semidefinite,
        True,
        [[[1.0, -2], [-2.0, 4]], [[1.0, -1], [-1.0, 1]]],
    ),
    (
        constraints.positive_definite,
        False,
        [[[1.0, -2], [-2.0, 4]], [[1.0, -1], [-1.0, 1]]],
    ),
    # Batch with eigenvalues {2, 6} and {2, 4}: positive definite.
    (constraints.symmetric, True, [[[4.0, 2], [2.0, 4]], [[3.0, -1], [-1.0, 3]]]),
    (
        constraints.positive_semidefinite,
        True,
        [[[4.0, 2], [2.0, 4]], [[3.0, -1], [-1.0, 3]]],
    ),
    (
        constraints.positive_definite,
        True,
        [[[4.0, 2], [2.0, 4]], [[3.0, -1], [-1.0, 3]]],
    ),
]
55
# Constraints exercised by ``test_biject_to`` and ``test_transform_to``.
# Each entry is ``(constraint_fn, *args)``. A 1-tuple holds a constraint that
# ``build_constraint`` uses as-is. Longer tuples hold a constraint factory
# followed by its positional arguments. ``build_constraint`` turns list
# arguments into float64 tensors (per-element bounds) and passes scalars
# through unchanged.
CONSTRAINTS = [
    (constraints.real,),
    (constraints.real_vector,),
    (constraints.positive,),
    (constraints.greater_than, [-10.0, -2, 0, 2, 10]),
    (constraints.greater_than, 0),
    (constraints.greater_than, 2),
    (constraints.greater_than, -2),
    (constraints.greater_than_eq, 0),
    (constraints.greater_than_eq, 2),
    (constraints.greater_than_eq, -2),
    (constraints.less_than, [-10.0, -2, 0, 2, 10]),
    (constraints.less_than, 0),
    (constraints.less_than, 2),
    (constraints.less_than, -2),
    (constraints.unit_interval,),
    (constraints.interval, [-4.0, -2, 0, 2, 4], [-3.0, 3, 1, 5, 5]),
    (constraints.interval, -2, -1),
    (constraints.interval, 1, 2),
    (constraints.half_open_interval, [-4.0, -2, 0, 2, 4], [-3.0, 3, 1, 5, 5]),
    (constraints.half_open_interval, -2, -1),
    (constraints.half_open_interval, 1, 2),
    (constraints.simplex,),
    (constraints.corr_cholesky,),
    (constraints.lower_cholesky,),
    (constraints.positive_definite,),
]
83
84
def build_constraint(constraint_fn, args, is_cuda=False):
    """Instantiate a constraint from a ``CONSTRAINTS`` entry.

    Args:
        constraint_fn: Either a constraint object, returned unchanged when
            ``args`` is empty, or a callable constraint factory such as
            ``constraints.interval``.
        args: Positional arguments for ``constraint_fn``. Lists are converted
            to float64 tensors; any other value (e.g. a Python scalar) is
            passed through as-is.
        is_cuda: If True, the converted tensors are placed on the current
            CUDA device instead of the CPU.

    Returns:
        The resulting constraint object.
    """
    if not args:
        return constraint_fn
    # ``torch.tensor(..., dtype=, device=)`` replaces the legacy
    # ``torch.DoubleTensor`` / ``torch.cuda.DoubleTensor`` type constructors.
    device = "cuda" if is_cuda else "cpu"
    return constraint_fn(
        *(
            torch.tensor(arg, dtype=torch.double, device=device)
            if isinstance(arg, list)
            else arg
            for arg in args
        )
    )
90
91
@pytest.mark.parametrize(("constraint_fn", "result", "value"), EXAMPLES)
@pytest.mark.parametrize(
    "is_cuda",
    [
        False,
        pytest.param(
            True, marks=pytest.mark.skipif(not TEST_CUDA, reason="CUDA not found.")
        ),
    ],
)
def test_constraint(constraint_fn, result, value, is_cuda):
    """Check that ``constraint_fn.check`` classifies ``value`` as expected.

    ``value`` is converted to a float64 tensor (on CUDA when ``is_cuda``).
    The check passes when every element of ``constraint_fn.check(value)``
    is True exactly when ``result`` is True.
    """
    device = "cuda" if is_cuda else "cpu"
    tensor = torch.tensor(value, dtype=torch.double, device=device)
    # Reduce to a Python bool so the comparison (and any failure report) is
    # against ``result`` itself rather than a 0-d tensor.
    actual = bool(constraint_fn.check(tensor).all())
    assert actual == result, (
        f"{constraint_fn}.check({value}) gave {actual}, expected {result}"
    )
105
106
@pytest.mark.parametrize(
    ("constraint_fn", "args"), [(c[0], c[1:]) for c in CONSTRAINTS]
)
@pytest.mark.parametrize(
    "is_cuda",
    [
        False,
        pytest.param(
            True, marks=pytest.mark.skipif(not TEST_CUDA, reason="CUDA not found.")
        ),
    ],
)
def test_biject_to(constraint_fn, args, is_cuda):
    """Round-trip a random input through ``biject_to(constraint)``.

    Checks that the transform reports itself as bijective and that its
    output satisfies the constraint. Also checks that ``t.inv`` recovers the
    input, and that ``log_abs_det_jacobian`` has the input shape with the
    transform's domain event dims stripped. Skips the case when ``biject_to``
    raises ``NotImplementedError`` for the constraint.
    """
    constraint = build_constraint(constraint_fn, args, is_cuda=is_cuda)
    try:
        t = biject_to(constraint)
    except NotImplementedError:
        pytest.skip("`biject_to` not implemented.")
    assert t.bijective, f"biject_to({constraint}) is not bijective"

    # corr_cholesky's unconstrained input needs a last dim of D * (D - 1) / 2;
    # 6 corresponds to D = 4.
    size = 6 if constraint_fn is constraints.corr_cholesky else 5
    device = "cuda" if is_cuda else "cpu"
    x = torch.randn(size, size, dtype=torch.double, device=device)

    y = t(x)
    assert constraint.check(y).all(), "\n".join(
        [
            f"Failed to biject_to({constraint})",
            f"x = {x}",
            f"biject_to(...)(x) = {y}",
        ]
    )
    x2 = t.inv(y)
    assert torch.allclose(x, x2), f"Error in biject_to({constraint}) inverse"

    j = t.log_abs_det_jacobian(x, y)
    expected_shape = x.shape[: x.dim() - t.domain.event_dim]
    assert j.shape == expected_shape, (
        f"biject_to({constraint}).log_abs_det_jacobian has shape {j.shape}, "
        f"expected {expected_shape}"
    )
146
147
@pytest.mark.parametrize(
    ("constraint_fn", "args"), [(c[0], c[1:]) for c in CONSTRAINTS]
)
@pytest.mark.parametrize(
    "is_cuda",
    [
        False,
        pytest.param(
            True, marks=pytest.mark.skipif(not TEST_CUDA, reason="CUDA not found.")
        ),
    ],
)
def test_transform_to(constraint_fn, args, is_cuda):
    """Check that ``transform_to(constraint)`` maps into the constraint.

    Unlike ``test_biject_to``, this does not require the transform to be
    bijective. It only requires ``t.inv`` to act as a pseudoinverse:
    ``t(t.inv(y))`` must reproduce ``y``.
    """
    constraint = build_constraint(constraint_fn, args, is_cuda=is_cuda)
    t = transform_to(constraint)

    # corr_cholesky's unconstrained input needs a last dim of D * (D - 1) / 2;
    # 6 corresponds to D = 4.
    size = 6 if constraint_fn is constraints.corr_cholesky else 5
    device = "cuda" if is_cuda else "cpu"
    x = torch.randn(size, size, dtype=torch.double, device=device)

    y = t(x)
    assert constraint.check(y).all(), f"Failed to transform_to({constraint})"
    x2 = t.inv(y)
    y2 = t(x2)
    assert torch.allclose(y, y2), f"Error in transform_to({constraint}) pseudoinverse"
175
176
# Allow running this file directly through PyTorch's test runner.
if __name__ == "__main__":
    run_tests()
179