xref: /aosp_15_r20/external/pytorch/torch/utils/_cpp_extension_versioner.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import collections
3
4
# Per-extension record: ``version`` is the current version counter and
# ``hash`` is the combined hash of the inputs that produced that version.
Entry = collections.namedtuple('Entry', 'version, hash')
6
7
def update_hash(seed, value):
    """Mix ``hash(value)`` into the running hash ``seed`` and return the result.

    Uses the same recipe as boost::hash_combine:
    https://www.boost.org/doc/libs/1_35_0/doc/html/boost/hash_combine_id241013.html

    Args:
        seed: Integer hash accumulated so far.
        value: Any hashable object; only its builtin ``hash()`` is used.

    Returns:
        The new integer hash. Python integers are arbitrary precision, so
        unlike the C++ original nothing is truncated to a machine word.
    """
    # NOTE: builtin hash() of str/bytes is salted per interpreter process by
    # default, so results are only comparable within a single process.
    golden_ratio = 0x9e3779b9
    shifted_left = seed << 6
    shifted_right = seed >> 2
    mixed = hash(value) + golden_ratio + shifted_left + shifted_right
    return seed ^ mixed
12
13
def hash_source_files(hash_value, source_files):
    """Fold the contents of every file in ``source_files`` into ``hash_value``.

    Args:
        hash_value: Running integer hash to extend.
        source_files: Iterable of file paths; files are hashed in iteration
            order, so the order affects the result.

    Returns:
        The running hash after each file's contents have been combined with
        ``update_hash``. ``hash_value`` is returned unchanged when
        ``source_files`` is empty.

    Raises:
        OSError: If a file cannot be opened or read.
    """
    for filename in source_files:
        # Read raw bytes rather than text: the hash only needs to detect
        # changes, and text mode would (a) raise UnicodeDecodeError on sources
        # whose bytes are invalid in the locale's default encoding and
        # (b) hide line-ending-only changes through newline translation.
        with open(filename, "rb") as file:
            hash_value = update_hash(hash_value, file.read())
    return hash_value
19
20
def hash_build_arguments(hash_value, build_arguments):
    """Fold every individual build argument into ``hash_value``.

    Args:
        hash_value: Running integer hash to extend.
        build_arguments: Iterable of argument groups. Each group is an
            iterable of hashable arguments; falsy groups (``None`` or empty)
            are skipped.

    Returns:
        The running hash after every argument of every non-empty group has
        been combined, in order, with ``update_hash``.
    """
    running = hash_value
    # Flatten lazily, dropping falsy groups before iterating into them.
    flattened = (
        argument
        for group in build_arguments
        if group
        for argument in group
    )
    for argument in flattened:
        running = update_hash(running, argument)
    return running
27
28
class ExtensionVersioner:
    """Keep a version counter per extension name, bumped when its inputs change.

    ``entries`` maps an extension name to an ``Entry(version, hash)`` holding
    the latest version and the hash of the inputs that produced it.
    """

    def __init__(self):
        self.entries = {}

    def get_version(self, name):
        """Return the recorded version for ``name``, or ``None`` if unknown."""
        known = self.entries.get(name)
        if known is None:
            return None
        return known.version

    def bump_version_if_changed(self,
                                name,
                                source_files,
                                build_arguments,
                                build_directory,
                                with_cuda,
                                is_python_module,
                                is_standalone):
        """Record the current inputs for ``name`` and return its version.

        The source file contents, build arguments, build directory and the
        three flags are combined into one hash. The first time ``name`` is
        seen it is stored with version 0. If the hash differs from the stored
        one, the version is incremented by one and the new hash is stored.
        Otherwise nothing changes.

        Returns:
            The version for ``name`` after the (possible) update.
        """
        digest = hash_source_files(0, source_files)
        digest = hash_build_arguments(digest, build_arguments)
        # Order matters: each value is mixed into the running hash in turn.
        scalar_inputs = (build_directory, with_cuda, is_python_module, is_standalone)
        for scalar in scalar_inputs:
            digest = update_hash(digest, scalar)

        previous = self.entries.get(name)
        if previous is None:
            version = 0
        elif previous.hash != digest:
            version = previous.version + 1
        else:
            # Inputs unchanged: keep the stored entry as is.
            return previous.version

        self.entries[name] = Entry(version, digest)
        return version
60