# mypy: ignore-errors

import contextlib
import functools
import inspect

import torch


# Test whether the hardware BF32 math mode is enabled. It is enabled only when:
# - MKLDNN is available
# - BF16 is supported by MKLDNN
def bf32_is_not_fp32():
    if not torch.backends.mkldnn.is_available():
        return False
    if not torch.ops.mkldnn._is_mkldnn_bf16_supported():
        return False
    return True
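

# Illustrative usage (a sketch, not part of the upstream module): a test that
# is only meaningful when BF32 can kick in may be gated on this check, e.g.
# with unittest.skipIf:
#
#     @unittest.skipIf(not bf32_is_not_fp32(), "BF32 math mode not available")
#     def test_mkldnn_matmul(self):
#         ...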


@contextlib.contextmanager
def bf32_off():
    old_matmul_precision = torch.get_float32_matmul_precision()
    try:
        torch.set_float32_matmul_precision("highest")
        yield
    finally:
        torch.set_float32_matmul_precision(old_matmul_precision)


@contextlib.contextmanager
def bf32_on(self, bf32_precision=1e-5):
    old_matmul_precision = torch.get_float32_matmul_precision()
    old_precision = self.precision
    try:
        torch.set_float32_matmul_precision("medium")
        self.precision = bf32_precision
        yield
    finally:
        torch.set_float32_matmul_precision(old_matmul_precision)
        self.precision = old_precision
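

# Illustrative usage of the two context managers (a sketch; `tc` is assumed to
# be a TestCase instance with a `precision` tolerance attribute, as in
# PyTorch's TestCase):
#
#     with bf32_on(tc, bf32_precision=1e-2):
#         out = torch.mm(a, b)  # float32 matmul, may use BF16 internally
#     with bf32_off():
#         ref = torch.mm(a, b)  # float32 matmul at full FP32 precision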


# A decorator that runs a test twice: once with the BF32 math mode disabled
# (float32 matmul precision "highest") and once with it enabled ("medium").
# When BF32 is enabled, the test's tolerance is relaxed to the given
# bf32_precision.
def bf32_on_and_off(bf32_precision=1e-5):
    def with_bf32_disabled(self, function_call):
        with bf32_off():
            function_call()

    def with_bf32_enabled(self, function_call):
        with bf32_on(self, bf32_precision):
            function_call()

    def wrapper(f):
        params = inspect.signature(f).parameters
        arg_names = tuple(params.keys())

        @functools.wraps(f)
        def wrapped(*args, **kwargs):
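            # Fold positional arguments into kwargs so the test can be
            # re-invoked below with keyword arguments only.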
            for k, v in zip(arg_names, args):
                kwargs[k] = v
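            # Run the test twice only when BF32 could actually affect it:
            # MKLDNN must report BF16 support, and (when given) the device
            # must be CPU and the dtype float32.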
            cond = bf32_is_not_fp32()
            if "device" in kwargs:
                cond = cond and (torch.device(kwargs["device"]).type == "cpu")
            if "dtype" in kwargs:
                cond = cond and (kwargs["dtype"] == torch.float)
            if cond:
                with_bf32_disabled(kwargs["self"], lambda: f(**kwargs))
                with_bf32_enabled(kwargs["self"], lambda: f(**kwargs))
            else:
                f(**kwargs)

        return wrapped

    return wrapper
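

# Illustrative usage (a sketch, not part of the upstream module): a test
# decorated to run both with and without BF32, using a looser tolerance for
# the BF32 run. `TestLinalg`/`test_addmm` are hypothetical names:
#
#     class TestLinalg(TestCase):
#         @bf32_on_and_off(bf32_precision=1e-2)
#         def test_addmm(self, device, dtype):
#             ...
#
# The decorator inspects the test's `device` and `dtype` arguments and only
# doubles the run for float32 on CPU when MKLDNN reports BF16 support.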