xref: /aosp_15_r20/external/pytorch/benchmarks/instruction_counts/core/api.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1"""Key enums and structs used to handle data flow within the benchmark."""
2# mypy: ignore-errors
3import dataclasses
4import enum
5import itertools as it
6import re
7import textwrap
8from typing import Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union
9
10from worker.main import WorkerTimerArgs
11
12
13if TYPE_CHECKING:
14    # Benchmark utils are only partially strict compliant, so MyPy won't follow
15    # imports using the public namespace. (Due to an exclusion rule in
16    # mypy-strict.ini)
17    from torch.utils.benchmark.utils.timer import Language
18else:
19    from torch.utils.benchmark import Language
20
21
# Note:
#   WorkerTimerArgs is defined in worker.main so that the worker does not
#   depend on any other files, including core.api. We mirror it here under
#   the public symbol `TimerArgs` for API consistency.
TimerArgs = WorkerTimerArgs
27
28
class RuntimeMode(enum.Enum):
    """Runtime label for a benchmark; each value is the string used in reports."""

    EAGER = "Eager"
    JIT = "TorchScript"
    # NOTE(review): empty label; presumably used when the runtime is stated by
    # the caller rather than inferred -- confirm against users of this enum.
    EXPLICIT = ""
33
34
class AutogradMode(enum.Enum):
    """Autograd label for a benchmark; each value is the string used in reports."""

    FORWARD = "Forward"
    FORWARD_BACKWARD = "Forward + Backward"
    # NOTE(review): empty label; presumably used when the autograd mode is
    # stated by the caller rather than inferred -- confirm against users.
    EXPLICIT = ""
39
40
@dataclasses.dataclass(frozen=True)
class AutoLabels:
    """Labels for a TimerArgs instance which are inferred during unpacking.

    The instance is immutable (frozen dataclass) and bundles the runtime,
    autograd and language labels that belong to one benchmark.
    """

    runtime: RuntimeMode
    autograd: AutogradMode
    language: Language

    @property
    def as_dict(self) -> Dict[str, str]:
        """Dict representation for CI reporting.

        Returns:
            A mapping with the keys "runtime", "autograd" and "language". The
            first two hold the enum values; the language is rendered as
            "Python" for `Language.PYTHON` and as "C++" for anything else.
        """
        if self.language == Language.PYTHON:
            language_name = "Python"
        else:
            language_name = "C++"

        return dict(
            runtime=self.runtime.value,
            autograd=self.autograd.value,
            language=language_name,
        )
57
58
@dataclasses.dataclass(frozen=True)
class GroupedSetup:
    """Setup snippets (Python, C++ and global) shared by a group of benchmarks.

    Every field is a block of source text. The text is dedented when the
    instance is created, so callers may pass indented triple-quoted strings.
    """

    py_setup: str = ""
    cpp_setup: str = ""
    global_setup: str = ""

    def __post_init__(self) -> None:
        """Replace each field with its `textwrap.dedent`-ed value."""
        for spec in dataclasses.fields(self):
            name = spec.name
            # Every field is expected to be declared as a plain `str`.
            assert spec.type == str
            dedented: str = textwrap.dedent(getattr(self, name))
            # The dataclass is frozen, so normal attribute assignment would
            # raise; write through `object.__setattr__` instead.
            object.__setattr__(self, name, dedented)
