xref: /aosp_15_r20/external/pytorch/test/distributions/test_utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: distributions"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport pytest
4*da0073e9SAndroid Build Coastguard Worker
5*da0073e9SAndroid Build Coastguard Workerimport torch
6*da0073e9SAndroid Build Coastguard Workerfrom torch.distributions.utils import tril_matrix_to_vec, vec_to_tril_matrix
7*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import run_tests
8*da0073e9SAndroid Build Coastguard Worker
9*da0073e9SAndroid Build Coastguard Worker
@pytest.mark.parametrize(
    "shape",
    [
        (2, 2),
        (3, 3),
        (2, 4, 4),
        (2, 2, 4, 4),
    ],
)
def test_tril_matrix_to_vec(shape):
    """Check that tril_matrix_to_vec and vec_to_tril_matrix invert each other.

    A random tensor of the given ``shape`` is created and, for every diagonal
    offset ``diag`` in ``[-n, n - 1]`` (``n`` being the size of the last
    dimension), its lower-triangular part ``mat.tril(diag)`` is packed into a
    vector and unpacked again. The unpacked matrix must equal the
    lower-triangular input.
    """
    mat = torch.randn(shape)
    n = mat.shape[-1]
    for diag in range(-n, n):
        expected = mat.tril(diag)
        packed = tril_matrix_to_vec(expected, diag)
        restored = vec_to_tril_matrix(packed, diag)
        # assert_close (unlike torch.allclose) does not broadcast, so a round
        # trip that loses batch dimensions fails here, and a mismatch is
        # reported with details. The round trip only moves elements around,
        # hence the comparison is exact.
        torch.testing.assert_close(restored, expected, rtol=0, atol=0)
28*da0073e9SAndroid Build Coastguard Worker
# When executed as a script, delegate to run_tests (imported above from
# torch.testing._internal.common_utils).
if __name__ == "__main__":
    run_tests()
31