# mypy: allow-untyped-defs
from typing import Any, Callable, cast, List, Optional, Union

import torch
import torch._dynamo
import torch._inductor
from torch._dynamo.testing import CompileCounterWithBackend
from torch.utils.benchmark import Timer


__all__ = ["bench_all", "benchmark_compile"]


_warned_tensor_cores = False
_default_float_32_precision = torch.get_float32_matmul_precision()

try:
    from tabulate import tabulate

    HAS_TABULATE = True
except ModuleNotFoundError:
    HAS_TABULATE = False
    tabulate = None  # type: ignore[assignment]
    print("tabulate is not installed; please `pip install tabulate` to use this utility")

if HAS_TABULATE:
    def _enable_tensor_cores():
        """Enable TF32 matmuls on Ampere (sm_80) or newer GPUs, warning once."""
        global _warned_tensor_cores

        if torch.cuda.is_available():
            if torch.backends.cuda.matmul.allow_tf32 is False and torch.cuda.get_device_capability() >= (8, 0):
                torch.set_float32_matmul_precision("high")
                if not _warned_tensor_cores:
                    print("Your GPU supports tensor cores.")
                    print("We will enable them automatically by setting `torch.set_float32_matmul_precision('high')`.")
                    _warned_tensor_cores = True

    def _disable_tensor_cores():
        """Restore the float32 matmul precision that was set at import time."""
        torch.set_float32_matmul_precision(_default_float_32_precision)

    def bench_loop(
        model: Union[torch.nn.Module, Callable],
        sample_input: Union[torch.Tensor, Any],
        num_iters: int = 5,
        optimizer: Optional[torch.optim.Optimizer] = None,
        loss_fn: Optional[Callable] = None,
    ):
        """Time ``num_iters`` iterations of ``model`` and return the mean per-iteration time in ms."""
        # Benchmark a full training step when both an optimizer and a loss
        # function are provided; otherwise benchmark a forward pass only.
        if optimizer and loss_fn:
            # Training mode
            stmt = """
            output = model(sample_input)
            loss = loss_fn(output)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            """
        else:
            # Inference mode
            stmt = "model(sample_input)"

        # Create the Timer object; torch.utils.benchmark.Timer handles
        # CUDA synchronization, unlike the stdlib timeit
        timer = Timer(
            stmt=stmt,
            globals={"model": model, "sample_input": sample_input, "optimizer": optimizer, "loss_fn": loss_fn},
        )

        result = timer.timeit(number=num_iters)

        # Mean time per iteration in milliseconds
        avg_time = result.mean * 1000
        return round(avg_time, 2)

    def benchmark_compile(
        model: Union[torch.nn.Module, Callable],
        sample_input: Union[torch.Tensor, Any],
        num_iters: int = 5,
        backend: Optional[str] = None,
        mode: Optional[str] = "default",
        optimizer: Optional[torch.optim.Optimizer] = None,
        loss_fn: Union[torch.nn.Module, Callable, None] = None,
    ):
        """
        Use this utility to benchmark ``torch.compile``.

        Returns a ``(compilation_time, running_time)`` tuple in milliseconds.
        If ``backend`` is None, the model is benchmarked eagerly and
        ``compilation_time`` is None.
        """
        if backend:
            try:
                torch._dynamo.reset()
                compile_counter_with_backend = CompileCounterWithBackend(backend)
                opt_model = torch.compile(model, backend=compile_counter_with_backend, mode=mode)

                # Compilation happens lazily on the first call, so a single
                # iteration measures the compile time (plus one run)
                compilation_time = bench_loop(opt_model, sample_input, 1, optimizer, loss_fn)

                running_time = bench_loop(opt_model, sample_input, num_iters, optimizer, loss_fn)

                if compile_counter_with_backend.frame_count == 0:
                    raise RuntimeError("No compilation occurred during benchmarking.")

                if compile_counter_with_backend.frame_count > 1:
                    raise RuntimeError("Recompilation occurred during benchmarking.")

            except Exception as e:
                print(e)
                print(f"Failed to compile {backend} with mode {mode}")
                return None, None
        else:
            opt_model = model
            compilation_time = None
            running_time = bench_loop(opt_model, sample_input, num_iters, optimizer, loss_fn)

        compilation_time = round(compilation_time, 2) if compilation_time is not None else None
        running_time = round(running_time, 2) if running_time is not None else None

        return compilation_time, running_time

    def bench_all(
        model: Union[torch.nn.Module, Callable],
        sample_input: Union[torch.Tensor, Any],
        num_iters: int = 5,
        optimizer: Optional[torch.optim.Optimizer] = None,
        loss_fn: Union[torch.nn.Module, Callable, None] = None,
    ):
        """
        A simple utility for benchmarking ``torch.compile``.

        In particular, it ensures that your GPU is set up to use tensor cores
        if it supports them. It also tries out all the main backends and
        returns a table of results so you can easily compare them.
        Many of the backends have their own optional dependencies, so please
        pip install them separately.

        Each call produces a single table: inference by default, or training
        if you pass in a ``torch.optim.Optimizer`` (and a loss function).

        The important warnings are:
            Your GPU supports tensor cores.
            We will enable them automatically by setting `torch.set_float32_matmul_precision('high')`.

        If a compilation fails for any reason, including a missing dependency,
        we will print ``Failed to compile {backend} with mode {mode}``.
        """
        field_names = ["Train/Inference", "Backend", "Mode", "Compilation Time", "Average Running Time"]
        table = []

        # Eager baseline, benchmarked in the same mode (training or inference)
        # as the compiled variants
        torch._dynamo.reset()
        _, eager_time = benchmark_compile(model, sample_input, num_iters, None, None, optimizer, loss_fn)
        table.append(
            [("Training" if optimizer else "Inference"), "Eager", "-", "-", f"{eager_time} ms"]
        )

        for backend in torch._dynamo.list_backends():
            if backend == "inductor":
                # Try every inductor mode, plus None (which selects the default)
                mode_options = cast(List[Optional[str]], list(torch._inductor.list_mode_options().keys())) + [None]
                for mode in mode_options:
                    if mode == "default":
                        continue
                    torch._dynamo.reset()
                    try:
                        if torch.cuda.is_available():
                            _enable_tensor_cores()
                        compilation_time, running_time = benchmark_compile(
                            model, sample_input, num_iters, backend, mode, optimizer, loss_fn)
                    finally:
                        if torch.cuda.is_available():
                            _disable_tensor_cores()
                    table.append([
                        ("Training" if optimizer else "Inference"),
                        backend if backend else "-",
                        mode if mode is not None else "-",
                        f"{compilation_time} ms " if compilation_time else "-",
                        f"{running_time} ms " if running_time else "-",
                    ])
            else:
                torch._dynamo.reset()
                compilation_time, running_time = benchmark_compile(
                    model, sample_input, num_iters, backend, None, optimizer, loss_fn)

                if running_time is not None:
                    table.append([
                        ("Training" if optimizer else "Inference"),
                        backend, "-",
                        f"{compilation_time} ms " if compilation_time else "-",
                        f"{running_time} ms ",
                    ])

        return tabulate(table, headers=field_names, tablefmt="github")
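

# Example usage: a minimal smoke test of the utilities above. This is an
# illustrative sketch, not part of the module's API; the model, input shape,
# iteration count, and loss function below are assumptions chosen purely
# for demonstration.
if __name__ == "__main__" and HAS_TABULATE:
    _model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU())
    _sample_input = torch.randn(8, 64)

    # Inference: compares eager mode against every available dynamo backend
    print(bench_all(_model, _sample_input, num_iters=10))

    # Training: passing an optimizer and a loss function switches bench_loop
    # into its training statement (forward, loss, backward, step).
    # Note that bench_loop calls loss_fn with the model output only.
    _optimizer = torch.optim.SGD(_model.parameters(), lr=0.01)

    def _sum_loss(output):
        return output.sum()

    print(bench_all(_model, _sample_input, 10, _optimizer, _sum_loss))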