1# mypy: allow-untyped-defs 2from contextlib import contextmanager 3 4import torch 5from torch.backends import __allow_nonbracketed_mutation, ContextProp, PropModule 6 7 8__all__ = ["is_available", "flags", "set_flags"] 9 10 11def is_available(): 12 r"""Return whether PyTorch is built with NNPACK support.""" 13 return torch._nnpack_available() 14 15 16def set_flags(_enabled): 17 r"""Set if nnpack is enabled globally""" 18 orig_flags = (torch._C._get_nnpack_enabled(),) 19 torch._C._set_nnpack_enabled(_enabled) 20 return orig_flags 21 22 23@contextmanager 24def flags(enabled=False): 25 r"""Context manager for setting if nnpack is enabled globally""" 26 with __allow_nonbracketed_mutation(): 27 orig_flags = set_flags(enabled) 28 try: 29 yield 30 finally: 31 with __allow_nonbracketed_mutation(): 32 set_flags(orig_flags[0]) 33