xref: /aosp_15_r20/external/pytorch/torch/backends/nnpack/__init__.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from contextlib import contextmanager
3
4import torch
5from torch.backends import __allow_nonbracketed_mutation, ContextProp, PropModule
6
7
8__all__ = ["is_available", "flags", "set_flags"]
9
10
def is_available():
    r"""Report whether this PyTorch build was compiled with NNPACK support.

    Returns:
        Whatever ``torch._nnpack_available()`` reports for the running build.
    """
    built_with_nnpack = torch._nnpack_available()
    return built_with_nnpack
14
15
16def set_flags(_enabled):
17    r"""Set if nnpack is enabled globally"""
18    orig_flags = (torch._C._get_nnpack_enabled(),)
19    torch._C._set_nnpack_enabled(_enabled)
20    return orig_flags
21
22
@contextmanager
def flags(enabled=False):
    r"""Temporarily set whether NNPACK is enabled globally.

    On entry the global NNPACK flag is set to ``enabled``. On exit, whether
    the ``with`` body finished normally or raised, the value that was active
    before entry is put back. Both the set and the restore run inside
    ``__allow_nonbracketed_mutation()``.

    Args:
        enabled: Value to apply for the duration of the ``with`` block.
    """
    with __allow_nonbracketed_mutation():
        saved_state = set_flags(enabled)
    try:
        yield
    finally:
        # Always undo the change, even when the managed block raised.
        prior_enabled = saved_state[0]
        with __allow_nonbracketed_mutation():
            set_flags(prior_enabled)
33