1"""Logic for converting human-readable benchmarks into executable form. 2 3This is mostly string manipulation, with just a bit of importlib magic. 4""" 5# mypy: ignore-errors 6import importlib.abc 7import importlib.util 8import itertools as it 9import os 10import re 11import textwrap 12import uuid 13from typing import List, Optional, Tuple, TYPE_CHECKING 14 15import torch 16 17 18if TYPE_CHECKING: 19 # See the note in api.py for why this is necessary. 20 from torch.utils.benchmark.utils.timer import Language 21else: 22 from torch.utils.benchmark import Language 23 24from core.api import AutogradMode, AutoLabels, GroupedBenchmark, RuntimeMode, TimerArgs 25from core.types import FlatDefinition, FlatIntermediateDefinition, Label 26from core.utils import get_temp_dir 27 28 29_ALL_MODES = tuple( 30 it.product( 31 RuntimeMode, 32 AutogradMode, 33 Language, 34 ) 35) 36 37 38def _generate_torchscript_file(model_src: str, name: str) -> Optional[str]: 39 """Returns the path a saved model if one can be constructed from `spec`. 40 41 Because TorchScript requires actual source code in order to script a 42 model, we can't simply `eval` an appropriate model string. Instead, we 43 must write the correct source to a temporary Python file and then import 44 the TorchScript model from that temporary file. 45 46 `model_src` must contain `jit_model = ...`, which `materialize` will supply. 47 """ 48 # Double check. 49 assert "jit_model = " in model_src, f"Missing jit_model definition:\n{model_src}" 50 51 # `torch.utils.benchmark.Timer` will automatically import torch, so we 52 # need to match that convention. 53 model_src = f"import torch\n{model_src}" 54 55 model_root = os.path.join(get_temp_dir(), "TorchScript_models") 56 os.makedirs(model_root, exist_ok=True) 57 module_path = os.path.join(model_root, f"torchscript_{name}.py") 58 artifact_path = os.path.join(model_root, f"torchscript_{name}.pt") 59 60 if os.path.exists(module_path): 61 # The uuid in `name` should protect against this, but it doesn't hurt 62 # to confirm. 63 raise ValueError(f"File {module_path} already exists.") 64 65 with open(module_path, "w") as f: 66 f.write(model_src) 67 68 # Import magic to actually load our function. 69 module_spec = importlib.util.spec_from_file_location( 70 f"torchscript__{name}", module_path 71 ) 72 assert module_spec is not None 73 module = importlib.util.module_from_spec(module_spec) 74 loader = module_spec.loader 75 assert loader is not None 76 77 loader.exec_module(module) 78 79 # And again, the type checker has no way of knowing that this line is valid. 80 jit_model = module.jit_model # type: ignore[attr-defined] 81 assert isinstance( 82 jit_model, (torch.jit.ScriptFunction, torch.jit.ScriptModule) 83 ), f"Expected ScriptFunction or ScriptModule, got: {type(jit_model)}" 84 jit_model.save(artifact_path) # type: ignore[call-arg] 85 86 # Cleanup now that we have the actual serialized model. 87 os.remove(module_path) 88 return artifact_path 89 90 91def _get_stmt( 92 benchmark: GroupedBenchmark, 93 runtime: RuntimeMode, 94 autograd: AutogradMode, 95 language: Language, 96) -> Optional[str]: 97 """Specialize a GroupedBenchmark for a particular configuration.""" 98 is_python = language == Language.PYTHON 99 100 # During GroupedBenchmark construction, py_fwd_stmt and cpp_fwd_stmt are 101 # set to the eager invocation. So in the RuntimeMode.EAGER case we can 102 # simply reuse them. For the RuntimeMode.JIT case, we need to generate 103 # an appropriate `jit_model(...)` invocation. 
    if runtime == RuntimeMode.EAGER:
        stmts = (benchmark.py_fwd_stmt, benchmark.cpp_fwd_stmt)

    else:
        assert runtime == RuntimeMode.JIT
        assert benchmark.signature_args is not None
        stmts = GroupedBenchmark._make_model_invocation(
            benchmark.signature_args, benchmark.signature_output, RuntimeMode.JIT
        )

    stmt = stmts[0 if is_python else 1]

    if autograd == AutogradMode.FORWARD_BACKWARD and stmt is not None:
        assert benchmark.signature_output is not None
        backward = (
            f"{benchmark.signature_output}"
            # In C++ we have to get the Tensor out of the IValue to call `.backward()`
            f"{'.toTensor()' if runtime == RuntimeMode.JIT and language == Language.CPP else ''}"
            f".backward(){';' if language == Language.CPP else ''}"
        )
        stmt = f"{stmt}\n{backward}"
    return stmt


