xref: /aosp_15_r20/external/pytorch/test/jit/test_hash.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: jit"]
2
3import os
4import sys
5from typing import List, Tuple
6
7import torch
8
9
# Make the helper files in test/ importable: resolve this file's real path,
# strip two path components with dirname() (test/jit/test_hash.py -> test/),
# and append that directory to the module search path.
pytorch_test_dir = os.path.dirname(os.path.realpath(__file__))
pytorch_test_dir = os.path.dirname(pytorch_test_dir)
sys.path.append(pytorch_test_dir)
13from torch.testing._internal.jit_utils import JitTestCase
14
15
if __name__ == "__main__":
    # Executing this module as a script is refused; the message directs the
    # user to run these tests through test/test_jit.py instead.
    _usage = (
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
    raise RuntimeError(_usage)
22
23
24class TestHash(JitTestCase):
25    def test_hash_tuple(self):
26        def fn(t1: Tuple[int, int], t2: Tuple[int, int]) -> bool:
27            return hash(t1) == hash(t2)
28
29        self.checkScript(fn, ((1, 2), (1, 2)))
30        self.checkScript(fn, ((1, 2), (3, 4)))
31        self.checkScript(fn, ((1, 2), (2, 1)))
32
33    def test_hash_tuple_nested_unhashable_type(self):
34        # Tuples may contain unhashable types like `list`, check that we error
35        # properly in that case.
36        @torch.jit.script
37        def fn_unhashable(t1: Tuple[int, List[int]]):
38            return hash(t1)
39
40        with self.assertRaisesRegexWithHighlight(RuntimeError, "unhashable", "hash"):
41            fn_unhashable((1, [1]))
42
43    def test_hash_tensor(self):
44        """Tensors should hash by identity"""
45
46        def fn(t1, t2):
47            return hash(t1) == hash(t2)
48
49        tensor1 = torch.tensor(1)
50        tensor1_clone = torch.tensor(1)
51        tensor2 = torch.tensor(2)
52
53        self.checkScript(fn, (tensor1, tensor1))
54        self.checkScript(fn, (tensor1, tensor1_clone))
55        self.checkScript(fn, (tensor1, tensor2))
56
57    def test_hash_none(self):
58        def fn():
59            n1 = None
60            n2 = None
61            return hash(n1) == hash(n2)
62
63        self.checkScript(fn, ())
64
65    def test_hash_bool(self):
66        def fn(b1: bool, b2: bool):
67            return hash(b1) == hash(b2)
68
69        self.checkScript(fn, (True, False))
70        self.checkScript(fn, (True, True))
71        self.checkScript(fn, (False, True))
72        self.checkScript(fn, (False, False))
73
74    def test_hash_float(self):
75        def fn(f1: float, f2: float):
76            return hash(f1) == hash(f2)
77
78        self.checkScript(fn, (1.2345, 1.2345))
79        self.checkScript(fn, (1.2345, 6.789))
80        self.checkScript(fn, (1.2345, float("inf")))
81        self.checkScript(fn, (float("inf"), float("inf")))
82        self.checkScript(fn, (1.2345, float("nan")))
83        if sys.version_info < (3, 10):
84            # Hash of two nans are not guaranteed to be equal. From https://docs.python.org/3/whatsnew/3.10.html :
85            # Hashes of NaN values of both float type and decimal.Decimal type now depend on object identity.
86            self.checkScript(fn, (float("nan"), float("nan")))
87        self.checkScript(fn, (float("nan"), float("inf")))
88
89    def test_hash_int(self):
90        def fn(i1: int, i2: int):
91            return hash(i1) == hash(i2)
92
93        self.checkScript(fn, (123, 456))
94        self.checkScript(fn, (123, 123))
95        self.checkScript(fn, (123, -123))
96        self.checkScript(fn, (-123, -123))
97        self.checkScript(fn, (123, 0))
98
99    def test_hash_string(self):
100        def fn(s1: str, s2: str):
101            return hash(s1) == hash(s2)
102
103        self.checkScript(fn, ("foo", "foo"))
104        self.checkScript(fn, ("foo", "bar"))
105        self.checkScript(fn, ("foo", ""))
106
107    def test_hash_device(self):
108        def fn(d1: torch.device, d2: torch.device):
109            return hash(d1) == hash(d2)
110
111        gpu0 = torch.device("cuda:0")
112        gpu1 = torch.device("cuda:1")
113        cpu = torch.device("cpu")
114        self.checkScript(fn, (gpu0, gpu0))
115        self.checkScript(fn, (gpu0, gpu1))
116        self.checkScript(fn, (gpu0, cpu))
117        self.checkScript(fn, (cpu, cpu))
118