# Owner(s): ["oncall: jit"]

import os
import sys
from typing import List, Tuple

import torch


# Put the grandparent directory of this file (the test/ directory) on
# sys.path so the helper files living there can be imported.
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase


if __name__ == "__main__":
    # These tests are meant to be collected through test/test_jit.py.
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )


class TestHash(JitTestCase):
    """Tests for the builtin ``hash`` inside TorchScript.

    Most cases define a tiny function that reports whether its two arguments
    hash equal, then hand it with sample arguments to ``checkScript``
    (inherited from ``JitTestCase``). That helper presumably compares the
    scripted result against plain Python execution.
    """

    def test_hash_tuple(self):
        def tuples_hash_equal(t1: Tuple[int, int], t2: Tuple[int, int]) -> bool:
            return hash(t1) == hash(t2)

        # Equal tuples, different tuples, and the same elements swapped.
        pairs = [((1, 2), (1, 2)), ((1, 2), (3, 4)), ((1, 2), (2, 1))]
        for pair in pairs:
            self.checkScript(tuples_hash_equal, pair)

    def test_hash_tuple_nested_unhashable_type(self):
        # A tuple can hold an unhashable element such as a list; hashing such
        # a tuple in a scripted function must fail with an "unhashable" error
        # whose highlighted source points at the ``hash`` call.
        @torch.jit.script
        def hash_tuple_with_list(t1: Tuple[int, List[int]]):
            return hash(t1)

        with self.assertRaisesRegexWithHighlight(RuntimeError, "unhashable", "hash"):
            hash_tuple_with_list((1, [1]))

    def test_hash_tensor(self):
        """Tensors should hash by identity, not by value."""

        def tensors_hash_equal(t1, t2):
            return hash(t1) == hash(t2)

        one = torch.tensor(1)
        other_one = torch.tensor(1)  # same value as `one`, distinct object
        two = torch.tensor(2)

        for args in ((one, one), (one, other_one), (one, two)):
            self.checkScript(tensors_hash_equal, args)

    def test_hash_none(self):
        def nones_hash_equal():
            first = None
            second = None
            return hash(first) == hash(second)

        self.checkScript(nones_hash_equal, ())

    def test_hash_bool(self):
        def bools_hash_equal(b1: bool, b2: bool):
            return hash(b1) == hash(b2)

        for args in [(True, False), (True, True), (False, True), (False, False)]:
            self.checkScript(bools_hash_equal, args)

    def test_hash_float(self):
        def floats_hash_equal(f1: float, f2: float):
            return hash(f1) == hash(f2)

        # Each float("nan") below is a separate object on purpose: NaN hashing
        # is identity-based on newer Pythons (see the version check further down).
        cases = [
            (1.2345, 1.2345),
            (1.2345, 6.789),
            (1.2345, float("inf")),
            (float("inf"), float("inf")),
            (1.2345, float("nan")),
        ]
        for args in cases:
            self.checkScript(floats_hash_equal, args)

        if sys.version_info < (3, 10):
            # Since Python 3.10, hashes of NaN floats (and decimal.Decimal NaNs)
            # depend on object identity, so two separate NaNs are no longer
            # guaranteed to hash equal. See
            # https://docs.python.org/3/whatsnew/3.10.html
            self.checkScript(floats_hash_equal, (float("nan"), float("nan")))
        self.checkScript(floats_hash_equal, (float("nan"), float("inf")))

    def test_hash_int(self):
        def ints_hash_equal(i1: int, i2: int):
            return hash(i1) == hash(i2)

        # Distinct values, equal values, sign flips, and zero.
        for args in [(123, 456), (123, 123), (123, -123), (-123, -123), (123, 0)]:
            self.checkScript(ints_hash_equal, args)

    def test_hash_string(self):
        def strings_hash_equal(s1: str, s2: str):
            return hash(s1) == hash(s2)

        for args in [("foo", "foo"), ("foo", "bar"), ("foo", "")]:
            self.checkScript(strings_hash_equal, args)

    def test_hash_device(self):
        def devices_hash_equal(d1: torch.device, d2: torch.device):
            return hash(d1) == hash(d2)

        # Device objects are only constructed and hashed here; no tensor is
        # ever placed on them.
        gpu0 = torch.device("cuda:0")
        gpu1 = torch.device("cuda:1")
        cpu = torch.device("cpu")

        for args in [(gpu0, gpu0), (gpu0, gpu1), (gpu0, cpu), (cpu, cpu)]:
            self.checkScript(devices_hash_equal, args)