def _get_setup(
    benchmark: GroupedBenchmark,
    runtime: RuntimeMode,
    language: Language,
    stmt: str,
    model_path: Optional[str],
) -> str:
    """Specialize a GroupedBenchmark for a particular configuration.

    Setup requires two extra pieces of information:
      1) The benchmark stmt. This is needed to warm up the model and avoid
         measuring lazy initialization.
      2) The model path so we can load it during the benchmark.

    These are only used when `runtime == RuntimeMode.JIT`.
    """

    # By the time we get here, details about how to set up a model have already
    # been determined by GroupedBenchmark. (Or set to None if appropriate.) We
    # simply need to collect and package the code blocks.
    if language == Language.PYTHON:
        setup = benchmark.setup.py_setup
        model_setup = benchmark.py_model_setup
    else:
        assert language == Language.CPP
        setup = benchmark.setup.cpp_setup
        model_setup = benchmark.cpp_model_setup

    if runtime == RuntimeMode.EAGER:
        return "\n".join([setup, model_setup or ""])

    assert runtime == RuntimeMode.JIT
    assert model_path is not None

    # We template `"{model_path}"`, so quotes would break model loading. The
    # model path is generated within the benchmark, so this is just an
    # abundance of caution rather than something that is expected in practice.
    assert '"' not in model_path

    # `stmt` may contain newlines, so we can't use f-strings. Instead we need
    # to generate templates so that dedent works properly.
    if language == Language.PYTHON:
        setup_template: str = textwrap.dedent(
            f"""
            jit_model = torch.jit.load("{model_path}")

            # Warmup `jit_model`
            for _ in range(3):
            {{stmt}}
            """
        )

    else:
        assert language == Language.CPP
        setup_template = textwrap.dedent(
            f"""
            const std::string fpath = "{model_path}";
            auto jit_model = torch::jit::load(fpath);

            // Warmup `jit_model`
            for (int i = 0; i < 3; i++) {{{{
            {{stmt}}
            }}}}
            """
        )

    model_load = setup_template.format(stmt=textwrap.indent(stmt, " " * 4))
    return "\n".join([setup, model_load])


def materialize(benchmarks: FlatIntermediateDefinition) -> FlatDefinition:
    """Convert a heterogeneous benchmark into an executable state.

    This entails generation of TorchScript model artifacts, splitting
    GroupedBenchmarks into multiple TimerArgs, and tagging the results with
    AutoLabels.
204 """ 205 results: List[Tuple[Label, AutoLabels, TimerArgs]] = [] 206 207 for label, args in benchmarks.items(): 208 if isinstance(args, TimerArgs): 209 # User provided an explicit TimerArgs, so no processing is necessary. 210 auto_labels = AutoLabels( 211 RuntimeMode.EXPLICIT, AutogradMode.EXPLICIT, args.language 212 ) 213 results.append((label, auto_labels, args)) 214 215 else: 216 assert isinstance(args, GroupedBenchmark) 217 218 model_path: Optional[str] = None 219 if args.py_model_setup and args.torchscript: 220 model_setup = ( 221 f"{args.py_model_setup}\njit_model = torch.jit.script(model)" 222 ) 223 224 # This is just for debugging. We just need a unique name for the 225 # model, but embedding the label makes debugging easier. 226 name: str = re.sub(r"[^a-z0-9_]", "_", "_".join(label).lower()) 227 name = f"{name}_{uuid.uuid4()}" 228 229 model_path = _generate_torchscript_file(model_setup, name=name) 230 231 for (runtime, autograd, language), num_threads in it.product( 232 _ALL_MODES, args.num_threads 233 ): 234 if runtime == RuntimeMode.EXPLICIT or autograd == AutogradMode.EXPLICIT: 235 continue 236 237 if runtime == RuntimeMode.JIT and not args.torchscript: 238 continue 239 240 if autograd == AutogradMode.FORWARD_BACKWARD and not args.autograd: 241 continue 242 243 stmt = _get_stmt(args, runtime, autograd, language) 244 if stmt is None: 245 continue 246 247 setup = _get_setup(args, runtime, language, stmt, model_path) 248 249 global_setup: str = "" 250 if language == Language.CPP and runtime == RuntimeMode.JIT: 251 global_setup = textwrap.dedent( 252 """ 253 #include <string> 254 #include <vector> 255 #include <torch/script.h> 256 """ 257 ) 258 259 autolabels = AutoLabels(runtime, autograd, language) 260 timer_args = TimerArgs( 261 stmt=stmt, 262 setup=setup, 263 global_setup=global_setup, 264 num_threads=num_threads, 265 language=language, 266 ) 267 268 results.append((label, autolabels, timer_args)) 269 270 return tuple(results) 271