"""Key enums and structs used to handle data flow within the benchmark."""
# mypy: ignore-errors
import dataclasses
import enum
import itertools as it
import re
import textwrap
from typing import Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union

from worker.main import WorkerTimerArgs


if TYPE_CHECKING:
    # Benchmark utils are only partially strict compliant, so MyPy won't follow
    # imports using the public namespace. (Due to an exclusion rule in
    # mypy-strict.ini)
    from torch.utils.benchmark.utils.timer import Language
else:
    from torch.utils.benchmark import Language


# Note:
#   WorkerTimerArgs is defined in worker.main so that the worker does not
#   depend on any files, including core.api. We mirror it with a public symbol
#   `TimerArgs` for API consistency.
TimerArgs = WorkerTimerArgs


class RuntimeMode(enum.Enum):
    """Runtime dimension of a benchmark; `.value` is the label used in reports."""

    EAGER = "Eager"
    JIT = "TorchScript"
    EXPLICIT = ""


class AutogradMode(enum.Enum):
    """Autograd dimension of a benchmark; `.value` is the label used in reports."""

    FORWARD = "Forward"
    FORWARD_BACKWARD = "Forward + Backward"
    EXPLICIT = ""


@dataclasses.dataclass(frozen=True)
class AutoLabels:
    """Labels for a TimerArgs instance which are inferred during unpacking."""

    runtime: RuntimeMode
    autograd: AutogradMode
    language: Language

    @property
    def as_dict(self) -> Dict[str, str]:
        """Dict representation for CI reporting."""
        return {
            "runtime": self.runtime.value,
            "autograd": self.autograd.value,
            "language": "Python" if self.language == Language.PYTHON else "C++",
        }


@dataclasses.dataclass(frozen=True)
class GroupedSetup:
    """Setup code shared by every benchmark in a group.

    Every field is passed through `textwrap.dedent` on construction, so callers
    may supply indented triple-quoted strings.
    """

    py_setup: str = ""
    cpp_setup: str = ""
    global_setup: str = ""

    def __post_init__(self) -> None:
        for field in dataclasses.fields(self):
            assert field.type == str
            value: str = getattr(self, field.name)
            # The dataclass is frozen, so bypass the generated `__setattr__`.
            object.__setattr__(self, field.name, textwrap.dedent(value))


