xref: /aosp_15_r20/external/pytorch/benchmarks/instruction_counts/core/expand.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1"""Logic for converting human-readable benchmarks into executable form.
2
3This is mostly string manipulation, with just a bit of importlib magic.
4"""
5# mypy: ignore-errors
6import importlib.abc
7import importlib.util
8import itertools as it
9import os
10import re
11import textwrap
12import uuid
13from typing import List, Optional, Tuple, TYPE_CHECKING
14
15import torch
16
17
18if TYPE_CHECKING:
19    # See the note in api.py for why this is necessary.
20    from torch.utils.benchmark.utils.timer import Language
21else:
22    from torch.utils.benchmark import Language
23
24from core.api import AutogradMode, AutoLabels, GroupedBenchmark, RuntimeMode, TimerArgs
25from core.types import FlatDefinition, FlatIntermediateDefinition, Label
26from core.utils import get_temp_dir
27
28
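# Every benchmark is expanded over the full cross product of runtime mode,
# autograd mode, and language; `materialize` filters out the combinations
# that a given benchmark does not support.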
_ALL_MODES = tuple(
    it.product(
        RuntimeMode,
        AutogradMode,
        Language,
    )
)


def _generate_torchscript_file(model_src: str, name: str) -> Optional[str]:
39    """Returns the path a saved model if one can be constructed from `spec`.
40
41    Because TorchScript requires actual source code in order to script a
42    model, we can't simply `eval` an appropriate model string. Instead, we
43    must write the correct source to a temporary Python file and then import
44    the TorchScript model from that temporary file.
45
46    `model_src` must contain `jit_model = ...`, which `materialize` will supply.
47    """
    # Double check.
    assert "jit_model = " in model_src, f"Missing jit_model definition:\n{model_src}"

    # `torch.utils.benchmark.Timer` will automatically import torch, so we
    # need to match that convention.
    model_src = f"import torch\n{model_src}"

    model_root = os.path.join(get_temp_dir(), "TorchScript_models")
    os.makedirs(model_root, exist_ok=True)
    module_path = os.path.join(model_root, f"torchscript_{name}.py")
    artifact_path = os.path.join(model_root, f"torchscript_{name}.pt")

    if os.path.exists(module_path):
        # The uuid in `name` should protect against this, but it doesn't hurt
        # to confirm.
        raise ValueError(f"File {module_path} already exists.")

    with open(module_path, "w") as f:
        f.write(model_src)

    # Import magic to actually load our function.
    module_spec = importlib.util.spec_from_file_location(
        f"torchscript__{name}", module_path
    )
    assert module_spec is not None
    module = importlib.util.module_from_spec(module_spec)
    loader = module_spec.loader
    assert loader is not None

    loader.exec_module(module)

    # And again, the type checker has no way of knowing that this line is valid.
    jit_model = module.jit_model  # type: ignore[attr-defined]
    assert isinstance(
        jit_model, (torch.jit.ScriptFunction, torch.jit.ScriptModule)
    ), f"Expected ScriptFunction or ScriptModule, got: {type(jit_model)}"
    jit_model.save(artifact_path)  # type: ignore[call-arg]

    # Cleanup now that we have the actual serialized model.
    os.remove(module_path)
    return artifact_path


def _get_stmt(
    benchmark: GroupedBenchmark,
    runtime: RuntimeMode,
    autograd: AutogradMode,
    language: Language,
) -> Optional[str]:
    """Specialize a GroupedBenchmark for a particular configuration."""
    is_python = language == Language.PYTHON

    # During GroupedBenchmark construction, py_fwd_stmt and cpp_fwd_stmt are
    # set to the eager invocation. So in the RuntimeMode.EAGER case we can
    # simply reuse them. For the RuntimeMode.JIT case, we need to generate
    # an appropriate `jit_model(...)` invocation.
    if runtime == RuntimeMode.EAGER:
        stmts = (benchmark.py_fwd_stmt, benchmark.cpp_fwd_stmt)

    else:
        assert runtime == RuntimeMode.JIT
        assert benchmark.signature_args is not None
        stmts = GroupedBenchmark._make_model_invocation(
            benchmark.signature_args, benchmark.signature_output, RuntimeMode.JIT
        )

    stmt = stmts[0 if is_python else 1]

    if autograd == AutogradMode.FORWARD_BACKWARD and stmt is not None:
        assert benchmark.signature_output is not None
        backward = (
            f"{benchmark.signature_output}"
            # In C++ we have to get the Tensor out of the IValue to call `.backward()`
            f"{'.toTensor()' if runtime == RuntimeMode.JIT and language == Language.CPP else ''}"
            f".backward(){';' if language == Language.CPP else ''}"
        )
        stmt = f"{stmt}\n{backward}"
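        # Illustrative result (assuming a signature like "f(x) -> y", Python, JIT),
        # the combined stmt is roughly:
        #   y = jit_model(x)
        #   y.backward()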
    return stmt


def _get_setup(
    benchmark: GroupedBenchmark,
    runtime: RuntimeMode,
    language: Language,
    stmt: str,
    model_path: Optional[str],
) -> str:
    """Specialize a GroupedBenchmark for a particular configuration.

    Setup requires two extra pieces of information:
      1) The benchmark stmt. This is needed to warm up the model and avoid
         measuring lazy initialization.
      2) The model path so we can load it during the benchmark.

    These are only used when `runtime == RuntimeMode.JIT`.
    """

    # By the time we get here, details about how to set up a model have already
    # been determined by GroupedBenchmark. (Or set to None if appropriate.) We
    # simply need to collect and package the code blocks.
    if language == Language.PYTHON:
        setup = benchmark.setup.py_setup
        model_setup = benchmark.py_model_setup
    else:
        assert language == Language.CPP
        setup = benchmark.setup.cpp_setup
        model_setup = benchmark.cpp_model_setup

    if runtime == RuntimeMode.EAGER:
        return "\n".join([setup, model_setup or ""])

    assert runtime == RuntimeMode.JIT
    assert model_path is not None

    # We template `"{model_path}"`, so quotes would break model loading. The
    # model path is generated within the benchmark, so this is just an
    # abundance of caution rather than something that is expected in practice.
    assert '"' not in model_path

    # `stmt` may contain newlines, so we can't interpolate it directly with an
    # f-string. Instead we generate a template and splice `stmt` in afterwards
    # so that `textwrap.dedent` works properly.
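    # `{{stmt}}` below renders as a literal `{stmt}` placeholder, which is filled
    # in by `.format(...)` after dedenting. For a hypothetical stmt of
    # `y = jit_model(x)`, the rendered Python setup looks roughly like:
    #   jit_model = torch.jit.load("<model_path>")
    #
    #   # Warmup `jit_model`
    #   for _ in range(3):
    #       y = jit_model(x)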
    if language == Language.PYTHON:
        setup_template: str = textwrap.dedent(
            f"""
            jit_model = torch.jit.load("{model_path}")

            # Warmup `jit_model`
            for _ in range(3):
            {{stmt}}
        """
        )

    else:
        assert language == Language.CPP
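        # Note: the quadruple braces below survive two rounds of brace escaping:
        # the f-string turns `{{{{` into `{{`, and the later `.format(...)` call
        # turns `{{` into the literal `{` required by the C++ for-loop.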
        setup_template = textwrap.dedent(
            f"""
            const std::string fpath = "{model_path}";
            auto jit_model = torch::jit::load(fpath);

            // Warmup `jit_model`
            for (int i = 0; i < 3; i++) {{{{
            {{stmt}}
            }}}}
        """
        )

    model_load = setup_template.format(stmt=textwrap.indent(stmt, " " * 4))
    return "\n".join([setup, model_load])


def materialize(benchmarks: FlatIntermediateDefinition) -> FlatDefinition:
    """Convert heterogeneous benchmark definitions into an executable state.

    This entails generation of TorchScript model artifacts, splitting
    GroupedBenchmarks into multiple TimerArgs, and tagging the results with
    AutoLabels.
    """
    results: List[Tuple[Label, AutoLabels, TimerArgs]] = []

    for label, args in benchmarks.items():
        if isinstance(args, TimerArgs):
            # User provided an explicit TimerArgs, so no processing is necessary.
            auto_labels = AutoLabels(
                RuntimeMode.EXPLICIT, AutogradMode.EXPLICIT, args.language
            )
            results.append((label, auto_labels, args))

        else:
            assert isinstance(args, GroupedBenchmark)

            model_path: Optional[str] = None
            if args.py_model_setup and args.torchscript:
                model_setup = (
                    f"{args.py_model_setup}\njit_model = torch.jit.script(model)"
                )

                # Any unique name for the model would do; embedding the
                # (sanitized) label just makes debugging easier.
                name: str = re.sub(r"[^a-z0-9_]", "_", "_".join(label).lower())
                name = f"{name}_{uuid.uuid4()}"

                model_path = _generate_torchscript_file(model_setup, name=name)

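            # Expand the GroupedBenchmark across every supported combination of
            # (runtime, autograd, language) and thread count; unsupported
            # combinations are skipped below.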
            for (runtime, autograd, language), num_threads in it.product(
                _ALL_MODES, args.num_threads
            ):
                if runtime == RuntimeMode.EXPLICIT or autograd == AutogradMode.EXPLICIT:
                    continue

                if runtime == RuntimeMode.JIT and not args.torchscript:
                    continue

                if autograd == AutogradMode.FORWARD_BACKWARD and not args.autograd:
                    continue

                stmt = _get_stmt(args, runtime, autograd, language)
                if stmt is None:
                    continue

                setup = _get_setup(args, runtime, language, stmt, model_path)

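                # (Assumption, for context:) the C++ TorchScript snippet is compiled
                # separately, so the headers needed for `torch::jit::load` have to be
                # supplied at global scope via `global_setup`.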
                global_setup: str = ""
                if language == Language.CPP and runtime == RuntimeMode.JIT:
                    global_setup = textwrap.dedent(
                        """
                        #include <string>
                        #include <vector>
                        #include <torch/script.h>
                    """
                    )

                autolabels = AutoLabels(runtime, autograd, language)
                timer_args = TimerArgs(
                    stmt=stmt,
                    setup=setup,
                    global_setup=global_setup,
                    num_threads=num_threads,
                    language=language,
                )

                results.append((label, autolabels, timer_args))

    return tuple(results)
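

# Example usage (illustrative only; `my_benchmarks` is a hypothetical
# FlatIntermediateDefinition mapping labels to TimerArgs / GroupedBenchmarks):
#
#   for label, auto_labels, timer_args in materialize(my_benchmarks):
#       ...  # hand each TimerArgs to the benchmark runner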