1# Owner(s): ["module: distributions"] 2 3import pytest 4 5import torch 6from torch.distributions import biject_to, constraints, transform_to 7from torch.testing._internal.common_cuda import TEST_CUDA 8from torch.testing._internal.common_utils import run_tests 9 10 11EXAMPLES = [ 12 (constraints.symmetric, False, [[2.0, 0], [2.0, 2]]), 13 (constraints.positive_semidefinite, False, [[2.0, 0], [2.0, 2]]), 14 (constraints.positive_definite, False, [[2.0, 0], [2.0, 2]]), 15 (constraints.symmetric, True, [[3.0, -5], [-5.0, 3]]), 16 (constraints.positive_semidefinite, False, [[3.0, -5], [-5.0, 3]]), 17 (constraints.positive_definite, False, [[3.0, -5], [-5.0, 3]]), 18 (constraints.symmetric, True, [[1.0, 2], [2.0, 4]]), 19 (constraints.positive_semidefinite, True, [[1.0, 2], [2.0, 4]]), 20 (constraints.positive_definite, False, [[1.0, 2], [2.0, 4]]), 21 (constraints.symmetric, True, [[[1.0, -2], [-2.0, 1]], [[2.0, 3], [3.0, 2]]]), 22 ( 23 constraints.positive_semidefinite, 24 False, 25 [[[1.0, -2], [-2.0, 1]], [[2.0, 3], [3.0, 2]]], 26 ), 27 ( 28 constraints.positive_definite, 29 False, 30 [[[1.0, -2], [-2.0, 1]], [[2.0, 3], [3.0, 2]]], 31 ), 32 (constraints.symmetric, True, [[[1.0, -2], [-2.0, 4]], [[1.0, -1], [-1.0, 1]]]), 33 ( 34 constraints.positive_semidefinite, 35 True, 36 [[[1.0, -2], [-2.0, 4]], [[1.0, -1], [-1.0, 1]]], 37 ), 38 ( 39 constraints.positive_definite, 40 False, 41 [[[1.0, -2], [-2.0, 4]], [[1.0, -1], [-1.0, 1]]], 42 ), 43 (constraints.symmetric, True, [[[4.0, 2], [2.0, 4]], [[3.0, -1], [-1.0, 3]]]), 44 ( 45 constraints.positive_semidefinite, 46 True, 47 [[[4.0, 2], [2.0, 4]], [[3.0, -1], [-1.0, 3]]], 48 ), 49 ( 50 constraints.positive_definite, 51 True, 52 [[[4.0, 2], [2.0, 4]], [[3.0, -1], [-1.0, 3]]], 53 ), 54] 55 56CONSTRAINTS = [ 57 (constraints.real,), 58 (constraints.real_vector,), 59 (constraints.positive,), 60 (constraints.greater_than, [-10.0, -2, 0, 2, 10]), 61 (constraints.greater_than, 0), 62 (constraints.greater_than, 2), 63 (constraints.greater_than, -2), 64 (constraints.greater_than_eq, 0), 65 (constraints.greater_than_eq, 2), 66 (constraints.greater_than_eq, -2), 67 (constraints.less_than, [-10.0, -2, 0, 2, 10]), 68 (constraints.less_than, 0), 69 (constraints.less_than, 2), 70 (constraints.less_than, -2), 71 (constraints.unit_interval,), 72 (constraints.interval, [-4.0, -2, 0, 2, 4], [-3.0, 3, 1, 5, 5]), 73 (constraints.interval, -2, -1), 74 (constraints.interval, 1, 2), 75 (constraints.half_open_interval, [-4.0, -2, 0, 2, 4], [-3.0, 3, 1, 5, 5]), 76 (constraints.half_open_interval, -2, -1), 77 (constraints.half_open_interval, 1, 2), 78 (constraints.simplex,), 79 (constraints.corr_cholesky,), 80 (constraints.lower_cholesky,), 81 (constraints.positive_definite,), 82] 83 84 85def build_constraint(constraint_fn, args, is_cuda=False): 86 if not args: 87 return constraint_fn 88 t = torch.cuda.DoubleTensor if is_cuda else torch.DoubleTensor 89 return constraint_fn(*(t(x) if isinstance(x, list) else x for x in args)) 90 91 92@pytest.mark.parametrize(("constraint_fn", "result", "value"), EXAMPLES) 93@pytest.mark.parametrize( 94 "is_cuda", 95 [ 96 False, 97 pytest.param( 98 True, marks=pytest.mark.skipif(not TEST_CUDA, reason="CUDA not found.") 99 ), 100 ], 101) 102def test_constraint(constraint_fn, result, value, is_cuda): 103 t = torch.cuda.DoubleTensor if is_cuda else torch.DoubleTensor 104 assert constraint_fn.check(t(value)).all() == result 105 106 107@pytest.mark.parametrize( 108 ("constraint_fn", "args"), [(c[0], c[1:]) for c in CONSTRAINTS] 109) 110@pytest.mark.parametrize( 111 "is_cuda", 112 [ 113 False, 114 pytest.param( 115 True, marks=pytest.mark.skipif(not TEST_CUDA, reason="CUDA not found.") 116 ), 117 ], 118) 119def test_biject_to(constraint_fn, args, is_cuda): 120 constraint = build_constraint(constraint_fn, args, is_cuda=is_cuda) 121 try: 122 t = biject_to(constraint) 123 except NotImplementedError: 124 pytest.skip("`biject_to` not implemented.") 125 assert t.bijective, f"biject_to({constraint}) is not bijective" 126 if constraint_fn is constraints.corr_cholesky: 127 # (D * (D-1)) / 2 (where D = 4) = 6 (size of last dim) 128 x = torch.randn(6, 6, dtype=torch.double) 129 else: 130 x = torch.randn(5, 5, dtype=torch.double) 131 if is_cuda: 132 x = x.cuda() 133 y = t(x) 134 assert constraint.check(y).all(), "\n".join( 135 [ 136 f"Failed to biject_to({constraint})", 137 f"x = {x}", 138 f"biject_to(...)(x) = {y}", 139 ] 140 ) 141 x2 = t.inv(y) 142 assert torch.allclose(x, x2), f"Error in biject_to({constraint}) inverse" 143 144 j = t.log_abs_det_jacobian(x, y) 145 assert j.shape == x.shape[: x.dim() - t.domain.event_dim] 146 147 148@pytest.mark.parametrize( 149 ("constraint_fn", "args"), [(c[0], c[1:]) for c in CONSTRAINTS] 150) 151@pytest.mark.parametrize( 152 "is_cuda", 153 [ 154 False, 155 pytest.param( 156 True, marks=pytest.mark.skipif(not TEST_CUDA, reason="CUDA not found.") 157 ), 158 ], 159) 160def test_transform_to(constraint_fn, args, is_cuda): 161 constraint = build_constraint(constraint_fn, args, is_cuda=is_cuda) 162 t = transform_to(constraint) 163 if constraint_fn is constraints.corr_cholesky: 164 # (D * (D-1)) / 2 (where D = 4) = 6 (size of last dim) 165 x = torch.randn(6, 6, dtype=torch.double) 166 else: 167 x = torch.randn(5, 5, dtype=torch.double) 168 if is_cuda: 169 x = x.cuda() 170 y = t(x) 171 assert constraint.check(y).all(), f"Failed to transform_to({constraint})" 172 x2 = t.inv(y) 173 y2 = t(x2) 174 assert torch.allclose(y, y2), f"Error in transform_to({constraint}) pseudoinverse" 175 176 177if __name__ == "__main__": 178 run_tests() 179