xref: /aosp_15_r20/external/pytorch/test/test_comparison_utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker#!/usr/bin/env python3
2*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: internals"]
3*da0073e9SAndroid Build Coastguard Worker
4*da0073e9SAndroid Build Coastguard Workerimport torch
5*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import run_tests, TestCase
6*da0073e9SAndroid Build Coastguard Worker
7*da0073e9SAndroid Build Coastguard Worker
class TestComparisonUtils(TestCase):
    """Exercise ``torch._assert_tensor_metadata`` on a one-element tensor.

    Each test either checks that matching (or omitted) metadata is accepted
    without raising, or that a single deliberately wrong field causes a
    ``RuntimeError``.
    """

    @staticmethod
    def _make_sample():
        # A fresh one-element tensor per test. test_all_equal_no_assert pins
        # its metadata as size [1], stride [1], dtype torch.float.
        return torch.tensor([0.5])

    def test_all_equal_no_assert(self):
        """Matching size, stride and dtype must not raise."""
        sample = self._make_sample()
        torch._assert_tensor_metadata(sample, [1], [1], torch.float)

    def test_all_equal_no_assert_nones(self):
        """Passing None for size, stride and dtype must not raise."""
        sample = self._make_sample()
        torch._assert_tensor_metadata(sample, None, None, None)

    def test_assert_dtype(self):
        """A float tensor checked against int32 must raise RuntimeError."""
        sample = self._make_sample()
        wrong_dtype = torch.int32
        with self.assertRaises(RuntimeError):
            torch._assert_tensor_metadata(sample, None, None, wrong_dtype)

    def test_assert_strides(self):
        """A stride mismatch ([1] actual vs. [3] expected) must raise."""
        sample = self._make_sample()
        wrong_strides = [3]
        with self.assertRaises(RuntimeError):
            torch._assert_tensor_metadata(sample, None, wrong_strides, torch.float)

    def test_assert_sizes(self):
        """A size mismatch ([1] actual vs. [3] expected) must raise."""
        sample = self._make_sample()
        wrong_sizes = [3]
        with self.assertRaises(RuntimeError):
            torch._assert_tensor_metadata(sample, wrong_sizes, [1], torch.float)
34*da0073e9SAndroid Build Coastguard Worker
35*da0073e9SAndroid Build Coastguard Worker
if "__main__" == __name__:
    # Hand control to PyTorch's shared test runner (imported from
    # torch.testing._internal.common_utils) when run as a script.
    run_tests()
38