#!/usr/bin/env python3

import gc
import importlib
import logging
import os
import re
import sys
import warnings
from collections import namedtuple
from os.path import abspath, exists

import torch


try:
    from .common import BenchmarkRunner, load_yaml_file, main
except ImportError:
    from common import BenchmarkRunner, load_yaml_file, main

from torch._dynamo.testing import collect_results, reduce_to_scalar_loss
from torch._dynamo.utils import clone_inputs

# We are primarily interested in the tf32 datatype
torch.backends.cuda.matmul.allow_tf32 = True

# Enable FX graph caching
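# (unless TORCHINDUCTOR_FX_GRAPH_CACHE was already set in the environment)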
29if "TORCHINDUCTOR_FX_GRAPH_CACHE" not in os.environ:
30    torch._inductor.config.fx_graph_cache = True
31
32
33def _reassign_parameters(model):
34    # torch_geometric models register parameter as tensors due to
35    # https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/nn/dense/linear.py#L158-L168
36    # Since it is unusual thing to do, we just reassign them to parameters
37    def state_dict_hook(module, destination, prefix, local_metadata):
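        # The hook runs after state_dict() has been fully populated, so
        # `destination` already holds the collected tensors and we can
        # promote them back to nn.Parameter in place.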
        for name, param in module.named_parameters():
            if isinstance(destination[name], torch.Tensor) and not isinstance(
                destination[name], torch.nn.Parameter
            ):
                destination[name] = torch.nn.Parameter(destination[name])

    model._register_state_dict_hook(state_dict_hook)


def setup_torchbench_cwd():
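    """Find a torchbench checkout near the current directory, chdir into it,
    and add it to sys.path; return the original working directory so the
    caller can restore it later."""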
    original_dir = abspath(os.getcwd())

    os.environ["KALDI_ROOT"] = "/tmp"  # avoids some spam
    for torchbench_dir in (
        "./torchbenchmark",
        "../torchbenchmark",
        "../torchbench",
        "../benchmark",
        "../../torchbenchmark",
        "../../torchbench",
        "../../benchmark",
        "../../../torchbenchmark",
        "../../../torchbench",
        "../../../benchmark",
    ):
        if exists(torchbench_dir):
            break

    if exists(torchbench_dir):
        torchbench_dir = abspath(torchbench_dir)
        os.chdir(torchbench_dir)
        sys.path.append(torchbench_dir)

    return original_dir


def process_hf_reformer_output(out):
    assert isinstance(out, list)
    # second output is unstable
    return [elem for i, elem in enumerate(out) if i != 1]


def process_hf_whisper_output(out):
    out_ret = []
    for i, elem in enumerate(out):
        if i == 0:
            assert isinstance(elem, dict)
            out_ret.append({k: v for k, v in elem.items() if k != "logits"})
        elif i != 1:
            out_ret.append(elem)

    return out_ret

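# Per-model postprocessors that drop unstable outputs before accuracy
# comparison; exposed to the runner via the get_output_amp_train_process_func
# property below.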
process_train_model_output = {
    "hf_Reformer": process_hf_reformer_output,
    "hf_Whisper": process_hf_whisper_output,
}


class TorchBenchmarkRunner(BenchmarkRunner):
    def __init__(self):
        super().__init__()
        self.suite_name = "torchbench"
        self.optimizer = None

    @property
    def _config(self):
        return load_yaml_file("torchbench.yaml")

    @property
    def _skip(self):
        return self._config["skip"]

    @property
    def _batch_size(self):
        return self._config["batch_size"]

    @property
    def _tolerance(self):
        return self._config["tolerance"]

    @property
    def _require_larger_multiplier_for_smaller_tensor(self):
        return self._config["require_larger_multiplier_for_smaller_tensor"]

    @property
    def _accuracy(self):
        return self._config["accuracy"]

    @property
    def skip_models(self):
        return self._skip["all"]

    @property
    def skip_models_for_cpu(self):
        return self._skip["device"]["cpu"]

    @property
    def skip_models_for_cuda(self):
        return self._skip["device"]["cuda"]

    @property
    def skip_models_for_freezing_cuda(self):
        return self._skip["freezing"]["cuda"]

    @property
    def skip_models_for_freezing_cpu(self):
        return self._skip["freezing"]["cpu"]

    @property
    def slow_models(self):
        return self._config["slow"]

    @property
    def very_slow_models(self):
        return self._config["very_slow"]

    @property
    def non_deterministic_models(self):
        return self._config["non_deterministic"]

    @property
    def get_output_amp_train_process_func(self):
        return process_train_model_output

    @property
    def skip_not_suitable_for_training_models(self):
        return self._skip["test"]["training"]

    @property
    def failing_fx2trt_models(self):
        return self._config["trt_not_yet_working"]

    @property
    def force_amp_for_fp16_bf16_models(self):
        return self._config["dtype"]["force_amp_for_fp16_bf16_models"]

    @property
    def force_fp16_for_bf16_models(self):
        return self._config["dtype"]["force_fp16_for_bf16_models"]

    @property
    def skip_accuracy_checks_large_models_dashboard(self):
        if self.args.dashboard or self.args.accuracy:
            return self._accuracy["skip"]["large_models"]
        return set()

    @property
    def skip_accuracy_check_as_eager_non_deterministic(self):
        if self.args.accuracy and self.args.training:
            return self._accuracy["skip"]["eager_not_deterministic"]
        return set()

    @property
    def skip_multiprocess_models(self):
        return self._skip["multiprocess"]

    @property
    def skip_models_due_to_control_flow(self):
        return self._skip["control_flow"]

    @property
    def guard_on_nn_module_models(self):
        return {
            "vision_maskrcnn",
        }

    @property
    def inline_inbuilt_nn_modules_models(self):
        return {
            "basic_gnn_edgecnn",
            "drq",
            "hf_Reformer",
            "DALLE2_pytorch",
            "hf_BigBird",
            "detectron2_maskrcnn_r_50_fpn",
            "detectron2_maskrcnn_r_101_fpn",
            "vision_maskrcnn",
            "doctr_reco_predictor",
            "hf_T5_generate",
        }

    def load_model(
        self,
        device,
        model_name,
        batch_size=None,
        part=None,
        extra_args=None,
    ):
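        """Import the named torchbench model and construct it for `device`.

        Returns a (device, benchmark_name, model, example_inputs, batch_size)
        tuple ready for validation and benchmarking.
        """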
        if self.args.enable_activation_checkpointing:
            raise NotImplementedError(
                "Activation checkpointing not implemented for Torchbench models"
            )
        is_training = self.args.training
        use_eval_mode = self.args.use_eval_mode
        dynamic_shapes = self.args.dynamic_shapes
        candidates = [
            f"torchbenchmark.models.{model_name}",
            f"torchbenchmark.canary_models.{model_name}",
            f"torchbenchmark.models.fb.{model_name}",
        ]
        for c in candidates:
            try:
                module = importlib.import_module(c)
                break
            except ModuleNotFoundError as e:
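                # If the missing module is the candidate itself, fall through
                # to the next candidate; a missing transitive dependency is a
                # real error and is re-raised.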
                if e.name != c:
                    raise
        else:
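            # for/else: reached only if the loop never hit `break`, i.e. no
            # candidate module could be imported.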
            raise ImportError(f"could not import any of {candidates}")
        benchmark_cls = getattr(module, "Model", None)
        if benchmark_cls is None:
            raise NotImplementedError(f"{model_name}.Model is None")

        if not hasattr(benchmark_cls, "name"):
            benchmark_cls.name = model_name

        cant_change_batch_size = (
            not getattr(benchmark_cls, "ALLOW_CUSTOMIZE_BSIZE", True)
            or model_name in self._config["dont_change_batch_size"]
        )
        if cant_change_batch_size:
            batch_size = None
        if (
            batch_size is None
            and is_training
            and model_name in self._batch_size["training"]
        ):
            batch_size = self._batch_size["training"][model_name]
        elif (
            batch_size is None
            and not is_training
            and model_name in self._batch_size["inference"]
        ):
            batch_size = self._batch_size["inference"][model_name]

        # Control the memory footprint for a few models
        if self.args.accuracy and model_name in self._accuracy["max_batch_size"]:
            batch_size = min(batch_size, self._accuracy["max_batch_size"][model_name])

        # workaround "RuntimeError: not allowed to set torch.backends.cudnn flags"
        torch.backends.__allow_nonbracketed_mutation_flag = True
        if extra_args is None:
            extra_args = []
        if part:
            extra_args += ["--part", part]

        # sam_fast only runs with amp
        if model_name == "sam_fast":
            self.args.amp = True
            self.setup_amp()

292        if model_name == "vision_maskrcnn" and is_training:
293            # Output of vision_maskrcnn model is a list of bounding boxes,
294            # sorted on the basis of their scores. This makes accuracy
295            # comparison hard with torch.compile. torch.compile can cause minor
296            # divergences in the output because of how fusion works for amp in
297            # TorchInductor compared to eager.  Therefore, instead of looking at
298            # all the bounding boxes, we compare only top 4.
            model_kwargs = {"box_detections_per_img": 4}
            benchmark = benchmark_cls(
                test="train",
                device=device,
                batch_size=batch_size,
                extra_args=extra_args,
                model_kwargs=model_kwargs,
            )
            use_eval_mode = True
        elif is_training:
            benchmark = benchmark_cls(
                test="train",
                device=device,
                batch_size=batch_size,
                extra_args=extra_args,
            )
        else:
            benchmark = benchmark_cls(
                test="eval",
                device=device,
                batch_size=batch_size,
                extra_args=extra_args,
            )
        model, example_inputs = benchmark.get_module()
        if model_name in [
            "basic_gnn_edgecnn",
            "basic_gnn_gcn",
            "basic_gnn_sage",
            "basic_gnn_gin",
        ]:
            _reassign_parameters(model)

        # Models that must be in train mode while training
        if is_training and (
            not use_eval_mode or model_name in self._config["only_training"]
        ):
            model.train()
        else:
            model.eval()
        gc.collect()
        batch_size = benchmark.batch_size
        if model_name == "torchrec_dlrm":
            batch_namedtuple = namedtuple(
                "Batch", "dense_features sparse_features labels"
            )
            example_inputs = tuple(
                batch_namedtuple(
                    dense_features=batch.dense_features,
                    sparse_features=batch.sparse_features,
                    labels=batch.labels,
                )
                for batch in example_inputs
            )
        # Torchbench has quite a different setup for yolov3, so we directly
        # pass the right example_inputs
354        if model_name == "yolov3":
355            example_inputs = (torch.rand(batch_size, 3, 384, 512).to(device),)
356        # See https://github.com/pytorch/benchmark/issues/1561
357        if model_name == "maml_omniglot":
358            batch_size = 5
359            assert example_inputs[0].shape[0] == batch_size
360        if model_name == "vision_maskrcnn":
361            batch_size = 1
362        # global current_name, current_device
363        # current_device = device
364        # current_name = benchmark.name
365
366        if self.args.trace_on_xla:
367            # work around for: https://github.com/pytorch/xla/issues/4174
368            import torch_xla  # noqa: F401
369        self.validate_model(model, example_inputs)
370        return device, benchmark.name, model, example_inputs, batch_size
371
372    def iter_model_names(self, args):
373        from torchbenchmark import _list_canary_model_paths, _list_model_paths
374
375        models = _list_model_paths()
376        models += [
377            f
378            for f in _list_canary_model_paths()
379            if os.path.basename(f) in self._config["canary_models"]
380        ]
381        models.sort()
382
383        start, end = self.get_benchmark_indices(len(models))
384        for index, model_path in enumerate(models):
385            if index < start or index >= end:
386                continue
387
388            model_name = os.path.basename(model_path)
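            # Apply the --filter/--exclude regexes (OR-joined), exact-name
            # excludes, and the suite's static skip list.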
            if (
                not re.search("|".join(args.filter), model_name, re.IGNORECASE)
                or re.search("|".join(args.exclude), model_name, re.IGNORECASE)
                or model_name in args.exclude_exact
                or model_name in self.skip_models
            ):
                continue

            yield model_name

    def pick_grad(self, name, is_training):
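        # maml performs inner-loop gradient updates even in eval mode, so it
        # needs gradients enabled outside of training as well.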
        if is_training or name in ("maml",):
            return torch.enable_grad()
        else:
            return torch.no_grad()

    def use_larger_multiplier_for_smaller_tensor(self, name):
        return name in self._require_larger_multiplier_for_smaller_tensor

    def get_tolerance_and_cosine_flag(self, is_training, current_device, name):
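        # Tolerances are tiered: fp16/amp defaults to 1e-3 with per-model
        # overrides, bf16 has a 1e-2 override list, and cuda/xpu training
        # bumps the fp32 default from 1e-4 to 1e-3.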
        tolerance = 1e-4
        cosine = self.args.cosine
        # Increase the tolerance for torch allclose
        if self.args.float16 or self.args.amp:
            if name in self._tolerance["higher_fp16"]:
                return 1e-2, cosine
            elif name in self._tolerance["even_higher"]:
                return 8 * 1e-2, cosine
            return 1e-3, cosine

        if self.args.bfloat16:
            if name in self._tolerance["higher_bf16"]:
                return 1e-2, cosine

        if is_training and (current_device == "cuda" or current_device == "xpu"):
            tolerance = 1e-3
            if name in self._tolerance["cosine"]:
                cosine = True
            elif name in self._tolerance["higher"]:
                tolerance = 1e-3
            elif name in self._tolerance["even_higher"]:
                tolerance = 8 * 1e-2
        return tolerance, cosine

    def compute_loss(self, pred):
        return reduce_to_scalar_loss(pred)

    def forward_pass(self, mod, inputs, collect_outputs=True):
        with self.autocast(**self.autocast_arg):
            if isinstance(inputs, dict):
                return mod(**inputs)
            else:
                return mod(*inputs)

    def forward_and_backward_pass(self, mod, inputs, collect_outputs=True):
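        # Clone the inputs so that any in-place mutation during the forward
        # pass does not corrupt the caller's copy, which may be reused for
        # eager-vs-compiled accuracy comparisons.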
        cloned_inputs = clone_inputs(inputs)
        self.optimizer_zero_grad(mod)
        with self.autocast(**self.autocast_arg):
            if isinstance(cloned_inputs, dict):
                pred = mod(**cloned_inputs)
            else:
                pred = mod(*cloned_inputs)
            loss = self.compute_loss(pred)
        self.grad_scaler.scale(loss).backward()
        self.optimizer_step()
        if collect_outputs:
            return collect_results(mod, pred, loss, cloned_inputs)
        return None


def torchbench_main():
    original_dir = setup_torchbench_cwd()
    logging.basicConfig(level=logging.WARNING)
    warnings.filterwarnings("ignore")
    main(TorchBenchmarkRunner(), original_dir)


if __name__ == "__main__":
    torchbench_main()
468