xref: /aosp_15_r20/external/pytorch/torch/utils/benchmark/utils/compile.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from typing import Any, Callable, cast, List, Optional, Union
3
4import torch
5import torch._dynamo
6from torch._dynamo.testing import CompileCounterWithBackend
7from torch.utils.benchmark import Timer
8
9
# NOTE(review): both names listed here are defined only inside the
# `if HAS_TABULATE:` block below, so they do not exist when tabulate is missing.
__all__ = ["bench_all", "benchmark_compile"]


# Module-wide flag so the tensor-core notice in _enable_tensor_cores() is
# printed at most once.
_warned_tensor_cores = False
# float32 matmul precision in effect when this module is imported;
# _disable_tensor_cores() restores this value.
_default_float_32_precision = torch.get_float32_matmul_precision()

# tabulate is treated as an optional dependency: when it cannot be imported,
# HAS_TABULATE is False, `tabulate` is bound to None and a hint is printed.
try:
    from tabulate import tabulate

    HAS_TABULATE = True
except ModuleNotFoundError:
    HAS_TABULATE = False
    tabulate = None  # type: ignore[assignment]
    print("tabulate is not installed, please pip install tabulate to use this utility")
24
25if HAS_TABULATE:
26    def _enable_tensor_cores():
27        global _warned_tensor_cores
28
29        if torch.cuda.is_available():
30            if torch.backends.cuda.matmul.allow_tf32 is False and torch.cuda.get_device_capability() >= (8, 0):
31                torch.set_float32_matmul_precision("high")
32                if not _warned_tensor_cores:
33                    print("Your GPU supports tensor cores")
34                    print("we will enable it automatically by setting `torch.set_float32_matmul_precision('high')`")
35                    _warned_tensor_cores = True
36
37    def _disable_tensor_cores():
38        torch.set_float32_matmul_precision(_default_float_32_precision)
39
40    def bench_loop(
41        model: Union[torch.nn.Module, Callable],
42        sample_input: Union[torch.Tensor, Any],
43        num_iters: int = 5,
44        optimizer: Optional[torch.optim.Optimizer] = None,
45        loss_fn: Optional[Callable] = None,
46    ):
47        # Define the statement and setup for the benchmark
48        if optimizer and loss_fn:
49            # Training mode
50            stmt = """
51    output = model(sample_input)
52    loss = loss_fn(output) if loss_fn else output.sum()
53    loss.backward()
54    optimizer.step()
55    optimizer.zero_grad()
56            """
57        else:
58            # Inference mode
59            stmt = "model(sample_input)"
60
61        # Create the Timer object
62        timer = Timer(
63            stmt=stmt,
64            globals={"model": model, "sample_input": sample_input, "optimizer": optimizer, "loss_fn": loss_fn},
65        )
66
67
68        result = timer.timeit(number=num_iters)
69
70        # Get the average time per iteration in milliseconds
71        avg_time = result.mean * 1000
72        return round(avg_time, 2)
73
74    def benchmark_compile(
75        model: Union[torch.nn.Module, Callable],
76        sample_input: Union[torch.Tensor, Any],
77        num_iters: int = 5,
78        backend: Optional[str] = None,
79        mode: Optional[str] = "default",
80        optimizer: Optional[torch.optim.Optimizer] = None,
81        loss_fn : Union[torch.nn.Module, Callable, None] = None,
82    ):
83        """
84        Use this utility to benchmark torch.compile
85        """
86        if backend:
87            try:
88                torch._dynamo.reset()
89                compile_counter_with_backend = CompileCounterWithBackend(backend)
90                opt_model = torch.compile(model, backend=compile_counter_with_backend, mode=mode)
91
92                # Compilation only happens after the first inference
93                compilation_time = bench_loop(opt_model, sample_input, 1, optimizer, loss_fn)
94
95                running_time = bench_loop(opt_model, sample_input, num_iters, optimizer, loss_fn)
96
97                if compile_counter_with_backend.frame_count == 0:
98                    raise RuntimeError("No compilation occurred during benchmarking.")
99
100                if compile_counter_with_backend.frame_count > 1:
101                    raise RuntimeError("Recompilation occurred during benchmarking.")
102
103            except Exception as e:
104                print(e)
105                print(f"Failed to compile {backend} with mode {mode}")
106                return None, None
107        else:
108            opt_model = model
109            compilation_time = None
110            running_time = bench_loop(opt_model, sample_input, num_iters, optimizer, loss_fn)
111
112        compilation_time = round(compilation_time, 2) if compilation_time else None
113        running_time = round(running_time, 2) if running_time else None
114
115
116        return compilation_time, running_time
117
118
119    def bench_all(
120        model : Union[torch.nn.Module, Callable],
121        sample_input: Union[torch.Tensor, Any],
122        num_iters : int = 5,
123        optimizer: Optional[torch.optim.Optimizer] = None,
124        loss_fn : Union[torch.nn.Module, Callable, None] = None,
125    ):
126        """
127        This is a simple utility that can be used to benchmark torch.compile
128        In particular it ensures that your GPU is setup to use tensor cores if it supports its
129        It also tries out all the main backends and prints a table of results so you can easily compare them all
130        Many of the backendds have their own optional dependencies so please pip install them seperately
131
132        You will get one table for inference and another for training
133        If you'd like to leverage this utility for training make sure to pass in a torch.optim.Optimizer
134
135        The important warnings are
136        Your GPU supports tensor cores
137        we will enable it automatically by setting `torch.set_float32_matmul_precision('high')`
138
139        If a compilation fails for any reason including the dependency not being included
140        then we will print Failed to compile {backend} with mode {mode}
141        """
142        field_names = ["Train/Inference", "Backend", "Mode", "Compilation Time", "Average Running Time"]
143        table = []
144
145
146        eager_time = None
147        torch._dynamo.reset()
148        _, eager_time = benchmark_compile(model, sample_input, num_iters, None, None, optimizer)
149        table.append(
150            [("Training" if optimizer else "Inference"), "Eager", "-", "-", f"{eager_time} ms"]
151        )
152
153        for backend in torch._dynamo.list_backends():
154
155            if backend == "inductor":
156                mode_options = cast(List[Optional[str]], list(torch._inductor.list_mode_options().keys())) + [None]
157                for mode in mode_options:
158                    if mode == "default":
159                        continue
160                    torch._dynamo.reset()
161                    try:
162                        if torch.cuda.is_available():
163                            _enable_tensor_cores()
164                        compilation_time, running_time = benchmark_compile(
165                            model, sample_input, num_iters, backend, mode, optimizer, loss_fn)
166                    finally:
167                        if torch.cuda.is_available():
168                            _disable_tensor_cores()
169                            table.append([
170                                ("Training" if optimizer else "Inference"),
171                                backend if backend else "-",
172                                mode if mode is not None else "-",
173                                f"{compilation_time} ms " if compilation_time else "-",
174                                f"{running_time} ms " if running_time else "-",
175                            ])
176
177            else:
178                torch._dynamo.reset()
179                compilation_time, running_time = benchmark_compile(
180                    model, sample_input, num_iters, backend, None, optimizer, loss_fn)
181
182                if running_time is not None:
183                    table.append([
184                        ("Training" if optimizer else "Inference"),
185                        backend, "-",
186                        f"{compilation_time} ms " or "-",
187                        f"{running_time} ms ",
188                    ])
189
190
191        return tabulate(table, headers=field_names, tablefmt="github")
192