# mypy: allow-untyped-defs
import sys
from contextlib import contextmanager
from typing import TYPE_CHECKING

import torch
from torch.backends import __allow_nonbracketed_mutation, ContextProp, PropModule


def is_available():
    r"""Return whether PyTorch is built with MKL-DNN support."""
    return torch._C._has_mkldnn


VERBOSE_OFF = 0
VERBOSE_ON = 1
VERBOSE_ON_CREATION = 2


class verbose:
    """
    On-demand oneDNN (formerly MKL-DNN) verbose functionality.

    To make it easier to debug performance issues, oneDNN can dump verbose
    messages containing information like kernel size, input data size and
    execution duration while executing the kernel. The verbose functionality
    can be invoked via an environment variable named ``DNNL_VERBOSE``;
    however, that approach dumps messages for every step, producing a large
    amount of output. Moreover, when investigating performance issues,
    verbose messages from a single iteration are generally enough. This
    on-demand functionality makes it possible to control the scope of verbose
    message dumping. In the following example, verbose messages will be
    dumped out for the second inference only.

    .. highlight:: python
    .. code-block:: python

        import torch
        model(data)
        with torch.backends.mkldnn.verbose(torch.backends.mkldnn.VERBOSE_ON):
            model(data)

    Args:
        level: Verbose level
            - ``VERBOSE_OFF``: Disable verbose output
            - ``VERBOSE_ON``: Enable verbose output
            - ``VERBOSE_ON_CREATION``: Enable verbose output, including oneDNN kernel creation
    """

    def __init__(self, level):
        self.level = level

    def __enter__(self):
        if self.level == VERBOSE_OFF:
            return
        st = torch._C._verbose.mkldnn_set_verbose(self.level)
        assert (
            st
        ), "Failed to set MKLDNN into verbose mode. Please consider disabling this verbose scope."
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        torch._C._verbose.mkldnn_set_verbose(VERBOSE_OFF)
        return False


def set_flags(_enabled, _deterministic=None):
    orig_flags = (torch._C._get_mkldnn_enabled(), torch._C._get_mkldnn_deterministic())
    torch._C._set_mkldnn_enabled(_enabled)
    if _deterministic is not None:
        torch._C._set_mkldnn_deterministic(_deterministic)
    return orig_flags


@contextmanager
def flags(enabled=False, deterministic=False):
    with __allow_nonbracketed_mutation():
        orig_flags = set_flags(enabled, deterministic)
    try:
        yield
    finally:
        with __allow_nonbracketed_mutation():
            set_flags(*orig_flags)


class MkldnnModule(PropModule):
    def __init__(self, m, name):
        super().__init__(m, name)

    enabled = ContextProp(torch._C._get_mkldnn_enabled, torch._C._set_mkldnn_enabled)
    deterministic = ContextProp(
        torch._C._get_mkldnn_deterministic, torch._C._set_mkldnn_deterministic
    )


if TYPE_CHECKING:
    enabled: ContextProp
    deterministic: ContextProp

sys.modules[__name__] = MkldnnModule(sys.modules[__name__], __name__)
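
# Usage sketch: a minimal example of how the public surface above is
# typically exercised from user code, assuming a PyTorch build with
# MKL-DNN/oneDNN support; ``torch.mm`` merely stands in for an arbitrary
# CPU workload.
#
#   import torch
#
#   if torch.backends.mkldnn.is_available():
#       x, w = torch.randn(64, 64), torch.randn(64, 64)
#       # Temporarily enable oneDNN with deterministic kernels; the prior
#       # flag values are restored when the block exits.
#       with torch.backends.mkldnn.flags(enabled=True, deterministic=True):
#           torch.mm(x, w)
#       # Dump oneDNN verbose messages for this single matmul only.
#       with torch.backends.mkldnn.verbose(torch.backends.mkldnn.VERBOSE_ON):
#           torch.mm(x, w)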