1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: distributions"] 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerimport pytest 4*da0073e9SAndroid Build Coastguard Worker 5*da0073e9SAndroid Build Coastguard Workerimport torch 6*da0073e9SAndroid Build Coastguard Workerfrom torch.distributions.utils import tril_matrix_to_vec, vec_to_tril_matrix 7*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import run_tests 8*da0073e9SAndroid Build Coastguard Worker 9*da0073e9SAndroid Build Coastguard Worker 10*da0073e9SAndroid Build Coastguard Worker@pytest.mark.parametrize( 11*da0073e9SAndroid Build Coastguard Worker "shape", 12*da0073e9SAndroid Build Coastguard Worker [ 13*da0073e9SAndroid Build Coastguard Worker (2, 2), 14*da0073e9SAndroid Build Coastguard Worker (3, 3), 15*da0073e9SAndroid Build Coastguard Worker (2, 4, 4), 16*da0073e9SAndroid Build Coastguard Worker (2, 2, 4, 4), 17*da0073e9SAndroid Build Coastguard Worker ], 18*da0073e9SAndroid Build Coastguard Worker) 19*da0073e9SAndroid Build Coastguard Workerdef test_tril_matrix_to_vec(shape): 20*da0073e9SAndroid Build Coastguard Worker mat = torch.randn(shape) 21*da0073e9SAndroid Build Coastguard Worker n = mat.shape[-1] 22*da0073e9SAndroid Build Coastguard Worker for diag in range(-n, n): 23*da0073e9SAndroid Build Coastguard Worker actual = mat.tril(diag) 24*da0073e9SAndroid Build Coastguard Worker vec = tril_matrix_to_vec(actual, diag) 25*da0073e9SAndroid Build Coastguard Worker tril_mat = vec_to_tril_matrix(vec, diag) 26*da0073e9SAndroid Build Coastguard Worker assert torch.allclose(tril_mat, actual) 27*da0073e9SAndroid Build Coastguard Worker 28*da0073e9SAndroid Build Coastguard Worker 29*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 30*da0073e9SAndroid Build Coastguard Worker run_tests() 31