# mypy: allow-untyped-defs
import types
from contextlib import contextmanager


# This flag forbids bare assignment to torch.backends.<cudnn|mkldnn>.enabled
# and friends while running our test suite, where it is very easy to forget to
# undo the change later.
__allow_nonbracketed_mutation_flag = True
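
# For example, in a test, prefer the bracketed form (the flags() context
# manager referenced by the error message below):
#
#     with torch.backends.cudnn.flags(enabled=False):
#         ...  # cudnn is disabled only within this block
#
# over bare assignment (torch.backends.cudnn.enabled = False), which silently
# leaks into whatever runs afterwards.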


def disable_global_flags():
    """Forbid bare assignment to torch.backends flags; used when running the test suite."""
    global __allow_nonbracketed_mutation_flag
    __allow_nonbracketed_mutation_flag = False


def flags_frozen():
    """Return True if bare assignment to the backend flags is currently forbidden."""
    return not __allow_nonbracketed_mutation_flag


@contextmanager
def __allow_nonbracketed_mutation():
    """Temporarily re-allow mutation of the backend flags, restoring the previous state on exit."""
    global __allow_nonbracketed_mutation_flag
    old = __allow_nonbracketed_mutation_flag
    __allow_nonbracketed_mutation_flag = True
    try:
        yield
    finally:
        __allow_nonbracketed_mutation_flag = old
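
# Illustrative sketch: a flags() context manager on a backend submodule can use
# __allow_nonbracketed_mutation to mutate its ContextProp attributes even while
# the global flags are frozen, e.g.
#
#     with __allow_nonbracketed_mutation():
#         cudnn.enabled = False  # permitted: flags_frozen() is False in here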


class ContextProp:
    """Descriptor that forwards reads to `getter` and guarded writes to `setter`."""

    def __init__(self, getter, setter):
        self.getter = getter
        self.setter = setter

    def __get__(self, obj, objtype):
        return self.getter()

    def __set__(self, obj, val):
        # Writes are rejected while the global flags are frozen (see
        # disable_global_flags above); use the flags() context manager instead.
        if not flags_frozen():
            self.setter(val)
        else:
            raise RuntimeError(
                f"not allowed to set {obj.__name__} flags "
                "after disable_global_flags; please use flags() context manager instead"
            )


class PropModule(types.ModuleType):
    """ModuleType subclass on which backend submodules host ContextProp descriptors.

    Attribute lookups that the class itself does not handle are delegated to
    the wrapped original module `m`.
    """

    def __init__(self, m, name):
        super().__init__(name)
        self.m = m

    def __getattr__(self, attr):
        return self.m.__getattribute__(attr)
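
# Illustrative sketch (the torch._C names are assumed, not defined in this
# file): a backend submodule such as torch.backends.cudnn subclasses
# PropModule, exposes each flag as a ContextProp over the corresponding
# C-binding getter/setter, and swaps itself into sys.modules so the
# descriptors apply to plain attribute access on the module:
#
#     class CudnnModule(PropModule):
#         enabled = ContextProp(torch._C._get_cudnn_enabled,
#                               torch._C._set_cudnn_enabled)
#
#     sys.modules[__name__] = CudnnModule(sys.modules[__name__], __name__)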


# These imports sit at the bottom because some of the submodules (e.g. cudnn,
# mkldnn) import the helpers defined above, so they must come after those
# definitions.
from torch.backends import (
    cpu as cpu,
    cuda as cuda,
    cudnn as cudnn,
    cusparselt as cusparselt,
    mha as mha,
    mkl as mkl,
    mkldnn as mkldnn,
    mps as mps,
    nnpack as nnpack,
    openmp as openmp,
    quantized as quantized,
)
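
# Example: the re-exported submodules carry the backend-specific flags and
# availability queries, e.g.
#
#     torch.backends.cudnn.is_available()
#     torch.backends.mps.is_available()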