1# mypy: allow-untyped-defs 2import types 3from contextlib import contextmanager 4 5 6# The idea for this parameter is that we forbid bare assignment 7# to torch.backends.<cudnn|mkldnn>.enabled and friends when running our 8# test suite, where it's very easy to forget to undo the change 9# later. 10__allow_nonbracketed_mutation_flag = True 11 12 13def disable_global_flags(): 14 global __allow_nonbracketed_mutation_flag 15 __allow_nonbracketed_mutation_flag = False 16 17 18def flags_frozen(): 19 return not __allow_nonbracketed_mutation_flag 20 21 22@contextmanager 23def __allow_nonbracketed_mutation(): 24 global __allow_nonbracketed_mutation_flag 25 old = __allow_nonbracketed_mutation_flag 26 __allow_nonbracketed_mutation_flag = True 27 try: 28 yield 29 finally: 30 __allow_nonbracketed_mutation_flag = old 31 32 33class ContextProp: 34 def __init__(self, getter, setter): 35 self.getter = getter 36 self.setter = setter 37 38 def __get__(self, obj, objtype): 39 return self.getter() 40 41 def __set__(self, obj, val): 42 if not flags_frozen(): 43 self.setter(val) 44 else: 45 raise RuntimeError( 46 f"not allowed to set {obj.__name__} flags " 47 "after disable_global_flags; please use flags() context manager instead" 48 ) 49 50 51class PropModule(types.ModuleType): 52 def __init__(self, m, name): 53 super().__init__(name) 54 self.m = m 55 56 def __getattr__(self, attr): 57 return self.m.__getattribute__(attr) 58 59 60from torch.backends import ( 61 cpu as cpu, 62 cuda as cuda, 63 cudnn as cudnn, 64 cusparselt as cusparselt, 65 mha as mha, 66 mkl as mkl, 67 mkldnn as mkldnn, 68 mps as mps, 69 nnpack as nnpack, 70 openmp as openmp, 71 quantized as quantized, 72) 73