70
71
@dataclasses.dataclass(frozen=True)
class GroupedBenchmark:
    """Base class for defining groups of benchmarks.

    Concrete interfaces:
     - `core.api.GroupedStmts`     (init_from_stmts)
     - `core.api.GroupedModules`   (init_from_model)
     - `core.api.GroupedVariants`  (init_from_variants)

    There are a variety of dimensions along which one might wish to measure
    PyTorch performance:
      - Python, C++
      - Eager, TorchScript
      - Single threaded, multi threaded
      - Training, inference

    It is useful to define them together, both for clear, concise benchmark
    definition and more intelligent post processing and analysis.

    There are also two programming idioms in PyTorch. One is to write free form
    code (so-called "NumPy with gradients"), and the other is to organize code
    using `torch.nn.Module`s. (This is how common neural network layers are
    exposed through the PyTorch API.) To support easy definition two simple
    initialization methods are provided:
     - `init_from_stmts`
     - `init_from_model`

    Those methods will document their unique constructor arguments, however
    most are shared and are defined here:
        setup: Defines how to initialize a benchmark in both Python and C++.
        signature:
            A string of the form:
            ```
                f(a, b, ...) -> c
            ```
            For instance, if Python setup is:
            ```
                x = torch.ones((2,), requires_grad=True)
                y = torch.ones((2,))
            ```
            and the corresponding stmt is:
            ```
                z = torch.dot(x, y)
            ```
            Then the signature is `f(x, y) -> z`. `signature` is required any
            time we need to generate part of a snippet:
             - When calling an opaque model provided by `init_from_model`
             - When `torchscript=True`
             - When `autograd=True`

            If a return value is not needed (e.g. because of in place mutation)
            then `-> None` is valid, but a non-None return must be provided if
            `autograd=True`

        torchscript:
            If True, also JIT the stmt or model and generate benchmarks which
            call the scripted version. Requires that `signature` is defined.

        autograd:
            If True, generate both forward and forward + backward benchmarks.
            Requires that `signature` is defined, and return value is not None.

        num_threads:
            Maps to the Timer arg. If a tuple of ints is provided, benchmarks
            will be generated for each value.

    A third method, `init_from_variants`, is provided to define several related
    benchmarks at once.
    """

    # These are the stmts which are actually executed by Timer. In the case of
    # `GroupedStmts` (init_from_stmts) they are passed through from user args.
    # In the case of `GroupedModules` (init_from_model) they are generated
    # using `signature`. (e.g. `f(x, y) -> z` generates `z = model(x, y)`)
    py_fwd_stmt: Optional[str]
    cpp_fwd_stmt: Optional[str]

    # Code block used to define a model. `init_from_stmts` will never populate
    # `cpp_model_setup`, but if TorchScript is requested it will generate
    # `py_model_setup` using `torch.jit.script`.
    py_model_setup: Optional[str]
    cpp_model_setup: Optional[str]

    # True if this benchmark used `init_from_stmts`, otherwise False.
    inferred_model_setup: bool

    # Described above
    setup: GroupedSetup
    signature_args: Optional[Tuple[str, ...]]
    signature_output: Optional[str]
    torchscript: bool
    autograd: bool
    num_threads: Tuple[int, ...]

    @classmethod
    def init_from_stmts(
        cls,
        py_stmt: Optional[str] = None,
        cpp_stmt: Optional[str] = None,
        # Generic constructor arguments
        setup: GroupedSetup = GroupedSetup(),
        signature: Optional[str] = None,
        torchscript: bool = False,
        autograd: bool = False,
        num_threads: Union[int, Tuple[int, ...]] = 1,
    ) -> "GroupedBenchmark":
        """Create a set of benchmarks from free-form statements.

        This method of benchmark definition is analogous to Timer use, where
        we simply execute the provided stmts.

        Args:
            py_stmt: Python statement(s) to time; dedented before use.
            cpp_stmt: C++ statement(s) to time; dedented before use.
            setup, signature, torchscript, autograd, num_threads:
                See the class docstring.

        Raises:
            ValueError: If `signature` is malformed, if `torchscript=True` but
                `py_stmt` or `signature` is missing, or if `autograd=True`
                without an output variable in `signature`.
        """
        if py_stmt is not None:
            py_stmt = textwrap.dedent(py_stmt)

        if cpp_stmt is not None:
            cpp_stmt = textwrap.dedent(cpp_stmt)

        signature_args, signature_output = cls._parse_signature(signature)

        # A Python model definition is only synthesized when TorchScript is
        # requested; otherwise the stmt is timed as-is.
        py_model_setup = (
            cls._model_from_py_stmt(
                py_stmt=py_stmt,
                signature_args=signature_args,
                signature_output=signature_output,
            )
            if torchscript
            else None
        )

        return cls(
            py_fwd_stmt=py_stmt,
            cpp_fwd_stmt=cpp_stmt,
            py_model_setup=py_model_setup,
            cpp_model_setup=None,
            inferred_model_setup=True,
            setup=setup,
            signature_args=signature_args,
            signature_output=signature_output,
            torchscript=torchscript,
            autograd=autograd,
            num_threads=cls._normalize_num_threads(num_threads),
        )

    @classmethod
    def init_from_model(
        cls,
        py_model_setup: Optional[str] = None,
        cpp_model_setup: Optional[str] = None,
        # Generic constructor arguments
        setup: GroupedSetup = GroupedSetup(),
        signature: Optional[str] = None,
        torchscript: bool = False,
        autograd: bool = False,
        num_threads: Union[int, Tuple[int, ...]] = 1,
    ) -> "GroupedBenchmark":
        """Create a set of benchmarks using torch.nn Modules.

        This method of benchmark creation takes setup code, and then calls
        a model rather than a free form block of code. As a result, there are
        two additional requirements compared to `init_from_stmts`:
          - `signature` must be provided.
          - A model (named "model") must be defined, either with `model = ...`
            or `def model(...): ...` in Python or `auto model = ...` in C++.

        Raises:
            ValueError: If `signature` is missing or malformed, if a provided
                model setup does not contain the text "model", or if
                `autograd=True` without an output variable in `signature`.
        """
        signature_args, signature_output = cls._parse_signature(signature)
        if signature_args is None:
            raise ValueError(
                "signature is needed when initializing from model definitions."
            )

        # The forward stmts are generated from the signature: the two values
        # returned here fill `py_fwd_stmt` and `cpp_fwd_stmt` positionally.
        return cls(
            *cls._make_model_invocation(
                signature_args, signature_output, RuntimeMode.EAGER
            ),
            py_model_setup=py_model_setup,
            cpp_model_setup=cpp_model_setup,
            inferred_model_setup=False,
            setup=setup,
            signature_args=signature_args,
            signature_output=signature_output,
            torchscript=torchscript,
            autograd=autograd,
            num_threads=cls._normalize_num_threads(num_threads),
        )

    @classmethod
    def init_from_variants(
        cls,
        py_block: str = "",
        cpp_block: str = "",
        num_threads: Union[int, Tuple[int, ...]] = 1,
    ) -> Dict[Union[Tuple[str, ...], Optional[str]], "GroupedBenchmark"]:
        """Create several related benchmarks from labeled code blocks.

        Both blocks are split into labeled sections by `_parse_variants`
        (`# @label` in Python, `// @label` in C++). The `setup` and
        `global setup` sections become the shared `GroupedSetup`; a Python
        global setup section is not allowed. For every other label, the
        Python and C++ lines are paired by position and each pair in which at
        least one side is non-empty becomes one `init_from_stmts` benchmark.

        Returns:
            A dict keyed by `(label,)` when the label has a single non-empty
            line pair, or by `(label, "Case: <i>")` when it has several.
        """
        py_cases, py_setup, py_global_setup = cls._parse_variants(
            py_block, Language.PYTHON
        )
        cpp_cases, cpp_setup, cpp_global_setup = cls._parse_variants(
            cpp_block, Language.CPP
        )

        assert not py_global_setup
        setup = GroupedSetup(
            py_setup=py_setup,
            cpp_setup=cpp_setup,
            global_setup=cpp_global_setup,
        )

        # NB: The key is actually `Tuple[str, ...]`, however MyPy gets confused
        #     and we use the superset `Union[Tuple[str, ...], Optional[str]` to
        #     match the expected signature.
        variants: Dict[Union[Tuple[str, ...], Optional[str]], GroupedBenchmark] = {}

        # `dict.fromkeys` de-duplicates labels while keeping first-seen order
        # (all Python labels first, then any C++-only labels).
        for label in dict.fromkeys(it.chain(py_cases, cpp_cases)):
            py_lines = py_cases.get(label, [])
            cpp_lines = cpp_cases.get(label, [])

            # Pair lines by position, padding the shorter side with "", and
            # drop pairs where both sides are empty.
            lines = [
                (py_stmt, cpp_stmt)
                for py_stmt, cpp_stmt in it.zip_longest(
                    py_lines, cpp_lines, fillvalue=""
                )
                if py_stmt or cpp_stmt
            ]

            for i, (py_stmt, cpp_stmt) in enumerate(lines):
                case = (f"Case: {i:>2}",) if len(lines) > 1 else ()
                variants[(label,) + case] = GroupedBenchmark.init_from_stmts(
                    py_stmt=py_stmt or None,
                    cpp_stmt=cpp_stmt or None,
                    setup=setup,
                    num_threads=num_threads,
                )

        return variants

    def __post_init__(self) -> None:
        """Validate field combinations.

        Raises:
            ValueError: If `autograd` is set without an output variable, or if
                a non-empty model setup does not contain the text "model".
        """
        if self.autograd and self.signature_output is None:
            raise ValueError(
                "An output variable must be specified when `autograd=True`."
            )

        if self.py_model_setup and "model" not in self.py_model_setup:
            raise ValueError(
                "`py_model_setup` appears to be missing `model` definition."
            )

        if self.cpp_model_setup and "model" not in self.cpp_model_setup:
            raise ValueError(
                "`cpp_model_setup` appears to be missing `model` definition."
            )

    @staticmethod
    def _normalize_num_threads(
        num_threads: Union[int, Tuple[int, ...]],
    ) -> Tuple[int, ...]:
        """Return `num_threads` as a tuple: an int is wrapped, a sequence copied."""
        if isinstance(num_threads, int):
            return (num_threads,)
        # `tuple(...)` is a no-op for tuples and keeps the frozen dataclass
        # hashable if a list is passed by mistake.
        return tuple(num_threads)

    # =========================================================================
    # == String manipulation methods ==========================================
    # =========================================================================

    @staticmethod
    def _parse_signature(
        signature: Optional[str],
    ) -> Tuple[Optional[Tuple[str, ...]], Optional[str]]:
        """Split a `f(a, b, ...) -> c` signature into argument names and output.

        Returns:
            `(None, None)` when `signature` is None. Otherwise a tuple of the
            argument names and the output name, where the output is None if
            the signature ends in `-> None`.

        Raises:
            ValueError: If the signature does not match the expected form or
                the output contains a comma (multiple return values).
        """
        if signature is None:
            return None, None

        match = re.search(r"^f\((.*)\) -> (.*)$", signature)
        if match is None:
            raise ValueError(f"Invalid signature: `{signature}`")

        raw_args, raw_output = match.groups()

        # Split on "," and strip each name so that `f(x,y)` and `f(x,  y)`
        # parse the same as `f(x, y)`. An empty argument list still yields
        # `("",)`, matching the historical result for `f() -> ...`.
        args: Tuple[str, ...] = tuple(arg.strip() for arg in raw_args.split(","))
        output: str = raw_output.strip()

        if "," in output:
            raise ValueError(
                f"Multiple return values are not currently allowed: `{output}`"
            )

        if output == "None":
            return args, None

        return args, output

    @staticmethod
    def _model_from_py_stmt(
        py_stmt: Optional[str],
        signature_args: Optional[Tuple[str, ...]],
        signature_output: Optional[str],
    ) -> str:
        """Wrap `py_stmt` in a `def model(...)` that returns `signature_output`.

        Raises:
            ValueError: If `py_stmt` or `signature_args` is None.
        """
        if py_stmt is None:
            raise ValueError("`py_stmt` must be defined in order to derive a model.")

        if signature_args is None:
            raise ValueError("signature is needed in order to derive a model.")

        # The f-string fills in the signature and leaves a literal
        # `{stmt_str}` placeholder, which `.format` then replaces with the
        # stmt indented to function-body depth.
        return textwrap.dedent(
            f"""\
            def model({', '.join(signature_args)}):
            {{stmt_str}}
                return {signature_output}
        """
        ).format(stmt_str=textwrap.indent(py_stmt, " " * 4))

    @staticmethod
    def _make_model_invocation(
        signature_args: Tuple[str, ...],
        signature_output: Optional[str],
        runtime: RuntimeMode,
    ) -> Tuple[str, str]:
        """Generate the Python and C++ statements which call the model.

        Args:
            signature_args: Names passed as arguments to the model.
            signature_output: Name the result is assigned to, or None for a
                bare call.
            runtime: `RuntimeMode.EAGER` calls `model`; `RuntimeMode.JIT`
                calls `jit_model`.

        Returns:
            `(py_invocation, cpp_invocation)`.
        """
        py_prefix, cpp_prefix = "", ""
        if signature_output is not None:
            py_prefix = f"{signature_output} = "
            cpp_prefix = f"auto {signature_output} = "

        if runtime == RuntimeMode.EAGER:
            model_name = "model"
            cpp_invocation = (
                f"{cpp_prefix}{model_name}->forward({', '.join(signature_args)});"
            )

        else:
            assert runtime == RuntimeMode.JIT
            model_name = "jit_model"
            # The generated C++ wraps each argument in an IValue and passes
            # them to `forward` as a single vector.
            cpp_invocation = textwrap.dedent(
                f"""\
                std::vector<torch::jit::IValue> ivalue_inputs({{
                    {', '.join([f'torch::jit::IValue({a})' for a in signature_args])}
                }});
                {cpp_prefix}{model_name}.forward(ivalue_inputs);
            """
            )

        # NB:
        #   In python we invoke __call__, however C++ doesn't have an analogous
        #   method so we invoke `forward` instead. This means that Python
        #   is doing extra work (e.g. checking hooks) compared to C++; however
        #   because this is the default user experience that's acceptable.
        py_invocation = f"{py_prefix}{model_name}({', '.join(signature_args)})"

        return py_invocation, cpp_invocation

    @staticmethod
    def _parse_variants(
        block: str, language: Language
    ) -> Tuple[Dict[str, List[str]], str, str]:
        """Split a labeled code block into per-label lists of lines.

        A label line is one whose stripped text contains `<comment> @<label>`,
        where `<comment>` is `#` for Python and `//` otherwise. The labels
        `setup` and `global setup` (any case, spaces or underscores) are
        normalized to `SETUP` and `GLOBAL_SETUP`. Lines seen before any label
        are stored under the empty label `""`. Non-label lines that start
        with the comment token are stored as empty strings rather than
        dropped, so the position of the remaining lines is preserved.

        Returns:
            `(lines_by_label, setup, global_setup)` where the two setup
            sections have been removed from the dict and joined with newlines.
        """
        block = textwrap.dedent(block).strip()
        comment = "#" if language == Language.PYTHON else "//"
        label_pattern = f"{comment} @(.+)$"
        label = ""

        lines_by_label: Dict[str, List[str]] = {"SETUP": [], "GLOBAL_SETUP": []}
        for line in block.splitlines(keepends=False):
            match = re.search(label_pattern, line.strip())
            if match:
                label = match.groups()[0]
                normalized = label.replace(" ", "_").upper()
                if normalized in ("SETUP", "GLOBAL_SETUP"):
                    label = normalized
                continue

            lines_by_label.setdefault(label, [])
            if line.startswith(comment):
                line = ""
            lines_by_label[label].append(line)

        setup = "\n".join(lines_by_label.pop("SETUP"))
        global_setup = "\n".join(lines_by_label.pop("GLOBAL_SETUP"))

        return lines_by_label, setup, global_setup
440
441
# These are the user facing APIs. Each name is an alias for the
# `GroupedBenchmark` classmethod constructor shown on its right-hand side.
GroupedStmts = GroupedBenchmark.init_from_stmts
GroupedModules = GroupedBenchmark.init_from_model
GroupedVariants = GroupedBenchmark.init_from_variants
446