1# mypy: ignore-errors 2 3import contextlib 4import functools 5import inspect 6 7import torch 8 9 10# Test whether hardware BF32 math mode enabled. It is enabled only on: 11# - MKLDNN is available 12# - BF16 is supported by MKLDNN 13def bf32_is_not_fp32(): 14 if not torch.backends.mkldnn.is_available(): 15 return False 16 if not torch.ops.mkldnn._is_mkldnn_bf16_supported(): 17 return False 18 return True 19 20 21@contextlib.contextmanager 22def bf32_off(): 23 old_matmul_precision = torch.get_float32_matmul_precision() 24 try: 25 torch.set_float32_matmul_precision("highest") 26 yield 27 finally: 28 torch.set_float32_matmul_precision(old_matmul_precision) 29 30 31@contextlib.contextmanager 32def bf32_on(self, bf32_precision=1e-5): 33 old_matmul_precision = torch.get_float32_matmul_precision() 34 old_precision = self.precision 35 try: 36 torch.set_float32_matmul_precision("medium") 37 self.precision = bf32_precision 38 yield 39 finally: 40 torch.set_float32_matmul_precision(old_matmul_precision) 41 self.precision = old_precision 42 43 44# This is a wrapper that wraps a test to run this test twice, one with 45# allow_bf32=True, another with allow_bf32=False. When running with 46# allow_bf32=True, it will use reduced precision as specified by the 47# argument 48def bf32_on_and_off(bf32_precision=1e-5): 49 def with_bf32_disabled(self, function_call): 50 with bf32_off(): 51 function_call() 52 53 def with_bf32_enabled(self, function_call): 54 with bf32_on(self, bf32_precision): 55 function_call() 56 57 def wrapper(f): 58 params = inspect.signature(f).parameters 59 arg_names = tuple(params.keys()) 60 61 @functools.wraps(f) 62 def wrapped(*args, **kwargs): 63 for k, v in zip(arg_names, args): 64 kwargs[k] = v 65 cond = bf32_is_not_fp32() 66 if "device" in kwargs: 67 cond = cond and (torch.device(kwargs["device"]).type == "cpu") 68 if "dtype" in kwargs: 69 cond = cond and (kwargs["dtype"] == torch.float) 70 if cond: 71 with_bf32_disabled(kwargs["self"], lambda: f(**kwargs)) 72 with_bf32_enabled(kwargs["self"], lambda: f(**kwargs)) 73 else: 74 f(**kwargs) 75 76 return wrapped 77 78 return wrapper 79