@dataclasses.dataclass(frozen=True)
class GroupedBenchmark:
    """Base class for defining groups of benchmarks.

    Concrete interfaces:
     - `core.api.GroupedStmts`     (init_from_stmts)
     - `core.api.GroupedModules`   (init_from_model)
     - `core.api.GroupedVariants`  (init_from_variants)

    There are a variety of dimensions along which one might wish to measure
    PyTorch performance:
      - Python, C++
      - Eager, TorchScript
      - Single threaded, multi threaded
      - Training, inference

    It is useful to define them together, both for clear, concise benchmark
    definition and more intelligent post processing and analysis.

    There are also two programming idioms in PyTorch. One is to write free form
    code (so-called "NumPy with gradients"), and the other is to organize code
    using `torch.nn.Module`s. (This is how common neural network layers are
    exposed through the PyTorch API.) To support easy definition two simple
    initialization methods are provided:
     - `init_from_stmts`
     - `init_from_model`

    Those methods will document their unique constructor arguments, however
    most are shared and are defined here:
        setup: Defines how to initialize a benchmark in both Python and C++.
        signature:
            A string of the form:
            ```
                f(a, b, ...) -> c
            ```
            For instance, if Python setup is:
            ```
                x = torch.ones((2,), requires_grad=True)
                y = torch.ones((2,))
            ```
            and the corresponding stmt is:
            ```
                z = torch.dot(x, y)
            ```
            Then the signature is `f(x, y) -> z`. `signature` is required any
            time we need to generate part of a snippet:
             - When calling an opaque model provided by `init_from_model`
             - When `torchscript=True`
             - When `autograd=True`

            If a return value is not needed (e.g. because of in place mutation)
            then `-> None` is valid, but a non-None return must be provided if
            `autograd=True`

        torchscript:
            If True, also JIT the stmt or model and generate benchmarks which
            call the scripted version. Requires that `signature` is defined.

        autograd:
            If True, generate both forward and forward + backward benchmarks.
            Requires that `signature` is defined, and return value is not None.

        num_threads:
            Maps to the Timer arg. If a tuple of ints is provided, benchmarks
            will be generated for each value.

    A third method, `init_from_variants`, is provided to define several related
    benchmarks at once.
    """

    # These are the stmts which are actually executed by Timer. In the case of
    # `GroupedStmts` (init_from_stmts) they are passed through from user args.
    # In the case of `GroupedModules` (init_from_model) they are generated
    # using `signature`. (e.g. `f(x, y) -> z` generates `z = model(x, y)`)
    py_fwd_stmt: Optional[str]
    cpp_fwd_stmt: Optional[str]

    # Code block used to define a model. `init_from_stmts` will never populate
    # `cpp_model_setup`, but if TorchScript is requested it will generate
    # `py_model_setup` using `torch.jit.script`.
    py_model_setup: Optional[str]
    cpp_model_setup: Optional[str]

    # True if this benchmark used `init_from_stmts`, otherwise False.
    inferred_model_setup: bool

    # Described above
    setup: GroupedSetup
    signature_args: Optional[Tuple[str, ...]]
    signature_output: Optional[str]
    torchscript: bool
    autograd: bool
    num_threads: Tuple[int, ...]

    @classmethod
    def init_from_stmts(
        cls,
        py_stmt: Optional[str] = None,
        cpp_stmt: Optional[str] = None,
        # Generic constructor arguments
        # (A shared default instance is safe because GroupedSetup is frozen.)
        setup: GroupedSetup = GroupedSetup(),
        signature: Optional[str] = None,
        torchscript: bool = False,
        autograd: bool = False,
        num_threads: Union[int, Tuple[int, ...]] = 1,
    ) -> "GroupedBenchmark":
        """Create a set of benchmarks from free-form statements.

        This method of benchmark definition is analogous to Timer use, where
        we simply execute the provided stmts.

        Args:
            py_stmt: Python statement(s) to time; dedented before use.
            cpp_stmt: C++ statement(s) to time; dedented before use.
            setup, signature, torchscript, autograd, num_threads:
                See the class docstring. When `torchscript=True`, a Python
                `model` function wrapping `py_stmt` is generated, which
                requires both `py_stmt` and `signature`.

        Raises:
            ValueError: if `signature` is malformed, if `torchscript=True`
                without `py_stmt` or `signature`, or if `autograd=True`
                without a non-None output.
        """
        if py_stmt is not None:
            py_stmt = textwrap.dedent(py_stmt)

        if cpp_stmt is not None:
            cpp_stmt = textwrap.dedent(cpp_stmt)

        signature_args, signature_output = cls._parse_signature(signature)
        py_model_setup = (
            cls._model_from_py_stmt(
                py_stmt=py_stmt,
                signature_args=signature_args,
                signature_output=signature_output,
            )
            if torchscript
            else None
        )

        return cls(
            py_fwd_stmt=py_stmt,
            cpp_fwd_stmt=cpp_stmt,
            py_model_setup=py_model_setup,
            cpp_model_setup=None,
            inferred_model_setup=True,
            setup=setup,
            signature_args=signature_args,
            signature_output=signature_output,
            torchscript=torchscript,
            autograd=autograd,
            num_threads=cls._normalize_num_threads(num_threads),
        )

    @classmethod
    def init_from_model(
        cls,
        py_model_setup: Optional[str] = None,
        cpp_model_setup: Optional[str] = None,
        # Generic constructor arguments
        # (A shared default instance is safe because GroupedSetup is frozen.)
        setup: GroupedSetup = GroupedSetup(),
        signature: Optional[str] = None,
        torchscript: bool = False,
        autograd: bool = False,
        num_threads: Union[int, Tuple[int, ...]] = 1,
    ) -> "GroupedBenchmark":
        """Create a set of benchmarks using torch.nn Modules.

        This method of benchmark creation takes setup code, and then calls
        a model rather than a free form block of code. As a result, there are
        two additional requirements compared to `init_from_stmts`:
          - `signature` must be provided.
          - A model (named "model") must be defined, either with `model = ...`
            or `def model(...): ...` in Python or `auto model = ...` in C++.

        Raises:
            ValueError: if `signature` is missing or malformed, or if a model
                setup block does not mention `model`.
        """
        signature_args, signature_output = cls._parse_signature(signature)
        if signature_args is None:
            raise ValueError(
                "signature is needed when initializing from model definitions."
            )

        return cls(
            # Positionally fills `py_fwd_stmt` and `cpp_fwd_stmt`.
            *cls._make_model_invocation(
                signature_args, signature_output, RuntimeMode.EAGER
            ),
            py_model_setup=py_model_setup,
            cpp_model_setup=cpp_model_setup,
            inferred_model_setup=False,
            setup=setup,
            signature_args=signature_args,
            signature_output=signature_output,
            torchscript=torchscript,
            autograd=autograd,
            num_threads=cls._normalize_num_threads(num_threads),
        )

    @classmethod
    def init_from_variants(
        cls,
        py_block: str = "",
        cpp_block: str = "",
        num_threads: Union[int, Tuple[int, ...]] = 1,
    ) -> Dict[Union[Tuple[str, ...], Optional[str]], "GroupedBenchmark"]:
        """Create many related benchmarks from labeled Python / C++ blocks.

        Each block is split into sections by `# @label` (Python) or
        `// @label` (C++) comment lines. A `@setup` section populates the
        shared `GroupedSetup`. A `@global_setup` section is only accepted in
        `cpp_block`. For every other label, the n-th Python line is paired
        with the n-th C++ line, pairs where both lines are empty are skipped,
        and each remaining pair becomes one benchmark via `init_from_stmts`.

        Returns:
            A dict keyed by `(label,)` when a label has a single case, or by
            `(label, f"Case: {i:>2}")` when it has several.

        Raises:
            ValueError: if `py_block` contains a non-empty global setup section.
        """
        py_cases, py_setup, py_global_setup = cls._parse_variants(
            py_block, Language.PYTHON
        )
        cpp_cases, cpp_setup, cpp_global_setup = cls._parse_variants(
            cpp_block, Language.CPP
        )

        # Explicit check rather than `assert`, which is stripped under `-O`.
        if py_global_setup:
            raise ValueError(
                "`GLOBAL_SETUP` is only supported in `cpp_block`, but a "
                "non-empty section was found in `py_block`."
            )

        setup = GroupedSetup(
            py_setup=py_setup,
            cpp_setup=cpp_setup,
            global_setup=cpp_global_setup,
        )

        # NB: The key is actually `Tuple[str, ...]`, however MyPy gets confused
        #     and we use the superset `Union[Tuple[str, ...], Optional[str]` to
        #     match the expected signature.
        variants: Dict[Union[Tuple[str, ...], Optional[str]], GroupedBenchmark] = {}

        seen_labels: Set[str] = set()
        for label in it.chain(py_cases.keys(), cpp_cases.keys()):
            if label in seen_labels:
                continue
            seen_labels.add(label)

            # Pad the shorter side with empty lines so `zip` pairs by position.
            py_lines = py_cases.get(label, [])
            cpp_lines = cpp_cases.get(label, [])

            n_lines = max(len(py_lines), len(cpp_lines))
            py_lines += [""] * (n_lines - len(py_lines))
            cpp_lines += [""] * (n_lines - len(cpp_lines))
            lines = [
                (py_stmt, cpp_stmt)
                for py_stmt, cpp_stmt in zip(py_lines, cpp_lines)
                if py_stmt or cpp_stmt
            ]

            for i, (py_stmt, cpp_stmt) in enumerate(lines):
                case = (f"Case: {i:>2}",) if len(lines) > 1 else ()
                variants[(label,) + case] = cls.init_from_stmts(
                    py_stmt=py_stmt or None,
                    cpp_stmt=cpp_stmt or None,
                    setup=setup,
                    num_threads=num_threads,
                )

        return variants

    def __post_init__(self) -> None:
        """Validate cross-field invariants.

        Raises:
            ValueError: if `autograd=True` without an output variable, or if a
                model setup block does not contain the text `model`.
        """
        if self.autograd and self.signature_output is None:
            raise ValueError(
                "An output variable must be specified when `autograd=True`."
            )

        if self.py_model_setup and "model" not in self.py_model_setup:
            raise ValueError(
                "`py_model_setup` appears to be missing `model` definition."
            )

        if self.cpp_model_setup and "model" not in self.cpp_model_setup:
            raise ValueError(
                "`cpp_model_setup` appears to be missing `model` definition."
            )

    # =========================================================================
    # == String manipulation methods ==========================================
    # =========================================================================

    @staticmethod
    def _normalize_num_threads(
        num_threads: Union[int, Tuple[int, ...]],
    ) -> Tuple[int, ...]:
        """Return `num_threads` as a tuple, wrapping a bare int."""
        if isinstance(num_threads, int):
            return (num_threads,)

        # `tuple()` also accepts lists, which would otherwise be stored as-is
        # and make this frozen (hashable) dataclass unhashable.
        return tuple(num_threads)

    @staticmethod
    def _parse_signature(
        signature: Optional[str],
    ) -> Tuple[Optional[Tuple[str, ...]], Optional[str]]:
        """Parse `f(a, b, ...) -> c` into `(("a", "b", ...), "c")`.

        `None` yields `(None, None)`, an output of `None` yields a `None`
        output, and an empty argument list (`f() -> c`) yields `()`.

        Raises:
            ValueError: if the signature does not match `f(...) -> ...`, has
                an empty argument or output name, or declares multiple
                return values.
        """
        if signature is None:
            return None, None

        match = re.search(r"^f\((.*)\) -> (.*)$", signature)
        if match is None:
            raise ValueError(f"Invalid signature: `{signature}`")

        raw_args, raw_output = match.groups()

        # Split on bare commas so `f(x,y)` parses like `f(x, y)`. An empty
        # `()` must map to `()`: a `("",)` result would emit a bogus
        # `torch::jit::IValue()` argument in the TorchScript C++ invocation.
        args: Tuple[str, ...] = (
            tuple(arg.strip() for arg in raw_args.split(","))
            if raw_args.strip()
            else ()
        )
        if any(not arg for arg in args):
            raise ValueError(f"Invalid signature (empty argument): `{signature}`")

        output: str = raw_output.strip()
        if not output:
            raise ValueError(f"Invalid signature (missing output): `{signature}`")

        if "," in output:
            raise ValueError(
                f"Multiple return values are not currently allowed: `{output}`"
            )

        if output == "None":
            return args, None

        return args, output

    @staticmethod
    def _model_from_py_stmt(
        py_stmt: Optional[str],
        signature_args: Optional[Tuple[str, ...]],
        signature_output: Optional[str],
    ) -> str:
        """Wrap `py_stmt` in a Python `def model(<args>): ... return <output>`.

        Raises:
            ValueError: if `py_stmt` or `signature_args` is None.
        """
        if py_stmt is None:
            raise ValueError("`py_stmt` must be defined in order to derive a model.")

        if signature_args is None:
            raise ValueError("signature is needed in order to derive a model.")

        # Built explicitly (rather than via a dedented template + `.format`) so
        # the generated indentation does not depend on source-file whitespace
        # and braces in argument names cannot be misread as format fields.
        args = ", ".join(signature_args)
        body = textwrap.indent(py_stmt, " " * 4)
        return f"def model({args}):\n{body}\n    return {signature_output}\n"

    @staticmethod
    def _make_model_invocation(
        signature_args: Tuple[str, ...],
        signature_output: Optional[str],
        runtime: RuntimeMode,
    ) -> Tuple[str, str]:
        """Return `(py_invocation, cpp_invocation)` calling the model.

        Eager mode calls `model`; JIT mode calls `jit_model`, and in C++ the
        arguments are first packed into a `std::vector<torch::jit::IValue>`.
        When `signature_output` is not None the result is assigned to it.
        """
        py_prefix, cpp_prefix = "", ""
        if signature_output is not None:
            py_prefix = f"{signature_output} = "
            cpp_prefix = f"auto {signature_output} = "

        if runtime == RuntimeMode.EAGER:
            model_name = "model"
            cpp_invocation = (
                f"{cpp_prefix}{model_name}->forward({', '.join(signature_args)});"
            )

        else:
            assert runtime == RuntimeMode.JIT
            model_name = "jit_model"
            ivalue_args = ", ".join(f"torch::jit::IValue({a})" for a in signature_args)
            cpp_invocation = (
                "std::vector<torch::jit::IValue> ivalue_inputs({\n"
                f"    {ivalue_args}\n"
                "});\n"
                f"{cpp_prefix}{model_name}.forward(ivalue_inputs);\n"
            )

        # NB:
        #   In python we invoke __call__, however C++ doesn't have an analogous
        #   method so we invoke `forward` instead. This means that Python
        #   is doing extra work (e.g. checking hooks) compared to C++; however
        #   because this is the default user experience that's acceptable.
        py_invocation = f"{py_prefix}{model_name}({', '.join(signature_args)})"

        return py_invocation, cpp_invocation

    @staticmethod
    def _parse_variants(
        block: str, language: Language
    ) -> Tuple[Dict[str, List[str]], str, str]:
        """Split a labeled block into per-label lines plus setup code.

        A line matching `# @label` (Python) or `// @label` (C++) starts a new
        section. Labels that normalize (spaces -> underscores, upper case) to
        `SETUP` or `GLOBAL_SETUP` are collected separately; lines before the
        first label are stored under the empty label "". Other lines starting
        with the comment marker are kept as empty strings, so they still
        occupy a position when lines are paired by index.

        Returns:
            `(lines_by_label, setup, global_setup)` where the setup values are
            the joined lines of the corresponding sections.
        """
        block = textwrap.dedent(block).strip()
        comment = "#" if language == Language.PYTHON else "//"
        label_pattern = f"{comment} @(.+)$"
        label = ""

        lines_by_label: Dict[str, List[str]] = {"SETUP": [], "GLOBAL_SETUP": []}
        for line in block.splitlines(keepends=False):
            match = re.search(label_pattern, line.strip())
            if match:
                label = match.groups()[0]
                normalized = label.replace(" ", "_").upper()
                if normalized in ("SETUP", "GLOBAL_SETUP"):
                    label = normalized
                continue

            lines_by_label.setdefault(label, [])
            if line.startswith(comment):
                line = ""
            lines_by_label[label].append(line)

        setup = "\n".join(lines_by_label.pop("SETUP"))
        global_setup = "\n".join(lines_by_label.pop("GLOBAL_SETUP"))

        return lines_by_label, setup, global_setup


# These are the user facing APIs.
GroupedStmts = GroupedBenchmark.init_from_stmts
GroupedModules = GroupedBenchmark.init_from_model
GroupedVariants = GroupedBenchmark.init_from_variants