1*da0073e9SAndroid Build Coastguard Worker#!/usr/bin/env python3 2*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: internals"] 3*da0073e9SAndroid Build Coastguard Worker 4*da0073e9SAndroid Build Coastguard Workerimport torch 5*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import run_tests, TestCase 6*da0073e9SAndroid Build Coastguard Worker 7*da0073e9SAndroid Build Coastguard Worker 8*da0073e9SAndroid Build Coastguard Workerclass TestComparisonUtils(TestCase): 9*da0073e9SAndroid Build Coastguard Worker def test_all_equal_no_assert(self): 10*da0073e9SAndroid Build Coastguard Worker t = torch.tensor([0.5]) 11*da0073e9SAndroid Build Coastguard Worker torch._assert_tensor_metadata(t, [1], [1], torch.float) 12*da0073e9SAndroid Build Coastguard Worker 13*da0073e9SAndroid Build Coastguard Worker def test_all_equal_no_assert_nones(self): 14*da0073e9SAndroid Build Coastguard Worker t = torch.tensor([0.5]) 15*da0073e9SAndroid Build Coastguard Worker torch._assert_tensor_metadata(t, None, None, None) 16*da0073e9SAndroid Build Coastguard Worker 17*da0073e9SAndroid Build Coastguard Worker def test_assert_dtype(self): 18*da0073e9SAndroid Build Coastguard Worker t = torch.tensor([0.5]) 19*da0073e9SAndroid Build Coastguard Worker 20*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(RuntimeError): 21*da0073e9SAndroid Build Coastguard Worker torch._assert_tensor_metadata(t, None, None, torch.int32) 22*da0073e9SAndroid Build Coastguard Worker 23*da0073e9SAndroid Build Coastguard Worker def test_assert_strides(self): 24*da0073e9SAndroid Build Coastguard Worker t = torch.tensor([0.5]) 25*da0073e9SAndroid Build Coastguard Worker 26*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(RuntimeError): 27*da0073e9SAndroid Build Coastguard Worker torch._assert_tensor_metadata(t, None, [3], torch.float) 28*da0073e9SAndroid Build Coastguard Worker 29*da0073e9SAndroid Build Coastguard Worker def test_assert_sizes(self): 30*da0073e9SAndroid Build Coastguard Worker t = torch.tensor([0.5]) 31*da0073e9SAndroid Build Coastguard Worker 32*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(RuntimeError): 33*da0073e9SAndroid Build Coastguard Worker torch._assert_tensor_metadata(t, [3], [1], torch.float) 34*da0073e9SAndroid Build Coastguard Worker 35*da0073e9SAndroid Build Coastguard Worker 36*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 37*da0073e9SAndroid Build Coastguard Worker run_tests() 38